# NOTE(review): the lines below replaced page-banner residue ("Spaces: Sleeping")
# captured when this file was extracted from a Hugging Face Space listing.
| import os | |
| import cv2 | |
| import json | |
| import torch | |
| import numpy as np | |
| from flask import Flask, Response, request, jsonify | |
| from ultralytics import YOLO | |
# -------------- CONFIG --------------
# Path to the YOLO weights file; overridable via the MODEL_PATH env var.
MODEL_PATH = os.environ.get("MODEL_PATH", "best.pt")
# Camera focal length in pixels, used by the pinhole distance model below.
# NOTE(review): 615 is plausible for a common webcam at VGA resolution —
# confirm by calibrating the actual camera.
FOCAL_LENGTH_PX = 615
# Assumed real-world widths (meters) per detected class; classes missing
# from this table get no distance estimate (estimate_distance returns None).
KNOWN_WIDTHS_M = {
    "person": 0.5, "car": 1.8, "truck": 2.3, "bus": 2.5,
    "bicycle": 0.6, "motorcycle": 0.7, "door": 0.9,
    "stairs": 1.2, "pole": 0.15, "bollard": 0.2
}
# Distance cutoffs in meters mapped to alert severity; distances beyond
# CAUTION (and unknown distances) are treated as SAFE by get_alert_level.
THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
# -------------- APP INIT --------------
app = Flask(__name__)
# Prefer GPU if available; Ultralytics accepts a CUDA device index or "cpu".
device = 0 if torch.cuda.is_available() else "cpu"
# Load model once at import time so every request reuses the same weights.
model = YOLO(MODEL_PATH)
model.to(device)
# Fuse for a small speed boost; ignore if unsupported by your build
try:
    model.fuse()
except Exception:
    # Best-effort optimization only — fusing failing must not stop startup.
    pass
| # -------------- UTILS -------------- | |
def estimate_distance(bbox_width_px, class_name):
    """Estimate object distance in meters with the pinhole model.

    Uses D = (known_width * focal_length) / pixel_width. Returns ``None``
    when the bounding box is degenerate (missing or <= 1 px wide) or the
    class has no entry in ``KNOWN_WIDTHS_M``.
    """
    # Guard against missing/degenerate boxes before dividing.
    if bbox_width_px is None or bbox_width_px <= 1:
        return None
    real_width_m = KNOWN_WIDTHS_M.get(class_name)
    if not real_width_m:
        return None
    return real_width_m * FOCAL_LENGTH_PX / float(bbox_width_px)
def get_alert_level(distance_m):
    """Map an estimated distance (meters, or None) to a severity string.

    Severities are checked from most to least urgent; an unknown distance
    or one beyond every threshold is "SAFE".
    """
    if distance_m is None:
        return "SAFE"
    # THRESHOLDS values ascend with leniency, so first match is the worst.
    for severity in ("CRITICAL", "WARNING", "CAUTION"):
        if distance_m <= THRESHOLDS[severity]:
            return severity
    return "SAFE"
def annotate_frame(frame, detections):
    """Draw each detection's box and label on ``frame``, colored by severity.

    Mutates and returns the same image. Label shows class name plus the
    rounded distance ("n/a" when no estimate exists).
    """
    # BGR colors per alert level; anything unrecognized renders green (safe).
    severity_colors = {
        "CRITICAL": (0, 0, 255),   # Red
        "WARNING": (0, 165, 255),  # Orange
        "CAUTION": (0, 255, 255),  # Yellow
    }
    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        color = severity_colors.get(det["alert_level"], (0, 255, 0))  # Green
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        dist_str = "n/a" if det["distance_m"] is None else f"{det['distance_m']}m"
        label = f"{det['class']} {dist_str}"
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        # Clamp so the label background never leaves the top of the image.
        y_top = max(y1 - th - 6, 0)
        # Filled rectangle behind the text keeps the label readable.
        cv2.rectangle(frame, (x1, y_top), (x1 + tw + 8, y_top + th + 6), color, -1)
        cv2.putText(frame, label, (x1 + 4, y_top + th),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2, cv2.LINE_AA)
    return frame
| # -------------- ROUTES -------------- | |
# BUG FIX: this handler was never registered with Flask (no @app.route
# decorator appears in the file), so the endpoint was unreachable.
# Route path inferred from the function name — confirm against clients.
@app.route("/ping", methods=["GET"])
def ping():
    """Liveness probe: return a trivial JSON payload with HTTP 200."""
    return jsonify({"ok": True}), 200
# BUG FIX: this handler was never registered with Flask (no @app.route
# decorator appears in the file), so the endpoint was unreachable.
# Route path inferred from the function name — confirm against clients.
@app.route("/process_frame", methods=["POST"])
def process_frame():
    """
    Accepts a single video frame via multipart/form-data field 'frame' (JPEG bytes),
    returns annotated JPEG as body with alert metadata in headers.

    Responses:
        200 — annotated image/jpeg body; X-Alert-Level carries the worst
              severity seen, X-Detections-Count the number of detections.
        400 — no 'frame' field, or the bytes do not decode to an image.
        500 — JPEG re-encoding failed.
    """
    if "frame" not in request.files:
        return jsonify({"error": "No frame uploaded"}), 400
    # Decode the uploaded bytes into a BGR image.
    file = request.files["frame"]
    file_bytes = file.read()
    img_bytes = np.frombuffer(file_bytes, np.uint8)
    frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
    if frame is None:
        return jsonify({"error": "Invalid image"}), 400
    # Run inference on the configured device.
    results = model(frame, conf=0.25, iou=0.5, verbose=False, device=device)
    boxes = results[0].boxes
    detections = []
    max_level = "SAFE"
    # Severity ranking used to keep the worst alert across all detections.
    order = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
    if boxes is not None and len(boxes) > 0:
        for b in boxes:
            x1, y1, x2, y2 = b.xyxy[0].tolist()
            cls_id = int(b.cls[0].item())
            conf = float(b.conf[0].item())
            class_name = model.names.get(cls_id, str(cls_id))
            bbox_w = int(x2 - x1)
            distance_m = estimate_distance(bbox_w, class_name)
            level = get_alert_level(distance_m)
            if order[level] > order[max_level]:
                max_level = level
            detections.append({
                "class": class_name,
                "confidence": round(conf, 3),
                # BUG FIX: explicit `is not None` instead of truthiness, so a
                # legitimate 0.0 distance would not be silently dropped.
                "distance_m": round(distance_m, 2) if distance_m is not None else None,
                "alert_level": level,
                "bbox": [int(x1), int(y1), int(x2), int(y2)]
            })
    # Annotate and encode JPEG (quality 80).
    annotated = annotate_frame(frame, detections)
    ok, buffer = cv2.imencode(".jpg", annotated, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    if not ok:
        return jsonify({"error": "Encode failed"}), 500
    encoded = buffer.tobytes()
    # Metadata travels in headers so the body stays a raw JPEG.
    headers = {
        "Content-Length": str(len(encoded)),  # some clients use it for streaming/decoding
        "X-Alert-Level": max_level,
        "X-Detections-Count": str(len(detections))
    }
    return Response(encoded, mimetype="image/jpeg", headers=headers)
| # -------------- MAIN -------------- | |
if __name__ == "__main__":
    # PORT env var lets the host pick the listen port (7860 by default —
    # presumably the Hugging Face Spaces convention; confirm deployment).
    listen_port = int(os.environ.get("PORT", 7860))
    # threaded=True allows concurrent requests from multiple clients.
    app.run(host="0.0.0.0", port=listen_port, threaded=True)