# fastapiyolo / app.py
# Author: omarash2016
# Revision: 737f6a7 (verified) — "Update app.py"
# Set writable cache/config directories BEFORE importing the libraries that
# read them at import time (matplotlib, huggingface/transformers, ultralytics).
# In the original file these were set after the imports, so they had no effect.
import os

os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"

from io import BytesIO

import cv2
import supervision as sv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from inference import get_model
from PIL import Image

app = FastAPI()

# Load the pre-trained YOLOv8-nano model once at server start-up so every
# request reuses the same weights instead of reloading them per call.
model = get_model(model_id="yolov8n-640")
def gen_frames():
    """Yield annotated camera frames as an MJPEG multipart byte stream.

    Opens the default camera (index 0), runs YOLO inference on each frame,
    draws bounding boxes and labels with supervision, and yields each
    annotated frame as one JPEG part of a multipart/x-mixed-replace stream.

    Yields:
        bytes: one ``--frame`` multipart part containing a JPEG image.

    Raises:
        HTTPException: 500 if the video device cannot be opened.
    """
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise HTTPException(status_code=500, detail="Could not open video device")
    # Create the annotators once, outside the loop — the original rebuilt
    # them for every single frame.
    bounding_box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()
    try:
        while True:
            success, frame = cap.read()
            if not success:
                break
            # OpenCV captures BGR; PIL and the model expect RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            try:
                # model.infer returns a list of results; take the first.
                results = model.infer(pil_img)[0]
            except Exception as e:
                # Best-effort streaming: skip frames whose inference fails.
                print(f"Inference error: {e}")
                continue
            detections = sv.Detections.from_inference(results)
            annotated_image = bounding_box_annotator.annotate(scene=pil_img, detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
            # Encode the annotated frame as JPEG in memory.
            buf = BytesIO()
            annotated_image.save(buf, format="JPEG")
            frame_bytes = buf.getvalue()
            # Yield one part of the MJPEG multipart stream.
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
    finally:
        # Release the camera even when the client disconnects and the
        # generator is closed mid-loop (GeneratorExit) — without this the
        # original leaked the video device on every dropped connection.
        cap.release()
@app.get("/live_feed")
async def live_feed():
    """
    Streams a live feed from the camera with inference annotations.
    Access via: http://localhost:8000/live_feed
    """
    # MJPEG: the browser replaces the image each time a new `frame`
    # boundary part arrives on the open connection.
    mjpeg_media_type = "multipart/x-mixed-replace; boundary=frame"
    return StreamingResponse(gen_frames(), media_type=mjpeg_media_type)
@app.get("/detect_classes")
async def detect_classes():
    """
    Detects and returns the labels of the detected classes as JSON.

    Reads 3 frames from the default camera, runs inference on each, and
    returns the unique detected class labels across those frames.
    (The original docstring said "every 30 frames"; the loop reads 3.)
    Access via: http://localhost:8000/detect_classes

    Raises:
        HTTPException: 500 if the camera cannot be opened or a frame
            cannot be read.
    """
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise HTTPException(status_code=500, detail="Could not open video device")
    frames_needed = 3
    # Bound total read attempts: the original `continue`-d on inference
    # errors without incrementing frame_count, so a persistent error made
    # the loop spin forever while holding the camera.
    max_attempts = 30
    attempts = 0
    frame_count = 0
    all_class_ids = []
    try:
        while frame_count < frames_needed and attempts < max_attempts:
            attempts += 1
            success, frame = cap.read()
            if not success:
                raise HTTPException(status_code=500, detail="Failed to read frame from camera")
            # Convert the frame from BGR (OpenCV default) to RGB for PIL/model.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            try:
                # model.infer returns a list of results; take the first.
                results = model.infer(pil_img)[0]
            except Exception as e:
                # Skip this frame; max_attempts above bounds the retries.
                print(f"Inference error: {e}")
                continue
            detections = sv.Detections.from_inference(results)
            # Collect class IDs from the current frame.
            if detections.class_id is not None:
                all_class_ids.extend(detections.class_id.tolist())
            frame_count += 1
    finally:
        # Always release the camera, including on every HTTPException path
        # (the original only released it on the read-failure branch).
        cap.release()
    unique_class_ids = set(all_class_ids)
    # NOTE(review): assumes the loaded model exposes `class_names` as an
    # id-indexable mapping — confirm against the installed `inference` version.
    class_labels = [model.class_names[class_id] for class_id in unique_class_ids]
    return JSONResponse(content={"detected_classes": class_labels})
def _main():
    """Run a local development server hosting this app on port 8000."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    _main()