mohammedafeef commited on
Commit
c35a0bd
·
verified ·
1 Parent(s): d09d352

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -55
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import os
2
  import cv2
 
3
  import torch
4
  import numpy as np
5
  from flask import Flask, Response, request, jsonify
6
  from ultralytics import YOLO
7
 
8
- app = Flask(__name__)
9
-
10
- # ---------------- CONFIG ----------------
11
- MODEL_PATH = "best.pt"
12
  FOCAL_LENGTH_PX = 615
13
  KNOWN_WIDTHS_M = {
14
  "person": 0.5, "car": 1.8, "truck": 2.3, "bus": 2.5,
@@ -17,15 +16,25 @@ KNOWN_WIDTHS_M = {
17
  }
18
  THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
19
 
 
 
 
 
20
  device = 0 if torch.cuda.is_available() else "cpu"
 
 
21
  model = YOLO(MODEL_PATH)
22
  model.to(device)
23
- model.fuse() # small speed boost
24
-
 
 
 
25
 
26
- # ---------------- UTIL FUNCTIONS ----------------
27
  def estimate_distance(bbox_width_px, class_name):
28
- if bbox_width_px <= 1:
 
29
  return None
30
  known_width = KNOWN_WIDTHS_M.get(class_name)
31
  if not known_width:
@@ -44,80 +53,102 @@ def get_alert_level(distance_m):
44
  return "SAFE"
45
 
46
  def annotate_frame(frame, detections):
 
47
  for det in detections:
48
  x1, y1, x2, y2 = det["bbox"]
49
- color = (0, 255, 0)
50
- if det["alert_level"] == "CRITICAL":
51
- color = (0, 0, 255)
52
- elif det["alert_level"] == "WARNING":
53
- color = (0, 165, 255)
54
- elif det["alert_level"] == "CAUTION":
55
- color = (0, 255, 255)
 
 
 
 
56
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
57
- label = f"{det['class']} {det['distance_m']}m"
58
- cv2.putText(frame, label, (x1, y1 - 10),
59
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
 
 
 
 
60
  return frame
61
 
 
 
 
 
62
 
63
- # ---------------- REALTIME VIDEO STREAM ----------------
64
  @app.route("/stream", methods=["POST"])
65
  def process_frame():
66
  """
67
- Accepts a single video frame (JPEG bytes from Unity or camera),
68
- returns detection data + optionally annotated frame.
69
  """
70
  if "frame" not in request.files:
71
  return jsonify({"error": "No frame uploaded"}), 400
72
 
 
73
  file = request.files["frame"]
74
- img_bytes = np.frombuffer(file.read(), np.uint8)
 
75
  frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
 
 
76
 
77
- results = model(frame, conf=0.25, iou=0.5, verbose=False)
 
78
  boxes = results[0].boxes
79
 
80
  detections = []
81
  max_level = "SAFE"
82
  order = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
83
 
84
- for b in boxes:
85
- x1, y1, x2, y2 = map(int, b.xyxy[0].tolist())
86
- cls_id = int(b.cls[0].item())
87
- conf = float(b.conf[0].item())
88
- class_name = model.names.get(cls_id, str(cls_id))
89
- bbox_w = x2 - x1
90
- distance_m = estimate_distance(bbox_w, class_name)
91
- level = get_alert_level(distance_m)
92
- if order[level] > order[max_level]:
93
- max_level = level
94
- detections.append({
95
- "class": class_name,
96
- "confidence": round(conf, 3),
97
- "distance_m": round(distance_m, 2) if distance_m else None,
98
- "alert_level": level,
99
- "bbox": [x1, y1, x2, y2]
100
- })
101
-
 
 
 
 
102
  annotated = annotate_frame(frame, detections)
103
- _, buffer = cv2.imencode('.jpg', annotated)
 
 
104
  encoded = buffer.tobytes()
105
 
106
- return Response(encoded, mimetype='image/jpeg',
107
- headers={
108
- "X-Alert-Level": max_level,
109
- "X-Detections": str(detections)
110
- })
111
-
112
-
113
- @app.route("/")
114
- def index():
115
- return """
116
- <h2>YOLOv8 Real-Time Detection Stream</h2>
117
- <p>POST /stream with 'frame' (JPEG) from Unity camera feed.</p>
118
- """
119
 
 
120
 
 
121
  if __name__ == "__main__":
122
  port = int(os.environ.get("PORT", 7860))
 
123
  app.run(host="0.0.0.0", port=port, threaded=True)
 
1
  import os
2
  import cv2
3
+ import json
4
  import torch
5
  import numpy as np
6
  from flask import Flask, Response, request, jsonify
7
  from ultralytics import YOLO
8
 
9
+ # -------------- CONFIG --------------
10
+ MODEL_PATH = os.environ.get("MODEL_PATH", "best.pt")
 
 
11
  FOCAL_LENGTH_PX = 615
12
  KNOWN_WIDTHS_M = {
13
  "person": 0.5, "car": 1.8, "truck": 2.3, "bus": 2.5,
 
16
  }
17
  THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
18
 
19
+ # -------------- APP INIT --------------
20
+ app = Flask(__name__)
21
+
22
+ # Prefer GPU if available
23
  device = 0 if torch.cuda.is_available() else "cpu"
24
+
25
+ # Load model
26
  model = YOLO(MODEL_PATH)
27
  model.to(device)
28
+ # Fuse for a small speed boost; ignore if unsupported by your build
29
+ try:
30
+ model.fuse()
31
+ except Exception:
32
+ pass
33
 
34
+ # -------------- UTILS --------------
35
  def estimate_distance(bbox_width_px, class_name):
36
+ """Approx distance using pinhole model D = (W * f) / w"""
37
+ if bbox_width_px is None or bbox_width_px <= 1:
38
  return None
39
  known_width = KNOWN_WIDTHS_M.get(class_name)
40
  if not known_width:
 
53
  return "SAFE"
54
 
55
  def annotate_frame(frame, detections):
56
+ """Draw boxes and labels colored by alert level."""
57
  for det in detections:
58
  x1, y1, x2, y2 = det["bbox"]
59
+ level = det["alert_level"]
60
+ # Color by severity
61
+ if level == "CRITICAL":
62
+ color = (0, 0, 255) # Red
63
+ elif level == "WARNING":
64
+ color = (0, 165, 255) # Orange
65
+ elif level == "CAUTION":
66
+ color = (0, 255, 255) # Yellow
67
+ else:
68
+ color = (0, 255, 0) # Green
69
+
70
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
71
+
72
+ dist_str = f"{det['distance_m']}m" if det["distance_m"] is not None else "n/a"
73
+ label = f"{det['class']} {dist_str}"
74
+ (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
75
+ y_top = max(y1 - th - 6, 0)
76
+ cv2.rectangle(frame, (x1, y_top), (x1 + tw + 8, y_top + th + 6), color, -1)
77
+ cv2.putText(frame, label, (x1 + 4, y_top + th),
78
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2, cv2.LINE_AA)
79
  return frame
80
 
81
+ # -------------- ROUTES --------------
82
+ @app.route("/ping")
83
+ def ping():
84
+ return jsonify({"ok": True}), 200
85
 
 
86
  @app.route("/stream", methods=["POST"])
87
  def process_frame():
88
  """
89
+ Accepts a single video frame via multipart/form-data field 'frame' (JPEG bytes),
90
+ returns annotated JPEG as body with alert metadata in headers.
91
  """
92
  if "frame" not in request.files:
93
  return jsonify({"error": "No frame uploaded"}), 400
94
 
95
+ # Decode image
96
  file = request.files["frame"]
97
+ file_bytes = file.read()
98
+ img_bytes = np.frombuffer(file_bytes, np.uint8)
99
  frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
100
+ if frame is None:
101
+ return jsonify({"error": "Invalid image"}), 400
102
 
103
+ # Run inference
104
+ results = model(frame, conf=0.25, iou=0.5, verbose=False, device=device)
105
  boxes = results[0].boxes
106
 
107
  detections = []
108
  max_level = "SAFE"
109
  order = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
110
 
111
+ if boxes is not None and len(boxes) > 0:
112
+ for b in boxes:
113
+ x1, y1, x2, y2 = b.xyxy[0].tolist()
114
+ cls_id = int(b.cls[0].item())
115
+ conf = float(b.conf[0].item())
116
+ class_name = model.names.get(cls_id, str(cls_id))
117
+ bbox_w = int(x2 - x1)
118
+
119
+ distance_m = estimate_distance(bbox_w, class_name)
120
+ level = get_alert_level(distance_m)
121
+ if order[level] > order[max_level]:
122
+ max_level = level
123
+
124
+ detections.append({
125
+ "class": class_name,
126
+ "confidence": round(conf, 3),
127
+ "distance_m": round(distance_m, 2) if distance_m else None,
128
+ "alert_level": level,
129
+ "bbox": [int(x1), int(y1), int(x2), int(y2)]
130
+ })
131
+
132
+ # Annotate and encode JPEG
133
  annotated = annotate_frame(frame, detections)
134
+ ok, buffer = cv2.imencode(".jpg", annotated, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
135
+ if not ok:
136
+ return jsonify({"error": "Encode failed"}), 500
137
  encoded = buffer.tobytes()
138
 
139
+ # Headers with metadata
140
+ hdr_alert = max_level
141
+ hdr_count = str(len(detections))
142
+ headers = {
143
+ "Content-Length": str(len(encoded)), # some clients use it for streaming/decoding
144
+ "X-Alert-Level": hdr_alert,
145
+ "X-Detections-Count": hdr_count
146
+ }
 
 
 
 
 
147
 
148
+ return Response(encoded, mimetype="image/jpeg", headers=headers)
149
 
150
+ # -------------- MAIN --------------
151
  if __name__ == "__main__":
152
  port = int(os.environ.get("PORT", 7860))
153
+ # threaded=True allows concurrent requests from multiple clients
154
  app.run(host="0.0.0.0", port=port, threaded=True)