openfree committed on
Commit
dcf4020
Β·
verified Β·
1 Parent(s): 1c7e0b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +498 -428
app.py CHANGED
@@ -1,460 +1,530 @@
1
- # UVIS - Gradio App with Upload, URL & Video Support + HF Token Authentication
2
- """
3
- This script launches the UVIS (Unified Visual Intelligence System) as a Gradio Web App.
4
- Supports image, video, and URL-based media inputs for detection, segmentation, and depth estimation.
5
- Outputs include scene blueprint, structured JSON, and downloadable results.
6
- Now includes HuggingFace token authentication for private model access.
7
- """
8
-
9
  import os
10
- import time
11
- import logging
12
- import traceback
13
-
 
 
 
14
  import gradio as gr
 
 
15
  from PIL import Image
16
- import cv2
17
- import timeout_decorator
18
- import spaces
19
- import tempfile
 
 
 
20
  import shutil
 
 
 
 
21
 
22
- from registry import get_model
23
- from core.describe_scene import describe_scene
24
- from core.process import process_image, process_video
25
- from core.input_handler import resolve_input, validate_video, validate_image
26
- from utils.helpers import format_error, generate_session_id
27
- from huggingface_hub import hf_hub_download, login
28
-
29
- # HuggingFace Token Authentication
30
- HF_TOKEN = os.getenv("HF_TOKEN")
31
- if HF_TOKEN:
32
- try:
33
- login(token=HF_TOKEN)
34
- print("βœ… Successfully authenticated with HuggingFace using HF_TOKEN")
35
- except Exception as e:
36
- print(f"⚠️ Failed to authenticate with HuggingFace: {e}")
37
- else:
38
- print("⚠️ HF_TOKEN not found in environment variables. Some models may not be accessible.")
39
-
40
- # Clear HF cache if needed
41
  try:
