mohammedafeef commited on
Commit
b5d1da8
·
verified ·
1 Parent(s): 322d453

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ from flask import Flask, request, jsonify, render_template_string
7
+ from ultralytics import YOLO
8
+
9
+ # ------------------------------
10
+ # CONFIG
11
+ # ------------------------------
12
+ MODEL_PATH = "best.pt" # your 5MB lightweight model
13
+ FOCAL_LENGTH_PX = 615
14
+ KNOWN_WIDTHS_M = {
15
+ "person": 0.5, "car": 1.8, "truck": 2.3, "bus": 2.5,
16
+ "bicycle": 0.6, "motorcycle": 0.7, "door": 0.9,
17
+ "stairs": 1.2, "pole": 0.15, "bollard": 0.2
18
+ }
19
+ THRESHOLDS = {"CRITICAL": 1.0, "WARNING": 2.0, "CAUTION": 3.0}
20
+
21
+ # ------------------------------
22
+ # Initialize Flask & YOLO
23
+ # ------------------------------
24
+ app = Flask(__name__)
25
+ device = 0 if torch.cuda.is_available() else "cpu"
26
+ model = YOLO(MODEL_PATH)
27
+ model.to(device)
28
+
29
+ # ------------------------------
30
+ # Utility Functions
31
+ # ------------------------------
32
+ def estimate_distance(bbox_width_px, class_name):
33
+ if bbox_width_px <= 1:
34
+ return None
35
+ known_width = KNOWN_WIDTHS_M.get(class_name)
36
+ if not known_width:
37
+ return None
38
+ return (known_width * FOCAL_LENGTH_PX) / float(bbox_width_px)
39
+
40
+ def get_alert_level(distance_m):
41
+ if distance_m is None:
42
+ return "SAFE"
43
+ if distance_m <= THRESHOLDS["CRITICAL"]:
44
+ return "CRITICAL"
45
+ if distance_m <= THRESHOLDS["WARNING"]:
46
+ return "WARNING"
47
+ if distance_m <= THRESHOLDS["CAUTION"]:
48
+ return "CAUTION"
49
+ return "SAFE"
50
+
51
+ # ------------------------------
52
+ # API ENDPOINT: /detect
53
+ # ------------------------------
54
+ @app.route("/detect", methods=["POST"])
55
+ def detect_objects():
56
+ """
57
+ Accepts an image file (from Unity or mobile camera)
58
+ and returns detections + alert level JSON.
59
+ """
60
+ if "file" not in request.files:
61
+ return jsonify({"error": "No file uploaded"}), 400
62
+
63
+ file = request.files["file"]
64
+ img_bytes = np.frombuffer(file.read(), np.uint8)
65
+ frame = cv2.imdecode(img_bytes, cv2.IMREAD_COLOR)
66
+
67
+ results = model(frame, conf=0.25, iou=0.5, verbose=False)
68
+ boxes = results[0].boxes
69
+
70
+ detections = []
71
+ max_level = "SAFE"
72
+ order = {"SAFE": 0, "CAUTION": 1, "WARNING": 2, "CRITICAL": 3}
73
+
74
+ for b in boxes:
75
+ x1, y1, x2, y2 = b.xyxy[0].tolist()
76
+ cls_id = int(b.cls[0].item())
77
+ conf = float(b.conf[0].item())
78
+ class_name = model.names.get(cls_id, str(cls_id))
79
+ bbox_w = x2 - x1
80
+ distance_m = estimate_distance(bbox_w, class_name)
81
+ level = get_alert_level(distance_m)
82
+
83
+ if order[level] > order[max_level]:
84
+ max_level = level
85
+
86
+ detections.append({
87
+ "class": class_name,
88
+ "confidence": round(conf, 3),
89
+ "distance_m": round(distance_m, 2) if distance_m else None,
90
+ "alert_level": level,
91
+ "bbox": [int(x1), int(y1), int(x2), int(y2)]
92
+ })
93
+
94
+ return jsonify({
95
+ "alert_level": max_level,
96
+ "detections": detections
97
+ })
98
+
99
+ # ------------------------------
100
+ # SIMPLE FRONTEND PAGE
101
+ # ------------------------------
102
+ @app.route("/")
103
+ def index():
104
+ return render_template_string("""
105
+ <html>
106
+ <head><title>YOLOv8 Real-Time Alerts</title></head>
107
+ <body style="font-family:sans-serif;text-align:center;">
108
+ <h2>YOLOv8 Object Detection API</h2>
109
+ <p>Send a POST /detect request with an image to get JSON alerts.</p>
110
+ <p>Example (curl):</p>
111
+ <code>
112
+ curl -X POST -F "file=@image.jpg" https://YOUR_SPACE_URL.hf.space/detect
113
+ </code>
114
+ </body>
115
+ </html>
116
+ """)
117
+
118
+ if __name__ == "__main__":
119
+ port = int(os.environ.get("PORT", 7860))
120
+ app.run(host="0.0.0.0", port=port)