# hacksjce / app.py — exported from a Hugging Face Space
# Author: mohammedafeef — "Update app.py" (commit c35a0bd, verified)
import os
import cv2
import json
import torch
import numpy as np
from flask import Flask, Response, request, jsonify
from ultralytics import YOLO
# -------------- CONFIG --------------
# Path to the YOLO weights file; overridable via the MODEL_PATH env var.
MODEL_PATH = os.environ.get("MODEL_PATH", "best.pt")
# Camera focal length in pixels, used by the pinhole distance model.
# NOTE(review): assumed calibration value — confirm for the actual camera.
FOCAL_LENGTH_PX = 615
# Approximate real-world widths (meters) per detectable class; classes
# missing from this table get no distance estimate.
KNOWN_WIDTHS_M = {
    "person": 0.5, "car": 1.8, "truck": 2.3, "bus": 2.5,
    "bicycle": 0.6, "motorcycle": 0.7, "door": 0.9,
    "stairs": 1.2, "pole": 0.15, "bollard": 0.2
}
# Distance cutoffs (meters) mapping to alert severities, ascending.
THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
# -------------- APP INIT --------------
app = Flask(__name__)
# Prefer GPU if available (device index 0), otherwise CPU.
device = 0 if torch.cuda.is_available() else "cpu"
# Load model once at import time so every request reuses it.
model = YOLO(MODEL_PATH)
model.to(device)
# Fuse for a small speed boost; ignore if unsupported by your build
try:
    model.fuse()
except Exception:
    pass
# -------------- UTILS --------------
def estimate_distance(bbox_width_px, class_name):
    """Estimate object distance in meters via the pinhole model.

    Uses D = (W * f) / w, where W is the class's known real-world width,
    f the focal length in pixels, and w the bounding-box width in pixels.
    Returns None when the pixel width is missing/degenerate (<= 1 px) or
    the class has no entry in KNOWN_WIDTHS_M.
    """
    if bbox_width_px is None or bbox_width_px <= 1:
        return None
    real_width = KNOWN_WIDTHS_M.get(class_name)
    if not real_width:
        return None
    return real_width * FOCAL_LENGTH_PX / float(bbox_width_px)
def get_alert_level(distance_m):
    """Map an estimated distance (meters) to an alert severity string.

    Checks the THRESHOLDS cutoffs from most to least severe and returns
    the first one the distance falls under; "SAFE" for None or anything
    beyond the CAUTION cutoff.
    """
    if distance_m is None:
        return "SAFE"
    for severity in ("CRITICAL", "WARNING", "CAUTION"):
        if distance_m <= THRESHOLDS[severity]:
            return severity
    return "SAFE"
def annotate_frame(frame, detections):
    """Draw a box and a class/distance label for each detection.

    Box and label-background colors encode severity (BGR): red for
    CRITICAL, orange for WARNING, yellow for CAUTION, green otherwise.
    Mutates `frame` in place and also returns it.
    """
    severity_colors = {
        "CRITICAL": (0, 0, 255),   # Red
        "WARNING": (0, 165, 255),  # Orange
        "CAUTION": (0, 255, 255),  # Yellow
    }
    for det in detections:
        x1, y1, x2, y2 = det["bbox"]
        # Fall back to green for SAFE / unknown levels
        color = severity_colors.get(det["alert_level"], (0, 255, 0))
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        if det["distance_m"] is not None:
            dist_str = f"{det['distance_m']}m"
        else:
            dist_str = "n/a"
        label = f"{det['class']} {dist_str}"
        # Size the filled label background to the rendered text
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        y_top = max(y1 - th - 6, 0)  # keep the label inside the frame
        cv2.rectangle(frame, (x1, y_top), (x1 + tw + 8, y_top + th + 6), color, -1)
        cv2.putText(frame, label, (x1 + 4, y_top + th),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2, cv2.LINE_AA)
    return frame
# -------------- ROUTES --------------
@app.route("/ping")
def ping():
    """Liveness probe: always responds 200 with a small JSON payload."""
    payload = {"ok": True}
    return jsonify(payload), 200
@app.route("/stream", methods=["POST"])
def process_frame():
    """
    Process one video frame posted as multipart/form-data field 'frame'
    (JPEG bytes): run YOLO detection, estimate per-object distance, and
    return the annotated JPEG as the body with alert metadata in the
    X-Alert-Level and X-Detections-Count response headers.

    Returns 400 for a missing/undecodable frame, 500 if JPEG encoding fails.
    """
    if "frame" not in request.files:
        return jsonify({"error": "No frame uploaded"}), 400
    # Decode the uploaded bytes into a BGR image
    file_bytes = request.files["frame"].read()
    img_bytes = np.frombuffer(file_bytes, np.uint8)
    frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
    if frame is None:
        return jsonify({"error": "Invalid image"}), 400
    # Run inference on the shared model (GPU when available)
    results = model(frame, conf=0.25, iou=0.5, verbose=False, device=device)
    boxes = results[0].boxes
    detections = []
    max_level = "SAFE"
    # Severity ranking used to keep the worst alert across all detections
    order = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
    if boxes is not None and len(boxes) > 0:
        for b in boxes:
            x1, y1, x2, y2 = b.xyxy[0].tolist()
            cls_id = int(b.cls[0].item())
            conf = float(b.conf[0].item())
            class_name = model.names.get(cls_id, str(cls_id))
            bbox_w = int(x2 - x1)
            distance_m = estimate_distance(bbox_w, class_name)
            level = get_alert_level(distance_m)
            if order[level] > order[max_level]:
                max_level = level
            detections.append({
                "class": class_name,
                "confidence": round(conf, 3),
                # FIX: compare against None explicitly — the previous
                # truthiness test would have reported a legitimate 0.0
                # distance as None.
                "distance_m": round(distance_m, 2) if distance_m is not None else None,
                "alert_level": level,
                "bbox": [int(x1), int(y1), int(x2), int(y2)],
            })
    # Annotate and encode JPEG (quality 80 keeps the payload small)
    annotated = annotate_frame(frame, detections)
    ok, buffer = cv2.imencode(".jpg", annotated, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    if not ok:
        return jsonify({"error": "Encode failed"}), 500
    encoded = buffer.tobytes()
    # Metadata travels in headers so the body can stay a raw JPEG
    headers = {
        "Content-Length": str(len(encoded)),  # some clients use it for streaming/decoding
        "X-Alert-Level": max_level,
        "X-Detections-Count": str(len(detections)),
    }
    return Response(encoded, mimetype="image/jpeg", headers=headers)
# -------------- MAIN --------------
if __name__ == "__main__":
    # Port comes from the PORT env var (7860 is the Hugging Face Spaces default).
    port = int(os.environ.get("PORT", 7860))
    # threaded=True allows concurrent requests from multiple clients
    app.run(host="0.0.0.0", port=port, threaded=True)