SkyGuardAI commited on
Commit
940e55e
·
verified ·
1 Parent(s): 6976790

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -64
app.py CHANGED
@@ -11,35 +11,30 @@ import warnings
11
  warnings.filterwarnings('ignore')
12
 
13
  # ============================================================
14
- # 1. LOAD MODELS (مرة واحدة فقط)
15
  # ============================================================
16
  print("Loading YOLOv11 model from Hugging Face Hub...")
 
17
  model_path = hf_hub_download(
18
- repo_id="SkyGuardAI/drone-detection-yolov11", # <--- غيّر هذا
19
  filename="best.pt"
20
  )
21
  model = YOLO(model_path)
22
  print("YOLO model loaded successfully.")
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
- # تفعيل half precision إذا كان الجهاز يدعمه (GPU أو CPU حديث)
26
- use_fp16 = device == "cuda" or (hasattr(torch, 'has_mps') and torch.has_mps)
27
- print(f"Using device: {device}, FP16: {use_fp16}")
28
 
29
  print("Loading BLIP model...")
30
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
31
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
32
- if use_fp16:
33
- blip_model = blip_model.half() # تقليل الدقة للنصف
34
  blip_model.eval()
35
  print("BLIP model loaded successfully.")
36
 
37
  # ============================================================
38
- # 2. HEATMAP FUNCTIONS (محسّنة)
39
  # ============================================================
40
  layer_outputs = {}
41
- # تعريف الألوان مرة واحدة خارج الدالة
42
- COLORMAP_JET = cv2.COLORMAP_JET
43
 
44
  def hook_fn(module, input, output):
45
  layer_outputs['feature_map'] = output.detach()
@@ -59,8 +54,9 @@ def get_best_layer(model):
59
  def generate_heatmap(model, image):
60
  try:
61
  layer_outputs.clear()
62
- # تقليل حجم الصورة للخريطة الحرارية (تسريع كبير)
63
- img_resized = cv2.resize(image, (320, 320)) # كان 640×640
 
64
  pytorch_model = model.model if hasattr(model, 'model') else model
65
  target_layer = get_best_layer(pytorch_model)
66
  if target_layer is None:
@@ -82,39 +78,34 @@ def generate_heatmap(model, image):
82
  heatmap = (heatmap - min_val) / (max_val - min_val)
83
  else:
84
  heatmap = np.zeros_like(heatmap)
85
- # إعادة التكبير إلى 640×640 بعد المعالجة (للحفاظ على جودة العرض)
86
  heatmap = cv2.resize(heatmap, (640, 640))
87
  threshold = np.percentile(heatmap, 70)
88
  heatmap = np.where(heatmap > threshold, heatmap, 0)
89
  if heatmap.max() > 0:
90
  heatmap = heatmap / heatmap.max()
91
- heatmap_colored = cv2.applyColorMap((heatmap * 255).astype(np.uint8), COLORMAP_JET)
92
  heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
93
- # تكبير الخريطة النهائية لتتناسب مع الصورة الأصلية (تتم خارج الدالة)
94
- return heatmap_colored
95
  except Exception as e:
96
  print(f"Heatmap error: {e}")
97
  return None
98
 
99
  # ============================================================
100
- # 3. DYNAMIC CAPTION (BLIP) - مع FP16 و تقليل حجم الصورة
101
  # ============================================================
102
  def generate_dynamic_caption(image):
103
  try:
104
- # تقليل حجم الصورة قبل تمريرها إلى BLIP (تسريع كبير)
105
- if image.shape[0] > 480 or image.shape[1] > 480:
106
- image_small = cv2.resize(image, (480, 480))
107
- else:
108
- image_small = image
109
- pil_image = Image.fromarray(image_small).convert('RGB')
110
- inputs = blip_processor(pil_image, return_tensors="pt").to(device)
111
- if use_fp16:
112
- inputs = {k: v.half() for k, v in inputs.items()}
113
  with torch.no_grad():
114
  out = blip_model.generate(
115
  **inputs,
116
- max_length=40, # قللنا من 60 إلى 40
117
- num_beams=3, # قللنا من 5 إلى 3
118
  temperature=0.7,
119
  repetition_penalty=1.2
120
  )
@@ -125,10 +116,9 @@ def generate_dynamic_caption(image):
125
  return "AI model is analyzing the scene."
126
 
127
  # ============================================================
128
- # 4. XAI REPORT (بدون تغيير)
129
  # ============================================================
130
  def build_xai_report(is_drone, confidence, drone_count, processing_time, image_caption):
131
- # ... (نفس الكود السابق، لم يتغير) ...
132
  if is_drone:
133
  drone_text = "a drone" if drone_count == 1 else f"{drone_count} drones"
134
  if confidence >= 0.8:
