openfree commited on
Commit
5bd2bc6
Β·
verified Β·
1 Parent(s): dcf4020

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -27
app.py CHANGED
@@ -62,10 +62,20 @@ class Visualization:
62
  self.get_bboxes()
63
 
64
  def get_cls_names(self):
65
- with open(f"{self.root}/data.yaml", 'r') as file:
 
 
 
 
 
 
66
  data = yaml.safe_load(file)
67
- class_names = data['names']
68
- self.class_dict = {index: name for index, name in enumerate(class_names)}
 
 
 
 
69
 
70
  def get_bboxes(self):
71
  self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}
@@ -291,8 +301,28 @@ def train_model(epochs, batch_size, img_size, device_selection):
291
  # Save the model path
292
  model_path = os.path.join(results_path, "weights", "best.pt")
293
 
 
 
 
 
 
 
 
 
 
294
  training_in_progress = False
295
- return plots, f"Training completed! Model saved to {model_path}"
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  except Exception as e:
298
  training_in_progress = False
@@ -303,17 +333,32 @@ def run_inference(input_image, conf_threshold):
303
  """Run inference on a single image"""
304
  global model
305
 
 
306
  if model is None:
307
- # Try to load a default model
308
- try:
309
- model = YOLO("yolo11n.pt")
310
- except:
311
- return None, "Please train the model first or load a pre-trained model!"
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  if input_image is None:
314
  return None, "Please upload an image!"
315
 
316
  try:
 
 
 
317
  # Save the input image temporarily
318
  temp_path = "temp_inference.jpg"
319
  input_image.save(temp_path)
@@ -326,20 +371,26 @@ def run_inference(input_image, conf_threshold):
326
 
327
  # Get detection info
328
  detections = []
329
- if results[0].boxes is not None:
330
  for box in results[0].boxes:
331
  cls = int(box.cls)
332
  conf = float(box.conf)
333
  cls_name = model.names[cls]
334
  detections.append(f"{cls_name}: {conf:.2f}")
 
 
 
 
 
 
 
 
335
 
336
  # Clean up
337
  if os.path.exists(temp_path):
338
  os.remove(temp_path)
339
 
340
- detection_text = "\n".join(detections) if detections else "No objects detected"
341
-
342
- return Image.fromarray(annotated_image), f"Detections:\n{detection_text}"
343
 
344
  except Exception as e:
345
  return None, f"Error during inference: {str(e)}"
@@ -349,11 +400,21 @@ def batch_inference(data_type, num_images):
349
  """Run inference on multiple images from test set"""
350
  global model
351
 
 
352
  if model is None:
353
- try:
354
- model = YOLO("yolo11n.pt")
355
- except:
356
- return [], "Please train the model first!"
 
 
 
 
 
 
 
 
 
357
 
358
  if dataset_path is None:
359
  return [], "Please download the dataset first!"
@@ -375,15 +436,47 @@ def batch_inference(data_type, num_images):
375
  annotated = results[0].plot()
376
  results_images.append(Image.fromarray(annotated))
377
 
378
- return results_images, f"Processed {len(results_images)} images from {data_type} dataset"
 
 
 
379
 
380
  except Exception as e:
381
  return [], f"Error during batch inference: {str(e)}"
382
 
383
- def load_pretrained_model(model_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  """Load a pre-trained model"""
385
  global model
386
  try:
 
 
 
 
 
 
387
  if not os.path.exists(model_path):
388
  # Try default paths
389
  default_paths = [
@@ -396,8 +489,17 @@ def load_pretrained_model(model_path):
396
  model_path = path
397
  break
398
 
399
- model = YOLO(model_path)
400
- return f"Model loaded successfully from {model_path}"
 
 
 
 
 
 
 
 
 
401
  except Exception as e:
402
  return f"Error loading model: {str(e)}"
403
 
@@ -424,6 +526,12 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
424
  ```
425
  KDATA_API={"username":"your_username","key":"your_api_key"}
426
  ```
 
 
 
 
 
 
427
  """)
428
 
429
  with gr.Tab("πŸ“Š Dataset"):
@@ -433,6 +541,13 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
433
 
434
  download_btn.click(download_dataset, outputs=download_status)
435
 
 
 
 
 
 
 
 
436
  gr.Markdown("### Visualize Dataset Samples")
437
  with gr.Row():
438
  data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
@@ -465,6 +580,8 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
465
  - Use smaller batch sizes (4-8) to avoid GPU memory issues
466
  - Start with fewer epochs (5-10) for testing
467
  - Image size 480 provides good balance between quality and speed
 
 
468
  """)
469
 
470
  with gr.Row():
@@ -482,20 +599,68 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
482
  inputs=[epochs_input, batch_size_input, img_size_input, device_input],
483
  outputs=[training_gallery, training_status])
