openfree commited on
Commit
8cb877d
Β·
verified Β·
1 Parent(s): 5bd2bc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -53
app.py CHANGED
@@ -125,10 +125,8 @@ class Visualization:
125
  cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
126
  cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
127
 
128
- # Convert BGR to RGB if needed
129
- if len(or_im.shape) == 3 and or_im.shape[2] == 3:
130
- or_im = cv2.cvtColor(or_im, cv2.COLOR_BGR2RGB)
131
-
132
  return Image.fromarray(or_im)
133
 
134
  def vis_samples(self, data_type, n_samples=4):
@@ -259,8 +257,12 @@ def train_model(epochs, batch_size, img_size, device_selection):
259
  else:
260
  device = 0 if torch.cuda.is_available() else "cpu"
261
 
262
- # Initialize model
263
- model = YOLO("yolo11n.pt")
 
 
 
 
264
 
265
  # Create project directory
266
  project_dir = "./xray_detection"
@@ -350,14 +352,23 @@ def run_inference(input_image, conf_threshold):
350
  model = YOLO("yolo11n.pt")
351
  print("Loaded default YOLOv11 model")
352
  except:
353
- return None, "Please train the model first or load a pre-trained model!"
 
 
 
 
 
354
 
355
  if input_image is None:
356
  return None, "Please upload an image!"
357
 
358
  try:
359
  # Check if model is trained on X-ray dataset
360
- model_info = f"Using model with {len(model.names)} classes\n"
 
 
 
 
361
 
362
  # Save the input image temporarily
363
  temp_path = "temp_inference.jpg"
@@ -375,16 +386,22 @@ def run_inference(input_image, conf_threshold):
375
  for box in results[0].boxes:
376
  cls = int(box.cls)
377
  conf = float(box.conf)
378
- cls_name = model.names[cls]
 
 
 
379
  detections.append(f"{cls_name}: {conf:.2f}")
380
 
381
  detection_text = model_info + "Detections:\n" + "\n".join(detections)
382
  else:
383
  # Check if it's because of wrong model
384
- if len(model.names) == 80: # COCO dataset has 80 classes
385
- detection_text = model_info + "No objects detected.\n\n⚠️ Note: This appears to be a general COCO model. For X-ray baggage detection, please train the model on the X-ray dataset first."
386
- else:
387
- detection_text = model_info + "No objects detected at this confidence threshold.\nTry lowering the confidence threshold."
 
 
 
388
 
389
  # Clean up
390
  if os.path.exists(temp_path):
@@ -412,7 +429,11 @@ def batch_inference(data_type, num_images):
412
  model = YOLO("yolo11n.pt")
413
  print("Loaded default model for batch inference")
414
  except:
415
- return [], "Please train the model first!"
 
 
 
 
416
  else:
417
  return [], "No trained model found. Please train the model first!"
418
 
@@ -437,7 +458,10 @@ def batch_inference(data_type, num_images):
437
  results_images.append(Image.fromarray(annotated))
438
 
439
  # Check model type
440
- model_type = "X-ray detection model" if len(model.names) != 80 else "General COCO model"
 
 
 
441
 
442
  return results_images, f"Processed {len(results_images)} images from {data_type} dataset using {model_type}"
443
 
@@ -482,7 +506,8 @@ def get_dataset_info():
482
  default_paths = [
483
  "./xray_detection/train/weights/best.pt",
484
  "./xray_detection/train/weights/last.pt",
485
- "yolo11n.pt"
 
486
  ]
487
  for path in default_paths:
488
  if os.path.exists(path):
@@ -492,26 +517,55 @@ def get_dataset_info():
492
  if os.path.exists(model_path):
493
  model = YOLO(model_path)
494
  # Check if it's a trained model by looking at class names
495
- if hasattr(model, 'names') and len(model.names) > 0:
496
- class_names = ", ".join([f"{i}: {name}" for i, name in model.names.items()][:5])
497
- if len(model.names) > 5:
498
- class_names += f"... (총 {len(model.names)} 클래슀)"
499
- return f"Model loaded successfully from {model_path}\n클래슀: {class_names}"
 
 
 