@@ -220,7 +210,7 @@ Scattered activation pattern confirms absence of strong drone-like features.
220
  return report
221
 
222
  # ============================================================
223
- # 5. MAIN PIPELINE FUNCTION (محسّنة)
224
  # ============================================================
225
  def drone_detection_pipeline(input_image):
226
  try:
@@ -229,24 +219,10 @@ def drone_detection_pipeline(input_image):
229
  else:
230
  img = input_image.copy()
231
 
232
- # تقليل حجم الصورة قبل المعالجة (تسريع كبير)
233
  original_h, original_w = img.shape[:2]
234
- # تغيير الحجم إلى 640×640 إذا كانت الصورة أكبر
235
- if original_h > 640 or original_w > 640:
236
- scale = 640 / max(original_h, original_w)
237
- new_w = int(original_w * scale)
238
- new_h = int(original_h * scale)
239
- img_small = cv2.resize(img, (new_w, new_h))
240
- else:
241
- img_small = img
242
-
243
- # 1. الكشف باستخدام YOLO على الصورة المصغرة
244
- results = model(img_small)
245
-
246
- # 2. الخريطة الحرارية (تعمل على الصورة المصغرة أيضاً)
247
- heatmap_overlay = generate_heatmap(model, img_small)
248
 
249
- # 3. استخراج معلومات الكشف
250
  is_drone = False
251
  confidence = 0.0
252
  drone_count = 0
@@ -264,21 +240,16 @@ def drone_detection_pipeline(input_image):
264
  confidence = max(confidence, conf)
265
  is_drone = drone_count > 0
266
 
267
- # 4. إعادة رسم الصورة مع المربعات (نستخدم الصورة الأصلية للحفاظ على الجودة)
268
- result_img = results[0].plot() if len(results) > 0 else img_small
269
  result_img_rgb = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
 
 
 
270
 
271
- # 5. إنشاء الوصف النصي (يتم على الصورة المصغرة لتسريع BLIP)
272
- caption = generate_dynamic_caption(img_small)
273
-
274
- # 6. بناء التقرير
275
- report = build_xai_report(is_drone, confidence, drone_count, 0, caption)
276
-
277
- # 7. معالجة الخريطة الحرارية لتتناسب مع أبعاد الصورة الأصلية
278
  if heatmap_overlay is not None:
279
  heatmap_resized = cv2.resize(heatmap_overlay, (original_w, original_h))
280
  else:
281
- heatmap_resized = np.zeros((original_h, original_w, 3), dtype=np.uint8)
282
 
283
  return result_img_rgb, heatmap_resized, report
284
  except Exception as e:
@@ -288,14 +259,28 @@ def drone_detection_pipeline(input_image):
288
  return blank, blank, error_msg
289
 
290
  # ============================================================
291
- # 6. INTERFACE - مع إضافة queue() لتحسين الاستجابة
292
  # ============================================================
293
  custom_css = """
294
- /* (نفس CSS السابق، لم يتغير) */
295
- .gradio-container, body, .gradio-container .main {
296
  background: #ffffff !important;
297
  font-family: 'Segoe UI', 'Roboto', sans-serif;
298
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  .skyguard-header {
300
  text-align: center;
301
  padding: 1.5rem 0 0.5rem 0;
@@ -320,20 +305,25 @@ custom_css = """
320
  color: #2c5282 !important;
321
  margin-top: 0.5rem;
322
  }
323
- .gr-button-primary {
 
324
  background: #0a2b4e !important;
325
  border: none !important;
326
  color: white !important;
327
  font-weight: 700 !important;
328
  border-radius: 30px !important;
329
  padding: 10px 28px !important;
 
330
  text-transform: uppercase !important;
331
  letter-spacing: 1px;
 
332
  }
333
- .gr-button-primary:hover {
334
  background: #1e4a76 !important;
335
  transform: scale(1.02);
 
336
  }
 
337
  .gr-tabs .tab-nav button {
338
  background-color: #eef2f5 !important;
339
  color: #0a2b4e !important;
@@ -341,11 +331,28 @@ custom_css = """
341
  border-radius: 8px 8px 0 0 !important;
342
  padding: 8px 20px !important;
343
  margin-right: 4px;
 
344
  }
345
  .gr-tabs .tab-nav button.selected {
346
  background-color: #0a2b4e !important;
347
  color: white !important;
348
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  .skyguard-footer {
350
  text-align: center;
351
  margin-top: 30px;
@@ -354,6 +361,17 @@ custom_css = """
354
  border-top: 1px solid #e2e8f0;
355
  padding-top: 15px;
356
  }
 
 
 
 
 
 
 
 
 
 
 
