openfree committed on
Commit
1c7e0b2
·
verified Β·
1 Parent(s): c2e0b70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +428 -498
app.py CHANGED
@@ -1,530 +1,460 @@
 
 
 
 
 
 
 
 
1
  import os
2
- # Set environment variables for Spaces compatibility
3
- os.environ['OMP_NUM_THREADS'] = '1'
4
- os.environ['MKL_NUM_THREADS'] = '1'
5
- import cv2
6
- import yaml
7
- import torch
8
- import random
9
  import gradio as gr
10
- import numpy as np
11
- import kagglehub
12
  from PIL import Image
13
- from glob import glob
14
- import matplotlib
15
- matplotlib.use('Agg') # Use non-interactive backend
16
- import matplotlib.pyplot as plt
17
- from matplotlib import patches
18
- from torchvision import transforms as T
19
- from ultralytics import YOLO
20
- import shutil
21
  import tempfile
22
- from pathlib import Path
23
- import json
24
- from io import BytesIO
 
 
 
 
 
25
 
26
- # Try to import spaces for Hugging Face Spaces GPU support
 
 
 
 
 
 
 
 
 
 
 
27
  try:
28
- import spaces
29
- ON_SPACES = True
30
- except ImportError:
31
- ON_SPACES = False
32
- # Create a dummy decorator if not on Spaces
33
- class spaces:
34
- @staticmethod
35
- def GPU(duration=60):
36
- def decorator(func):
37
- return func
38
- return decorator
39
-
40
- # Set Kaggle API credentials from environment variable
41
- if os.getenv("KDATA_API"):
42
- kaggle_key = os.getenv("KDATA_API")
43
- # Parse the key if it's in JSON format
44
- if "{" in kaggle_key:
45
- key_data = json.loads(kaggle_key)
46
- os.environ["KAGGLE_USERNAME"] = key_data.get("username", "")
47
- os.environ["KAGGLE_KEY"] = key_data.get("key", "")
48
-
49
- # Global variables
50
- model = None
51
- dataset_path = None
52
- training_in_progress = False
53
-
54
- class Visualization:
55
- def __init__(self, root, data_types, n_ims, rows, cmap=None):
56
- self.n_ims, self.rows = n_ims, rows
57
- self.cmap, self.data_types = cmap, data_types
58
- self.colors = ["firebrick", "darkorange", "blueviolet"]
59
- self.root = root
60
-
61
- self.get_cls_names()
62
- self.get_bboxes()
63
-
64
- def get_cls_names(self):
65
- with open(f"{self.root}/data.yaml", 'r') as file:
66
- data = yaml.safe_load(file)
67
- class_names = data['names']
68
- self.class_dict = {index: name for index, name in enumerate(class_names)}
69
-
70
- def get_bboxes(self):
71
- self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}
72
- for data_type in self.data_types:
73
- all_bboxes, all_analysis_datas = [], {}
74
- im_paths = glob(f"{self.root}/{data_type}/images/*")
75
-
76
- for idx, im_path in enumerate(im_paths):
77
- bboxes = []
78
- im_ext = os.path.splitext(im_path)[-1]
79
- lbl_path = im_path.replace(im_ext, ".txt")
80
- lbl_path = lbl_path.replace(f"{data_type}/images", f"{data_type}/labels")
81
- if not os.path.isfile(lbl_path):
82
- continue
83
- meta_data = open(lbl_path).readlines()
84
- for data in meta_data:
85
- parts = data.strip().split()[:5]
86
- cls_name = self.class_dict[int(parts[0])]
87
- bboxes.append([cls_name] + [float(x) for x in parts[1:]])
88
- if cls_name not in all_analysis_datas:
89
- all_analysis_datas[cls_name] = 1
90
- else:
91
- all_analysis_datas[cls_name] += 1
92
- all_bboxes.append(bboxes)
93
-
94
- self.vis_datas[data_type] = all_bboxes
95
- self.analysis_datas[data_type] = all_analysis_datas
96
- self.im_paths[data_type] = im_paths
97
-
98
- def plot_single(self, im_path, bboxes):
99
- or_im = np.array(Image.open(im_path).convert("RGB"))
100
- height, width, _ = or_im.shape
101
 
102
- for bbox in bboxes:
103
- class_id, x_center, y_center, w, h = bbox
 
104
 
105
- x_min = int((x_center - w / 2) * width)
106
- y_min = int((y_center - h / 2) * height)
107
- x_max = int((x_center + w / 2) * width)
108
- y_max = int((y_center + h / 2) * height)
109
-
110
- color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
111
- cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
112
- color=color, thickness=3)
113
-
114
- # Add text overlay
115
- cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
116
- cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
117
-
118
- # Convert BGR to RGB if needed
119
- if len(or_im.shape) == 3 and or_im.shape[2] == 3:
120
- or_im = cv2.cvtColor(or_im, cv2.COLOR_BGR2RGB)
121
-
122
- return Image.fromarray(or_im)
123
 
