# FastAPI server: streams the default webcam with YOLOv8 inference annotations
# and exposes a JSON endpoint listing the detected object classes.
import os

# Set writable cache directories BEFORE importing the libraries that read
# these variables at import time (matplotlib, Hugging Face, Ultralytics);
# setting them afterwards has no effect on already-imported modules.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"

from io import BytesIO

import cv2
import supervision as sv
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from inference import get_model
from PIL import Image

app = FastAPI()

# Load the pre-trained model once when the server starts.
model = get_model(model_id="yolov8n-640")
def gen_frames():
    """Yield MJPEG frame chunks from the default camera, annotated with detections.

    Opens camera index 0, runs the module-level ``model`` on each frame,
    draws bounding boxes and labels with supervision, and yields multipart
    JPEG chunks suitable for a ``multipart/x-mixed-replace`` response.

    Raises:
        HTTPException: 500 if the camera cannot be opened.
    """
    # Open the default camera (index 0).
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise HTTPException(status_code=500, detail="Could not open video device")
    # Create the annotators once, not on every frame.
    bounding_box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()
    try:
        while True:
            success, frame = cap.read()
            if not success:
                break
            # OpenCV delivers BGR; convert to RGB for PIL / the model.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            try:
                # model.infer returns a list, so take the first element.
                results = model.infer(pil_img)[0]
            except Exception as e:
                # Best-effort streaming: skip frames whose inference fails.
                print(f"Inference error: {e}")
                continue
            # Convert inference results to a Supervision detections object.
            detections = sv.Detections.from_inference(results)
            # Annotate the frame with bounding boxes and labels.
            annotated_image = bounding_box_annotator.annotate(scene=pil_img, detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
            # Encode the annotated image to JPEG entirely in memory.
            buf = BytesIO()
            annotated_image.save(buf, format="JPEG")
            frame_bytes = buf.getvalue()
            # Yield the frame in MJPEG multipart format.
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
    finally:
        # Release the camera even when the client disconnects (GeneratorExit)
        # or an unexpected error escapes the loop — otherwise the device leaks.
        cap.release()
@app.get("/live_feed")
async def live_feed():
    """
    Streams a live feed from the camera with inference annotations.
    Access via: http://localhost:8000/live_feed
    """
    # MJPEG over HTTP: the browser replaces each part as it arrives.
    return StreamingResponse(
        gen_frames(),
        media_type="multipart/x-mixed-replace; boundary=frame"
    )
@app.get("/detect_classes")
async def detect_classes():
    """
    Detect objects over a few camera frames and return unique class labels as JSON.

    Samples 3 frames from the default camera, runs inference on each, and
    returns the union of detected class labels.
    Access via: http://localhost:8000/detect_classes

    Raises:
        HTTPException: 500 if the camera cannot be opened, a frame read fails,
            or inference fails repeatedly.
    """
    # Open the default camera (index 0).
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise HTTPException(status_code=500, detail="Could not open video device")
    frame_count = 0
    attempts = 0
    all_class_ids = []
    try:
        while frame_count < 3:
            # Guard against an endless loop (and a permanently held camera)
            # when inference fails on every frame: `continue` below never
            # advances frame_count.
            attempts += 1
            if attempts > 30:
                raise HTTPException(status_code=500, detail="Inference failed repeatedly")
            success, frame = cap.read()
            if not success:
                raise HTTPException(status_code=500, detail="Failed to read frame from camera")
            # OpenCV delivers BGR; convert to RGB for PIL / the model.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            try:
                # model.infer returns a list, so take the first element.
                results = model.infer(pil_img)[0]
            except Exception as e:
                print(f"Inference error: {e}")
                continue
            # Convert inference results to a Supervision detections object.
            detections = sv.Detections.from_inference(results)
            # Collect class IDs from the current frame.
            if detections.class_id is not None:
                all_class_ids.extend(detections.class_id.tolist())
            frame_count += 1
    finally:
        # Always release the camera, including on the error paths above.
        cap.release()
    # Map the unique numeric class IDs back to human-readable labels.
    unique_class_ids = set(all_class_ids)
    class_labels = [model.class_names[class_id] for class_id in unique_class_ids]
    return JSONResponse(content={"detected_classes": class_labels})
if __name__ == "__main__":
    # Launch a development server on all interfaces when run directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)