357
  """
358
 
359
  with gr.Blocks(title="SkyGuard - Drone Detection System", theme=gr.themes.Soft(), css=custom_css) as demo:
@@ -392,8 +410,7 @@ with gr.Blocks(title="SkyGuard - Drone Detection System", theme=gr.themes.Soft()
392
  """)
393
 
394
  # ============================================================
395
- # 7. RUN APP - مع إضافة queue() لتحسين تدفق الطلبات
396
  # ============================================================
397
  if __name__ == "__main__":
398
- demo.queue(default_concurrency_limit=1) # معالجة طلب واحد فقط في كل مرة
399
  demo.launch()
 
11
  warnings.filterwarnings('ignore')
12
 
13
  # ============================================================
14
+ # 1. LOAD MODELS (CHANGE USERNAME BELOW)
15
  # ============================================================
16
  print("Loading YOLOv11 model from Hugging Face Hub...")
17
+ # ⚠️ IMPORTANT: Replace "YOUR_USERNAME" with your actual Hugging Face username
18
  model_path = hf_hub_download(
19
+ repo_id="SkyGuardAI/drone-detection-yolov11", # <--- CHANGE THIS
20
  filename="best.pt"
21
  )
22
  model = YOLO(model_path)
23
  print("YOLO model loaded successfully.")
24
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ print(f"Using device: {device}")
 
 
27
 
28
  print("Loading BLIP model...")
29
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
30
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
 
 
31
  blip_model.eval()
32
  print("BLIP model loaded successfully.")
33
 
34
  # ============================================================
35
+ # 2. HEATMAP FUNCTIONS (unchanged)
36
  # ============================================================
37
  layer_outputs = {}
 
 
38
 
39
  def hook_fn(module, input, output):
40
  layer_outputs['feature_map'] = output.detach()
 
54
  def generate_heatmap(model, image):
55
  try:
56
  layer_outputs.clear()
57
+ if isinstance(image, Image.Image):
58
+ image = np.array(image)
59
+ img_resized = cv2.resize(image, (640, 640))
60
  pytorch_model = model.model if hasattr(model, 'model') else model
61
  target_layer = get_best_layer(pytorch_model)
62
  if target_layer is None:
 
78
  heatmap = (heatmap - min_val) / (max_val - min_val)
79
  else:
80
  heatmap = np.zeros_like(heatmap)
 
81
  heatmap = cv2.resize(heatmap, (640, 640))
82
  threshold = np.percentile(heatmap, 70)
83
  heatmap = np.where(heatmap > threshold, heatmap, 0)
84
  if heatmap.max() > 0:
85
  heatmap = heatmap / heatmap.max()
86
+ heatmap_colored = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
87
  heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
88
+ overlay = cv2.addWeighted(img_resized, 0.6, heatmap_colored, 0.4, 0)
89
+ return overlay
90
  except Exception as e:
91
  print(f"Heatmap error: {e}")
92
  return None
93
 
94
  # ============================================================
95
+ # 3. DYNAMIC CAPTION (BLIP)
96
  # ============================================================
97
  def generate_dynamic_caption(image):
98
  try:
99
+ if isinstance(image, np.ndarray):
100
+ image = Image.fromarray(image).convert('RGB')
101
+ elif isinstance(image, Image.Image):
102
+ image = image.convert('RGB')
103
+ inputs = blip_processor(image, return_tensors="pt").to(device)
 
 
 
 
104
  with torch.no_grad():
105
  out = blip_model.generate(
106
  **inputs,
107
+ max_length=60,
108
+ num_beams=5,
109
  temperature=0.7,
110
  repetition_penalty=1.2
111
  )
 
116
  return "AI model is analyzing the scene."
117
 
118
  # ============================================================
119
+ # 4. XAI REPORT
120
  # ============================================================
121
  def build_xai_report(is_drone, confidence, drone_count, processing_time, image_caption):
 
122
  if is_drone:
123
  drone_text = "a drone" if drone_count == 1 else f"{drone_count} drones"
124
  if confidence >= 0.8:
 
210
  return report
211
 
212
  # ============================================================
213
+ # 5. MAIN PIPELINE FUNCTION
214
  # ============================================================
215
  def drone_detection_pipeline(input_image):
216
  try:
 
219
  else:
220
  img = input_image.copy()
221
 
 
222
  original_h, original_w = img.shape[:2]
223
+ results = model(img)
224
+ heatmap_overlay = generate_heatmap(model, img)
 
 
 
 
 
 
 
 
 
 
 
 
225
 
 
226
  is_drone = False
227
  confidence = 0.0
228
  drone_count = 0
 
240
  confidence = max(confidence, conf)
241
  is_drone = drone_count > 0
242
 
243
+ result_img = results[0].plot() if len(results) > 0 else img
 
244
  result_img_rgb = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