42
- cache_paths = [
43
- os.path.expanduser("~/.cache/huggingface"),
44
- "/home/user/.cache/huggingface"
45
- ]
46
- for path in cache_paths:
47
- if os.path.exists(path):
48
- shutil.rmtree(path, ignore_errors=True)
49
- print("πŸ’₯ Nuked HF model cache from runtime.")
50
- except Exception as e:
51
- print("🚫 Failed to nuke cache:", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Setup logging
54
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
55
- logger = logging.getLogger(__name__)
56
 
57
- # Model mappings
58
- DETECTION_MODEL_MAP = {
59
- "YOLOv8-Nano": "yolov8n",
60
- "YOLOv8-Small": "yolov8s",
61
- "YOLOv8-Large": "yolov8l",
62
- "YOLOv11-Beta": "yolov11b"
63
- }
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- SEGMENTATION_MODEL_MAP = {
66
- "SegFormer-B0": "segformer_b0",
67
- "SegFormer-B5": "segformer_b5",
68
- "DeepLabV3-ResNet50": "deeplabv3_resnet50"
69
- }
 
 
 
 
 
 
 
 
 
 
70
 
71
- DEPTH_MODEL_MAP = {
72
- "MiDaS v21 Small 256": "midas_v21_small_256",
73
- "MiDaS v21 384": "midas_v21_384",
74
- "DPT Hybrid 384": "dpt_hybrid_384",
75
- "DPT Swin2 Large 384": "dpt_swin2_large_384",
76
- "DPT Beit Large 512": "dpt_beit_large_512"
77
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # Modified get_model wrapper to include HF token
80
- def get_model_with_auth(model_type, model_name, device="cpu"):
81
- """
82
- Wrapper for get_model that includes HF token authentication.
83
- """
84
- # Pass HF_TOKEN to the registry get_model function if it exists
85
- # This assumes the registry.get_model can accept a token parameter
86
  try:
87
- if hasattr(get_model, '__code__') and 'token' in get_model.__code__.co_varnames:
88
- return get_model(model_type, model_name, device=device, token=HF_TOKEN)
89
- else:
90
- # If get_model doesn't support token, use standard call
91
- return get_model(model_type, model_name, device=device)
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
- logger.error(f"Failed to load model {model_type}/{model_name}: {e}")
94
- # Fallback: try without token parameter
95
- return get_model(model_type, model_name, device=device)
96
 
97
- @spaces.GPU
98
- def handle(mode, media_upload, url,
99
- run_det, det_model, det_confidence,
100
- run_seg, seg_model,
101
- run_depth, depth_model,
102
- blend):
103
- """
104
- Master handler for resolving input and processing.
105
- Returns: (img_out, vid_out, json_out, zip_out)
106
- """
107
- session_id = generate_session_id()
108
- logger.info(f"Session ID: {session_id} | Handler activated with mode: {mode}")
109
- start_time = time.time()
 
110
 
111
- # Check HF authentication status
112
- if not HF_TOKEN:
113
- logger.warning("Processing without HF authentication. Some models may not be available.")
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- media = resolve_input(mode, media_upload, url)
116
- if not media:
117
- return (
118
- gr.update(visible=False),
119
- gr.update(visible=False),
120
- format_error("No valid input provided. Please check your upload or URL."),
121
- None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- first_input = media[0]
125
-
126
- # πŸ”§ Resolve dropdown label to model keys
127
- resolved_det_model = DETECTION_MODEL_MAP.get(det_model, det_model)
128
- resolved_seg_model = SEGMENTATION_MODEL_MAP.get(seg_model, seg_model)
129
- resolved_depth_model = DEPTH_MODEL_MAP.get(depth_model, depth_model)
130
-
131
- # --- VIDEO PATH ---
132
- if isinstance(first_input, str) and first_input.lower().endswith((".mp4", ".mov", ".avi")):
133
- valid, err = validate_video(first_input)
134
- if not valid:
135
- return (
136
- gr.update(visible=False),
137
- gr.update(visible=False),
138
- format_error(err),
139
- None
140
- )
141
- try:
142
- # Pass HF_TOKEN to process_video if needed
143
- _, msg, output_video_path = process_video(
144
- video_path=first_input,
145
- run_det=run_det,
146
- det_model=resolved_det_model,
147
- det_confidence=det_confidence,
148
- run_seg=run_seg,
149
- seg_model=resolved_seg_model,
150
- run_depth=run_depth,
151
- depth_model=resolved_depth_model,
152
- blend=blend,
153
- hf_token=HF_TOKEN # Pass token if process_video supports it
154
- )
155
- return (
156
- gr.update(visible=False), # hide image
157
- gr.update(value=output_video_path, visible=True), # show video
158
- msg,
159
- output_video_path # for download
160
- )
161
- except Exception as e:
162
- logger.error(f"Video processing failed: {e}")
163
- # If it's an authentication error, provide specific message
164
- if "401" in str(e) or "unauthorized" in str(e).lower():
165
- error_msg = "Authentication failed. Please check HF_TOKEN environment variable."
166
- else:
167
- error_msg = str(e)
168
- return (
169
- gr.update(visible=False),
170
- gr.update(visible=False),
171
- format_error(error_msg),
172
- None
173
- )
174
-
175
- # --- IMAGE PATH ---
176
- elif isinstance(first_input, Image.Image):
177
- valid, err = validate_image(first_input)
178
- if not valid:
179
- return (
180
- gr.update(visible=False),
181
- gr.update(visible=False),
182
- format_error(err),
183
- None
184
- )
185
  try:
186
- # Pass HF_TOKEN to process_image if needed
187
- result_img, msg, output_zip = process_image(
188
- image=first_input,
189
- run_det=run_det,
190
- det_model=resolved_det_model,
191
- det_confidence=det_confidence,
192
- run_seg=run_seg,
193
- seg_model=resolved_seg_model,
194
- run_depth=run_depth,
195
- depth_model=resolved_depth_model,
196
- blend=blend,
197
- hf_token=HF_TOKEN # Pass token if process_image supports it
198
- )
199
- return (
200
- gr.update(value=result_img, visible=True), # show image
201
- gr.update(visible=False), # hide video
202
- msg,
203
- output_zip
204
- )
205
- except timeout_decorator.timeout_decorator.TimeoutError:
206
- logger.error("Image processing timed out.")
207
- return (
208
- gr.update(visible=False),
209
- gr.update(visible=False),
210
- format_error("Processing timed out. Try a smaller image or simpler model."),
211
- None
212
- )
213
- except Exception as e:
214
- traceback.print_exc()
215
- logger.error(f"Image processing failed: {e}")
216
- # If it's an authentication error, provide specific message
217
- if "401" in str(e) or "unauthorized" in str(e).lower():
218
- error_msg = "Authentication failed. Please check HF_TOKEN environment variable."
219
- else:
220
- error_msg = str(e)
221
- return (
222
- gr.update(visible=False),
223
- gr.update(visible=False),
224
- format_error(error_msg),
225
- None
226
- )
227
-
228
- logger.warning("Unsupported media type resolved.")
229
- return (
230
- gr.update(visible=False),
231
- gr.update(visible=False),
232
- format_error("Unsupported input type."),
233
- None
234
- )
235
-
236
- def show_preview_from_upload(files):
237
- if not files:
238
- return gr.update(visible=False), gr.update(visible=False)
239
 
240
- file = files[0]
241
- filename = file.name.lower()
242
-
243
- if filename.endswith((".png", ".jpg", ".jpeg", ".webp")):
244
- img = Image.open(file).convert("RGB")
245
- return gr.update(value=img, visible=True), gr.update(visible=False)
246
-
247
- elif filename.endswith((".mp4", ".mov", ".avi")):
248
- # Copy uploaded video to a known temp location
249
- temp_dir = tempfile.mkdtemp()
250
- ext = os.path.splitext(filename)[-1]
251
- safe_path = os.path.join(temp_dir, f"uploaded_video{ext}")
252
- with open(safe_path, "wb") as f:
253
- f.write(file.read())
254
-
255
- return gr.update(visible=False), gr.update(value=safe_path, visible=True)
256
-
257
- return gr.update(visible=False), gr.update(visible=False)
258
-
259
- def show_preview_from_url(url_input):
260
- if not url_input:
261
- return gr.update(visible=False), gr.update(visible=False)
262
- path = url_input.strip().lower()
263
- if path.endswith((".png", ".jpg", ".jpeg", ".webp")):
264
- return gr.update(value=url_input, visible=True), gr.update(visible=False)
265
- elif path.endswith((".mp4", ".mov", ".avi")):
266
- return gr.update(visible=False), gr.update(value=url_input, visible=True)
267
- return gr.update(visible=False), gr.update(visible=False)
 
 
 
 
 
268
 
269
- def clear_model_cache():
270
- """
271
- Deletes all model weight folders so they are redownloaded fresh.
272
- """
273
- folders = [
274
- "models/detection/weights",
275
- "models/segmentation/weights",
276
- "models/depth/weights"
277
- ]
278
- for folder in folders:
279
- shutil.rmtree(folder, ignore_errors=True)
280
- logger.info(f"πŸ—‘οΈ Cleared: {folder}")
281
 
282
- # Also clear HF cache if token is available
283
- if HF_TOKEN:
284
  try:
285
- cache_paths = [
286
- os.path.expanduser("~/.cache/huggingface"),
287
- "/home/user/.cache/huggingface"
288
- ]
289
- for path in cache_paths:
290
- if os.path.exists(path):
291
- shutil.rmtree(path, ignore_errors=True)
292
- return "βœ… Model cache and HF cache cleared. Models will be reloaded on next run."
293
- except Exception as e:
294
- return f"⚠️ Model cache cleared, but failed to clear HF cache: {e}"
295
-
296
- return "βœ… Model cache cleared. Models will be reloaded on next run."
297
-
298
- def check_auth_status():
299
- """
300
- Check and display current authentication status.
301
- """
302
- if HF_TOKEN:
303
- return f"βœ… Authenticated with HuggingFace (Token: {HF_TOKEN[:8]}...)"
304
- else:
305
- return "❌ Not authenticated. Set HF_TOKEN environment variable for private model access."
306
-
307
- # Gradio Interface
308
- with gr.Blocks(title="UVIS - Unified Visual Intelligence System") as demo:
309
- gr.Markdown("## Unified Visual Intelligence System (UVIS)")
310
 
311
- # Authentication Status
312
- with gr.Row():
313
- auth_status = gr.Textbox(
314
- label="HF Authentication Status",
315
- value=check_auth_status(),
316
- interactive=False
317
- )
318
 
319
- with gr.Row():
320
- # left panel
321
- with gr.Column(scale=2):
322
- # Input Mode Toggle
323
- mode = gr.Radio(["Upload", "URL"], value="Upload", label="Input Mode")
324
-
325
- # File upload: accepts multiple images or one video (user chooses wisely)
326
- media_upload = gr.File(
327
- label="Upload Images (1–5) or 1 Video",
328
- file_types=["image", ".mp4", ".mov", ".avi"],
329
- file_count="multiple",
330
- visible=True
331
- )
332
-
333
- # URL input
334
- url = gr.Textbox(label="URL (Image/Video)", visible=False)
335
-
336
- # Toggle visibility
337
- def toggle_inputs(selected_mode):
338
- return [
339
- gr.update(visible=(selected_mode == "Upload")), # media_upload
340
- gr.update(visible=(selected_mode == "URL")), # url
341
- gr.update(visible=False), # preview_image
342
- gr.update(visible=False) # preview_video
343
- ]
344
-
345
- mode.change(toggle_inputs, inputs=mode, outputs=[media_upload, url])
346
-
347
- # Visibility logic function
348
- def toggle_visibility(checked):
349
- return gr.update(visible=checked)
350
-
351
- run_det = gr.Checkbox(label="Object Detection")
352
- run_seg = gr.Checkbox(label="Semantic Segmentation")
353
- run_depth = gr.Checkbox(label="Depth Estimation")
354
-
355
- with gr.Row():
356
- with gr.Column(visible=False) as OD_Settings:
357
- with gr.Accordion("Object Detection Settings", open=True):
358
- det_model = gr.Dropdown(
359
- choices=list(DETECTION_MODEL_MAP.keys()),
360
- label="Detection Model",
361
- value="YOLOv8-Nano"
362
- )
363
- det_confidence = gr.Slider(0.1, 1.0, 0.5, label="Detection Confidence Threshold")
364
- nms_thresh = gr.Slider(0.1, 1.0, 0.45, label="NMS Threshold")
365
- max_det = gr.Slider(1, 100, 20, step=1, label="Max Detections")
366
- iou_thresh = gr.Slider(0.1, 1.0, 0.5, label="IoU Threshold")
367
- class_filter = gr.CheckboxGroup(["Person", "Car", "Dog"], label="Class Filter")
368
-
369
- with gr.Column(visible=False) as SS_Settings:
370
- with gr.Accordion("Semantic Segmentation Settings", open=True):
371
- seg_model = gr.Dropdown(
372
- choices=list(SEGMENTATION_MODEL_MAP.keys()),
373
- label="Segmentation Model",
374
- value="DeepLabV3-ResNet50"
375
- )
376
- resize_strategy = gr.Dropdown(["Crop", "Pad", "Scale"], label="Resize Strategy", value="Scale")
377
- overlay_alpha = gr.Slider(0.0, 1.0, 0.5, label="Overlay Opacity")
378
- seg_classes = gr.CheckboxGroup(["Road", "Sky", "Building"], label="Target Classes")
379
- enable_crf = gr.Checkbox(label="Postprocessing (CRF)")
380
 
381
- with gr.Column(visible=False) as DE_Settings:
382
- with gr.Accordion("Depth Estimation Settings", open=True):
383
- depth_model = gr.Dropdown(
384
- choices=list(DEPTH_MODEL_MAP.keys()),
385
- label="Depth Model",
386
- value="MiDaS v21 Small 256"
387
- )
388
- output_type = gr.Dropdown(["Raw", "Disparity", "Scaled"], label="Output Type", value="Scaled")
389
- colormap = gr.Dropdown(["Jet", "Viridis", "Plasma"], label="Colormap", value="Jet")
390
- blend = gr.Slider(0.0, 1.0, 0.5, label="Overlay Blend")
391
- normalize = gr.Checkbox(label="Normalize Depth", value=True)
392
- max_depth = gr.Slider(0.1, 10.0, 5.0, label="Max Depth (meters)")
393
-
394
- # Attach Visibility Logic
395
- run_det.change(fn=toggle_visibility, inputs=[run_det], outputs=[OD_Settings])
396
- run_seg.change(fn=toggle_visibility, inputs=[run_seg], outputs=[SS_Settings])
397
- run_depth.change(fn=toggle_visibility, inputs=[run_depth], outputs=[DE_Settings])
398
-
399
- blend = gr.Slider(0.0, 1.0, 0.5, label="Overlay Blend")
400
 
401
- # Run Button
402
- run = gr.Button("Run Analysis", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
- # Right panel
405
- with gr.Column(scale=1):
406
- # Only one is shown at a time β€” image or video
407
- img_out = gr.Image(label="Preview / Processed Output", visible=False)
408
- vid_out = gr.Video(label="Preview / Processed Video", visible=False, streaming=True, autoplay=True)
409
- json_out = gr.JSON(label="Scene JSON")
410
- zip_out = gr.File(label="Download Results")
411
-
412
- with gr.Row():
413
- clear_button = gr.Button("🧹 Clear Model Cache")
414
- refresh_auth_button = gr.Button("πŸ”„ Refresh Auth Status")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
- status_box = gr.Textbox(label="Status", interactive=False)
 
 
417
 
418
- clear_button.click(fn=clear_model_cache, inputs=[], outputs=[status_box])
419
- refresh_auth_button.click(fn=check_auth_status, inputs=[], outputs=[auth_status])
420
-
421
- media_upload.change(show_preview_from_upload, inputs=media_upload, outputs=[img_out, vid_out])
422
- url.submit(show_preview_from_url, inputs=url, outputs=[img_out, vid_out])
423
-
424
- # Button Click Event
425
- run.click(
426
- fn=handle,
427
- inputs=[
428
- mode, media_upload, url,
429
- run_det, det_model, det_confidence,
430
- run_seg, seg_model,
431
- run_depth, depth_model,
432
- blend
433
- ],
434
- outputs=[
435
- img_out, # will be visible only if it's an image
436
- vid_out, # will be visible only if it's a video
437
- json_out,
438
- zip_out
439
- ]
440
- )
441
-
442
- # Footer Section
443
- gr.Markdown("---")
444
- gr.Markdown(
445
- f"""
446
- <div style='text-align: center; font-size: 14px;'>
447
- Built by <b>Durga Deepak Valluri</b><br>
448
- <a href="https://github.com/DurgaDeepakValluri" target="_blank">GitHub</a> |
449
- <a href="https://deecoded.io" target="_blank">Website</a> |
450
- <a href="https://www.linkedin.com/in/durga-deepak-valluri" target="_blank">LinkedIn</a><br>
451
- <span style='font-size: 12px; color: #666;'>
452
- {'πŸ” HF Authentication Active' if HF_TOKEN else 'πŸ”“ No HF Authentication'}
453
- </span>
454
- </div>
455
- """,
456
- )
457
 
458
- # Launch the Gradio App
459
  if __name__ == "__main__":
460
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ # Set environment variables for Spaces compatibility
3
+ os.environ['OMP_NUM_THREADS'] = '1'
4
+ os.environ['MKL_NUM_THREADS'] = '1'
5
+ import cv2
6
+ import yaml
7
+ import torch
8
+ import random
9
  import gradio as gr
10
+ import numpy as np
11
+ import kagglehub
12
  from PIL import Image
13
+ from glob import glob
14
+ import matplotlib
15
+ matplotlib.use('Agg') # Use non-interactive backend
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib import patches
18
+ from torchvision import transforms as T
19
+ from ultralytics import YOLO
20
  import shutil
21
+ import tempfile
22
+ from pathlib import Path
23
+ import json
24
+ from io import BytesIO
25
 
26
# Hugging Face Spaces provides the ``spaces`` package for ZeroGPU scheduling;
# outside Spaces we substitute a no-op stand-in so ``@spaces.GPU`` still works.
try:
    import spaces
    ON_SPACES = True
except ImportError:
    ON_SPACES = False

    class spaces:
        """No-op stand-in: ``spaces.GPU`` becomes an identity decorator."""
        @staticmethod
        def GPU(duration=60):
            def decorator(func):
                return func
            return decorator

# Set Kaggle API credentials from environment variable
if os.getenv("KDATA_API"):
    kaggle_key = os.getenv("KDATA_API")
    # The secret may be the whole kaggle.json blob; split it into the
    # individual variables the Kaggle client expects.
    if "{" in kaggle_key:
        creds = json.loads(kaggle_key)
        os.environ["KAGGLE_USERNAME"] = creds.get("username", "")
        os.environ["KAGGLE_KEY"] = creds.get("key", "")

# Shared app state mutated by the handler functions below.
model = None
dataset_path = None
training_in_progress = False
53
+
54
class Visualization:
    """Load a YOLO-format dataset, draw bounding boxes on sample images,
    and summarize per-class object counts for analysis plots.

    Expects ``root`` to contain ``data.yaml`` plus ``<split>/images`` and
    ``<split>/labels`` directories for each split in ``data_types``.
    """

    def __init__(self, root, data_types, n_ims, rows, cmap=None):
        # n_ims/rows/cmap are kept for interface compatibility with callers.
        self.n_ims, self.rows = n_ims, rows
        self.cmap, self.data_types = cmap, data_types
        self.colors = ["firebrick", "darkorange", "blueviolet"]
        self.root = root

        self.get_cls_names()
        self.get_bboxes()

    def get_cls_names(self):
        """Read class names from data.yaml into ``self.class_dict`` {index: name}."""
        with open(f"{self.root}/data.yaml", 'r') as file:
            data = yaml.safe_load(file)
        class_names = data['names']
        self.class_dict = {index: name for index, name in enumerate(class_names)}

    def get_bboxes(self):
        """Parse every label file per split, collecting boxes and class counts."""
        self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}
        for data_type in self.data_types:
            all_bboxes, all_analysis_datas = [], {}
            kept_paths = []

            for im_path in glob(f"{self.root}/{data_type}/images/*"):
                bboxes = []
                im_ext = os.path.splitext(im_path)[-1]
                lbl_path = im_path.replace(im_ext, ".txt")
                lbl_path = lbl_path.replace(f"{data_type}/images", f"{data_type}/labels")
                if not os.path.isfile(lbl_path):
                    continue
                # FIX: close the label file deterministically (was a bare open()).
                with open(lbl_path) as lbl_file:
                    meta_data = lbl_file.readlines()
                for data in meta_data:
                    # YOLO label line: class_id x_center y_center width height
                    parts = data.strip().split()[:5]
                    cls_name = self.class_dict[int(parts[0])]
                    bboxes.append([cls_name] + [float(x) for x in parts[1:]])
                    if cls_name not in all_analysis_datas:
                        all_analysis_datas[cls_name] = 1
                    else:
                        all_analysis_datas[cls_name] += 1
                all_bboxes.append(bboxes)
                kept_paths.append(im_path)

            self.vis_datas[data_type] = all_bboxes
            self.analysis_datas[data_type] = all_analysis_datas
            # FIX: store only labeled image paths so indices stay aligned with
            # vis_datas (unlabeled images previously desynced path <-> bboxes).
            self.im_paths[data_type] = kept_paths

    def plot_single(self, im_path, bboxes):
        """Draw YOLO-normalized bboxes on one image; return a PIL RGB image."""
        or_im = np.array(Image.open(im_path).convert("RGB"))
        height, width, _ = or_im.shape

        for bbox in bboxes:
            cls_name, x_center, y_center, w, h = bbox

            # Convert normalized center/size to pixel corner coordinates.
            x_min = int((x_center - w / 2) * width)
            y_min = int((y_center - h / 2) * height)
            x_max = int((x_center + w / 2) * width)
            y_max = int((y_center + h / 2) * height)

            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
                          color=color, thickness=3)

        # Add text overlay
        cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        # FIX: the array came from PIL and is already RGB; the previous
        # BGR->RGB conversion wrongly swapped the red/blue channels.
        return Image.fromarray(or_im)

    def vis_samples(self, data_type, n_samples=4):
        """Return up to ``n_samples`` randomly chosen annotated PIL images."""
        if data_type not in self.vis_datas:
            return None

        population = self.vis_datas[data_type]
        # FIX: guard the empty split (randint(0, -1) would raise ValueError).
        if not population:
            return None

        indices = [random.randint(0, len(population) - 1)
                   for _ in range(min(n_samples, len(population)))]

        figs = []
        for idx in indices:
            im_path = self.im_paths[data_type][idx]
            bboxes = population[idx]
            figs.append(self.plot_single(im_path, bboxes))

        return figs

    def data_analysis(self, data_type):
        """Bar-plot the per-class object counts for one split as a PIL image."""
        if data_type not in self.analysis_datas:
            return None

        plt.style.use('default')
        fig, ax = plt.subplots(figsize=(12, 6))

        cls_names = list(self.analysis_datas[data_type].keys())
        counts = list(self.analysis_datas[data_type].values())

        color_map = {"train": "firebrick", "valid": "darkorange", "test": "blueviolet"}
        color = color_map.get(data_type, "steelblue")

        indices = np.arange(len(counts))
        bars = ax.bar(indices, counts, 0.7, color=color)

        ax.set_xlabel("Class Names", fontsize=12)
        ax.set_xticks(indices)
        ax.set_xticklabels(cls_names, rotation=45, ha='right')
        ax.set_ylabel("Data Counts", fontsize=12)
        ax.set_title(f"{data_type.upper()} Dataset Class Distribution", fontsize=14)

        # Annotate each bar with its count.
        for bar, v in zip(bars, counts):
            ax.text(bar.get_x() + bar.get_width() / 2, v + 1, str(v),
                    ha='center', va='bottom', fontsize=10, color='navy')

        plt.tight_layout()

        # Render to an in-memory PNG and hand back a PIL image (Agg backend).
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)

        return img
176
 
177
def download_dataset():
    """Download the X-ray baggage dataset via kagglehub into ./xray_dataset.

    Side effect: sets the module-level ``dataset_path`` global.
    Returns a human-readable status string for the UI.
    """
    global dataset_path

    local_dir = "./xray_dataset"
    try:
        # kagglehub downloads into its own cache directory.
        fetched = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")
        dataset_path = fetched

        # Mirror the cache into a predictable local folder.
        if fetched != local_dir and os.path.exists(fetched):
            if os.path.exists(local_dir):
                shutil.rmtree(local_dir)
            shutil.copytree(fetched, local_dir)
            dataset_path = local_dir

        return f"Dataset downloaded successfully to: {dataset_path}"
    except Exception as e:
        return f"Error downloading dataset: {str(e)}\n\nPlease ensure KDATA_API environment variable is set correctly."
 
 
197
 
198
def visualize_data(data_type, num_samples):
    """Render random annotated samples from the chosen dataset split.

    Returns (list of PIL images, status string) for the Gradio gallery.
    """
    if dataset_path is None:
        return [], "Please download the dataset first!"

    try:
        viz = Visualization(root=dataset_path, data_types=[data_type],
                            n_ims=num_samples, rows=2, cmap="rgb")
        images = viz.vis_samples(data_type, num_samples)
        if images is None:
            return [], f"No data found for {data_type} dataset"
        return images, f"Showing {len(images)} samples from {data_type} dataset"
    except Exception as e:
        return [], f"Error visualizing data: {str(e)}"
212
 
213
def analyze_class_distribution(data_type):
    """Plot the per-class object counts for the chosen dataset split.

    Returns (PIL image or None, status string) for the Gradio UI.
    """
    if dataset_path is None:
        return None, "Please download the dataset first!"

    try:
        viz = Visualization(root=dataset_path, data_types=[data_type],
                            n_ims=20, rows=5, cmap="rgb")
        chart = viz.data_analysis(data_type)
        if chart is None:
            return None, f"No data found for {data_type} dataset"
        return chart, f"Class distribution for {data_type} dataset"
    except Exception as e:
        return None, f"Error analyzing data: {str(e)}"
227
 
228
@spaces.GPU(duration=300)  # Request GPU for 5 minutes for training
def train_model(epochs, batch_size, img_size, device_selection):
    """Train a YOLOv11 nano model on the downloaded dataset.

    Mutates the module-level ``model`` and ``training_in_progress`` globals.
    Returns (list of result-plot PIL images, status string).
    """
    global model, training_in_progress

    # Guard clauses: need a dataset and no concurrent run.
    if dataset_path is None:
        return [], "Please download the dataset first!"
    if training_in_progress:
        return [], "Training already in progress!"

    training_in_progress = True
    try:
        cuda_ok = torch.cuda.is_available()
        # On Spaces prefer the allocated GPU; otherwise honor the dropdown
        # ("Auto"/anything else falls back to GPU-if-available).
        if ON_SPACES and cuda_ok:
            device = 0
        elif device_selection == "CPU":
            device = "cpu"
        else:
            device = 0 if cuda_ok else "cpu"

        model = YOLO("yolo11n.pt")

        project_dir = "./xray_detection"
        os.makedirs(project_dir, exist_ok=True)

        # workers=0 avoids DataLoader multiprocessing issues in the Spaces
        # sandbox; cache=False keeps memory bounded; amp speeds up training.
        model.train(
            data=f"{dataset_path}/data.yaml",
            epochs=epochs,
            imgsz=img_size,
            batch=batch_size,
            device=device,
            project=project_dir,
            name="train",
            exist_ok=True,
            verbose=True,
            patience=5,
            save_period=5,
            workers=0,
            single_cls=False,
            rect=False,
            cache=False,
            amp=True,
        )

        # Collect whichever result plots Ultralytics produced.
        results_path = os.path.join(project_dir, "train")
        plot_files = ["results.png", "confusion_matrix.png", "val_batch0_pred.jpg",
                      "train_batch0.jpg", "val_batch0_labels.jpg"]
        plots = []
        for plot_file in plot_files:
            candidate = os.path.join(results_path, plot_file)
            if os.path.exists(candidate):
                plots.append(Image.open(candidate))

        model_path = os.path.join(results_path, "weights", "best.pt")
        return plots, f"Training completed! Model saved to {model_path}"
    except Exception as e:
        return [], f"Error during training: {str(e)}"
    finally:
        # Always release the in-progress flag, on success or failure.
        training_in_progress = False
300
 
301
@spaces.GPU(duration=60)  # Request GPU for 1 minute for inference
def run_inference(input_image, conf_threshold):
    """Run single-image detection with the current (or fallback) model.

    Args:
        input_image: PIL image uploaded by the user, or None.
        conf_threshold: minimum confidence for reported detections.

    Returns:
        (annotated PIL image or None, status/detections string).
    """
    global model

    if model is None:
        # Lazily fall back to the pretrained nano checkpoint.
        try:
            model = YOLO("yolo11n.pt")
        # FIX: no bare except (it swallowed SystemExit/KeyboardInterrupt).
        except Exception:
            return None, "Please train the model first or load a pre-trained model!"

    if input_image is None:
        return None, "Please upload an image!"

    temp_path = None
    try:
        # FIX: unique temp file instead of the fixed "temp_inference.jpg"
        # (avoids clashes between concurrent requests); cleanup is
        # guaranteed by the finally block.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        input_image.save(temp_path)

        results = model(temp_path, conf=conf_threshold, verbose=False,
                        device=0 if torch.cuda.is_available() else 'cpu')

        # Ultralytics Results.plot() returns a BGR ndarray; FIX: flip the
        # channel axis so PIL renders correct colors (copy for contiguity).
        annotated_image = np.ascontiguousarray(results[0].plot()[..., ::-1])

        # Summarize detections as "class: confidence" lines.
        detections = []
        if results[0].boxes is not None:
            for box in results[0].boxes:
                cls = int(box.cls)
                conf = float(box.conf)
                cls_name = model.names[cls]
                detections.append(f"{cls_name}: {conf:.2f}")

        detection_text = "\n".join(detections) if detections else "No objects detected"
        return Image.fromarray(annotated_image), f"Detections:\n{detection_text}"

    except Exception as e:
        return None, f"Error during inference: {str(e)}"
    finally:
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
346
 
347
@spaces.GPU(duration=60)  # Request GPU for batch inference
def batch_inference(data_type, num_images):
    """Annotate up to ``num_images`` images from the given dataset split.

    Args:
        data_type: split name ("train"/"valid"/"test").
        num_images: maximum number of images to process.

    Returns:
        (list of annotated PIL images, status string).
    """
    global model

    if model is None:
        try:
            model = YOLO("yolo11n.pt")
        # FIX: no bare except (consistent with run_inference; a bare
        # except also swallowed SystemExit/KeyboardInterrupt).
        except Exception:
            return [], "Please train the model first!"

    if dataset_path is None:
        return [], "Please download the dataset first!"

    try:
        image_dir = f"{dataset_path}/{data_type}/images"
        if not os.path.exists(image_dir):
            return [], f"Directory {image_dir} not found!"

        image_files = glob(f"{image_dir}/*")[:num_images]
        if not image_files:
            return [], f"No images found in {image_dir}"

        results_images = []
        for img_path in image_files:
            results = model(img_path, verbose=False)
            # FIX: Results.plot() returns a BGR ndarray; flip channels so
            # the PIL image shows correct colors.
            annotated = np.ascontiguousarray(results[0].plot()[..., ::-1])
            results_images.append(Image.fromarray(annotated))

        return results_images, f"Processed {len(results_images)} images from {data_type} dataset"

    except Exception as e:
        return [], f"Error during batch inference: {str(e)}"
 
 
 
382
 
383
def load_pretrained_model(model_path):
    """Load a YOLO checkpoint into the module-level ``model`` global.

    When ``model_path`` does not exist, probes the usual training-output
    locations (best.pt, last.pt) and finally the stock nano weights.
    Returns a status string.
    """
    global model
    try:
        if not os.path.exists(model_path):
            # Probe fallback locations in priority order.
            fallbacks = (
                "./xray_detection/train/weights/best.pt",
                "./xray_detection/train/weights/last.pt",
                "yolo11n.pt",
            )
            for candidate in fallbacks:
                if os.path.exists(candidate):
                    model_path = candidate
                    break

        model = YOLO(model_path)
        return f"Model loaded successfully from {model_path}"
    except Exception as e:
        return f"Error loading model: {str(e)}"
403
 
404
+ # Create Gradio interface
405
+ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
406
+ gr.Markdown("""
407
+ # 🎯 X-ray Baggage Anomaly Detection with YOLOv11
408
+
409
+ This application allows you to:
410
+ 1. Download and visualize the X-ray baggage dataset
411
+ 2. Analyze class distributions
412
+ 3. Train a YOLOv11 model for object detection
413
+ 4. Run inference on new images
414
+
415
+ **Note:** GPU will be automatically allocated when needed for training and inference.
416
+ """)
417
+
418
+ # Add instructions for Kaggle API setup
419
+ with gr.Accordion("πŸ“ Setup Instructions", open=False):
420
+ gr.Markdown("""
421
+ ### Kaggle API Setup
422
+ 1. Get your Kaggle API credentials from https://www.kaggle.com/settings
423
+ 2. Set the KDATA_API environment variable in Hugging Face Spaces settings:
424
+ ```
425
+ KDATA_API={"username":"your_username","key":"your_api_key"}
426
+ ```
427
+ """)
428
+
429
+ with gr.Tab("πŸ“Š Dataset"):
430
+ with gr.Row():
431
+ download_btn = gr.Button("Download Dataset", variant="primary", scale=1)
432
+ download_status = gr.Textbox(label="Status", interactive=False, scale=3)
433
+
434
+ download_btn.click(download_dataset, outputs=download_status)
435
+
436
+ gr.Markdown("### Visualize Dataset Samples")
437
+ with gr.Row():
438
+ data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
439
+ num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples")
440
+ viz_btn = gr.Button("Visualize Samples")
441
+
442
+ viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto")
443
+ viz_status = gr.Textbox(label="Status", interactive=False)
444
+
445
+ viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples],
446
+ outputs=[viz_gallery, viz_status])
447
+
448
+ gr.Markdown("### Analyze Class Distribution")
449
+ with gr.Row():
450
+ data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
451
+ analyze_btn = gr.Button("Analyze Distribution")
452
+
453
+ distribution_plot = gr.Image(label="Class Distribution", type="pil")
454
+ analysis_status = gr.Textbox(label="Status", interactive=False)
455
+
456
+ analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
457
+ outputs=[distribution_plot, analysis_status])
458
+
459
+ with gr.Tab("πŸš€ Training"):
460
+ gr.Markdown("### Train YOLOv11 Model")
461
+ gr.Markdown("""
462
+ **Note:** Training will automatically use GPU if available. This may take several minutes.
463
+
464
+ **Tips for Hugging Face Spaces:**
465
+ - Use smaller batch sizes (4-8) to avoid GPU memory issues
466
+ - Start with fewer epochs (5-10) for testing
467
+ - Image size 480 provides good balance between quality and speed
468
+ """)
469
+
470
+ with gr.Row():
471
+ epochs_input = gr.Slider(1, 50, 10, step=1, label="Epochs")
472
+ batch_size_input = gr.Slider(4, 32, 8, step=4, label="Batch Size (lower for limited GPU)")
473
+ img_size_input = gr.Slider(320, 640, 480, step=32, label="Image Size")
474
+ device_input = gr.Radio(["Auto", "GPU", "CPU"], value="Auto", label="Device")
475
+
476
+ train_btn = gr.Button("Start Training", variant="primary")
477
+
478
+ training_gallery = gr.Gallery(label="Training Results", columns=3, height="auto")
479
+ training_status = gr.Textbox(label="Training Status", interactive=False)
480
+
481
+ train_btn.click(train_model,
482
+ inputs=[epochs_input, batch_size_input, img_size_input, device_input],
483
+ outputs=[training_gallery, training_status])
484
+
485
+ gr.Markdown("### Load Pre-trained Model")
486
+ with gr.Row():
487
+ model_path_input = gr.Textbox(label="Model Path", value="./xray_detection/train/weights/best.pt")
488
+ load_model_btn = gr.Button("Load Model")
489
+ load_status = gr.Textbox(label="Status", interactive=False)
490
+
491
+ load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
492
+
493
+ with gr.Tab("πŸ” Inference"):
494
+ gr.Markdown("### Single Image Inference")
495
+
496
+ with gr.Row():
497
+ with gr.Column():
498
+ input_image = gr.Image(type="pil", label="Upload Image")
499
+ conf_threshold = gr.Slider(0.1, 0.9, 0.5, step=0.05, label="Confidence Threshold")
500
+ inference_btn = gr.Button("Run Detection", variant="primary")
501
 