484
 
485
- gr.Markdown("### Load Pre-trained Model")
486
- with gr.Row():
487
- model_path_input = gr.Textbox(label="Model Path", value="./xray_detection/train/weights/best.pt")
488
- load_model_btn = gr.Button("Load Model")
489
- load_status = gr.Textbox(label="Status", interactive=False)
490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
492
 
493
  with gr.Tab("πŸ” Inference"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  gr.Markdown("### Single Image Inference")
 
495
 
496
  with gr.Row():
497
  with gr.Column():
498
- input_image = gr.Image(type="pil", label="Upload Image")
499
  conf_threshold = gr.Slider(0.1, 0.9, 0.5, step=0.05, label="Confidence Threshold")
500
  inference_btn = gr.Button("Run Detection", variant="primary")
501
 
@@ -508,6 +673,7 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
508
  outputs=[output_image, detection_info])
509
 
510
  gr.Markdown("### Batch Inference")
 
511
 
512
  with gr.Row():
513
  batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type")
 
62
  self.get_bboxes()
63
 
64
  def get_cls_names(self):
65
+ yaml_path = f"{self.root}/data.yaml"
66
+ if not os.path.exists(yaml_path):
67
+ print(f"Warning: {yaml_path} not found")
68
+ self.class_dict = {}
69
+ return
70
+
71
+ with open(yaml_path, 'r') as file:
72
  data = yaml.safe_load(file)
73
+ class_names = data.get('names', [])
74
+ self.class_dict = {index: name for index, name in enumerate(class_names)}
75
+
76
+ # Print class names for debugging
77
+ if self.class_dict:
78
+ print(f"Dataset classes: {', '.join(class_names)}")
79
 
80
  def get_bboxes(self):
81
  self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}
 
301
  # Save the model path
302
  model_path = os.path.join(results_path, "weights", "best.pt")
303
 
304
+ # Load the trained model to ensure it's ready for inference
305
+ if os.path.exists(model_path):
306
+ model = YOLO(model_path)
307
+ class_info = f"\nTrained on {len(model.names)} classes: {', '.join(list(model.names.values())[:5])}"
308
+ if len(model.names) > 5:
309
+ class_info += f"... (총 {len(model.names)} 클래슀)"
310
+ else:
311
+ class_info = ""
312
+
313
  training_in_progress = False
314
+
315
+ # Provide instructions for saving the model
316
+ save_instructions = """
317
+
318
+ ⚠️ **Important**: Models trained on Hugging Face Spaces are temporary!
319
+ To save your model permanently:
320
+ 1. Download the model file: `./xray_detection/train/weights/best.pt`
321
+ 2. Upload it to HuggingFace Hub or save locally
322
+ 3. Use 'Load Pre-trained Model' with the saved path
323
+ """
324
+
325
+ return plots, f"Training completed! Model saved to {model_path}{class_info}{save_instructions}"
326
 
327
  except Exception as e:
328
  training_in_progress = False
 
333
  """Run inference on a single image"""
334
  global model
335
 
336
+ # Try to load the trained model if not already loaded
337
  if model is None:
338
+ # First, try to load the trained model
339
+ trained_model_path = "./xray_detection/train/weights/best.pt"
340
+ if os.path.exists(trained_model_path):
341
+ try:
342
+ model = YOLO(trained_model_path)
343
+ print(f"Loaded trained model from {trained_model_path}")
344
+ except:
345
+ pass
346
+
347
+ # If still no model, try to load default
348
+ if model is None:
349
+ try:
350
+ model = YOLO("yolo11n.pt")
351
+ print("Loaded default YOLOv11 model")
352
+ except:
353
+ return None, "Please train the model first or load a pre-trained model!"
354
 
355
  if input_image is None:
356
  return None, "Please upload an image!"
357
 
358
  try:
359
+ # Check if model is trained on X-ray dataset
360
+ model_info = f"Using model with {len(model.names)} classes\n"
361
+
362
  # Save the input image temporarily
363
  temp_path = "temp_inference.jpg"
364
  input_image.save(temp_path)
 
371
 
372
  # Get detection info
373
  detections = []
374
+ if results[0].boxes is not None and len(results[0].boxes) > 0:
375
  for box in results[0].boxes:
376
  cls = int(box.cls)
377
  conf = float(box.conf)
