openfree commited on
Commit
9162e76
Β·
verified Β·
1 Parent(s): ede7184

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -36
app.py CHANGED
@@ -15,13 +15,27 @@ from ultralytics import YOLO
15
  import shutil
16
  import tempfile
17
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Set Kaggle API credentials from environment variable
20
  if os.getenv("KDATA_API"):
21
  kaggle_key = os.getenv("KDATA_API")
22
  # Parse the key if it's in JSON format
23
  if "{" in kaggle_key:
24
- import json
25
  key_data = json.loads(kaggle_key)
26
  os.environ["KAGGLE_USERNAME"] = key_data.get("username", "")
27
  os.environ["KAGGLE_KEY"] = key_data.get("key", "")
@@ -95,6 +109,7 @@ class Visualization:
95
  ax.imshow(or_im)
96
  ax.axis("off")
97
  ax.set_title(f"Number of objects: {len(bboxes)}")
 
98
 
99
  return fig
100
 
@@ -147,23 +162,37 @@ def download_dataset():
147
  """Download the dataset using kagglehub"""
148
  global dataset_path
149
  try:
 
 
 
 
150
  dataset_path = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")
 
 
 
 
 
 
 
 
151
  return f"Dataset downloaded successfully to: {dataset_path}"
152
  except Exception as e:
153
- return f"Error downloading dataset: {str(e)}"
154
 
155
  def visualize_data(data_type, num_samples):
156
  """Visualize sample images from the dataset"""
157
  if dataset_path is None:
158
- return None, "Please download the dataset first!"
159
 
160
  try:
161
  vis = Visualization(root=dataset_path, data_types=[data_type],
162
  n_ims=num_samples, rows=2, cmap="rgb")
163
  figs = vis.vis_samples(data_type, num_samples)
 
 
164
  return figs, f"Showing {len(figs)} samples from {data_type} dataset"
165
  except Exception as e:
166
- return None, f"Error visualizing data: {str(e)}"
167
 
168
  def analyze_class_distribution(data_type):
169
  """Analyze class distribution in the dataset"""
@@ -174,34 +203,43 @@ def analyze_class_distribution(data_type):
174
  vis = Visualization(root=dataset_path, data_types=[data_type],
175
  n_ims=20, rows=5, cmap="rgb")
176
  fig = vis.data_analysis(data_type)
 
 
177
  return fig, f"Class distribution for {data_type} dataset"
178
  except Exception as e:
179
  return None, f"Error analyzing data: {str(e)}"
180
 
 
181
  def train_model(epochs, batch_size, img_size, device_selection):
182
  """Train YOLOv11 model"""
183
  global model, training_in_progress
184
 
185
  if dataset_path is None:
186
- return None, "Please download the dataset first!"
187
 
188
  if training_in_progress:
189
- return None, "Training already in progress!"
190
 
191
  training_in_progress = True
192
 
193
  try:
194
- # Determine device
195
- if device_selection == "Auto":
 
 
196
  device = 0 if torch.cuda.is_available() else "cpu"
197
  elif device_selection == "CPU":
198
  device = "cpu"
199
  else:
200
- device = 0
201
 
202
  # Initialize model
203
  model = YOLO("yolo11n.pt")
204
 
 
 
 
 
205
  # Train model
206
  results = model.train(
207
  data=f"{dataset_path}/data.yaml",
@@ -209,34 +247,50 @@ def train_model(epochs, batch_size, img_size, device_selection):
209
  imgsz=img_size,
210
  batch=batch_size,
211
  device=device,
212
- project="xray_detection",
213
  name="train",
214
  exist_ok=True,
215
- verbose=True
 
 
216
  )
217
 
218
- # Plot training results
219
- results_path = "xray_detection/train"
220
  plots = []
221
 
222
- for plot_file in ["results.png", "confusion_matrix.png", "val_batch0_pred.jpg"]:
 
 
 
223
  plot_path = os.path.join(results_path, plot_file)
224
  if os.path.exists(plot_path):
225
  plots.append(Image.open(plot_path))
226
 
 
 
 
227
  training_in_progress = False