124
- def vis_samples(self, data_type, n_samples=4):
125
- if data_type not in self.vis_datas:
126
- return None
127
-
128
- indices = [random.randint(0, len(self.vis_datas[data_type]) - 1)
129
- for _ in range(min(n_samples, len(self.vis_datas[data_type])))]
130
-
131
- figs = []
132
- for idx in indices:
133
- im_path = self.im_paths[data_type][idx]
134
- bboxes = self.vis_datas[data_type][idx]
135
- fig = self.plot_single(im_path, bboxes)
136
- figs.append(fig)
137
-
138
- return figs
139
 
140
- def data_analysis(self, data_type):
141
- if data_type not in self.analysis_datas:
142
- return None
143
-
144
- plt.style.use('default')
145
- fig, ax = plt.subplots(figsize=(12, 6))
146
-
147
- cls_names = list(self.analysis_datas[data_type].keys())
148
- counts = list(self.analysis_datas[data_type].values())
149
-
150
- color_map = {"train": "firebrick", "valid": "darkorange", "test": "blueviolet"}
151
- color = color_map.get(data_type, "steelblue")
152
-
153
- indices = np.arange(len(counts))
154
- bars = ax.bar(indices, counts, 0.7, color=color)
155
-
156
- ax.set_xlabel("Class Names", fontsize=12)
157
- ax.set_xticks(indices)
158
- ax.set_xticklabels(cls_names, rotation=45, ha='right')
159
- ax.set_ylabel("Data Counts", fontsize=12)
160
- ax.set_title(f"{data_type.upper()} Dataset Class Distribution", fontsize=14)
161
-
162
- for i, (bar, v) in enumerate(zip(bars, counts)):
163
- ax.text(bar.get_x() + bar.get_width()/2, v + 1, str(v),
164
- ha='center', va='bottom', fontsize=10, color='navy')
165
-
166
- plt.tight_layout()
167
-
168
- # Save to BytesIO and convert to PIL Image
169
- buf = BytesIO()
170
- fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
171
- buf.seek(0)
172
- img = Image.open(buf)
173
- plt.close(fig)
174
-
175
- return img
176
 
177
- def download_dataset():
178
- """Download the dataset using kagglehub"""
179
- global dataset_path
 
 
 
 
180
  try:
181
- # Create a local directory to store the dataset
182
- local_dir = "./xray_dataset"
183
-
184
- # Download dataset
185
- dataset_path = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")
186
-
187
- # If the dataset is downloaded to a temporary location, copy it to our local directory
188
- if dataset_path != local_dir and os.path.exists(dataset_path):
189
- if os.path.exists(local_dir):
190
- shutil.rmtree(local_dir)
191
- shutil.copytree(dataset_path, local_dir)
192
- dataset_path = local_dir
193
-
194
- return f"Dataset downloaded successfully to: {dataset_path}"
195
  except Exception as e:
196
- return f"Error downloading dataset: {str(e)}\n\nPlease ensure KDATA_API environment variable is set correctly."
 
 
197
 
198
- def visualize_data(data_type, num_samples):
199
- """Visualize sample images from the dataset"""
200
- if dataset_path is None:
201
- return [], "Please download the dataset first!"
202
-
203
- try:
204
- vis = Visualization(root=dataset_path, data_types=[data_type],
205
- n_ims=num_samples, rows=2, cmap="rgb")
206
- figs = vis.vis_samples(data_type, num_samples)
207
- if figs is None:
208
- return [], f"No data found for {data_type} dataset"
209
- return figs, f"Showing {len(figs)} samples from {data_type} dataset"
210
- except Exception as e:
211
- return [], f"Error visualizing data: {str(e)}"
212
 
213
- def analyze_class_distribution(data_type):
214
- """Analyze class distribution in the dataset"""
215
- if dataset_path is None:
216
- return None, "Please download the dataset first!"
217
-
218
- try:
219
- vis = Visualization(root=dataset_path, data_types=[data_type],
220
- n_ims=20, rows=5, cmap="rgb")
221
- fig = vis.data_analysis(data_type)
222
- if fig is None:
223
- return None, f"No data found for {data_type} dataset"
224
- return fig, f"Class distribution for {data_type} dataset"
225
- except Exception as e:
226
- return None, f"Error analyzing data: {str(e)}"
227
 
228
- @spaces.GPU(duration=300) # Request GPU for 5 minutes for training
229
- def train_model(epochs, batch_size, img_size, device_selection):
230
- """Train YOLOv11 model"""
231
- global model, training_in_progress
232
-
233
- if dataset_path is None:
234
- return [], "Please download the dataset first!"
235
-
236
- if training_in_progress:
237
- return [], "Training already in progress!"
238
-
239
- training_in_progress = True
240
-
241
- try:
242
- # Determine device - on Spaces, always use GPU if available
243
- if ON_SPACES and torch.cuda.is_available():
244
- device = 0
245
- elif device_selection == "Auto":
246
- device = 0 if torch.cuda.is_available() else "cpu"
247
- elif device_selection == "CPU":
248
- device = "cpu"
249
- else:
250
- device = 0 if torch.cuda.is_available() else "cpu"
251
-
252
- # Initialize model
253
- model = YOLO("yolo11n.pt")
254
-
255
- # Create project directory
256
- project_dir = "./xray_detection"
257
- os.makedirs(project_dir, exist_ok=True)
258
-
259
- # Train model with workers=0 to avoid multiprocessing issues on Spaces
260
- results = model.train(
261
- data=f"{dataset_path}/data.yaml",
262
- epochs=epochs,
263
- imgsz=img_size,
264
- batch=batch_size,
265
- device=device,
266
- project=project_dir,
267
- name="train",
268
- exist_ok=True,
269
- verbose=True,
270
- patience=5, # Reduce patience for faster training on Spaces
271
- save_period=5, # Save checkpoints every 5 epochs
272
- workers=0, # Important: Set to 0 to avoid multiprocessing issues
273
- single_cls=False,
274
- rect=False,
275
- cache=False, # Disable caching to avoid memory issues
276
- amp=True # Use automatic mixed precision for faster training
277
  )
