Subh775 commited on
Commit
c7ef42a
·
verified ·
1 Parent(s): 86e6366

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -217
app.py CHANGED
@@ -2,198 +2,115 @@ import os
2
  import sys
3
  import io
4
  import base64
5
- import time
6
  import threading
7
- import traceback
8
- import json
9
  import requests
10
- import numpy as np
11
- import torch
12
  from flask import Flask, request, jsonify, send_from_directory
13
  from PIL import Image
 
 
 
 
14
 
15
- # Ensure local modules take precedence (fixes issues if rfdetr is both local and installed)
16
  sys.path.insert(0, os.getcwd())
17
 
18
- # Libraries for Models
19
- from ultralytics import YOLO
20
- import supervision as sv
 
 
 
21
 
22
- # Import RF-DETR (Must be present in project folder or installed)
23
- try:
24
- from rfdetr import RFDETRSegPreview
25
- except ImportError:
26
- print("[WARN] rfdetr module not found. RF-DETR inference will fail.")
27
- RFDETRSegPreview = None
28
-
29
- # --- Configuration ---
30
- os.environ["CUDA_VISIBLE_DEVICES"] = "" # Force CPU
31
  app = Flask(__name__, static_folder="static")
32
 
33
- # Class Names mapping (Ensuring consistency)
34
- CLASS_NAMES = {0: 'Gun', 1: 'Explosive', 2: 'Grenade', 3: 'Knife'}
35
-
36
- # --- Weight URLs ---
37
- # RF-DETR
38
- RF_REPO = "Subh775/Threat-Detection-RFDETR"
39
- RF_WEIGHT_URL = f"https://huggingface.co/{RF_REPO}/resolve/main/checkpoint_best_total.pth"
40
- RF_WEIGHT_PATH = "/tmp/rfdetr_best.pth"
41
-
42
- # YOLOv8
43
- YOLO_REPO = "Subh775/Threat-Detection-YOLOv8n"
44
- YOLO_WEIGHT_URL = f"https://huggingface.co/{YOLO_REPO}/resolve/main/weights/best.pt"
45
- YOLO_WEIGHT_PATH = "/tmp/yolov8_best.pt"
46
-
47
- # Global Models
48
- MODEL_RF = None
49
- MODEL_YOLO = None
50
- LOCK = threading.Lock()
51
-
52
- # --- Helper Functions ---
53
-
54
- def download_file(url, dst):
55
- if os.path.exists(dst) and os.path.getsize(dst) > 0:
56
- return
57
- print(f"[INFO] Downloading {url} to {dst}...")
58
- try:
59
- r = requests.get(url, stream=True, timeout=180)
60
- r.raise_for_status()
61
- with open(dst, "wb") as f:
62
- for chunk in r.iter_content(chunk_size=8192):
63
- f.write(chunk)
64
- print(f"[INFO] Download finished: {dst}")
65
- except Exception as e:
66
- print(f"[ERROR] Download failed: {e}")
67
-
68
- def init_models():
69
- """Load both models into memory."""
70
- global MODEL_RF, MODEL_YOLO
71
- with LOCK:
72
- # 1. Load RF-DETR
73
- if MODEL_RF is None and RFDETRSegPreview is not None:
74
- try:
75
- download_file(RF_WEIGHT_URL, RF_WEIGHT_PATH)
76
- print("[INFO] Loading RF-DETR...")
77
- # Initialize with CPU params
78
- # Added try-except to catch architecture mismatches (e.g. Nano vs Base)
79
- try:
80
- MODEL_RF = RFDETRSegPreview(pretrain_weights=RF_WEIGHT_PATH)
81
- # Attempt optimization if method exists
82
- if hasattr(MODEL_RF, 'optimize_for_inference'):
83
- MODEL_RF.optimize_for_inference()
84
- print("[INFO] RF-DETR Ready.")
85
- except RuntimeError as re:
86
- print(f"[ERROR] RF-DETR Architecture Mismatch: {re}")
87
- print("[WARN] Skipping RF-DETR loading. App will run with YOLO only.")
88
- MODEL_RF = None
89
- except Exception as e:
90
- print(f"[ERROR] RF-DETR Load Failed: {e}")
91
- traceback.print_exc()
92
-
93
- # 2. Load YOLOv8
94
- if MODEL_YOLO is None:
95
- try:
96
- download_file(YOLO_WEIGHT_URL, YOLO_WEIGHT_PATH)
97
- print("[INFO] Loading YOLOv8...")
98
- MODEL_YOLO = YOLO(YOLO_WEIGHT_PATH)
99
- print("[INFO] YOLOv8 Ready.")
100
- except Exception as e:
101
- print(f"[ERROR] YOLOv8 Load Failed: {e}")
102
- traceback.print_exc()
103
-
104
- def encode_image(pil_img):
105
- try:
106
- buf = io.BytesIO()
107
- pil_img.save(buf, format="JPEG", quality=85)
108
- return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode('utf-8')
109
- except Exception as e:
110
- print(f"[ERROR] Encode failed: {e}")
111
- return ""
112
-
113
- def decode_image(data_url):
114
- try:
115
- if "," in data_url:
116
- header, encoded = data_url.split(",", 1)
117
- else:
118
- encoded = data_url
119
- data = base64.b64decode(encoded)
120
- return Image.open(io.BytesIO(data)).convert("RGB")
121
- except Exception:
122
- raise ValueError("Invalid Image Data")
123
-
124
- def annotate_common(image, detections, model_name):
125
  """
126
- Standardize annotation using Supervision for both models.
 
127
  """