500
  return f"Model loaded successfully from {model_path}"
501
  else:
502
  return "Model file not found. Please train a model first or provide a valid path."
503
  except Exception as e:
504
  return f"Error loading model: {str(e)}"
505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  # Create Gradio interface
507
  with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
508
  gr.Markdown("""
509
- # 🎯 X-ray Baggage Anomaly Detection with YOLOv11
510
 
511
  This application allows you to:
512
  1. Download and visualize the X-ray baggage dataset
513
  2. Analyze class distributions
514
- 3. Train a YOLOv11 model for object detection
515
  4. Run inference on new images
516
 
517
  **Note:** GPU will be automatically allocated when needed for training and inference.
@@ -543,10 +597,36 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
543
 
544
  # Dataset info section
545
  with gr.Row():
546
- dataset_info = gr.Markdown(value=get_dataset_info())
547
  info_btn = gr.Button("πŸ”„ Refresh Dataset Info", scale=0)
548
 
549
- info_btn.click(get_dataset_info, outputs=dataset_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
  gr.Markdown("### Visualize Dataset Samples")
552
  with gr.Row():
@@ -572,7 +652,7 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
572
  outputs=[distribution_plot, analysis_status])
573
 
574
  with gr.Tab("πŸš€ Training"):
575
- gr.Markdown("### Train YOLOv11 Model")
576
  gr.Markdown("""
577
  **Note:** Training will automatically use GPU if available. This may take several minutes.
578
 
@@ -601,6 +681,14 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
601
 
602
  gr.Markdown("### Model Management")
603
 
 
 
 
 
 
 
 
 
604
  with gr.Row():
605
  with gr.Column():
606
  gr.Markdown("#### Download Trained Model")
@@ -617,38 +705,11 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
617
  load_model_btn = gr.Button("Load Model")
618
  load_status = gr.Textbox(label="Status", interactive=False)
619
 
620
- # Add download functionality
621
- def download_trained_model():
622
- model_path = "./xray_detection/train/weights/best.pt"
623
- if os.path.exists(model_path):
624
- return gr.update(value=model_path, visible=True)
625
- else:
626
- return gr.update(value=None, visible=True)
627
-
628
  download_model_btn.click(download_trained_model, outputs=download_file)
629
  load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
630
 
631
  with gr.Tab("πŸ” Inference"):
632
  # Model status check
633
- def check_model_status():
634
- global model
635
- if model is None:
636
- # Try to load trained model
637
- trained_path = "./xray_detection/train/weights/best.pt"
638
- if os.path.exists(trained_path):
639
- try:
640
- model = YOLO(trained_path)
641
- return f"βœ… Trained model loaded: {len(model.names)} classes"
642
- except:
643
- return "❌ No model loaded. Please train or load a model first."
644
- return "❌ No model loaded. Please train or load a model first."
645
- else:
646
- num_classes = len(model.names)
647
- if num_classes == 80:
648
- return f"⚠️ Default COCO model loaded ({num_classes} classes). For X-ray detection, please train on the X-ray dataset."
649
- else:
650
- return f"βœ… Model loaded: {num_classes} classes - {', '.join(list(model.names.values())[:5])}..."
651
-
652
  with gr.Row():
653
  model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False)
654
  refresh_status_btn = gr.Button("πŸ”„ Refresh Status", scale=0)
 
125
  cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
126
  cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
127
 
128
+ # OpenCV uses BGR, but PIL expects RGB, and we already loaded as RGB
129
+ # So no conversion needed
 
 
130
  return Image.fromarray(or_im)
131
 
132
  def vis_samples(self, data_type, n_samples=4):
 
257
  else:
258
  device = 0 if torch.cuda.is_available() else "cpu"
259
 
260
+ # Initialize model - use yolov8n if yolo11n not available
261
+ try:
262
+ model = YOLO("yolo11n.pt")
263
+ except Exception as e:
264
+ print(f"YOLOv11 not available: {e}, falling back to YOLOv8")
265
+ model = YOLO("yolov8n.pt") # Fallback to YOLOv8
266
 