502
+ with gr.Column():
503
+ output_image = gr.Image(type="pil", label="Detection Result")
504
+ detection_info = gr.Textbox(label="Detection Info", lines=5)
505
 
506
+ inference_btn.click(run_inference,
507
+ inputs=[input_image, conf_threshold],
508
+ outputs=[output_image, detection_info])
509
+
510
+ gr.Markdown("### Batch Inference")
511
+
512
+ with gr.Row():
513
+ batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type")
514
+ batch_num_images = gr.Slider(1, 10, 5, step=1, label="Number of Images")
515
+ batch_btn = gr.Button("Run Batch Inference")
516
+
517
+ batch_gallery = gr.Gallery(label="Batch Results", columns=3, height="auto")
518
+ batch_status = gr.Textbox(label="Status", interactive=False)
519
+
520
+ batch_btn.click(batch_inference,
521
+ inputs=[batch_data_type, batch_num_images],
522
+ outputs=[batch_gallery, batch_status])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
+ # Launch the app
525
# Launch the app
if __name__ == "__main__":
    # Share links only make sense when running locally; on Hugging Face
    # Spaces the app is already served publicly, so omit the flag there.
    launch_options = {"ssr_mode": False}
    if not ON_SPACES:
        launch_options["share"] = True
    demo.launch(**launch_options)