128
- try:
129
- # Create annotators
130
- box_annotator = sv.BoxAnnotator(thickness=2)
131
-
132
- labels = []
133
- # Handle different detection formats if necessary
134
- for class_id, confidence in zip(detections.class_id, detections.confidence):
135
- name = CLASS_NAMES.get(class_id, f"Class {class_id}")
136
- labels.append(f"{name} {confidence:.2f}")
137
-
138
- label_annotator = sv.LabelAnnotator(text_scale=0.5, text_padding=4)
139
-
140
- annotated_frame = image.copy()
141
- annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
142
- annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
143
-
144
- return annotated_frame
145
- except Exception as e:
146
- print(f"[WARN] Annotation failed for {model_name}: {e}")
147
- return image
148
-
149
- # --- Inference Logic ---
150
-
151
- def run_rfdetr_inference(image, conf):
152
- # FIX: If model is None, return original image, NOT a dict
153
- if MODEL_RF is None:
154
- return image, 0, 0
155
 
156
- start_time = time.perf_counter()
 
 
 
 
 
 
 
 
 
157
 
158
- try:
159
- # Run prediction
160
- detections = MODEL_RF.predict(image, threshold=conf)
161
-
162
- # Annotate
163
- annotated_img = annotate_common(image, detections, "RF-DETR")
164
- count = len(detections)
165
-
166
- latency = (time.perf_counter() - start_time) * 1000 # ms
167
- return annotated_img, count, latency
168
-
169
- except Exception as e:
170
- print(f"RF-DETR Inference Error: {e}")
171
- # Return original image on error
172
- return image, 0, 0
173
-
174
- def run_yolo_inference(image, conf):
175
- # FIX: If model is None, return original image, NOT a dict
176
- if MODEL_YOLO is None:
177
- return image, 0, 0
178
-
179
- start_time = time.perf_counter()
180
-
181
- try:
182
- # Run YOLO inference
183
- results = MODEL_YOLO(image, conf=conf, verbose=False)[0]
184
-
185
- # Convert to Supervision Detections
186
- detections = sv.Detections.from_ultralytics(results)
187
-
188
- annotated_img = annotate_common(image, detections, "YOLOv8")
189
- count = len(detections)
190
-
191
- latency = (time.perf_counter() - start_time) * 1000 # ms
192
- return annotated_img, count, latency
193
- except Exception as e:
194
- print(f"YOLO Inference Error: {e}")
195
- # Return original image on error
196
- return image, 0, 0
197
 
198
  # --- Routes ---
199
 
@@ -201,50 +118,54 @@ def run_yolo_inference(image, conf):
201
  def index():
202
  return send_from_directory('static', 'index.html')
203
 
204
- @app.route('/health', methods=['GET'])
205
- def health():
206
- return jsonify({"status": "running"})
207
-
208
  @app.route('/predict', methods=['POST'])
209
  def predict():
210
  try:
211
- # Ensure models are loaded (lazy loading)
212
- init_models()
213
-
214
- payload = request.json
215
- if not payload or 'image' not in payload:
216
- return jsonify({'error': 'No image provided'}), 400
217
-
218
- image = decode_image(payload['image'])
219
- conf = float(payload.get('conf', 0.25))
220
-
221
- # 1. Run RF-DETR
222
- rf_img, rf_count, rf_lat = run_rfdetr_inference(image.copy(), conf)
223
-
224
- # 2. Run YOLOv8
225
- yolo_img, yolo_count, yolo_lat = run_yolo_inference(image.copy(), conf)
226
-
227
- response = {
228
- "rfdetr": {
229
- "image": encode_image(rf_img),
230
- "count": rf_count,
231
- "latency": f"{rf_lat:.2f} ms",
232
- "model_name": "RF-DETR Nano"
233
- },
234
- "yolov8": {
235
- "image": encode_image(yolo_img),
236
- "count": yolo_count,
237
- "latency": f"{yolo_lat:.2f} ms",
238
- "model_name": "YOLOv8 Nano"
239
- }
240
- }
241
- return jsonify(response)
 
 
 
 
 
 
 
 
242
 
243
  except Exception as e:
244
- traceback.print_exc()
245
- return jsonify({'error': str(e)}), 500
246
 
247
  if __name__ == '__main__':
248
- # Initial warmup in background
249
- threading.Thread(target=init_models, daemon=True).start()
250
  app.run(host='0.0.0.0', port=7860)
 
2
  import sys
3
  import io
4
  import base64
 
5
  import threading
 
 
6
  import requests
 
 
7
  from flask import Flask, request, jsonify, send_from_directory
8
  from PIL import Image
9
+ import torch
10
+ import supervision as sv
11
+ from ultralytics import YOLO
12
+ from rfdetr import RFDETRNano
13
 
14
+ # Ensure local 'rfdetr' folder is found if present
15
  sys.path.insert(0, os.getcwd())
16
 
17
+ # Attempt to import the specific RF-DETR class
18
+ # try:
19
+ # from rfdetr import RFDETRNano
20
+ # except ImportError:
21
+ # print("[WARN] 'rfdetr' library not found. RF-DETR will be disabled.")
22
+ # RFDETRNano = None
23
 
 
 
 
 
 
 
 
 
 
24
  app = Flask(__name__, static_folder="static")
25
 
26
+ # --- Constants & Configuration ---
27
+ # Map Class IDs to Names (Common for both models if they share the dataset)
28
+ CLASS_MAP = {0: 'Gun', 1: 'Explosive', 2: 'Grenade', 3: 'Knife'}
29
+
30
+ # Weight Paths
31
+ RF_WEIGHTS_URL = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
32
+ RF_WEIGHTS_PATH = "/tmp/rfdetr_best.pth"
33
+
34
+ YOLO_WEIGHTS_URL = "https://huggingface.co/Subh775/Threat-Detection-YOLOv8n/resolve/main/weights/best.pt"
35
+ YOLO_WEIGHTS_PATH = "/tmp/yolov8_best.pt"
36
+
37
+ # Global Model Instances
38
+ models = {
39
+ "rf": None,
40
+ "yolo": None
41
+ }
42
+
43
+ # --- Utilities ---
44
+
45
+ def download_if_missing(url, path):
46
+ """Downloads file from URL if it doesn't exist locally."""
47
+ if not os.path.exists(path):
48
+ print(f"[INFO] Downloading weights: {path}...")
49
+ try:
50
+ r = requests.get(url, stream=True)
51
+ r.raise_for_status()
52
+ with open(path, "wb") as f:
53
+ for chunk in r.iter_content(chunk_size=8192):
54
+ f.write(chunk)
55
+ print("[INFO] Download complete.")
56
+ except Exception as e:
57
+ print(f"[ERROR] Failed to download {url}: {e}")
58
+
59
+ def get_models():
60
+ """Lazy loader: initializes models only if they aren't ready."""
61
+ # 1. Load RF-DETR
62
+ if models["rf"] is None and RFDETRNano:
63
+ download_if_missing(RF_WEIGHTS_URL, RF_WEIGHTS_PATH)
64
+ try:
65
+ print("[INFO] Loading RF-DETR Nano...")
66
+ models["rf"] = RFDETRNano(pretrain_weights=RF_WEIGHTS_PATH)
67
+ except Exception as e:
68
+ print(f"[ERROR] RF-DETR Init Failed: {e}")
69
+
70
+ # 2. Load YOLOv8
71
+ if models["yolo"] is None:
72
+ download_if_missing(YOLO_WEIGHTS_URL, YOLO_WEIGHTS_PATH)
73
+ try:
74
+ print("[INFO] Loading YOLOv8...")
75
+ models["yolo"] = YOLO(YOLO_WEIGHTS_PATH)
76
+ except Exception as e:
77
+ print(f"[ERROR] YOLO Init Failed: {e}")
78
+
79
+ return models["rf"], models["yolo"]
80
+
81
+ def img_to_base64(img):
82
+ """Encodes PIL Image to Base64 string."""
83
+ buf = io.BytesIO()
84
+ img.save(buf, format="JPEG", quality=85)
85
+ return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode('utf-8')
86
+
87
+ def base64_to_img(data_str):
88
+ """Decodes Base64 string to PIL Image."""
89
+ if "base64," in data_str:
90
+ data_str = data_str.split("base64,")[1]
91
+ return Image.open(io.BytesIO(base64.b64decode(data_str))).convert("RGB")
92
+
93
+ def annotate_image(image, detections):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  """
95
+ Annotates an image with bounding boxes and labels using Supervision.
96
+ Expects detections to be a supervision.Detections object.
97
  """
