File size: 5,154 Bytes
b5d1da8
 
c35a0bd
b5d1da8
 
3748a91
b5d1da8
 
c35a0bd
 
b5d1da8
 
 
 
 
 
 
 
c35a0bd
 
 
 
b5d1da8
c35a0bd
 
b5d1da8
 
c35a0bd
 
 
 
 
3748a91
c35a0bd
b5d1da8
c35a0bd
 
b5d1da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3748a91
c35a0bd
3748a91
 
c35a0bd
 
 
 
 
 
 
 
 
 
 
3748a91
c35a0bd
 
 
 
 
 
 
 
3748a91
 
c35a0bd
 
 
 
3748a91
 
 
b5d1da8
c35a0bd
 
b5d1da8
3748a91
 
b5d1da8
c35a0bd
3748a91
c35a0bd
 
b5d1da8
c35a0bd
 
b5d1da8
c35a0bd
 
b5d1da8
 
 
 
 
 
c35a0bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3748a91
c35a0bd
 
 
3748a91
 
c35a0bd
 
 
 
 
 
 
 
3748a91
c35a0bd
b5d1da8
c35a0bd
b5d1da8
 
c35a0bd
3748a91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json
import logging
import os
import threading

import cv2
import numpy as np
import torch
from flask import Flask, Response, request, jsonify
from ultralytics import YOLO

# -------------- CONFIG --------------
# Path to the YOLO weights; the MODEL_PATH environment variable overrides it.
MODEL_PATH = os.getenv("MODEL_PATH", "best.pt")

# Camera focal length in pixels, used by the pinhole distance estimate.
FOCAL_LENGTH_PX = 615

# Assumed real-world width, in metres, of each class the distance estimate
# supports. Classes absent from this table get no distance.
KNOWN_WIDTHS_M = dict(
    person=0.5,
    car=1.8,
    truck=2.3,
    bus=2.5,
    bicycle=0.6,
    motorcycle=0.7,
    door=0.9,
    stairs=1.2,
    pole=0.15,
    bollard=0.2,
)

# Inclusive upper distance bound (metres) for each alert level; anything
# farther than CAUTION is treated as SAFE.
THRESHOLDS = dict(CRITICAL=1.0, WARNING=2.0, CAUTION=3.0)

# -------------- APP INIT --------------
app = Flask(__name__)

# Run inference on CUDA device 0 when CUDA is available, otherwise on the CPU.
device = 0 if torch.cuda.is_available() else "cpu"

# Load the model once at import time; all requests share this instance.
model = YOLO(MODEL_PATH)
model.to(device)

# model.fuse() is an optional inference speed-up that not every model/build
# supports. A failure is logged (instead of silently ignored) and the
# unfused model is used, so start-up never aborts because of it.
try:
    model.fuse()
except Exception as exc:  # best-effort optimisation; deliberately non-fatal
    logging.getLogger(__name__).warning(
        "model.fuse() failed; continuing with the unfused model: %s", exc
    )

# -------------- UTILS --------------
def estimate_distance(bbox_width_px, class_name, *, focal_length_px=None, known_widths=None):
    """Approximate the distance to an object with the pinhole model.

    Computes ``D = (W * f) / w``, where ``W`` is the real-world width of the
    object's class in metres, ``f`` is the focal length in pixels and ``w``
    is the detected bounding-box width in pixels.

    Args:
        bbox_width_px: Bounding-box width in pixels (int or float). ``None``
            or a width of 1 px or less is treated as unusable.
        class_name: Detector class label used to look up the real width.
        focal_length_px: Focal length in pixels. Defaults to the module's
            ``FOCAL_LENGTH_PX``; override it for a differently calibrated camera.
        known_widths: Mapping of class name to real width in metres.
            Defaults to the module's ``KNOWN_WIDTHS_M``.

    Returns:
        The estimated distance in metres as a float. Returns ``None`` when
        the width is unusable or the class has no (non-zero) known width.
    """
    if bbox_width_px is None or bbox_width_px <= 1:
        return None
    # Resolve defaults lazily so callers may override either table per call.
    if known_widths is None:
        known_widths = KNOWN_WIDTHS_M
    if focal_length_px is None:
        focal_length_px = FOCAL_LENGTH_PX
    known_width = known_widths.get(class_name)
    if not known_width:
        return None
    return (known_width * focal_length_px) / float(bbox_width_px)

def get_alert_level(distance_m):
    """Return the alert level name for a distance in metres.

    Levels are tried from most to least severe against ``THRESHOLDS``; each
    threshold is an inclusive upper bound. A missing distance (``None``), or
    one beyond every threshold, yields ``"SAFE"``.
    """
    if distance_m is None:
        return "SAFE"
    for level_name in ("CRITICAL", "WARNING", "CAUTION"):
        if distance_m <= THRESHOLDS[level_name]:
            return level_name
    return "SAFE"