278
-
279
- # Collect training result plots
280
- results_path = os.path.join(project_dir, "train")
281
- plots = []
282
-
283
- plot_files = ["results.png", "confusion_matrix.png", "val_batch0_pred.jpg",
284
- "train_batch0.jpg", "val_batch0_labels.jpg"]
285
-
286
- for plot_file in plot_files:
287
- plot_path = os.path.join(results_path, plot_file)
288
- if os.path.exists(plot_path):
289
- plots.append(Image.open(plot_path))
290
-
291
- # Save the model path
292
- model_path = os.path.join(results_path, "weights", "best.pt")
293
-
294
- training_in_progress = False
295
- return plots, f"Training completed! Model saved to {model_path}"
296
-
297
- except Exception as e:
298
- training_in_progress = False
299
- return [], f"Error during training: {str(e)}"
300
 
301
- @spaces.GPU(duration=60) # Request GPU for 1 minute for inference
302
- def run_inference(input_image, conf_threshold):
303
- """Run inference on a single image"""
304
- global model
305
-
306
- if model is None:
307
- # Try to load a default model
 
 
 
 
 
 
 
 
 
 
308
  try:
309
- model = YOLO("yolo11n.pt")
310
- except:
311
- return None, "Please train the model first or load a pre-trained model!"
312
-
313
- if input_image is None:
314
- return None, "Please upload an image!"
315
-
316
- try:
317
- # Save the input image temporarily
318
- temp_path = "temp_inference.jpg"
319
- input_image.save(temp_path)
320
-
321
- # Run inference with workers=0
322
- results = model(temp_path, conf=conf_threshold, verbose=False, device=0 if torch.cuda.is_available() else 'cpu')
323
-
324
- # Draw results
325
- annotated_image = results[0].plot()
326
-
327
- # Get detection info
328
- detections = []
329
- if results[0].boxes is not None:
330
- for box in results[0].boxes:
331
- cls = int(box.cls)
332
- conf = float(box.conf)
333
- cls_name = model.names[cls]
334
- detections.append(f"{cls_name}: {conf:.2f}")
335
-
336
- # Clean up
337
- if os.path.exists(temp_path):
338
- os.remove(temp_path)
339
-
340
- detection_text = "\n".join(detections) if detections else "No objects detected"
341
-
342
- return Image.fromarray(annotated_image), f"Detections:\n{detection_text}"
343
-
344
- except Exception as e:
345
- return None, f"Error during inference: {str(e)}"
346
 
347
- @spaces.GPU(duration=60) # Request GPU for batch inference
348
- def batch_inference(data_type, num_images):
349
- """Run inference on multiple images from test set"""
350
- global model
351
-
352
- if model is None:
 
 
 
 
353
  try:
354
- model = YOLO("yolo11n.pt")
355
- except:
356
- return [], "Please train the model first!"
357
-
358
- if dataset_path is None:
359
- return [], "Please download the dataset first!"
360
-
361
- try:
362
- image_dir = f"{dataset_path}/{data_type}/images"
363
- if not os.path.exists(image_dir):
364
- return [], f"Directory {image_dir} not found!"
365
-
366
- image_files = glob(f"{image_dir}/*")[:num_images]
367
-
368
- if not image_files:
369
- return [], f"No images found in {image_dir}"
370
-
371
- results_images = []
372
-
373
- for img_path in image_files:
374
- results = model(img_path, verbose=False)
375
- annotated = results[0].plot()
376
- results_images.append(Image.fromarray(annotated))
377
-
378
- return results_images, f"Processed {len(results_images)} images from {data_type} dataset"
379
-
380
- except Exception as e:
381
- return [], f"Error during batch inference: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
- def load_pretrained_model(model_path):
384
- """Load a pre-trained model"""
385
- global model
386
- try:
387
- if not os.path.exists(model_path):
388
- # Try default paths
389
- default_paths = [
390
- "./xray_detection/train/weights/best.pt",
391
- "./xray_detection/train/weights/last.pt",
392
- "yolo11n.pt"
393
- ]
394
- for path in default_paths:
395
- if os.path.exists(path):
396
- model_path = path
397
- break
398
-
399
- model = YOLO(model_path)
400
- return f"Model loaded successfully from {model_path}"
401
- except Exception as e:
402
- return f"Error loading model: {str(e)}"
403
 