267
  # Create project directory
268
  project_dir = "./xray_detection"
 
352
  model = YOLO("yolo11n.pt")
353
  print("Loaded default YOLOv11 model")
354
  except:
355
+ try:
356
+ model = YOLO("yolov8n.pt")
357
+ print("Loaded YOLOv8 model as fallback")
358
+ except Exception as e:
359
+ print(f"Failed to load default model: {e}")
360
+ return None, "Please train the model first or load a pre-trained model!"
361
 
362
  if input_image is None:
363
  return None, "Please upload an image!"
364
 
365
  try:
366
  # Check if model is trained on X-ray dataset
367
+ model_info = ""
368
+ try:
369
+ model_info = f"Using model with {len(model.names)} classes\n"
370
+ except:
371
+ model_info = "Using loaded model\n"
372
 
373
  # Save the input image temporarily
374
  temp_path = "temp_inference.jpg"
 
386
  for box in results[0].boxes:
387
  cls = int(box.cls)
388
  conf = float(box.conf)
389
+ try:
390
+ cls_name = model.names[cls]
391
+ except:
392
+ cls_name = f"Class {cls}"
393
  detections.append(f"{cls_name}: {conf:.2f}")
394
 
395
  detection_text = model_info + "Detections:\n" + "\n".join(detections)
396
  else:
397
  # Check if it's because of wrong model
398
+ try:
399
+ if len(model.names) == 80: # COCO dataset has 80 classes
400
+ detection_text = model_info + "No objects detected.\n\n⚠️ Note: This appears to be a general COCO model. For X-ray baggage detection, please train the model on the X-ray dataset first."
401
+ else:
402
+ detection_text = model_info + "No objects detected at this confidence threshold.\nTry lowering the confidence threshold."
403
+ except:
404
+ detection_text = model_info + "No objects detected."
405
 
406
  # Clean up
407
  if os.path.exists(temp_path):
 
429
  model = YOLO("yolo11n.pt")
430
  print("Loaded default model for batch inference")
431
  except:
432
+ try:
433
+ model = YOLO("yolov8n.pt")
434
+ print("Loaded YOLOv8 model as fallback for batch inference")
435
+ except:
436
+ return [], "Please train the model first!"
437
  else:
438
  return [], "No trained model found. Please train the model first!"
439
 
 
458
  results_images.append(Image.fromarray(annotated))
459
 
460
  # Check model type
461
+ try:
462
+ model_type = "X-ray detection model" if len(model.names) != 80 else "General COCO model"
463
+ except:
464
+ model_type = "Loaded model"
465
 
466
  return results_images, f"Processed {len(results_images)} images from {data_type} dataset using {model_type}"
467
 
 
506
  default_paths = [
507
  "./xray_detection/train/weights/best.pt",
508
  "./xray_detection/train/weights/last.pt",
509
+ "yolo11n.pt",
510
+ "yolov8n.pt"
511
  ]
512
  for path in default_paths:
513
  if os.path.exists(path):
 
517
  if os.path.exists(model_path):
518
  model = YOLO(model_path)
519
  # Check if it's a trained model by looking at class names
520
+ try:
521
+ if hasattr(model, 'names') and len(model.names) > 0:
522
+ class_names = ", ".join([f"{i}: {name}" for i, name in model.names.items()][:5])
523
+ if len(model.names) > 5:
524
+ class_names += f"... (총 {len(model.names)} 클래슀)"
525
+ return f"Model loaded successfully from {model_path}\n클래슀: {class_names}"
526
+ except:
527
+ pass
528
  return f"Model loaded successfully from {model_path}"
529
  else:
530
  return "Model file not found. Please train a model first or provide a valid path."
531
  except Exception as e:
532
  return f"Error loading model: {str(e)}"
533
 