def annotate_frame(frame, detections):
    """Draw each detection's box and distance label onto ``frame``.

    Box and label-background colours (BGR) encode the alert level:
    CRITICAL is red, WARNING is orange, CAUTION is yellow and anything else
    is green. The label reads ``"<class> <distance>m"`` (or
    ``"<class> n/a"`` without a distance) and sits just above the box,
    clamped to the top of the image.

    The OpenCV drawing calls modify ``frame`` in place; the same object is
    returned.
    """
    palette = {
        "CRITICAL": (0, 0, 255),    # red
        "WARNING": (0, 165, 255),   # orange
        "CAUTION": (0, 255, 255),   # yellow
    }
    fallback_colour = (0, 255, 0)   # green

    for det in detections:
        left, top, right, bottom = det["bbox"]
        colour = palette.get(det["alert_level"], fallback_colour)
        cv2.rectangle(frame, (left, top), (right, bottom), colour, 2)

        distance = det["distance_m"]
        dist_text = "n/a" if distance is None else f"{distance}m"
        text = f"{det['class']} {dist_text}"

        (text_w, text_h), _baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
        # Filled background strip above the box, never above the image edge.
        label_top = max(top - text_h - 6, 0)
        cv2.rectangle(
            frame,
            (left, label_top),
            (left + text_w + 8, label_top + text_h + 6),
            colour,
            -1,
        )
        cv2.putText(
            frame,
            text,
            (left + 4, label_top + text_h),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 0),
            2,
            cv2.LINE_AA,
        )
    return frame

# -------------- ROUTES --------------
@app.route("/ping")
def ping():
    """Liveness probe: answer with the JSON body ``{"ok": true}`` and HTTP 200."""
    payload = {"ok": True}
    status_code = 200
    return jsonify(payload), status_code

# Relative severity of each alert level, used to pick the worst one per frame.
_LEVEL_ORDER = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}

# The server can run with threaded=True, so several requests may reach the
# shared module-level ``model`` at the same time. Ultralytics warns against
# concurrent predictions on one model instance, so inference is serialised.
_INFERENCE_LOCK = threading.Lock()


def _boxes_to_detections(boxes, names):
    """Convert one image's YOLO boxes into detection dicts and the worst level.

    Args:
        boxes: ``results[0].boxes`` from a YOLO prediction, or ``None``.
        names: Mapping of class id to class name (``model.names``).

    Returns:
        ``(detections, max_level)``. ``detections`` is a list of dicts with
        the keys ``class``, ``confidence``, ``distance_m``, ``alert_level``
        and ``bbox`` (an integer ``[x1, y1, x2, y2]``). ``max_level`` is the
        most severe alert level among them, or ``"SAFE"`` when there are none.
    """
    detections = []
    max_level = "SAFE"
    if boxes is None or len(boxes) == 0:
        return detections, max_level

    for box in boxes:
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        cls_id = int(box.cls[0].item())
        conf = float(box.conf[0].item())
        class_name = names.get(cls_id, str(cls_id))

        # Pass the float width: truncating it to int shrinks the width and
        # therefore biases the pinhole distance estimate upward.
        distance_m = estimate_distance(x2 - x1, class_name)
        level = get_alert_level(distance_m)
        if _LEVEL_ORDER[level] > _LEVEL_ORDER[max_level]:
            max_level = level

        detections.append({
            "class": class_name,
            "confidence": round(conf, 3),
            "distance_m": round(distance_m, 2) if distance_m is not None else None,
            "alert_level": level,
            "bbox": [int(x1), int(y1), int(x2), int(y2)],
        })
    return detections, max_level


@app.route("/stream", methods=["POST"])
def process_frame():
    """Detect obstacles in one uploaded video frame.

    Expects the frame's encoded image bytes (JPEG) in the multipart/form-data
    field ``frame``.

    Returns:
        On success, HTTP 200 with the annotated frame as an ``image/jpeg``
        body. ``X-Alert-Level`` carries the most severe alert level in the
        frame and ``X-Detections-Count`` the number of detections.
        HTTP 400 with a JSON error when the field is missing, empty or cannot
        be decoded. HTTP 500 with a JSON error when JPEG re-encoding fails.
    """
    if "frame" not in request.files:
        return jsonify({"error": "No frame uploaded"}), 400

    file_bytes = request.files["frame"].read()
    # cv2.imdecode raises (rather than returning None) on an empty buffer,
    # which would surface as a 500; reject empty uploads up front.
    if not file_bytes:
        return jsonify({"error": "Empty frame uploaded"}), 400
    frame = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        return jsonify({"error": "Invalid image"}), 400

    with _INFERENCE_LOCK:
        results = model(frame, conf=0.25, iou=0.5, verbose=False, device=device)
    detections, max_level = _boxes_to_detections(results[0].boxes, model.names)

    # Annotate and encode JPEG
    annotated = annotate_frame(frame, detections)
    ok, buffer = cv2.imencode(".jpg", annotated, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    if not ok:
        return jsonify({"error": "Encode failed"}), 500
    encoded = buffer.tobytes()

    # Alert metadata travels in headers so the body stays a plain JPEG.
    headers = {
        "Content-Length": str(len(encoded)),  # some clients use it for streaming/decoding
        "X-Alert-Level": max_level,
        "X-Detections-Count": str(len(detections)),
    }
    return Response(encoded, mimetype="image/jpeg", headers=headers)

# -------------- MAIN --------------
if __name__ == "__main__":
    # Start Flask's built-in server on every interface. The PORT environment
    # variable overrides the default 7860; threaded=True lets it handle
    # requests from several clients concurrently.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, threaded=True)