98
+ # Initialize annotators
99
+ box_annotator = sv.BoxAnnotator(thickness=2)
100
+ label_annotator = sv.LabelAnnotator(text_scale=0.5, text_padding=4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ # Generate labels: "ClassName Confidence"
103
+ labels = []
104
+ for class_id, conf in zip(detections.class_id, detections.confidence):
105
+ name = CLASS_MAP.get(class_id, str(class_id))
106
+ labels.append(f"{name} {conf:.2f}")
107
+
108
+ # Apply annotations
109
+ annotated = image.copy()
110
+ annotated = box_annotator.annotate(scene=annotated, detections=detections)
111
+ annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
112
 
113
+ return annotated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  # --- Routes ---
116
 
 
118
  def index():
119
  return send_from_directory('static', 'index.html')
120
 
 
 
 
 
121
  @app.route('/predict', methods=['POST'])
122
  def predict():
123
  try:
124
+ data = request.json
125
+ if not data or 'image' not in data:
126
+ return jsonify({"error": "No image data provided"}), 400
127
+
128
+ # Parse inputs
129
+ raw_image = base64_to_img(data['image'])
130
+ conf_threshold = float(data.get('conf', 0.25))
131
+
132
+ # Ensure models are loaded
133
+ rf_model, yolo_model = get_models()
134
+
135
+ # --- Run RF-DETR ---
136
+ rf_result_b64 = data['image'] # Fallback to original
137
+ if rf_model:
138
+ try:
139
+ # Predict -> Returns Supervision Detections
140
+ detections = rf_model.predict(raw_image, threshold=conf_threshold)
141
+ annotated_rf = annotate_image(raw_image, detections)
142
+ rf_result_b64 = img_to_base64(annotated_rf)
143
+ except Exception as e:
144
+ print(f"RF-DETR Inference Error: {e}")
145
+
146
+ # --- Run YOLOv8 ---
147
+ yolo_result_b64 = data['image'] # Fallback to original
148
+ if yolo_model:
149
+ try:
150
+ # Predict -> Returns Ultralytics Results -> Convert to Supervision
151
+ results = yolo_model(raw_image, conf=conf_threshold, verbose=False)[0]
152
+ detections = sv.Detections.from_ultralytics(results)
153
+ annotated_yolo = annotate_image(raw_image, detections)
154
+ yolo_result_b64 = img_to_base64(annotated_yolo)
155
+ except Exception as e:
156
+ print(f"YOLO Inference Error: {e}")
157
+
158
+ # Return JSON
159
+ return jsonify({
160
+ "rfdetr": {"image": rf_result_b64},
161
+ "yolov8": {"image": yolo_result_b64}
162
+ })
163
 
164
  except Exception as e:
165
+ print(f"Server Error: {e}")
166
+ return jsonify({"error": str(e)}), 500
167
 
168
  if __name__ == '__main__':
169
+ # Pre-load models in background to speed up first request
170
+ threading.Thread(target=get_models).start()
171
  app.run(host='0.0.0.0', port=7860)