245
+ caption = generate_dynamic_caption(img)
246
+ processing_time_ms = 0
247
+ report = build_xai_report(is_drone, confidence, drone_count, processing_time_ms, caption)
248
 
 
 
 
 
 
 
 
249
  if heatmap_overlay is not None:
250
  heatmap_resized = cv2.resize(heatmap_overlay, (original_w, original_h))
251
  else:
252
+ heatmap_resized = np.zeros_like(result_img_rgb)
253
 
254
  return result_img_rgb, heatmap_resized, report
255
  except Exception as e:
 
259
  return blank, blank, error_msg
260
 
261
  # ============================================================
262
+ # 6. INTERFACE - DARK BLUE ONLY, NO PURPLE/VIOLET
263
  # ============================================================
264
  custom_css = """
265
+ /* Force white background */
266
+ .gradio-container, body, .gradio-container .main, .gradio-container .gradio-container {
267
  background: #ffffff !important;
268
  font-family: 'Segoe UI', 'Roboto', sans-serif;
269
  }
270
+ /* Override any purple/violet colors from Gradio default theme */
271
+ :root {
272
+ --primary-50: #0a2b4e !important;
273
+ --primary-100: #1e4a76 !important;
274
+ --primary-200: #2c5282 !important;
275
+ --primary-300: #1e3a8a !important;
276
+ --primary-400: #0a2b4e !important;
277
+ --primary-500: #0a2b4e !important;
278
+ --primary-600: #0a2b4e !important;
279
+ --primary-700: #0a2b4e !important;
280
+ --primary-800: #0a2b4e !important;
281
+ --primary-900: #0a2b4e !important;
282
+ }
283
+ /* Header - dark blue */
284
  .skyguard-header {
285
  text-align: center;
286
  padding: 1.5rem 0 0.5rem 0;
 
305
  color: #2c5282 !important;
306
  margin-top: 0.5rem;
307
  }
308
+ /* Primary button - solid dark blue, no gradient, no purple */
309
+ .gr-button-primary, button.gr-button-primary, .gr-button.primary {
310
  background: #0a2b4e !important;
311
  border: none !important;
312
  color: white !important;
313
  font-weight: 700 !important;
314
  border-radius: 30px !important;
315
  padding: 10px 28px !important;
316
+ transition: all 0.2s ease !important;
317
  text-transform: uppercase !important;
318
  letter-spacing: 1px;
319
+ box-shadow: none !important;
320
  }
321
+ .gr-button-primary:hover, button.gr-button-primary:hover, .gr-button.primary:hover {
322
  background: #1e4a76 !important;
323
  transform: scale(1.02);
324
+ box-shadow: 0 2px 8px rgba(10,43,78,0.2) !important;
325
  }
326
+ /* Tabs - remove purple */
327
  .gr-tabs .tab-nav button {
328
  background-color: #eef2f5 !important;
329
  color: #0a2b4e !important;
 
331
  border-radius: 8px 8px 0 0 !important;
332
  padding: 8px 20px !important;
333
  margin-right: 4px;
334
+ border: none !important;
335
  }
336
  .gr-tabs .tab-nav button.selected {
337
  background-color: #0a2b4e !important;
338
  color: white !important;
339
  }
340
+ /* Input and output boxes */
341
+ .gr-box, .gr-form, .gr-input, .gr-panel {
342
+ background-color: #f9fafb !important;
343
+ border: 1px solid #d1d5db !important;
344
+ border-radius: 12px !important;
345
+ }
346
+ /* Labels */
347
+ label, .gr-form label {
348
+ color: #1e293b !important;
349
+ font-weight: 500 !important;
350
+ }
351
+ /* Markdown text */
352
+ .gr-markdown p, .gr-markdown {
353
+ color: #1e293b !important;
354
+ }
355
+ /* Footer */
356
  .skyguard-footer {
357
  text-align: center;
358
  margin-top: 30px;
 
361
  border-top: 1px solid #e2e8f0;
362
  padding-top: 15px;
363
  }
364
+ /* Additional overrides for any purple elements */
365
+ .gr-button, .gr-button-lg, .gr-button-sm, button {
366
+ --tw-ring-color: #0a2b4e !important;
367
+ }
368
+ .prose a, a {
369
+ color: #1e4a76 !important;
370
+ }
371
+ input:focus, textarea:focus, select:focus {
372
+ border-color: #0a2b4e !important;
373
+ ring-color: #0a2b4e !important;
374
+ }
375
  """
376
 
377
  with gr.Blocks(title="SkyGuard - Drone Detection System", theme=gr.themes.Soft(), css=custom_css) as demo:
 
410
  """)
411
 
412
  # ============================================================
413
+ # 7. RUN APP
414
  # ============================================================
415
  if __name__ == "__main__":
 
416
  demo.launch()