Spaces:
Sleeping
Sleeping
File size: 5,154 Bytes
b5d1da8 c35a0bd b5d1da8 3748a91 b5d1da8 c35a0bd b5d1da8 c35a0bd b5d1da8 c35a0bd b5d1da8 c35a0bd 3748a91 c35a0bd b5d1da8 c35a0bd b5d1da8 3748a91 c35a0bd 3748a91 c35a0bd 3748a91 c35a0bd 3748a91 c35a0bd 3748a91 b5d1da8 c35a0bd b5d1da8 3748a91 b5d1da8 c35a0bd 3748a91 c35a0bd b5d1da8 c35a0bd b5d1da8 c35a0bd b5d1da8 c35a0bd 3748a91 c35a0bd 3748a91 c35a0bd 3748a91 c35a0bd b5d1da8 c35a0bd b5d1da8 c35a0bd 3748a91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import cv2
import json
import torch
import numpy as np
from flask import Flask, Response, request, jsonify
from ultralytics import YOLO
# -------------- CONFIG --------------
# Path to the YOLO weights; overridable through the MODEL_PATH env var.
MODEL_PATH = os.environ.get("MODEL_PATH", "best.pt")

# Approximate camera focal length in pixels, used by the pinhole distance model.
FOCAL_LENGTH_PX = 615

# Typical real-world widths (meters) per detectable class, for distance estimation.
KNOWN_WIDTHS_M = {
    "person": 0.5,
    "car": 1.8,
    "truck": 2.3,
    "bus": 2.5,
    "bicycle": 0.6,
    "motorcycle": 0.7,
    "door": 0.9,
    "stairs": 1.2,
    "pole": 0.15,
    "bollard": 0.2,
}

# Distance cutoffs (meters) mapping to alert severities; anything beyond is SAFE.
THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
# -------------- APP INIT --------------
app = Flask(__name__)

# Prefer GPU if available (device 0); otherwise fall back to CPU inference.
device = 0 if torch.cuda.is_available() else "cpu"

# Load the detection model once at startup and move it to the chosen device.
model = YOLO(MODEL_PATH)
model.to(device)

# Fuse conv+bn layers for a small speed boost; some builds don't support it,
# in which case we simply keep the unfused model.
try:
    model.fuse()
except Exception:
    pass
# -------------- UTILS --------------
def estimate_distance(bbox_width_px, class_name):
    """Estimate distance in meters via the pinhole model D = (W * f) / w.

    Returns None when the box width is missing/degenerate (<= 1 px) or the
    class has no entry (or a zero entry) in KNOWN_WIDTHS_M.
    """
    usable_box = bbox_width_px is not None and bbox_width_px > 1
    real_width = KNOWN_WIDTHS_M.get(class_name)
    if not (usable_box and real_width):
        return None
    return (real_width * FOCAL_LENGTH_PX) / float(bbox_width_px)
def get_alert_level(distance_m):
    """Map an estimated distance (meters, or None) to an alert severity string.

    Checks thresholds from most to least severe; anything beyond the CAUTION
    cutoff — or an unknown distance — is SAFE.
    """
    if distance_m is None:
        return "SAFE"
    for level in ("CRITICAL", "WARNING", "CAUTION"):
        if distance_m <= THRESHOLDS[level]:
            return level
    return "SAFE"
def annotate_frame(frame, detections):
    """Draw bounding boxes and labels on *frame*, colored by alert severity.

    Mutates and returns the same frame. Colors are BGR (OpenCV convention).
    """
    palette = {
        "CRITICAL": (0, 0, 255),   # Red
        "WARNING": (0, 165, 255),  # Orange
        "CAUTION": (0, 255, 255),  # Yellow
    }
    safe_color = (0, 255, 0)       # Green

    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        color = palette.get(det["alert_level"], safe_color)
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

        # Label: "<class> <distance>m" or "<class> n/a" when no estimate.
        if det["distance_m"] is not None:
            label = f"{det['class']} {det['distance_m']}m"
        else:
            label = f"{det['class']} n/a"

        # Filled background strip above the box (clamped to the top edge),
        # then the label text in black on top of it.
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        y_top = max(y1 - th - 6, 0)
        cv2.rectangle(frame, (x1, y_top), (x1 + tw + 8, y_top + th + 6), color, -1)
        cv2.putText(frame, label, (x1 + 4, y_top + th),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2, cv2.LINE_AA)
    return frame
# -------------- ROUTES --------------
@app.route("/ping")
def ping():
    """Liveness probe: always answers 200 with {"ok": true}."""
    payload = jsonify({"ok": True})
    return payload, 200
def _extract_detections(boxes):
    """Convert YOLO result boxes into detection dicts.

    Returns (detections, max_level) where max_level is the most severe
    alert level seen across all boxes ("SAFE" when there are none).
    """
    severity = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
    detections = []
    max_level = "SAFE"
    if boxes is None or len(boxes) == 0:
        return detections, max_level

    for b in boxes:
        x1, y1, x2, y2 = b.xyxy[0].tolist()
        cls_id = int(b.cls[0].item())
        conf = float(b.conf[0].item())
        class_name = model.names.get(cls_id, str(cls_id))

        distance_m = estimate_distance(int(x2 - x1), class_name)
        level = get_alert_level(distance_m)
        if severity[level] > severity[max_level]:
            max_level = level

        detections.append({
            "class": class_name,
            "confidence": round(conf, 3),
            # Explicit None check: `if distance_m` would also drop a
            # legitimate 0.0 reading (falsy-zero bug in the original).
            "distance_m": round(distance_m, 2) if distance_m is not None else None,
            "alert_level": level,
            "bbox": [int(x1), int(y1), int(x2), int(y2)],
        })
    return detections, max_level


@app.route("/stream", methods=["POST"])
def process_frame():
    """
    Accepts a single video frame via multipart/form-data field 'frame' (JPEG bytes),
    returns annotated JPEG as body with alert metadata in headers.

    Responses:
        200 image/jpeg  — annotated frame; X-Alert-Level / X-Detections-Count headers.
        400 json        — missing 'frame' field or undecodable image.
        500 json        — JPEG encoding failed.
    """
    if "frame" not in request.files:
        return jsonify({"error": "No frame uploaded"}), 400

    # Decode the uploaded bytes into a BGR image.
    file_bytes = request.files["frame"].read()
    frame = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        return jsonify({"error": "Invalid image"}), 400

    # Run inference and turn the raw boxes into detection dicts.
    results = model(frame, conf=0.25, iou=0.5, verbose=False, device=device)
    detections, max_level = _extract_detections(results[0].boxes)

    # Annotate and encode JPEG (quality 80).
    annotated = annotate_frame(frame, detections)
    ok, buffer = cv2.imencode(".jpg", annotated, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    if not ok:
        return jsonify({"error": "Encode failed"}), 500
    encoded = buffer.tobytes()

    headers = {
        "Content-Length": str(len(encoded)),  # some clients use it for streaming/decoding
        "X-Alert-Level": max_level,
        "X-Detections-Count": str(len(detections)),
    }
    return Response(encoded, mimetype="image/jpeg", headers=headers)
# -------------- MAIN --------------
if __name__ == "__main__":
    listen_port = int(os.environ.get("PORT", 7860))
    # threaded=True allows concurrent requests from multiple clients
    app.run(host="0.0.0.0", port=listen_port, threaded=True)
|