"""Flask service: detect obstacles in a video frame with YOLO, estimate their
distance via a pinhole-camera model, and return an annotated JPEG plus alert
metadata in response headers."""

import os
import json
import threading

import cv2
import numpy as np
import torch
from flask import Flask, Response, request, jsonify
from ultralytics import YOLO

# -------------- CONFIG --------------
MODEL_PATH = os.environ.get("MODEL_PATH", "best.pt")

# Approximate camera focal length in pixels for the pinhole distance model.
FOCAL_LENGTH_PX = 615

# Real-world widths (meters) of classes we can range-estimate; classes absent
# from this table get no distance estimate and are reported as SAFE.
KNOWN_WIDTHS_M = {
    "person": 0.5,
    "car": 1.8,
    "truck": 2.3,
    "bus": 2.5,
    "bicycle": 0.6,
    "motorcycle": 0.7,
    "door": 0.9,
    "stairs": 1.2,
    "pole": 0.15,
    "bollard": 0.2,
}

# Distance cut-offs (meters) per severity; checked from most to least severe.
THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}

# Severity ranking used to pick the worst alert level in a frame.
_ALERT_ORDER = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}

# BGR box colors per alert level.
_ALERT_COLORS = {
    "CRITICAL": (0, 0, 255),   # red
    "WARNING": (0, 165, 255),  # orange
    "CAUTION": (0, 255, 255),  # yellow
    "SAFE": (0, 255, 0),       # green
}

# -------------- APP INIT --------------
app = Flask(__name__)

# Prefer GPU if available.
device = 0 if torch.cuda.is_available() else "cpu"

# Load model once at startup and move it to the chosen device.
model = YOLO(MODEL_PATH)
model.to(device)

# Fuse for a small speed boost; ignore if unsupported by your build.
try:
    model.fuse()
except Exception:
    pass

# The server runs with threaded=True, but a single shared YOLO model is not
# guaranteed to be thread-safe — serialize inference calls across requests.
_model_lock = threading.Lock()


# -------------- UTILS --------------
def estimate_distance(bbox_width_px, class_name):
    """Approximate distance in meters using the pinhole model D = (W * f) / w.

    Args:
        bbox_width_px: Detected bounding-box width in pixels.
        class_name: Detected class label, used to look up the known width.

    Returns:
        Estimated distance in meters, or None when the box is degenerate
        (width <= 1 px) or the class has no known real-world width.
    """
    if bbox_width_px is None or bbox_width_px <= 1:
        return None
    known_width = KNOWN_WIDTHS_M.get(class_name)
    if known_width is None:
        return None
    return (known_width * FOCAL_LENGTH_PX) / float(bbox_width_px)


def get_alert_level(distance_m):
    """Map an estimated distance (meters, or None) to an alert-level string."""
    if distance_m is None:
        return "SAFE"
    if distance_m <= THRESHOLDS["CRITICAL"]:
        return "CRITICAL"
    if distance_m <= THRESHOLDS["WARNING"]:
        return "WARNING"
    if distance_m <= THRESHOLDS["CAUTION"]:
        return "CAUTION"
    return "SAFE"


def annotate_frame(frame, detections):
    """Draw boxes and labels colored by alert level.

    Mutates and returns *frame*. Each detection dict must carry the keys
    'bbox' ([x1, y1, x2, y2] ints), 'alert_level', 'class', 'distance_m'.
    """
    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        level = det["alert_level"]
        # Unknown levels fall back to the SAFE (green) color.
        color = _ALERT_COLORS.get(level, _ALERT_COLORS["SAFE"])

        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        dist_str = f"{det['distance_m']}m" if det["distance_m"] is not None else "n/a"
        label = f"{det['class']} {dist_str}"
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        # Keep the label banner inside the image when the box touches the top.
        y_top = max(y1 - th - 6, 0)
        cv2.rectangle(frame, (x1, y_top), (x1 + tw + 8, y_top + th + 6), color, -1)
        cv2.putText(frame, label, (x1 + 4, y_top + th),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2, cv2.LINE_AA)
    return frame


# -------------- ROUTES --------------
@app.route("/ping")
def ping():
    """Liveness probe."""
    return jsonify({"ok": True}), 200


@app.route("/stream", methods=["POST"])
def process_frame():
    """
    Accepts a single video frame via multipart/form-data field 'frame'
    (JPEG bytes), returns annotated JPEG as body with alert metadata
    in headers (X-Alert-Level, X-Detections-Count).
    """
    if "frame" not in request.files:
        return jsonify({"error": "No frame uploaded"}), 400

    # Decode image bytes into a BGR frame.
    file = request.files["frame"]
    file_bytes = file.read()
    img_bytes = np.frombuffer(file_bytes, np.uint8)
    frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
    if frame is None:
        return jsonify({"error": "Invalid image"}), 400

    # Run inference; the lock serializes access to the shared model
    # because Flask serves requests from multiple threads.
    with _model_lock:
        results = model(frame, conf=0.25, iou=0.5, verbose=False, device=device)
    boxes = results[0].boxes

    detections = []
    max_level = "SAFE"

    if boxes is not None and len(boxes) > 0:
        for b in boxes:
            x1, y1, x2, y2 = b.xyxy[0].tolist()
            cls_id = int(b.cls[0].item())
            conf = float(b.conf[0].item())
            class_name = model.names.get(cls_id, str(cls_id))

            bbox_w = int(x2 - x1)
            distance_m = estimate_distance(bbox_w, class_name)
            level = get_alert_level(distance_m)

            # Track the worst (highest-severity) level seen in this frame.
            if _ALERT_ORDER[level] > _ALERT_ORDER[max_level]:
                max_level = level

            detections.append({
                "class": class_name,
                "confidence": round(conf, 3),
                # Explicit None check: a 0.0 distance must not collapse to None.
                "distance_m": round(distance_m, 2) if distance_m is not None else None,
                "alert_level": level,
                "bbox": [int(x1), int(y1), int(x2), int(y2)],
            })

    # Annotate and encode JPEG.
    annotated = annotate_frame(frame, detections)
    ok, buffer = cv2.imencode(".jpg", annotated, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    if not ok:
        return jsonify({"error": "Encode failed"}), 500
    encoded = buffer.tobytes()

    headers = {
        "Content-Length": str(len(encoded)),  # some clients use it for streaming/decoding
        "X-Alert-Level": max_level,
        "X-Detections-Count": str(len(detections)),
    }
    return Response(encoded, mimetype="image/jpeg", headers=headers)


# -------------- MAIN --------------
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    # threaded=True allows concurrent requests from multiple clients;
    # model access inside /stream is serialized by _model_lock.
    app.run(host="0.0.0.0", port=port, threaded=True)