Update app.py
Browse files
app.py
CHANGED
|
@@ -125,10 +125,8 @@ class Visualization:
|
|
| 125 |
cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
|
| 126 |
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
|
| 127 |
|
| 128 |
-
#
|
| 129 |
-
|
| 130 |
-
or_im = cv2.cvtColor(or_im, cv2.COLOR_BGR2RGB)
|
| 131 |
-
|
| 132 |
return Image.fromarray(or_im)
|
| 133 |
|
| 134 |
def vis_samples(self, data_type, n_samples=4):
|
|
@@ -259,8 +257,12 @@ def train_model(epochs, batch_size, img_size, device_selection):
|
|
| 259 |
else:
|
| 260 |
device = 0 if torch.cuda.is_available() else "cpu"
|
| 261 |
|
| 262 |
-
# Initialize model
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
# Create project directory
|
| 266 |
project_dir = "./xray_detection"
|
|
@@ -350,14 +352,23 @@ def run_inference(input_image, conf_threshold):
|
|
| 350 |
model = YOLO("yolo11n.pt")
|
| 351 |
print("Loaded default YOLOv11 model")
|
| 352 |
except:
|
| 353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if input_image is None:
|
| 356 |
return None, "Please upload an image!"
|
| 357 |
|
| 358 |
try:
|
| 359 |
# Check if model is trained on X-ray dataset
|
| 360 |
-
model_info =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
# Save the input image temporarily
|
| 363 |
temp_path = "temp_inference.jpg"
|
|
@@ -375,16 +386,22 @@ def run_inference(input_image, conf_threshold):
|
|
| 375 |
for box in results[0].boxes:
|
| 376 |
cls = int(box.cls)
|
| 377 |
conf = float(box.conf)
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
| 379 |
detections.append(f"{cls_name}: {conf:.2f}")
|
| 380 |
|
| 381 |
detection_text = model_info + "Detections:\n" + "\n".join(detections)
|
| 382 |
else:
|
| 383 |
# Check if it's because of wrong model
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
# Clean up
|
| 390 |
if os.path.exists(temp_path):
|
|
@@ -412,7 +429,11 @@ def batch_inference(data_type, num_images):
|
|
| 412 |
model = YOLO("yolo11n.pt")
|
| 413 |
print("Loaded default model for batch inference")
|
| 414 |
except:
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
else:
|
| 417 |
return [], "No trained model found. Please train the model first!"
|
| 418 |
|
|
@@ -437,7 +458,10 @@ def batch_inference(data_type, num_images):
|
|
| 437 |
results_images.append(Image.fromarray(annotated))
|
| 438 |
|
| 439 |
# Check model type
|
| 440 |
-
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
return results_images, f"Processed {len(results_images)} images from {data_type} dataset using {model_type}"
|
| 443 |
|
|
@@ -482,7 +506,8 @@ def get_dataset_info():
|
|
| 482 |
default_paths = [
|
| 483 |
"./xray_detection/train/weights/best.pt",
|
| 484 |
"./xray_detection/train/weights/last.pt",
|
| 485 |
-
"yolo11n.pt"
|
|
|
|
| 486 |
]
|
| 487 |
for path in default_paths:
|
| 488 |
if os.path.exists(path):
|
|
@@ -492,26 +517,55 @@ def get_dataset_info():
|
|
| 492 |
if os.path.exists(model_path):
|
| 493 |
model = YOLO(model_path)
|
| 494 |
# Check if it's a trained model by looking at class names
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
| 500 |
return f"Model loaded successfully from {model_path}"
|
| 501 |
else:
|
| 502 |
return "Model file not found. Please train a model first or provide a valid path."
|
| 503 |
except Exception as e:
|
| 504 |
return f"Error loading model: {str(e)}"
|
| 505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
# Create Gradio interface
|
| 507 |
with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
|
| 508 |
gr.Markdown("""
|
| 509 |
-
# π― X-ray Baggage Anomaly Detection with
|
| 510 |
|
| 511 |
This application allows you to:
|
| 512 |
1. Download and visualize the X-ray baggage dataset
|
| 513 |
2. Analyze class distributions
|
| 514 |
-
3. Train a
|
| 515 |
4. Run inference on new images
|
| 516 |
|
| 517 |
**Note:** GPU will be automatically allocated when needed for training and inference.
|
|
@@ -543,10 +597,36 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
|
|
| 543 |
|
| 544 |
# Dataset info section
|
| 545 |
with gr.Row():
|
| 546 |
-
dataset_info = gr.Markdown(value=
|
| 547 |
info_btn = gr.Button("π Refresh Dataset Info", scale=0)
|
| 548 |
|
| 549 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 550 |
|
| 551 |
gr.Markdown("### Visualize Dataset Samples")
|
| 552 |
with gr.Row():
|
|
@@ -572,7 +652,7 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
|
|
| 572 |
outputs=[distribution_plot, analysis_status])
|
| 573 |
|
| 574 |
with gr.Tab("π Training"):
|
| 575 |
-
gr.Markdown("### Train
|
| 576 |
gr.Markdown("""
|
| 577 |
**Note:** Training will automatically use GPU if available. This may take several minutes.
|
| 578 |
|
|
@@ -601,6 +681,14 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
|
|
| 601 |
|
| 602 |
gr.Markdown("### Model Management")
|
| 603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
with gr.Row():
|
| 605 |
with gr.Column():
|
| 606 |
gr.Markdown("#### Download Trained Model")
|
|
@@ -617,38 +705,11 @@ with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft())
|
|
| 617 |
load_model_btn = gr.Button("Load Model")
|
| 618 |
load_status = gr.Textbox(label="Status", interactive=False)
|
| 619 |
|
| 620 |
-
# Add download functionality
|
| 621 |
-
def download_trained_model():
|
| 622 |
-
model_path = "./xray_detection/train/weights/best.pt"
|
| 623 |
-
if os.path.exists(model_path):
|
| 624 |
-
return gr.update(value=model_path, visible=True)
|
| 625 |
-
else:
|
| 626 |
-
return gr.update(value=None, visible=True)
|
| 627 |
-
|
| 628 |
download_model_btn.click(download_trained_model, outputs=download_file)
|
| 629 |
load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
|
| 630 |
|
| 631 |
with gr.Tab("π Inference"):
|
| 632 |
# Model status check
|
| 633 |
-
def check_model_status():
|
| 634 |
-
global model
|
| 635 |
-
if model is None:
|
| 636 |
-
# Try to load trained model
|
| 637 |
-
trained_path = "./xray_detection/train/weights/best.pt"
|
| 638 |
-
if os.path.exists(trained_path):
|
| 639 |
-
try:
|
| 640 |
-
model = YOLO(trained_path)
|
| 641 |
-
return f"β
Trained model loaded: {len(model.names)} classes"
|
| 642 |
-
except:
|
| 643 |
-
return "β No model loaded. Please train or load a model first."
|
| 644 |
-
return "β No model loaded. Please train or load a model first."
|
| 645 |
-
else:
|
| 646 |
-
num_classes = len(model.names)
|
| 647 |
-
if num_classes == 80:
|
| 648 |
-
return f"β οΈ Default COCO model loaded ({num_classes} classes). For X-ray detection, please train on the X-ray dataset."
|
| 649 |
-
else:
|
| 650 |
-
return f"β
Model loaded: {num_classes} classes - {', '.join(list(model.names.values())[:5])}..."
|
| 651 |
-
|
| 652 |
with gr.Row():
|
| 653 |
model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False)
|
| 654 |
refresh_status_btn = gr.Button("π Refresh Status", scale=0)
|
|
|
|
| 125 |
cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
|
| 126 |
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
|
| 127 |
|
| 128 |
+
# OpenCV uses BGR, but PIL expects RGB, and we already loaded as RGB
|
| 129 |
+
# So no conversion needed
|
|
|
|
|
|
|
| 130 |
return Image.fromarray(or_im)
|
| 131 |
|
| 132 |
def vis_samples(self, data_type, n_samples=4):
|
|
|
|
| 257 |
else:
|
| 258 |
device = 0 if torch.cuda.is_available() else "cpu"
|
| 259 |
|
| 260 |
+
# Initialize model - use yolov8n if yolo11n not available
|
| 261 |
+
try:
|
| 262 |
+
model = YOLO("yolo11n.pt")
|
| 263 |
+
except Exception as e:
|
| 264 |
+
print(f"YOLOv11 not available: {e}, falling back to YOLOv8")
|
| 265 |
+
model = YOLO("yolov8n.pt") # Fallback to YOLOv8
|
| 266 |
|
| 267 |
# Create project directory
|
| 268 |
project_dir = "./xray_detection"
|
|
|
|
| 352 |
model = YOLO("yolo11n.pt")
|
| 353 |
print("Loaded default YOLOv11 model")
|
| 354 |
except:
|
| 355 |
+
try:
|
| 356 |
+
model = YOLO("yolov8n.pt")
|
| 357 |
+
print("Loaded YOLOv8 model as fallback")
|
| 358 |
+
except Exception as e:
|
| 359 |
+
print(f"Failed to load default model: {e}")
|
| 360 |
+
return None, "Please train the model first or load a pre-trained model!"
|
| 361 |
|
| 362 |
if input_image is None:
|
| 363 |
return None, "Please upload an image!"
|
| 364 |
|
| 365 |
try:
|
| 366 |
# Check if model is trained on X-ray dataset
|
| 367 |
+
model_info = ""
|
| 368 |
+
try:
|
| 369 |
+
model_info = f"Using model with {len(model.names)} classes\n"
|
| 370 |
+
except:
|
| 371 |
+
model_info = "Using loaded model\n"
|
| 372 |
|
| 373 |
# Save the input image temporarily
|
| 374 |
temp_path = "temp_inference.jpg"
|
|
|
|
| 386 |
for box in results[0].boxes:
|
| 387 |
cls = int(box.cls)
|
| 388 |
conf = float(box.conf)
|
| 389 |
+
try:
|
| 390 |
+
cls_name = model.names[cls]
|
| 391 |
+
except:
|
| 392 |
+
cls_name = f"Class {cls}"
|
| 393 |
detections.append(f"{cls_name}: {conf:.2f}")
|
| 394 |
|
| 395 |
detection_text = model_info + "Detections:\n" + "\n".join(detections)
|
| 396 |
else:
|
| 397 |
# Check if it's because of wrong model
|
| 398 |
+
try:
|
| 399 |
+
if len(model.names) == 80: # COCO dataset has 80 classes
|
| 400 |
+
detection_text = model_info + "No objects detected.\n\nβ οΈ Note: This appears to be a general COCO model. For X-ray baggage detection, please train the model on the X-ray dataset first."
|
| 401 |
+
else:
|
| 402 |
+
detection_text = model_info + "No objects detected at this confidence threshold.\nTry lowering the confidence threshold."
|
| 403 |
+
except:
|
| 404 |
+
detection_text = model_info + "No objects detected."
|
| 405 |
|
| 406 |
# Clean up
|
| 407 |
if os.path.exists(temp_path):
|
|
|
|
| 429 |
model = YOLO("yolo11n.pt")
|
| 430 |
print("Loaded default model for batch inference")
|
| 431 |
except:
|
| 432 |
+
try:
|
| 433 |
+
model = YOLO("yolov8n.pt")
|
| 434 |
+
print("Loaded YOLOv8 model as fallback for batch inference")
|
| 435 |
+
except:
|
| 436 |
+
return [], "Please train the model first!"
|
| 437 |
else:
|
| 438 |
return [], "No trained model found. Please train the model first!"
|
| 439 |
|
|
|
|
| 458 |
results_images.append(Image.fromarray(annotated))
|
| 459 |
|
| 460 |
# Check model type
|
| 461 |
+
try:
|
| 462 |
+
model_type = "X-ray detection model" if len(model.names) != 80 else "General COCO model"
|
| 463 |
+
except:
|
| 464 |
+
model_type = "Loaded model"
|
| 465 |
|
| 466 |
return results_images, f"Processed {len(results_images)} images from {data_type} dataset using {model_type}"
|
| 467 |
|
|
|
|
| 506 |
default_paths = [
|
| 507 |
"./xray_detection/train/weights/best.pt",
|
| 508 |
"./xray_detection/train/weights/last.pt",
|
| 509 |
+
"yolo11n.pt",
|
| 510 |
+
"yolov8n.pt"
|
| 511 |
]
|
| 512 |
for path in default_paths:
|
| 513 |
if os.path.exists(path):
|
|
|
|
| 517 |
if os.path.exists(model_path):
|
| 518 |
model = YOLO(model_path)
|
| 519 |
# Check if it's a trained model by looking at class names
|
| 520 |
+
try:
|
| 521 |
+
if hasattr(model, 'names') and len(model.names) > 0:
|
| 522 |
+
class_names = ", ".join([f"{i}: {name}" for i, name in model.names.items()][:5])
|
| 523 |
+
if len(model.names) > 5:
|
| 524 |
+
class_names += f"... (μ΄ {len(model.names)} ν΄λμ€)"
|
| 525 |
+
return f"Model loaded successfully from {model_path}\nν΄λμ€: {class_names}"
|
| 526 |
+
except:
|
| 527 |
+
pass
|
| 528 |
return f"Model loaded successfully from {model_path}"
|
| 529 |
else:
|
| 530 |
return "Model file not found. Please train a model first or provide a valid path."
|
| 531 |
except Exception as e:
|
| 532 |
return f"Error loading model: {str(e)}"
|
| 533 |
|
| 534 |
+
def check_model_status():
|
| 535 |
+
"""Check current model status"""
|
| 536 |
+
global model
|
| 537 |
+
if model is None:
|
| 538 |
+
# Try to load trained model
|
| 539 |
+
trained_path = "./xray_detection/train/weights/best.pt"
|
| 540 |
+
if os.path.exists(trained_path):
|
| 541 |
+
try:
|
| 542 |
+
model = YOLO(trained_path)
|
| 543 |
+
return f"β
Trained model loaded: {len(model.names)} classes"
|
| 544 |
+
except:
|
| 545 |
+
return "β No model loaded. Please train or load a model first."
|
| 546 |
+
return "β No model loaded. Please train or load a model first."
|
| 547 |
+
else:
|
| 548 |
+
try:
|
| 549 |
+
num_classes = len(model.names)
|
| 550 |
+
if num_classes == 80:
|
| 551 |
+
return f"β οΈ Default COCO model loaded ({num_classes} classes). For X-ray detection, please train on the X-ray dataset."
|
| 552 |
+
else:
|
| 553 |
+
class_names = ', '.join(list(model.names.values())[:5])
|
| 554 |
+
if len(model.names) > 5:
|
| 555 |
+
class_names += "..."
|
| 556 |
+
return f"β
Model loaded: {num_classes} classes - {class_names}"
|
| 557 |
+
except:
|
| 558 |
+
return "β
Model loaded"
|
| 559 |
+
|
| 560 |
# Create Gradio interface
|
| 561 |
with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
|
| 562 |
gr.Markdown("""
|
| 563 |
+
# π― X-ray Baggage Anomaly Detection with YOLO
|
| 564 |
|
| 565 |
This application allows you to:
|
| 566 |
1. Download and visualize the X-ray baggage dataset
|
| 567 |
2. Analyze class distributions
|
| 568 |
+
3. Train a YOLO model for object detection
|
| 569 |
4. Run inference on new images
|
| 570 |
|
| 571 |
**Note:** GPU will be automatically allocated when needed for training and inference.
|
|
|
|
| 597 |
|
| 598 |
# Dataset info section
|
| 599 |
with gr.Row():
|
| 600 |
+
dataset_info = gr.Markdown(value="Dataset not downloaded yet.")
|
| 601 |
info_btn = gr.Button("π Refresh Dataset Info", scale=0)
|
| 602 |
|
| 603 |
+
def update_dataset_info():
|
| 604 |
+
return get_dataset_info()
|
| 605 |
+
|
| 606 |
+
info_btn.click(update_dataset_info, outputs=dataset_info)
|
| 607 |
+
|
| 608 |
+
gr.Markdown("### Visualize Dataset Samples")
|
| 609 |
+
with gr.Row():
|
| 610 |
+
data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
|
| 611 |
+
num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples")
|
| 612 |
+
viz_btn = gr.Button("Visualize Samples")
|
| 613 |
+
|
| 614 |
+
viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto")
|
| 615 |
+
viz_status = gr.Textbox(label="Status", interactive=False)
|
| 616 |
+
|
| 617 |
+
viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples],
|
| 618 |
+
outputs=[viz_gallery, viz_status])
|
| 619 |
+
|
| 620 |
+
gr.Markdown("### Analyze Class Distribution")
|
| 621 |
+
with gr.Row():
|
| 622 |
+
data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
|
| 623 |
+
analyze_btn = gr.Button("Analyze Distribution")
|
| 624 |
+
|
| 625 |
+
distribution_plot = gr.Image(label="Class Distribution", type="pil")
|
| 626 |
+
analysis_status = gr.Textbox(label="Status", interactive=False)
|
| 627 |
+
|
| 628 |
+
analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
|
| 629 |
+
outputs=[distribution_plot, analysis_status])
|
| 630 |
|
| 631 |
gr.Markdown("### Visualize Dataset Samples")
|
| 632 |
with gr.Row():
|
|
|
|
| 652 |
outputs=[distribution_plot, analysis_status])
|
| 653 |
|
| 654 |
with gr.Tab("π Training"):
|
| 655 |
+
gr.Markdown("### Train YOLO Model")
|
| 656 |
gr.Markdown("""
|
| 657 |
**Note:** Training will automatically use GPU if available. This may take several minutes.
|
| 658 |
|
|
|
|
| 681 |
|
| 682 |
gr.Markdown("### Model Management")
|
| 683 |
|
| 684 |
+
# Define download_trained_model function here
|
| 685 |
+
def download_trained_model():
|
| 686 |
+
model_path = "./xray_detection/train/weights/best.pt"
|
| 687 |
+
if os.path.exists(model_path):
|
| 688 |
+
return gr.update(value=model_path, visible=True)
|
| 689 |
+
else:
|
| 690 |
+
return gr.update(value=None, visible=True)
|
| 691 |
+
|
| 692 |
with gr.Row():
|
| 693 |
with gr.Column():
|
| 694 |
gr.Markdown("#### Download Trained Model")
|
|
|
|
| 705 |
load_model_btn = gr.Button("Load Model")
|
| 706 |
load_status = gr.Textbox(label="Status", interactive=False)
|
| 707 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
download_model_btn.click(download_trained_model, outputs=download_file)
|
| 709 |
load_model_btn.click(load_pretrained_model, inputs=model_path_input, outputs=load_status)
|
| 710 |
|
| 711 |
with gr.Tab("π Inference"):
|
| 712 |
# Model status check
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
with gr.Row():
|
| 714 |
model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False)
|
| 715 |
refresh_status_btn = gr.Button("π Refresh Status", scale=0)
|