534
+ def check_model_status():
535
+ """Check current model status"""
536
+ global model
537
+ if model is None:
538
+ # Try to load trained model
539
+ trained_path = "./xray_detection/train/weights/best.pt"
540
+ if os.path.exists(trained_path):
541
+ try:
542
+ model = YOLO(trained_path)
543
+ return f"βœ… Trained model loaded: {len(model.names)} classes"
544
+ except:
545
+ return "❌ No model loaded. Please train or load a model first."
546
+ return "❌ No model loaded. Please train or load a model first."
547
+ else:
548
+ try:
549
+ num_classes = len(model.names)
550
+ if num_classes == 80:
551
+ return f"⚠️ Default COCO model loaded ({num_classes} classes). For X-ray detection, please train on the X-ray dataset."
552
+ else:
553
+ class_names = ', '.join(list(model.names.values())[:5])
554
+ if len(model.names) > 5:
555
+ class_names += "..."
556
+ return f"βœ… Model loaded: {num_classes} classes - {class_names}"
557
+ except:
558
+ return "βœ… Model loaded"
559
+
560
  # Create Gradio interface
561
  with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
562
  gr.Markdown("""
563
+ # 🎯 X-ray Baggage Anomaly Detection with YOLO
564
 
565
  This application allows you to:
566
  1. Download and visualize the X-ray baggage dataset
567
  2. Analyze class distributions
568
+ 3. Train a YOLO model for object detection
569
  4. Run inference on new images
570
 
571
  **Note:** GPU will be automatically allocated when needed for training and inference.
 
597
 
598
  # Dataset info section
599
  with gr.Row():
600
+ dataset_info = gr.Markdown(value="Dataset not downloaded yet.")
601
  info_btn = gr.Button("πŸ”„ Refresh Dataset Info", scale=0)
602
 
603
+ def update_dataset_info():
604
+ return get_dataset_info()
605
+
606
+ info_btn.click(update_dataset_info, outputs=dataset_info)
607
+
608
+ gr.Markdown("### Visualize Dataset Samples")
609
+ with gr.Row():
610
+ data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
611
+ num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples")
612
+ viz_btn = gr.Button("Visualize Samples")
613
+
614
+ viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto")
615
+ viz_status = gr.Textbox(label="Status", interactive=False)
616
+
617
+ viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples],
618
+ outputs=[viz_gallery, viz_status])
619
+
620
+ gr.Markdown("### Analyze Class Distribution")
621
+ with gr.Row():
622
+ data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
623
+ analyze_btn = gr.Button("Analyze Distribution")
624
+
625
+ distribution_plot = gr.Image(label="Class Distribution", type="pil")
626
+ analysis_status = gr.Textbox(label="Status", interactive=False)
627
+
628
+ analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
629
+ outputs=[distribution_plot, analysis_status])
630
 
631
  gr.Markdown("### Visualize Dataset Samples")
632
  with gr.Row():
 
652
  outputs=[distribution_plot, analysis_status])
653
 
654
  with gr.Tab("πŸš€ Training"):
655
+ gr.Markdown("### Train YOLO Model")
656
  gr.Markdown("""
657
  **Note:** Training will automatically use GPU if available. This may take several minutes.
658
 
 
681
 
682
  gr.Markdown("### Model Management")
683
 
684
+ # Define download_trained_model function here
685
+ def download_trained_model():
686
+ model_path = "./xray_detection/train/weights/best.pt"
687
+ if os.path.exists(model_path):
688
+ return gr.update(value=model_path, visible=True)
689
+ else:
690
+ return gr.update(value=None, visible=True)
691
+
692
  with gr.Row():
693
  with gr.Column():
694
  gr.Markdown("#### Download Trained Model")
 
705
  load_model_btn = gr.Button("Load Model")
706
  load_status = gr.Textbox(label="Status", interactive=False)
707
 
 
 
 
 
 
 
 
 
708
  download_model_btn.click(download_trained_model, outputs=download_file)
709
  load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
710
 
711
  with gr.Tab("πŸ” Inference"):
712
  # Model status check
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  with gr.Row():
714
  model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False)
715
  refresh_status_btn = gr.Button("πŸ”„ Refresh Status", scale=0)