378
  cls_name = model.names[cls]
379
  detections.append(f"{cls_name}: {conf:.2f}")
380
+
381
+ detection_text = model_info + "Detections:\n" + "\n".join(detections)
382
+ else:
383
+ # Check if it's because of wrong model
384
+ if len(model.names) == 80: # COCO dataset has 80 classes
385
+ detection_text = model_info + "No objects detected.\n\n⚠️ Note: This appears to be a general COCO model. For X-ray baggage detection, please train the model on the X-ray dataset first."
386
+ else:
387
+ detection_text = model_info + "No objects detected at this confidence threshold.\nTry lowering the confidence threshold."
388
 
389
  # Clean up
390
  if os.path.exists(temp_path):
391
  os.remove(temp_path)
392
 
393
+ return Image.fromarray(annotated_image), detection_text
 
 
394
 
395
  except Exception as e:
396
  return None, f"Error during inference: {str(e)}"
 
400
  """Run inference on multiple images from test set"""
401
  global model
402
 
403
+ # Try to load the trained model if not already loaded
404
  if model is None:
405
+ trained_model_path = "./xray_detection/train/weights/best.pt"
406
+ if os.path.exists(trained_model_path):
407
+ try:
408
+ model = YOLO(trained_model_path)
409
+ print(f"Loaded trained model for batch inference")
410
+ except:
411
+ try:
412
+ model = YOLO("yolo11n.pt")
413
+ print("Loaded default model for batch inference")
414
+ except:
415
+ return [], "Please train the model first!"
416
+ else:
417
+ return [], "No trained model found. Please train the model first!"
418
 
419
  if dataset_path is None:
420
  return [], "Please download the dataset first!"
 
436
  annotated = results[0].plot()
437
  results_images.append(Image.fromarray(annotated))
438
 
439
+ # Check model type
440
+ model_type = "X-ray detection model" if len(model.names) != 80 else "General COCO model"
441
+
442
+ return results_images, f"Processed {len(results_images)} images from {data_type} dataset using {model_type}"
443
 
444
  except Exception as e:
445
  return [], f"Error during batch inference: {str(e)}"
446
 
447
+ def get_dataset_info():
448
+ """Get information about the X-ray dataset classes"""
449
+ if dataset_path is None:
450
+ return "Dataset not downloaded yet."
451
+
452
+ try:
453
+ yaml_path = f"{dataset_path}/data.yaml"
454
+ if not os.path.exists(yaml_path):
455
+ return "Dataset configuration file not found."
456
+
457
+ with open(yaml_path, 'r') as file:
458
+ data = yaml.safe_load(file)
459
+
460
+ class_names = data.get('names', [])
461
+ num_classes = len(class_names)
462
+
463
+ info = f"**X-ray Baggage Dataset Info:**\n"
464
+ info += f"- Number of classes: {num_classes}\n"
465
+ info += f"- Classes: {', '.join(class_names)}\n"
466
+ info += f"\nThese are the prohibited items that the model will learn to detect."
467
+
468
+ return info
469
+ except Exception as e:
470
+ return f"Error reading dataset info: {str(e)}"
471
  """Load a pre-trained model"""
472
  global model
473
  try:
474
+ # Check if it's a HuggingFace model path
475
+ if model_path.startswith("hf://") or "/" in model_path and not os.path.exists(model_path):
476
+ # Load from HuggingFace Hub
477
+ model = YOLO(model_path)
478
+ return f"Model loaded successfully from HuggingFace: {model_path}"
479
+
480
  if not os.path.exists(model_path):
481
  # Try default paths