404
- # Create Gradio interface
405
- with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
406
- gr.Markdown("""
407
- # 🎯 X-ray Baggage Anomaly Detection with YOLOv11
408
-
409
- This application allows you to:
410
- 1. Download and visualize the X-ray baggage dataset
411
- 2. Analyze class distributions
412
- 3. Train a YOLOv11 model for object detection
413
- 4. Run inference on new images
414
 
415
- **Note:** GPU will be automatically allocated when needed for training and inference.
416
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
- # Add instructions for Kaggle API setup
419
- with gr.Accordion("πŸ“ Setup Instructions", open=False):
420
- gr.Markdown("""
421
- ### Kaggle API Setup
422
- 1. Get your Kaggle API credentials from https://www.kaggle.com/settings
423
- 2. Set the KDATA_API environment variable in Hugging Face Spaces settings:
424
- ```
425
- KDATA_API={"username":"your_username","key":"your_api_key"}
426
- ```
427
- """)
 
 
 
428
 
429
- with gr.Tab("πŸ“Š Dataset"):
430
- with gr.Row():
431
- download_btn = gr.Button("Download Dataset", variant="primary", scale=1)
432
- download_status = gr.Textbox(label="Status", interactive=False, scale=3)
433
-
434
- download_btn.click(download_dataset, outputs=download_status)
435
-
436
- gr.Markdown("### Visualize Dataset Samples")
437
- with gr.Row():
438
- data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
439
- num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples")
440
- viz_btn = gr.Button("Visualize Samples")
441
-
442
- viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto")
443
- viz_status = gr.Textbox(label="Status", interactive=False)
444
-
445
- viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples],
446
- outputs=[viz_gallery, viz_status])
447
-
448
- gr.Markdown("### Analyze Class Distribution")
449
- with gr.Row():
450
- data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
451
- analyze_btn = gr.Button("Analyze Distribution")
452
-
453
- distribution_plot = gr.Image(label="Class Distribution", type="pil")
454
- analysis_status = gr.Textbox(label="Status", interactive=False)
455
-
456
- analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
457
- outputs=[distribution_plot, analysis_status])
458
 
459
- with gr.Tab("πŸš€ Training"):
460
- gr.Markdown("### Train YOLOv11 Model")
461
- gr.Markdown("""
462
- **Note:** Training will automatically use GPU if available. This may take several minutes.
463
-
464
- **Tips for Hugging Face Spaces:**
465
- - Use smaller batch sizes (4-8) to avoid GPU memory issues
466
- - Start with fewer epochs (5-10) for testing
467
- - Image size 480 provides good balance between quality and speed
468
- """)
469
-
470
- with gr.Row():
471
- epochs_input = gr.Slider(1, 50, 10, step=1, label="Epochs")
472
- batch_size_input = gr.Slider(4, 32, 8, step=4, label="Batch Size (lower for limited GPU)")
473
- img_size_input = gr.Slider(320, 640, 480, step=32, label="Image Size")
474
- device_input = gr.Radio(["Auto", "GPU", "CPU"], value="Auto", label="Device")
475
-
476
- train_btn = gr.Button("Start Training", variant="primary")
477
-
478
- training_gallery = gr.Gallery(label="Training Results", columns=3, height="auto")
479
- training_status = gr.Textbox(label="Training Status", interactive=False)
480
-
481
- train_btn.click(train_model,
482
- inputs=[epochs_input, batch_size_input, img_size_input, device_input],
483
- outputs=[training_gallery, training_status])
484
-
485
- gr.Markdown("### Load Pre-trained Model")
486
- with gr.Row():
487
- model_path_input = gr.Textbox(label="Model Path", value="./xray_detection/train/weights/best.pt")
488
- load_model_btn = gr.Button("Load Model")
489
- load_status = gr.Textbox(label="Status", interactive=False)
490
-
491
- load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
492
 
493
- with gr.Tab("πŸ” Inference"):
494
- gr.Markdown("### Single Image Inference")
495
-
496
- with gr.Row():
497
- with gr.Column():
498
- input_image = gr.Image(type="pil", label="Upload Image")
499
- conf_threshold = gr.Slider(0.1, 0.9, 0.5, step=0.05, label="Confidence Threshold")
500
- inference_btn = gr.Button("Run Detection", variant="primary")
 
 
 
 
 
501
 
502
- with gr.Column():
503
- output_image = gr.Image(type="pil", label="Detection Result")
504
- detection_info = gr.Textbox(label="Detection Info", lines=5)
505
-
506
- inference_btn.click(run_inference,
507
- inputs=[input_image, conf_threshold],
508
- outputs=[output_image, detection_info])
509
-
510
- gr.Markdown("### Batch Inference")
511
-
512
- with gr.Row():
513
- batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type")
514
- batch_num_images = gr.Slider(1, 10, 5, step=1, label="Number of Images")
515
- batch_btn = gr.Button("Run Batch Inference")
516
-
517
- batch_gallery = gr.Gallery(label="Batch Results", columns=3, height="auto")
518
- batch_status = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
- batch_btn.click(batch_inference,
521
- inputs=[batch_data_type, batch_num_images],
522
- outputs=[batch_gallery, batch_status])
523
 
524
- # Launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  if __name__ == "__main__":
526
- # Check if running on Hugging Face Spaces
527
- if ON_SPACES:
528
- demo.launch(ssr_mode=False)
529
- else:
530
- demo.launch(share=True, ssr_mode=False)
 
1
+ # UVIS - Gradio App with Upload, URL & Video Support + HF Token Authentication
2
+ """
3
+ This script launches the UVIS (Unified Visual Intelligence System) as a Gradio Web App.
4
+ Supports image, video, and URL-based media inputs for detection, segmentation, and depth estimation.
5
+ Outputs include scene blueprint, structured JSON, and downloadable results.
6
+ Now includes HuggingFace token authentication for private model access.
7
+ """
8
+
9
  import os
10
+ import time
11
+ import logging
12
+ import traceback
13
+
 
 
 
14
  import gradio as gr
 
 
15
  from PIL import Image
16
+ import cv2
17
+ import timeout_decorator
18
+ import spaces
 
 
 
 
 
19
  import tempfile
20
+ import shutil
21
+
22
+ from registry import get_model
23
+ from core.describe_scene import describe_scene
24
+ from core.process import process_image, process_video
25
+ from core.input_handler import resolve_input, validate_video, validate_image
26
+ from utils.helpers import format_error, generate_session_id
27
+ from huggingface_hub import hf_hub_download, login
28
 
29
+ # HuggingFace Token Authentication
30
+ HF_TOKEN = os.getenv("HF_TOKEN")
31
+ if HF_TOKEN:
32
+ try:
33
+ login(token=HF_TOKEN)
34
+ print("βœ… Successfully authenticated with HuggingFace using HF_TOKEN")
35
+ except Exception as e:
36
+ print(f"⚠️ Failed to authenticate with HuggingFace: {e}")
37
+ else:
38
+ print("⚠️ HF_TOKEN not found in environment variables. Some models may not be accessible.")
39
+
40
+ # Clear HF cache if needed
41
  try:
42
+ cache_paths = [
43
+ os.path.expanduser("~/.cache/huggingface"),
44
+ "/home/user/.cache/huggingface"
45
+ ]
46
+ for path in cache_paths:
47
+ if os.path.exists(path):
48
+ shutil.rmtree(path, ignore_errors=True)
49
+ print("πŸ’₯ Nuked HF model cache from runtime.")
50
+ except Exception as e:
51
+ print("🚫 Failed to nuke cache:", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # Setup logging
54
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
55
+ logger = logging.getLogger(__name__)
56
 
57
+ # Model mappings
58
+ DETECTION_MODEL_MAP = {
59
+ "YOLOv8-Nano": "yolov8n",
60
+ "YOLOv8-Small": "yolov8s",
61
+ "YOLOv8-Large": "yolov8l",
62
+ "YOLOv11-Beta": "yolov11b"
63
+ }
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ SEGMENTATION_MODEL_MAP = {
66
+ "SegFormer-B0": "segformer_b0",
67
+ "SegFormer-B5": "segformer_b5",
68
+ "DeepLabV3-ResNet50": "deeplabv3_resnet50"
69
+ }
 
 
 
 
 
 
 
 
 
 
70
 
71
+ DEPTH_MODEL_MAP = {
72
+ "MiDaS v21 Small 256": "midas_v21_small_256",
73
+ "MiDaS v21 384": "midas_v21_384",
74
+ "DPT Hybrid 384": "dpt_hybrid_384",
75
+ "DPT Swin2 Large 384": "dpt_swin2_large_384",
76
+ "DPT Beit Large 512": "dpt_beit_large_512"
77
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ # Modified get_model wrapper to include HF token
80
+ def get_model_with_auth(model_type, model_name, device="cpu"):
81
+ """
82
+ Wrapper for get_model that includes HF token authentication.
83
+ """
84
+ # Pass HF_TOKEN to the registry get_model function if it exists
85
+ # This assumes the registry.get_model can accept a token parameter
86
  try:
87
+ if hasattr(get_model, '__code__') and 'token' in get_model.__code__.co_varnames:
88
+ return get_model(model_type, model_name, device=device, token=HF_TOKEN)
89
+ else:
90
+ # If get_model doesn't support token, use standard call
91
+ return get_model(model_type, model_name, device=device)
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
+ logger.error(f"Failed to load model {model_type}/{model_name}: {e}")
94
+ # Fallback: try without token parameter
95
+ return get_model(model_type, model_name, device=device)
96
 
97
+ @spaces.GPU
98
+ def handle(mode, media_upload, url,
99
+ run_det, det_model, det_confidence,
100
+ run_seg, seg_model,
101
+ run_depth, depth_model,
102
+ blend):
103
+ """
104
+ Master handler for resolving input and processing.
105
+ Returns: (img_out, vid_out, json_out, zip_out)
106
+ """
107
+ session_id = generate_session_id()
108
+ logger.info(f"Session ID: {session_id} | Handler activated with mode: {mode}")
109
+ start_time = time.time()
 
110
 
111
+ # Check HF authentication status
112
+ if not HF_TOKEN:
113
+ logger.warning("Processing without HF authentication. Some models may not be available.")
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ media = resolve_input(mode, media_upload, url)
116
+ if not media:
117
+ return (
118
+ gr.update(visible=False),
119
+ gr.update(visible=False),
120
+ format_error("No valid input provided. Please check your upload or URL."),
121
+ None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ first_input = media[0]
125
+
126
+ # πŸ”§ Resolve dropdown label to model keys
127
+ resolved_det_model = DETECTION_MODEL_MAP.get(det_model, det_model)
128
+ resolved_seg_model = SEGMENTATION_MODEL_MAP.get(seg_model, seg_model)
129
+ resolved_depth_model = DEPTH_MODEL_MAP.get(depth_model, depth_model)
130
+
131
+ # --- VIDEO PATH ---
132
+ if isinstance(first_input, str) and first_input.lower().endswith((".mp4", ".mov", ".avi")):
133
+ valid, err = validate_video(first_input)
134
+ if not valid:
135
+ return (
136
+ gr.update(visible=False),
137
+ gr.update(visible=False),
138
+ format_error(err),
139
+ None
140
+ )
141
  try:
142
+ # Pass HF_TOKEN to process_video if needed
143
+ _, msg, output_video_path = process_video(
144
+ video_path=first_input,
145
+ run_det=run_det,
146
+ det_model=resolved_det_model,
147
+ det_confidence=det_confidence,
148
+ run_seg=run_seg,
149
+ seg_model=resolved_seg_model,
150
+ run_depth=run_depth,
151
+ depth_model=resolved_depth_model,
152
+ blend=blend,
153
+ hf_token=HF_TOKEN # Pass token if process_video supports it
154
+ )
155
+ return (
156
+ gr.update(visible=False), # hide image
157
+ gr.update(value=output_video_path, visible=True), # show video
158
+ msg,
159
+ output_video_path # for download
160
+ )
161
+ except Exception as e:
162
+ logger.error(f"Video processing failed: {e}")
163
+ # If it's an authentication error, provide specific message
164
+ if "401" in str(e) or "unauthorized" in str(e).lower():
165
+ error_msg = "Authentication failed. Please check HF_TOKEN environment variable."
166
+ else:
167
+ error_msg = str(e)
168
+ return (
169
+ gr.update(visible=False),
170
+ gr.update(visible=False),
171
+ format_error(error_msg),
172
+ None
173
+ )
 
 
 
 
 
174
 
175
+ # --- IMAGE PATH ---
176
+ elif isinstance(first_input, Image.Image):
177
+ valid, err = validate_image(first_input)
178
+ if not valid:
179
+ return (
180
+ gr.update(visible=False),
181
+ gr.update(visible=False),
182
+ format_error(err),
183
+ None
184
+ )
185
  try:
186
+ # Pass HF_TOKEN to process_image if needed
187
+ result_img, msg, output_zip = process_image(
188
+ image=first_input,
189
+ run_det=run_det,
190
+ det_model=resolved_det_model,
191
+ det_confidence=det_confidence,
192
+ run_seg=run_seg,
193
+ seg_model=resolved_seg_model,
194
+ run_depth=run_depth,
195
+ depth_model=resolved_depth_model,
196
+ blend=blend,
197
+ hf_token=HF_TOKEN # Pass token if process_image supports it
198
+ )
199
+ return (
200
+ gr.update(value=result_img, visible=True), # show image
201
+ gr.update(visible=False), # hide video
202
+ msg,
203
+ output_zip
204
+ )
205
+ except timeout_decorator.timeout_decorator.TimeoutError:
206
+ logger.error("Image processing timed out.")
207
+ return (
208
+ gr.update(visible=False),
209
+ gr.update(visible=False),
210
+ format_error("Processing timed out. Try a smaller image or simpler model."),
211
+ None
212
+ )
213
+ except Exception as e:
214
+ traceback.print_exc()
215
+ logger.error(f"Image processing failed: {e}")
216
+ # If it's an authentication error, provide specific message
217
+ if "401" in str(e) or "unauthorized" in str(e).lower():
218
+ error_msg = "Authentication failed. Please check HF_TOKEN environment variable."
219
+ else:
220
+ error_msg = str(e)
221
+ return (
222
+ gr.update(visible=False),
223
+ gr.update(visible=False),
224
+ format_error(error_msg),
225
+ None
226
+ )
227
 
228
+ logger.warning("Unsupported media type resolved.")
229
+ return (
230
+ gr.update(visible=False),
231
+ gr.update(visible=False),
232
+ format_error("Unsupported input type."),
233
+ None
234
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
+ def show_preview_from_upload(files):
237
+ if not files:
238
+ return gr.update(visible=False), gr.update(visible=False)
 
 
 
 
 
 
 
239
 
240
+ file = files[0]
241
+ filename = file.name.lower()
242
+
243
+ if filename.endswith((".png", ".jpg", ".jpeg", ".webp")):
244
+ img = Image.open(file).convert("RGB")
245
+ return gr.update(value=img, visible=True), gr.update(visible=False)
246
+
247
+ elif filename.endswith((".mp4", ".mov", ".avi")):
248
+ # Copy uploaded video to a known temp location
249
+ temp_dir = tempfile.mkdtemp()
250
+ ext = os.path.splitext(filename)[-1]
251
+ safe_path = os.path.join(temp_dir, f"uploaded_video{ext}")
252
+ with open(safe_path, "wb") as f:
253
+ f.write(file.read())
254
+
255
+ return gr.update(visible=False), gr.update(value=safe_path, visible=True)
256
+
257
+ return gr.update(visible=False), gr.update(visible=False)
258
+
259
+ def show_preview_from_url(url_input):
260
+ if not url_input:
261
+ return gr.update(visible=False), gr.update(visible=False)
262
+ path = url_input.strip().lower()
263
+ if path.endswith((".png", ".jpg", ".jpeg", ".webp")):
264
+ return gr.update(value=url_input, visible=True), gr.update(visible=False)
265
+ elif path.endswith((".mp4", ".mov", ".avi")):
266
+ return gr.update(visible=False), gr.update(value=url_input, visible=True)
267
+ return gr.update(visible=False), gr.update(visible=False)
268
+
269
+ def clear_model_cache():
270
+ """
271
+ Deletes all model weight folders so they are redownloaded fresh.
272
+ """
273
+ folders = [
274
+ "models/detection/weights",
275
+ "models/segmentation/weights",
276
+ "models/depth/weights"
277
+ ]
278
+ for folder in folders:
279
+ shutil.rmtree(folder, ignore_errors=True)
280
+ logger.info(f"πŸ—‘οΈ Cleared: {folder}")
281
 
282
+ # Also clear HF cache if token is available
283
+ if HF_TOKEN:
284
+ try:
285
+ cache_paths = [
286
+ os.path.expanduser("~/.cache/huggingface"),
287
+ "/home/user/.cache/huggingface"
288
+ ]
289
+ for path in cache_paths:
290
+ if os.path.exists(path):
291
+ shutil.rmtree(path, ignore_errors=True)
292
+ return "βœ… Model cache and HF cache cleared. Models will be reloaded on next run."
293
+ except Exception as e:
294
+ return f"⚠️ Model cache cleared, but failed to clear HF cache: {e}"
295
 
296
+ return "βœ… Model cache cleared. Models will be reloaded on next run."
297
+
298
+ def check_auth_status():
299
+ """
300
+ Check and display current authentication status.
301
+ """
302
+ if HF_TOKEN:
303
+ return f"βœ… Authenticated with HuggingFace (Token: {HF_TOKEN[:8]}...)"
304
+ else:
305
+ return "❌ Not authenticated. Set HF_TOKEN environment variable for private model access."
306
+
307
# Gradio Interface
#
# Builds the UVIS demo UI. Relies on names defined earlier in this file:
# the model maps (DETECTION_MODEL_MAP, SEGMENTATION_MODEL_MAP,
# DEPTH_MODEL_MAP), the HF_TOKEN constant, and the callbacks
# clear_model_cache, check_auth_status, show_preview_from_upload,
# show_preview_from_url, and handle.
with gr.Blocks(title="UVIS - Unified Visual Intelligence System") as demo:
    gr.Markdown("## Unified Visual Intelligence System (UVIS)")

    # Authentication Status banner (read-only).
    with gr.Row():
        auth_status = gr.Textbox(
            label="HF Authentication Status",
            value=check_auth_status(),
            interactive=False
        )

    with gr.Row():
        # Left panel: input selection and per-task settings.
        with gr.Column(scale=2):
            # Input Mode Toggle
            mode = gr.Radio(["Upload", "URL"], value="Upload", label="Input Mode")

            # File upload: accepts multiple images or one video (user chooses wisely)
            media_upload = gr.File(
                label="Upload Images (1–5) or 1 Video",
                file_types=["image", ".mp4", ".mov", ".avi"],
                file_count="multiple",
                visible=True
            )

            # URL input (only visible in URL mode).
            url = gr.Textbox(label="URL (Image/Video)", visible=False)

            # Toggle visibility of the two input widgets when the mode changes.
            #
            # BUGFIX: the original returned FOUR updates (two of them for
            # preview components that are not in the wired `outputs` list and
            # are not even defined at this point in the script), which makes
            # Gradio raise a return-value mismatch on every mode change.
            # Exactly one update per wired output is returned instead.
            def toggle_inputs(selected_mode):
                return [
                    gr.update(visible=(selected_mode == "Upload")),  # media_upload
                    gr.update(visible=(selected_mode == "URL")),     # url
                ]

            mode.change(toggle_inputs, inputs=mode, outputs=[media_upload, url])

            # Shared helper: map a checkbox state to a panel-visibility update.
            def toggle_visibility(checked):
                return gr.update(visible=checked)

            run_det = gr.Checkbox(label="Object Detection")
            run_seg = gr.Checkbox(label="Semantic Segmentation")
            run_depth = gr.Checkbox(label="Depth Estimation")

            with gr.Row():
                with gr.Column(visible=False) as OD_Settings:
                    with gr.Accordion("Object Detection Settings", open=True):
                        det_model = gr.Dropdown(
                            choices=list(DETECTION_MODEL_MAP.keys()),
                            label="Detection Model",
                            value="YOLOv8-Nano"
                        )
                        det_confidence = gr.Slider(0.1, 1.0, 0.5, label="Detection Confidence Threshold")
                        nms_thresh = gr.Slider(0.1, 1.0, 0.45, label="NMS Threshold")
                        max_det = gr.Slider(1, 100, 20, step=1, label="Max Detections")
                        iou_thresh = gr.Slider(0.1, 1.0, 0.5, label="IoU Threshold")
                        class_filter = gr.CheckboxGroup(["Person", "Car", "Dog"], label="Class Filter")

                with gr.Column(visible=False) as SS_Settings:
                    with gr.Accordion("Semantic Segmentation Settings", open=True):
                        seg_model = gr.Dropdown(
                            choices=list(SEGMENTATION_MODEL_MAP.keys()),
                            label="Segmentation Model",
                            value="DeepLabV3-ResNet50"
                        )
                        resize_strategy = gr.Dropdown(["Crop", "Pad", "Scale"], label="Resize Strategy", value="Scale")
                        overlay_alpha = gr.Slider(0.0, 1.0, 0.5, label="Overlay Opacity")
                        seg_classes = gr.CheckboxGroup(["Road", "Sky", "Building"], label="Target Classes")
                        enable_crf = gr.Checkbox(label="Postprocessing (CRF)")

                with gr.Column(visible=False) as DE_Settings:
                    with gr.Accordion("Depth Estimation Settings", open=True):
                        depth_model = gr.Dropdown(
                            choices=list(DEPTH_MODEL_MAP.keys()),
                            label="Depth Model",
                            value="MiDaS v21 Small 256"
                        )
                        output_type = gr.Dropdown(["Raw", "Disparity", "Scaled"], label="Output Type", value="Scaled")
                        colormap = gr.Dropdown(["Jet", "Viridis", "Plasma"], label="Colormap", value="Jet")
                        blend = gr.Slider(0.0, 1.0, 0.5, label="Overlay Blend")
                        normalize = gr.Checkbox(label="Normalize Depth", value=True)
                        max_depth = gr.Slider(0.1, 10.0, 5.0, label="Max Depth (meters)")

            # Attach Visibility Logic: each checkbox shows/hides its panel.
            run_det.change(fn=toggle_visibility, inputs=[run_det], outputs=[OD_Settings])
            run_seg.change(fn=toggle_visibility, inputs=[run_seg], outputs=[SS_Settings])
            run_depth.change(fn=toggle_visibility, inputs=[run_depth], outputs=[DE_Settings])

            # BUGFIX: the original created a second, orphaned "Overlay Blend"
            # slider at this point, silently rebinding `blend` and leaving the
            # slider inside the Depth Estimation panel disconnected from
            # `handle`. The duplicate is removed; `blend` now refers to the
            # depth-panel slider defined above.

            # Run Button
            run = gr.Button("Run Analysis", variant="primary")

        # Right panel: outputs.
        with gr.Column(scale=1):
            # Only one is shown at a time β€” image or video
            img_out = gr.Image(label="Preview / Processed Output", visible=False)
            vid_out = gr.Video(label="Preview / Processed Video", visible=False, streaming=True, autoplay=True)
            json_out = gr.JSON(label="Scene JSON")
            zip_out = gr.File(label="Download Results")

            with gr.Row():
                clear_button = gr.Button("🧹 Clear Model Cache")
                refresh_auth_button = gr.Button("πŸ”„ Refresh Auth Status")

            status_box = gr.Textbox(label="Status", interactive=False)

            clear_button.click(fn=clear_model_cache, inputs=[], outputs=[status_box])
            refresh_auth_button.click(fn=check_auth_status, inputs=[], outputs=[auth_status])

    # Live preview as soon as media is supplied (wired after the output
    # components exist, so the references resolve).
    media_upload.change(show_preview_from_upload, inputs=media_upload, outputs=[img_out, vid_out])
    url.submit(show_preview_from_url, inputs=url, outputs=[img_out, vid_out])

    # Button Click Event: run the selected tasks and publish all outputs.
    run.click(
        fn=handle,
        inputs=[
            mode, media_upload, url,
            run_det, det_model, det_confidence,
            run_seg, seg_model,
            run_depth, depth_model,
            blend
        ],
        outputs=[
            img_out,   # will be visible only if it's an image
            vid_out,   # will be visible only if it's a video
            json_out,
            zip_out
        ]
    )

    # Footer Section
    gr.Markdown("---")
    gr.Markdown(
        f"""
        <div style='text-align: center; font-size: 14px;'>
        Built by <b>Durga Deepak Valluri</b><br>
        <a href="https://github.com/DurgaDeepakValluri" target="_blank">GitHub</a> |
        <a href="https://deecoded.io" target="_blank">Website</a> |
        <a href="https://www.linkedin.com/in/durga-deepak-valluri" target="_blank">LinkedIn</a><br>
        <span style='font-size: 12px; color: #666;'>
        {'πŸ” HF Authentication Active' if HF_TOKEN else 'πŸ”“ No HF Authentication'}
        </span>
        </div>
        """,
    )

# Launch the Gradio App (script entry point).
if __name__ == "__main__":
    demo.launch()