228
- return plots, f"Training completed! Model saved to {results_path}"
229
 
230
  except Exception as e:
231
  training_in_progress = False
232
- return None, f"Error during training: {str(e)}"
233
 
 
234
  def run_inference(input_image, conf_threshold):
235
  """Run inference on a single image"""
236
  global model
237
 
238
  if model is None:
239
- return None, "Please train the model first or load a pre-trained model!"
 
 
 
 
 
 
 
240
 
241
  try:
242
  # Save the input image temporarily
@@ -251,15 +305,16 @@ def run_inference(input_image, conf_threshold):
251
 
252
  # Get detection info
253
  detections = []
254
- for r in results:
255
- for box in r.boxes:
256
  cls = int(box.cls)
257
  conf = float(box.conf)
258
  cls_name = model.names[cls]
259
  detections.append(f"{cls_name}: {conf:.2f}")
260
 
261
  # Clean up
262
- os.remove(temp_path)
 
263
 
264
  detection_text = "\n".join(detections) if detections else "No objects detected"
265
 
@@ -268,20 +323,30 @@ def run_inference(input_image, conf_threshold):
268
  except Exception as e:
269
  return None, f"Error during inference: {str(e)}"
270
 
 
271
  def batch_inference(data_type, num_images):
272
  """Run inference on multiple images from test set"""
273
  global model
274
 
275
  if model is None:
276
- return None, "Please train the model first!"
 
 
 
277
 
278
  if dataset_path is None:
279
- return None, "Please download the dataset first!"
280
 
281
  try:
282
  image_dir = f"{dataset_path}/{data_type}/images"
 
 
 
283
  image_files = glob(f"{image_dir}/*")[:num_images]
284
 
 
 
 
285
  results_images = []
286
 
287
  for img_path in image_files:
@@ -292,19 +357,31 @@ def batch_inference(data_type, num_images):
292
  return results_images, f"Processed {len(results_images)} images from {data_type} dataset"
293
 
294
  except Exception as e:
295
- return None, f"Error during batch inference: {str(e)}"
296
 
297
  def load_pretrained_model(model_path):
298
  """Load a pre-trained model"""
299
  global model
300
  try:
 
 
 
 
 
 
 
 
 
 
 
 
301
  model = YOLO(model_path)
302
  return f"Model loaded successfully from {model_path}"
303
  except Exception as e:
304
  return f"Error loading model: {str(e)}"
305
 
306
  # Create Gradio interface
307
- with gr.Blocks(title="X-ray Baggage Anomaly Detection") as demo:
308
  gr.Markdown("""
309
  # 🎯 X-ray Baggage Anomaly Detection with YOLOv11
310
 
@@ -313,12 +390,25 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection") as demo:
313
  2. Analyze class distributions
314
  3. Train a YOLOv11 model for object detection
315
  4. Run inference on new images
 
 
316
  """)
317
 
 
 
 
 
 
 
 
 
 
 
 
318
  with gr.Tab("πŸ“Š Dataset"):
319
  with gr.Row():
320
- download_btn = gr.Button("Download Dataset", variant="primary")
321
- download_status = gr.Textbox(label="Status", interactive=False)
322
 
323
  download_btn.click(download_dataset, outputs=download_status)
324
 
@@ -347,6 +437,7 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection") as demo:
347
 
348
  with gr.Tab("πŸš€ Training"):
349
  gr.Markdown("### Train YOLOv11 Model")
 
350
 
351
  with gr.Row():
352
  epochs_input = gr.Slider(1, 50, 10, step=1, label="Epochs")
@@ -365,7 +456,7 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection") as demo:
365
 
366
  gr.Markdown("### Load Pre-trained Model")
367
  with gr.Row():
368
- model_path_input = gr.Textbox(label="Model Path", value="yolo11n.pt")
369
  load_model_btn = gr.Button("Load Model")
370
  load_status = gr.Textbox(label="Status", interactive=False)
371
 
@@ -375,14 +466,14 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection") as demo:
375
  gr.Markdown("### Single Image Inference")
376
 
377
  with gr.Row():