482
  default_paths = [
 
489
  model_path = path
490
  break
491
 
492
+ if os.path.exists(model_path):
493
+ model = YOLO(model_path)
494
+ # Check if it's a trained model by looking at class names
495
+ if hasattr(model, 'names') and len(model.names) > 0:
496
+ class_names = ", ".join([f"{i}: {name}" for i, name in model.names.items()][:5])
497
+ if len(model.names) > 5:
498
+ class_names += f"... (총 {len(model.names)} 클래슀)"
499
+ return f"Model loaded successfully from {model_path}\n클래슀: {class_names}"
500
+ return f"Model loaded successfully from {model_path}"
501
+ else:
502
+ return "Model file not found. Please train a model first or provide a valid path."
503
  except Exception as e:
504
  return f"Error loading model: {str(e)}"
505
 
 
526
  ```
527
  KDATA_API={"username":"your_username","key":"your_api_key"}
528
  ```
529
+
530
+ ### Model Persistence on Hugging Face Spaces
531
+ - Models trained on Spaces are **temporary** and will be lost when the Space restarts
532
+ - After training, download your model using the "Download Model" button
533
+ - For permanent storage, upload to HuggingFace Hub or save locally
534
+ - You can load saved models using the "Load Pre-trained Model" feature
535
  """)
536
 
537
  with gr.Tab("πŸ“Š Dataset"):
 
541
 
542
  download_btn.click(download_dataset, outputs=download_status)
543
 
544
+ # Dataset info section
545
+ with gr.Row():
546
+ dataset_info = gr.Markdown(value=get_dataset_info())
547
+ info_btn = gr.Button("πŸ”„ Refresh Dataset Info", scale=0)
548
+
549
+ info_btn.click(get_dataset_info, outputs=dataset_info)
550
+
551
  gr.Markdown("### Visualize Dataset Samples")
552
  with gr.Row():
553
  data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
 
580
  - Use smaller batch sizes (4-8) to avoid GPU memory issues
581
  - Start with fewer epochs (5-10) for testing
582
  - Image size 480 provides good balance between quality and speed
583
+
584
+ ⚠️ **Important**: Models are temporary on Spaces! Download your model after training.
585
  """)
586
 
587
  with gr.Row():
 
599
  inputs=[epochs_input, batch_size_input, img_size_input, device_input],
600
  outputs=[training_gallery, training_status])
601
 
602
+ gr.Markdown("### Model Management")
 
 
 
 
603
 
604
+ with gr.Row():
605
+ with gr.Column():
606
+ gr.Markdown("#### Download Trained Model")
607
+ download_model_btn = gr.Button("Download Model (.pt)", variant="secondary")
608
+ download_file = gr.File(label="Download Model File", visible=False)
609
+
610
+ with gr.Column():
611
+ gr.Markdown("#### Load Pre-trained Model")
612
+ model_path_input = gr.Textbox(
613
+ label="Model Path",
614
+ value="./xray_detection/train/weights/best.pt",
615
+ placeholder="Local path or HuggingFace model ID"
616
+ )
617
+ load_model_btn = gr.Button("Load Model")
618
+ load_status = gr.Textbox(label="Status", interactive=False)
619
+
620
+ # Add download functionality
621
+ def download_trained_model():
622
+ model_path = "./xray_detection/train/weights/best.pt"
623
+ if os.path.exists(model_path):
624
+ return gr.update(value=model_path, visible=True)
625
+ else:
626
+ return gr.update(value=None, visible=True)
627
+
628
+ download_model_btn.click(download_trained_model, outputs=download_file)
629
  load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
630
 
631
  with gr.Tab("πŸ” Inference"):
632
+ # Model status check
633
+ def check_model_status():
634
+ global model
635
+ if model is None:
636
+ # Try to load trained model
637
+ trained_path = "./xray_detection/train/weights/best.pt"
638
+ if os.path.exists(trained_path):
639
+ try:
640
+ model = YOLO(trained_path)
641
+ return f"βœ… Trained model loaded: {len(model.names)} classes"
642
+ except:
643
+ return "❌ No model loaded. Please train or load a model first."
644
+ return "❌ No model loaded. Please train or load a model first."
645
+ else:
646
+ num_classes = len(model.names)
647
+ if num_classes == 80:
648
+ return f"⚠️ Default COCO model loaded ({num_classes} classes). For X-ray detection, please train on the X-ray dataset."
649
+ else:
650
+ return f"βœ… Model loaded: {num_classes} classes - {', '.join(list(model.names.values())[:5])}..."
651
+
652
+ with gr.Row():
653
+ model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False)
654
+ refresh_status_btn = gr.Button("πŸ”„ Refresh Status", scale=0)
655
+
656
+ refresh_status_btn.click(check_model_status, outputs=model_status)
657
+
658
  gr.Markdown("### Single Image Inference")
659
+ gr.Markdown("Upload an X-ray baggage image to detect prohibited items.")
660
 
661
  with gr.Row():
662
  with gr.Column():
663
+ input_image = gr.Image(type="pil", label="Upload X-ray Image")
664
  conf_threshold = gr.Slider(0.1, 0.9, 0.5, step=0.05, label="Confidence Threshold")
665
  inference_btn = gr.Button("Run Detection", variant="primary")
666
 
 
673
  outputs=[output_image, detection_info])
674
 
675
  gr.Markdown("### Batch Inference")
676
+ gr.Markdown("Run detection on multiple images from the test dataset.")
677
 
678
  with gr.Row():
679
  batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type")