mohammedafeef commited on
Commit
3748a91
·
verified ·
1 Parent(s): 636c470

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -48
app.py CHANGED
@@ -1,15 +1,14 @@
1
  import os
2
  import cv2
3
- import math
4
  import torch
5
  import numpy as np
6
- from flask import Flask, request, jsonify, render_template_string
7
  from ultralytics import YOLO
8
 
9
- # ------------------------------
10
- # CONFIG
11
- # ------------------------------
12
- MODEL_PATH = "best.pt" # your 5MB lightweight model
13
  FOCAL_LENGTH_PX = 615
14
  KNOWN_WIDTHS_M = {
15
  "person": 0.5, "car": 1.8, "truck": 2.3, "bus": 2.5,
@@ -18,17 +17,13 @@ KNOWN_WIDTHS_M = {
18
  }
19
  THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
20
 
21
- # ------------------------------
22
- # Initialize Flask & YOLO
23
- # ------------------------------
24
- app = Flask(__name__)
25
  device = 0 if torch.cuda.is_available() else "cpu"
26
  model = YOLO(MODEL_PATH)
27
  model.to(device)
 
28
 
29
- # ------------------------------
30
- # Utility Functions
31
- # ------------------------------
32
  def estimate_distance(bbox_width_px, class_name):
33
  if bbox_width_px <= 1:
34
  return None
@@ -48,19 +43,34 @@ def get_alert_level(distance_m):
48
  return "CAUTION"
49
  return "SAFE"
50
 
51
- # ------------------------------
52
- # API ENDPOINT: /detect
53
- # ------------------------------
54
- @app.route("/detect", methods=["POST"])
55
- def detect_objects():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  """
57
- Accepts an image file (from Unity or mobile camera)
58
- and returns detections + alert level JSON.
59
  """
60
- if "file" not in request.files:
61
- return jsonify({"error": "No file uploaded"}), 400
62
 
63
- file = request.files["file"]
64
  img_bytes = np.frombuffer(file.read(), np.uint8)
65
  frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
66
 
@@ -72,49 +82,42 @@ def detect_objects():
72
  order = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
73
 
74
  for b in boxes:
75
- x1, y1, x2, y2 = b.xyxy[0].tolist()
76
  cls_id = int(b.cls[0].item())
77
  conf = float(b.conf[0].item())
78
  class_name = model.names.get(cls_id, str(cls_id))
79
  bbox_w = x2 - x1
80
  distance_m = estimate_distance(bbox_w, class_name)
81
  level = get_alert_level(distance_m)
82
-
83
  if order[level] > order[max_level]:
84
  max_level = level
85
-
86
  detections.append({
87
  "class": class_name,
88
  "confidence": round(conf, 3),
89
  "distance_m": round(distance_m, 2) if distance_m else None,
90
  "alert_level": level,
91
- "bbox": [int(x1), int(y1), int(x2), int(y2)]
92
  })
93
 
94
- return jsonify({
95
- "alert_level": max_level,
96
- "detections": detections
97
- })
 
 
 
 
 
 
98
 
99
- # ------------------------------
100
- # SIMPLE FRONTEND PAGE
101
- # ------------------------------
102
  @app.route("/")
103
  def index():
104
- return render_template_string("""
105
- <html>
106
- <head><title>YOLOv8 Real-Time Alerts</title></head>
107
- <body style="font-family:sans-serif;text-align:center;">
108
- <h2>YOLOv8 Object Detection API</h2>
109
- <p>Send a POST /detect request with an image to get JSON alerts.</p>
110
- <p>Example (curl):</p>
111
- <code>
112
- curl -X POST -F "file=@image.jpg" https://YOUR_SPACE_URL.hf.space/detect
113
- </code>
114
- </body>
115
- </html>
116
- """)
117
 
118
  if __name__ == "__main__":
119
  port = int(os.environ.get("PORT", 7860))
120
- app.run(host="0.0.0.0", port=port)
 
1
  import os
2
  import cv2
 
3
  import torch
4
  import numpy as np
5
+ from flask import Flask, Response, request, jsonify
6
  from ultralytics import YOLO
7
 
8
+ app = Flask(__name__)
9
+
10
+ # ---------------- CONFIG ----------------
11
+ MODEL_PATH = "best.pt"
12
  FOCAL_LENGTH_PX = 615
13
  KNOWN_WIDTHS_M = {
14
  "person": 0.5, "car": 1.8, "truck": 2.3, "bus": 2.5,
 
17
  }
18
  THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
19
 
 
 
 
 
20
  device = 0 if torch.cuda.is_available() else "cpu"
21
  model = YOLO(MODEL_PATH)
22
  model.to(device)
23
+ model.fuse() # small speed boost
24
 
25
+
26
+ # ---------------- UTIL FUNCTIONS ----------------
 
27
  def estimate_distance(bbox_width_px, class_name):
28
  if bbox_width_px <= 1:
29
  return None
 
43
  return "CAUTION"
44
  return "SAFE"
45
 
46
+ def annotate_frame(frame, detections):
47
+ for det in detections:
48
+ x1, y1, x2, y2 = det["bbox"]
49
+ color = (0, 255, 0)
50
+ if det["alert_level"] == "CRITICAL":
51
+ color = (0, 0, 255)
52
+ elif det["alert_level"] == "WARNING":
53
+ color = (0, 165, 255)
54
+ elif det["alert_level"] == "CAUTION":
55
+ color = (0, 255, 255)
56
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
57
+ label = f"{det['class']} {det['distance_m']}m"
58
+ cv2.putText(frame, label, (x1, y1 - 10),
59
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
60
+ return frame
61
+
62
+
63
+ # ---------------- REALTIME VIDEO STREAM ----------------
64
+ @app.route("/stream", methods=["POST"])
65
+ def process_frame():
66
  """
67
+ Accepts a single video frame (JPEG bytes from Unity or camera),
68
+ returns detection data + optionally annotated frame.
69
  """
70
+ if "frame" not in request.files:
71
+ return jsonify({"error": "No frame uploaded"}), 400
72
 
73
+ file = request.files["frame"]
74
  img_bytes = np.frombuffer(file.read(), np.uint8)
75
  frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
76
 
 
82
  order = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
83
 
84
  for b in boxes:
85
+ x1, y1, x2, y2 = map(int, b.xyxy[0].tolist())
86
  cls_id = int(b.cls[0].item())
87
  conf = float(b.conf[0].item())
88
  class_name = model.names.get(cls_id, str(cls_id))
89
  bbox_w = x2 - x1
90
  distance_m = estimate_distance(bbox_w, class_name)
91
  level = get_alert_level(distance_m)
 
92
  if order[level] > order[max_level]:
93
  max_level = level
 
94
  detections.append({
95
  "class": class_name,
96
  "confidence": round(conf, 3),
97
  "distance_m": round(distance_m, 2) if distance_m else None,
98
  "alert_level": level,
99
+ "bbox": [x1, y1, x2, y2]
100
  })
101
 
102
+ annotated = annotate_frame(frame, detections)
103
+ _, buffer = cv2.imencode('.jpg', annotated)
104
+ encoded = buffer.tobytes()
105
+
106
+ return Response(encoded, mimetype='image/jpeg',
107
+ headers={
108
+ "X-Alert-Level": max_level,
109
+ "X-Detections": str(detections)
110
+ })
111
+
112
 
 
 
 
113
  @app.route("/")
114
  def index():
115
+ return """
116
+ <h2>YOLOv8 Real-Time Detection Stream</h2>
117
+ <p>POST /stream with 'frame' (JPEG) from Unity camera feed.</p>
118
+ """
119
+
 
 
 
 
 
 
 
 
120
 
121
  if __name__ == "__main__":
122
  port = int(os.environ.get("PORT", 7860))
123
+ app.run(host="0.0.0.0", port=port, threaded=True)