378
- input_image = gr.Image(type="pil", label="Upload Image")
379
- conf_threshold = gr.Slider(0.1, 0.9, 0.5, step=0.05, label="Confidence Threshold")
380
-
381
- inference_btn = gr.Button("Run Detection", variant="primary")
382
-
383
- with gr.Row():
384
- output_image = gr.Image(type="pil", label="Detection Result")
385
- detection_info = gr.Textbox(label="Detection Info", lines=5)
386
 
387
  inference_btn.click(run_inference,
388
  inputs=[input_image, conf_threshold],
@@ -404,4 +495,8 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection") as demo:
404
 
405
  # Launch the app
406
  if __name__ == "__main__":
407
- demo.launch(share=True)
 
 
 
 
 
15
  import shutil
16
  import tempfile
17
  from pathlib import Path
18
+ import json
19
+
20
+ # Try to import spaces for Hugging Face Spaces GPU support
21
+ try:
22
+ import spaces
23
+ ON_SPACES = True
24
+ except ImportError:
25
+ ON_SPACES = False
26
+ # Create a dummy decorator if not on Spaces
27
+ class spaces:
28
+ @staticmethod
29
+ def GPU(duration=60):
30
+ def decorator(func):
31
+ return func
32
+ return decorator
33
 
34
  # Set Kaggle API credentials from environment variable
35
  if os.getenv("KDATA_API"):
36
  kaggle_key = os.getenv("KDATA_API")
37
  # Parse the key if it's in JSON format
38
  if "{" in kaggle_key:
 
39
  key_data = json.loads(kaggle_key)
40
  os.environ["KAGGLE_USERNAME"] = key_data.get("username", "")
41
  os.environ["KAGGLE_KEY"] = key_data.get("key", "")
 
109
  ax.imshow(or_im)
110
  ax.axis("off")
111
  ax.set_title(f"Number of objects: {len(bboxes)}")
112
+ plt.tight_layout()
113
 
114
  return fig
115
 
 
162
  """Download the dataset using kagglehub"""
163
  global dataset_path
164
  try:
165
+ # Create a local directory to store the dataset
166
+ local_dir = "./xray_dataset"
167
+
168
+ # Download dataset
169
  dataset_path = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")
170
+
171
+ # If the dataset is downloaded to a temporary location, copy it to our local directory
172
+ if dataset_path != local_dir and os.path.exists(dataset_path):
173
+ if os.path.exists(local_dir):
174
+ shutil.rmtree(local_dir)
175
+ shutil.copytree(dataset_path, local_dir)
176
+ dataset_path = local_dir
177
+
178
  return f"Dataset downloaded successfully to: {dataset_path}"
179
  except Exception as e:
180
+ return f"Error downloading dataset: {str(e)}\n\nPlease ensure KDATA_API environment variable is set correctly."
181
 
182
  def visualize_data(data_type, num_samples):
183
  """Visualize sample images from the dataset"""
184
  if dataset_path is None:
185
+ return [], "Please download the dataset first!"
186
 
187
  try:
188
  vis = Visualization(root=dataset_path, data_types=[data_type],
189
  n_ims=num_samples, rows=2, cmap="rgb")
190
  figs = vis.vis_samples(data_type, num_samples)
191
+ if figs is None:
192
+ return [], f"No data found for {data_type} dataset"
193
  return figs, f"Showing {len(figs)} samples from {data_type} dataset"
194
  except Exception as e:
195
+ return [], f"Error visualizing data: {str(e)}"
196
 
197
  def analyze_class_distribution(data_type):
198
  """Analyze class distribution in the dataset"""
 
203
  vis = Visualization(root=dataset_path, data_types=[data_type],
204
  n_ims=20, rows=5, cmap="rgb")
205
  fig = vis.data_analysis(data_type)
206
+ if fig is None:
207
+ return None, f"No data found for {data_type} dataset"
208
  return fig, f"Class distribution for {data_type} dataset"
209
  except Exception as e:
210
  return None, f"Error analyzing data: {str(e)}"
211
 
212
+ @spaces.GPU(duration=300) # Request GPU for 5 minutes for training
213
  def train_model(epochs, batch_size, img_size, device_selection):
214
  """Train YOLOv11 model"""
215
  global model, training_in_progress
216
 
217
  if dataset_path is None:
218
+ return [], "Please download the dataset first!"
219
 
220
  if training_in_progress:
221
+ return [], "Training already in progress!"
222
 
223
  training_in_progress = True
224
 
225
  try:
226
+ # Determine device - on Spaces, always use GPU if available
227
+ if ON_SPACES and torch.cuda.is_available():
228
+ device = 0
229
+ elif device_selection == "Auto":
230
  device = 0 if torch.cuda.is_available() else "cpu"
231
  elif device_selection == "CPU":
232
  device = "cpu"
233
  else:
234
+ device = 0 if torch.cuda.is_available() else "cpu"
235
 
236
  # Initialize model
237
  model = YOLO("yolo11n.pt")
238
 
239
+ # Create project directory
240
+ project_dir = "./xray_detection"
241
+ os.makedirs(project_dir, exist_ok=True)
242
+
243
  # Train model
244
  results = model.train(
245
  data=f"{dataset_path}/data.yaml",
 
247
  imgsz=img_size,
248
  batch=batch_size,
249
  device=device,
250
+ project=project_dir,
251
  name="train",
252
  exist_ok=True,
253
+ verbose=True,
254
+ patience=5, # Reduce patience for faster training on Spaces
255
+ save_period=5 # Save checkpoints every 5 epochs
256
  )
257
 
258
+ # Collect training result plots
259
+ results_path = os.path.join(project_dir, "train")
260
  plots = []
261
 
262
+ plot_files = ["results.png", "confusion_matrix.png", "val_batch0_pred.jpg",
263
+ "train_batch0.jpg", "val_batch0_labels.jpg"]
264
+
265
+ for plot_file in plot_files:
266
  plot_path = os.path.join(results_path, plot_file)
267
  if os.path.exists(plot_path):
268
  plots.append(Image.open(plot_path))
269
 
270
+ # Save the model path
271
+ model_path = os.path.join(results_path, "weights", "best.pt")
272
+
273
  training_in_progress = False
274
+ return plots, f"Training completed! Model saved to {model_path}"
275
 
276
  except Exception as e:
277
  training_in_progress = False
278
+ return [], f"Error during training: {str(e)}"
279
 
280
+ @spaces.GPU(duration=60) # Request GPU for 1 minute for inference
281
  def run_inference(input_image, conf_threshold):
282
  """Run inference on a single image"""
283
  global model
284
 
285
  if model is None:
286
+ # Try to load a default model
287
+ try:
288
+ model = YOLO("yolo11n.pt")
289
+ except:
290
+ return None, "Please train the model first or load a pre-trained model!"
291
+
292
+ if input_image is None:
293
+ return None, "Please upload an image!"
294
 
295
  try:
296
  # Save the input image temporarily
 
305
 
306
  # Get detection info
307
  detections = []
308
+ if results[0].boxes is not None:
309
+ for box in results[0].boxes:
310
  cls = int(box.cls)
311
  conf = float(box.conf)
312
  cls_name = model.names[cls]
313
  detections.append(f"{cls_name}: {conf:.2f}")
314
 
315
  # Clean up
316
+ if os.path.exists(temp_path):
317
+ os.remove(temp_path)
318
 
319
  detection_text = "\n".join(detections) if detections else "No objects detected"
320
 
 
323
  except Exception as e:
324
  return None, f"Error during inference: {str(e)}"
325
 
326
+ @spaces.GPU(duration=60) # Request GPU for batch inference
327
  def batch_inference(data_type, num_images):
328
  """Run inference on multiple images from test set"""
329
  global model
330
 
331
  if model is None:
332
+ try:
333
+ model = YOLO("yolo11n.pt")
334
+ except:
335
+ return [], "Please train the model first!"
336
 
337
  if dataset_path is None:
338
+ return [], "Please download the dataset first!"
339
 
340
  try:
341
  image_dir = f"{dataset_path}/{data_type}/images"
342
+ if not os.path.exists(image_dir):
343
+ return [], f"Directory {image_dir} not found!"
344
+
345
  image_files = glob(f"{image_dir}/*")[:num_images]
346
 
347
+ if not image_files:
348
+ return [], f"No images found in {image_dir}"
349
+
350
  results_images = []
351
 
352
  for img_path in image_files:
 
357
  return results_images, f"Processed {len(results_images)} images from {data_type} dataset"
358
 
359
  except Exception as e:
360
+ return [], f"Error during batch inference: {str(e)}"
361
 
362
  def load_pretrained_model(model_path):
363
  """Load a pre-trained model"""
364
  global model
365
  try:
366
+ if not os.path.exists(model_path):
367
+ # Try default paths
368
+ default_paths = [
369
+ "./xray_detection/train/weights/best.pt",
370
+ "./xray_detection/train/weights/last.pt",
371
+ "yolo11n.pt"
372
+ ]
373
+ for path in default_paths:
374
+ if os.path.exists(path):
375
+ model_path = path
376
+ break
377
+
378
  model = YOLO(model_path)
379
  return f"Model loaded successfully from {model_path}"
380
  except Exception as e:
381
  return f"Error loading model: {str(e)}"
382
 
383
  # Create Gradio interface
384
+ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
385
  gr.Markdown("""
386
  # 🎯 X-ray Baggage Anomaly Detection with YOLOv11
387
 
 
390
  2. Analyze class distributions
391
  3. Train a YOLOv11 model for object detection
392
  4. Run inference on new images
393
+
394
+ **Note:** GPU will be automatically allocated when needed for training and inference.
395
  """)
396
 
397
+ # Add instructions for Kaggle API setup
398
+ with gr.Accordion("πŸ“ Setup Instructions", open=False):
399
+ gr.Markdown("""
400
+ ### Kaggle API Setup
401
+ 1. Get your Kaggle API credentials from https://www.kaggle.com/settings
402
+ 2. Set the KDATA_API environment variable in Hugging Face Spaces settings:
403
+ ```
404
+ KDATA_API={"username":"your_username","key":"your_api_key"}
405
+ ```
406
+ """)
407
+
408
  with gr.Tab("πŸ“Š Dataset"):
409
  with gr.Row():
410
+ download_btn = gr.Button("Download Dataset", variant="primary", scale=1)
411
+ download_status = gr.Textbox(label="Status", interactive=False, scale=3)
412
 
413
  download_btn.click(download_dataset, outputs=download_status)
414
 
 
437
 
438
  with gr.Tab("πŸš€ Training"):
439
  gr.Markdown("### Train YOLOv11 Model")
440
+ gr.Markdown("**Note:** Training will automatically use GPU if available. This may take several minutes.")
441
 
442
  with gr.Row():
443
  epochs_input = gr.Slider(1, 50, 10, step=1, label="Epochs")
 
456
 
457
  gr.Markdown("### Load Pre-trained Model")
458
  with gr.Row():
459
+ model_path_input = gr.Textbox(label="Model Path", value="./xray_detection/train/weights/best.pt")
460
  load_model_btn = gr.Button("Load Model")
461
  load_status = gr.Textbox(label="Status", interactive=False)
462
 
 
466
  gr.Markdown("### Single Image Inference")
467
 
468
  with gr.Row():
469
+ with gr.Column():
470
+ input_image = gr.Image(type="pil", label="Upload Image")
471
+ conf_threshold = gr.Slider(0.1, 0.9, 0.5, step=0.05, label="Confidence Threshold")
472
+ inference_btn = gr.Button("Run Detection", variant="primary")
473
+
474
+ with gr.Column():
475
+ output_image = gr.Image(type="pil", label="Detection Result")
476
+ detection_info = gr.Textbox(label="Detection Info", lines=5)
477
 
478
  inference_btn.click(run_inference,
479
  inputs=[input_image, conf_threshold],
 
495
 
496
  # Launch the app
497
  if __name__ == "__main__":
498
+ # Check if running on Hugging Face Spaces
499
+ if ON_SPACES:
500
+ demo.launch(ssr_mode=False)
501
+ else:
502
+ demo.launch(share=True, ssr_mode=False)