# XRAY / app.py
# openfree's picture
# Update app.py
# df57a34 verified
import os
# Set environment variables for Spaces compatibility
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
import cv2
import yaml
import torch
import random
import gradio as gr
import numpy as np
import kagglehub
from PIL import Image
from glob import glob
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
from matplotlib import patches
from torchvision import transforms as T
from ultralytics import YOLO
import shutil
import tempfile
from pathlib import Path
import json
from io import BytesIO
# Try to import `spaces` for Hugging Face Spaces GPU support.
try:
    import spaces
    ON_SPACES = True
except ImportError:
    ON_SPACES = False

    class spaces:
        """No-op stand-in used when the `spaces` package is not installed."""

        @staticmethod
        def GPU(duration=60, **_kwargs):
            """Return the decorated function unchanged.

            Works both as ``@spaces.GPU(duration=...)`` and as a bare
            ``@spaces.GPU``; in the bare form the function itself arrives
            in the ``duration`` slot.
            """
            if callable(duration):
                # Bare decorator usage: `duration` is the decorated function.
                return duration

            def decorator(func):
                return func

            return decorator
def configure_kaggle_credentials(raw_value):
    """Export Kaggle credentials from a JSON string into the environment.

    Args:
        raw_value: Contents of the ``KDATA_API`` variable, expected to look
            like ``{"username": "...", "key": "..."}``. ``None``, empty and
            non-JSON values are ignored.

    Returns:
        True if at least one of ``KAGGLE_USERNAME`` / ``KAGGLE_KEY`` was set,
        otherwise False.
    """
    if not raw_value or "{" not in raw_value:
        return False
    try:
        key_data = json.loads(raw_value)
    except json.JSONDecodeError as exc:
        # A malformed secret must not take the whole app down at import time.
        print(f"Warning: KDATA_API is not valid JSON ({exc}); Kaggle credentials not set.")
        return False
    if not isinstance(key_data, dict):
        return False

    applied = False
    for env_name, field in (("KAGGLE_USERNAME", "username"), ("KAGGLE_KEY", "key")):
        value = key_data.get(field)
        if value:
            # Only overwrite when a value is present, so credentials already
            # in the environment are not clobbered with an empty string.
            os.environ[env_name] = str(value)
            applied = True
    return applied


# Set Kaggle API credentials from environment variable
configure_kaggle_credentials(os.getenv("KDATA_API"))
# Global variables
# Currently loaded YOLO model; None until one is trained, found on disk or uploaded.
model = None
# Root directory of the downloaded dataset; set by download_dataset().
dataset_path = None
# Guard flag toggled by train_model() to reject overlapping training requests.
training_in_progress = False
class Visualization:
    """Read YOLO-format labels under a dataset root and render previews.

    The directory layout read by this class is ``<root>/data.yaml`` plus
    ``<root>/<split>/images/*`` and ``<root>/<split>/labels/<stem>.txt``.

    After construction:
        class_dict: maps class index -> class name (empty if data.yaml is missing).
        im_paths: maps split -> list of image paths that have a label file.
        vis_datas: maps split -> list of per-image box lists, index-aligned
            with ``im_paths``; each box is ``[name, xc, yc, w, h]``.
        analysis_datas: maps split -> ``{class name: box count}``.
    """

    def __init__(self, root, data_types, n_ims, rows, cmap=None):
        self.n_ims, self.rows = n_ims, rows
        self.cmap, self.data_types = cmap, data_types
        self.colors = ["firebrick", "darkorange", "blueviolet"]
        self.root = root
        self.get_cls_names()
        self.get_bboxes()

    def get_cls_names(self):
        """Populate ``self.class_dict`` from ``<root>/data.yaml``."""
        yaml_path = f"{self.root}/data.yaml"
        if not os.path.exists(yaml_path):
            print(f"Warning: {yaml_path} not found")
            self.class_dict = {}
            return
        with open(yaml_path, 'r') as file:
            # An empty YAML file parses to None.
            data = yaml.safe_load(file) or {}
        names = data.get('names') or []
        if isinstance(names, dict):
            # `names` may be written as an {index: name} mapping instead of a list.
            self.class_dict = {int(index): str(name) for index, name in names.items()}
        else:
            self.class_dict = {index: name for index, name in enumerate(names)}
        # Print class names for debugging
        if self.class_dict:
            print(f"Dataset classes: {', '.join(str(n) for n in self.class_dict.values())}")

    def get_bboxes(self):
        """Parse label files for every split in ``self.data_types``.

        Images without a label file are skipped entirely, so that
        ``im_paths[split][i]`` always corresponds to ``vis_datas[split][i]``.
        """
        self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}
        for data_type in self.data_types:
            labelled_paths, all_bboxes, class_counts = [], [], {}
            labels_dir = f"{self.root}/{data_type}/labels"
            # Sorted so that results do not depend on filesystem order.
            for im_path in sorted(glob(f"{self.root}/{data_type}/images/*")):
                stem = os.path.splitext(os.path.basename(im_path))[0]
                lbl_path = os.path.join(labels_dir, f"{stem}.txt")
                if not os.path.isfile(lbl_path):
                    continue
                bboxes = []
                with open(lbl_path) as label_file:
                    for line in label_file:
                        parts = line.split()[:5]
                        if len(parts) < 5:
                            # Blank or truncated line: not a usable box.
                            continue
                        cls_id = int(parts[0])
                        cls_name = self.class_dict.get(cls_id, f"Class {cls_id}")
                        bboxes.append([cls_name] + [float(x) for x in parts[1:]])
                        class_counts[cls_name] = class_counts.get(cls_name, 0) + 1
                labelled_paths.append(im_path)
                all_bboxes.append(bboxes)
            self.vis_datas[data_type] = all_bboxes
            self.analysis_datas[data_type] = class_counts
            self.im_paths[data_type] = labelled_paths

    def plot_single(self, im_path, bboxes):
        """Return a PIL image of ``im_path`` with ``bboxes`` drawn on it.

        Box coordinates are fractions of the image size (centre x/y, width,
        height) and are scaled to pixels here.
        """
        or_im = np.array(Image.open(im_path).convert("RGB"))
        height, width, _ = or_im.shape
        for _cls_name, x_center, y_center, w, h in bboxes:
            x_min = int((x_center - w / 2) * width)
            y_min = int((y_center - h / 2) * height)
            x_max = int((x_center + w / 2) * width)
            y_max = int((y_center + h / 2) * height)
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
                          color=color, thickness=3)
        # Add text overlay
        cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
        # The array was loaded as RGB and only drawn on, so no channel
        # conversion is needed before handing it back to PIL.
        return Image.fromarray(or_im)

    def vis_samples(self, data_type, n_samples=4):
        """Return up to ``n_samples`` distinct annotated images for a split.

        Returns None when ``data_type`` was not loaded.
        """
        if data_type not in self.vis_datas:
            return None
        total = len(self.vis_datas[data_type])
        count = max(0, min(int(n_samples), total))
        # Sample without replacement so the same image is not shown twice.
        indices = random.sample(range(total), count)
        return [
            self.plot_single(self.im_paths[data_type][idx], self.vis_datas[data_type][idx])
            for idx in indices
        ]

    def data_analysis(self, data_type):
        """Return a PIL image of the per-class box-count bar chart.

        Returns None when ``data_type`` was not loaded.
        """
        if data_type not in self.analysis_datas:
            return None
        plt.style.use('default')
        fig, ax = plt.subplots(figsize=(12, 6))
        cls_names = list(self.analysis_datas[data_type].keys())
        counts = list(self.analysis_datas[data_type].values())
        color_map = {"train": "firebrick", "valid": "darkorange", "test": "blueviolet"}
        color = color_map.get(data_type, "steelblue")
        indices = np.arange(len(counts))
        bars = ax.bar(indices, counts, 0.7, color=color)
        ax.set_xlabel("Class Names", fontsize=12)
        ax.set_xticks(indices)
        ax.set_xticklabels(cls_names, rotation=45, ha='right')
        ax.set_ylabel("Data Counts", fontsize=12)
        ax.set_title(f"{data_type.upper()} Dataset Class Distribution", fontsize=14)
        for bar, v in zip(bars, counts):
            ax.text(bar.get_x() + bar.get_width() / 2, v + 1, str(v),
                    ha='center', va='bottom', fontsize=10, color='navy')
        plt.tight_layout()
        # Save to BytesIO and convert to PIL Image
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        plt.close(fig)
        return img
def download_dataset():
    """Download the dataset with kagglehub and copy it to ``./xray_dataset``.

    The global ``dataset_path`` is only updated once the local copy is in
    place, so a failed download or copy never leaves it pointing at a
    half-prepared location.

    Returns:
        A status message describing success or the error encountered.
    """
    global dataset_path
    # Local directory the dataset is copied into
    local_dir = "./xray_dataset"
    try:
        downloaded = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")
        if not downloaded or not os.path.exists(downloaded):
            raise FileNotFoundError(f"downloaded dataset path does not exist: {downloaded}")
        # Copy out of the download location unless it already is our local directory.
        if os.path.abspath(downloaded) != os.path.abspath(local_dir):
            if os.path.exists(local_dir):
                shutil.rmtree(local_dir)
            shutil.copytree(downloaded, local_dir)
        dataset_path = local_dir
        return f"Dataset downloaded successfully to: {dataset_path}"
    except Exception as e:
        return f"Error downloading dataset: {str(e)}\n\nPlease ensure KDATA_API environment variable is set correctly."
def visualize_data(data_type, num_samples):
    """Build annotated sample images for one dataset split.

    Returns a ``(images, status message)`` pair; ``images`` is an empty
    list whenever nothing could be shown.
    """
    if dataset_path is None:
        return [], "Please download the dataset first!"

    try:
        viewer = Visualization(
            root=dataset_path,
            data_types=[data_type],
            n_ims=num_samples,
            rows=2,
            cmap="rgb",
        )
        samples = viewer.vis_samples(data_type, num_samples)
    except Exception as e:
        return [], f"Error visualizing data: {str(e)}"

    if samples is None:
        return [], f"No data found for {data_type} dataset"
    return samples, f"Showing {len(samples)} samples from {data_type} dataset"
def analyze_class_distribution(data_type):
    """Render the class-distribution chart for one dataset split.

    Returns a ``(chart image or None, status message)`` pair.
    """
    if dataset_path is None:
        return None, "Please download the dataset first!"

    try:
        viewer = Visualization(
            root=dataset_path,
            data_types=[data_type],
            n_ims=20,
            rows=5,
            cmap="rgb",
        )
        chart = viewer.data_analysis(data_type)
    except Exception as e:
        return None, f"Error analyzing data: {str(e)}"

    if chart is None:
        return None, f"No data found for {data_type} dataset"
    return chart, f"Class distribution for {data_type} dataset"
def _resolve_training_device(device_selection):
    """Map the UI device choice to the ``device`` argument passed to training."""
    cuda_available = torch.cuda.is_available()
    # On Spaces a visible GPU always wins; elsewhere "CPU" is honoured.
    if device_selection == "CPU" and not (ON_SPACES and cuda_available):
        return "cpu"
    return 0 if cuda_available else "cpu"


@spaces.GPU(duration=300)  # Request GPU for 5 minutes for training
def train_model(epochs, batch_size, img_size, device_selection):
    """Train a YOLO model on the downloaded dataset.

    Args:
        epochs: Number of training epochs.
        batch_size: Training batch size.
        img_size: Training image size in pixels.
        device_selection: One of "Auto", "GPU" or "CPU".

    Returns:
        ``(plots, status)`` where ``plots`` is a list of PIL images of the
        result plots found in the run directory and ``status`` is a
        human-readable message. On any failure ``plots`` is empty.
    """
    global model, training_in_progress
    if dataset_path is None:
        return [], "Please download the dataset first!"
    if training_in_progress:
        return [], "Training already in progress!"
    training_in_progress = True
    try:
        device = _resolve_training_device(device_selection)

        # Read dataset info
        yaml_path = f"{dataset_path}/data.yaml"
        with open(yaml_path, 'r') as file:
            data_config = yaml.safe_load(file) or {}
        class_names = data_config.get('names', [])
        print(f"Training on {len(class_names)} classes: {class_names}")

        # Initialize model - use yolov8n if yolo11n not available
        try:
            model = YOLO("yolo11n.pt")
        except Exception as e:
            print(f"YOLOv11 not available: {e}, falling back to YOLOv8")
            model = YOLO("yolov8n.pt")  # Fallback to YOLOv8

        # Create project directory
        project_dir = "./xray_detection"
        os.makedirs(project_dir, exist_ok=True)

        # Train model with optimized settings for X-ray detection
        model.train(
            data=yaml_path,
            # UI sliders may deliver numeric values as floats; coerce to int.
            epochs=int(epochs),
            imgsz=int(img_size),
            batch=int(batch_size),
            device=device,
            project=project_dir,
            name="train",
            exist_ok=True,
            verbose=True,
            patience=5,  # Reduce patience for faster training on Spaces
            save_period=5,  # Save checkpoints every 5 epochs
            workers=0,  # Important: Set to 0 to avoid multiprocessing issues
            single_cls=False,
            rect=False,
            cache=False,  # Disable caching to avoid memory issues
            amp=True,  # Use automatic mixed precision for faster training
            # Optimization settings
            optimizer='AdamW',
            lr0=0.001,  # Initial learning rate
            lrf=0.01,  # Final learning rate factor
            momentum=0.937,
            weight_decay=0.0005,
            warmup_epochs=3.0,
            warmup_momentum=0.8,
            warmup_bias_lr=0.1,
            # Loss weights
            box=7.5,
            cls=0.5,
            dfl=1.5,
            # Augmentation settings for X-ray images
            hsv_h=0.0,  # No hue augmentation for X-ray
            hsv_s=0.0,  # No saturation augmentation
            hsv_v=0.1,  # Slight value augmentation
            degrees=0.0,  # No rotation
            translate=0.1,
            scale=0.5,
            shear=0.0,
            perspective=0.0,
            flipud=0.0,  # No vertical flip for X-ray
            fliplr=0.5,  # Horizontal flip is okay
            mosaic=1.0,
            mixup=0.0,
            copy_paste=0.0,
        )

        # Collect training result plots
        results_path = os.path.join(project_dir, "train")
        plot_files = ["results.png", "confusion_matrix.png", "val_batch0_pred.jpg",
                      "train_batch0.jpg", "val_batch0_labels.jpg"]
        plots = []
        for plot_file in plot_files:
            plot_path = os.path.join(results_path, plot_file)
            if os.path.exists(plot_path):
                plots.append(Image.open(plot_path))

        # Reload the best weights so the global model is ready for inference
        model_path = os.path.join(results_path, "weights", "best.pt")
        if os.path.exists(model_path):
            try:
                model = YOLO(model_path)
                class_info = f"\n✅ Trained on {len(model.names)} classes: {', '.join(model.names.values())}"
                # Run a test inference on a blank image to ensure the model works
                model(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)
                class_info += "\n✅ Model test passed - ready for inference!"
            except Exception as e:
                class_info = f"\n⚠️ Model loaded but test failed: {str(e)}"
        else:
            class_info = "\n❌ Model file not found!"

        # Provide instructions for saving the model
        save_instructions = """
✅ **Training Complete!**
📥 **Next Steps:**
1. Click "📥 Download Model (.pt)" button below to save your model
2. Keep the downloaded file safe - you'll need it after Space restarts
3. To reuse: Upload the model file in the "Upload & Load Model" section
⚠️ **Important**: This model will be lost when the Space restarts!
"""
        return plots, f"Model saved to {model_path}{class_info}{save_instructions}"
    except Exception as e:
        return [], f"Error during training: {str(e)}"
    finally:
        # Always release the guard, whichever way training ended.
        training_in_progress = False
# ๐Ÿ” Inference (Modified to highlight bomb, pistol, spring, grenade, eod_gear, battery)
@spaces.GPU(duration=60)  # Request GPU for 1 minute for inference
def run_inference(input_image, conf_threshold):
    """Run detection on a single image and describe what was found.

    If no model is loaded yet, the trained weights on disk are tried first,
    then the stock ``yolo11n.pt`` / ``yolov8n.pt`` weights.

    Args:
        input_image: PIL image to analyse, or None.
        conf_threshold: Minimum confidence for a detection to be reported.

    Returns:
        ``(annotated PIL image, text report)``; the image is None when
        inference could not be run.
    """
    global model
    # Try to load the trained model if not already loaded
    if model is None:
        trained_model_path = "./xray_detection/train/weights/best.pt"
        if os.path.exists(trained_model_path):
            try:
                model = YOLO(trained_model_path)
                print(f"Loaded trained model from {trained_model_path}")
            except Exception as exc:
                # Best effort: fall through to the stock weights below.
                print(f"Could not load {trained_model_path}: {exc}")
    # If still no model, try default
    if model is None:
        for fallback in ("yolo11n.pt", "yolov8n.pt"):
            try:
                model = YOLO(fallback)
                print(f"Loaded fallback model: {fallback}")
                break
            except Exception:
                continue
    if model is None:
        return None, "Please train the model first or load a pre-trained model!"
    if input_image is None:
        return None, "Please upload an image!"
    try:
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')
        # The PIL image is passed to the model directly, so there is no
        # shared temp file to collide between requests or to clean up.
        results = model(
            input_image,
            conf=conf_threshold,
            verbose=False,
            device=0 if torch.cuda.is_available() else 'cpu',
            imgsz=640,
            augment=False,
            agnostic_nms=False,
            max_det=300,
        )
        result = results[0]
        # Draw annotated image
        annotated_bgr = result.plot(
            conf=True,
            labels=True,
            boxes=True,
            masks=False,
            probs=False,
        )

        # Prepare detection information
        danger_set = {'bomb', 'pistol', 'spring', 'grenade', 'eod_gear', 'battery'}
        detections = []
        detection_count = 0
        if result.boxes is not None:
            detection_count = len(result.boxes)
            for idx, box in enumerate(result.boxes):
                cls = int(box.cls)
                conf_val = float(box.conf)
                xyxy = list(map(int, box.xyxy[0].tolist()))
                cls_name = model.names.get(cls, f"Class {cls}")
                # Highlight dangerous items
                prefix = "‼️ " if cls_name in danger_set else ""
                detections.append(
                    f"{idx + 1}. {prefix}{cls_name}: {conf_val:.3f} "
                    f"| Box: [{xyxy[0]}, {xyxy[1]}, {xyxy[2]}, {xyxy[3]}]"
                )

        # Assemble detection text
        det_text_header = (
            f"Model classes ({len(model.names)}): {', '.join(list(model.names.values())[:10])}...\n"
            f"Confidence threshold: {conf_threshold}\n\n"
        )
        if detections:
            detection_text = (
                det_text_header
                + f"✅ Found {detection_count} object(s):\n\n"
                + "\n".join(detections)
            )
        else:
            detection_text = det_text_header + "❌ No objects detected."
        # plot() yields a BGR array; reverse the channel axis for PIL (RGB).
        return Image.fromarray(annotated_bgr[..., ::-1]), detection_text
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error during inference: {str(e)}"
@spaces.GPU(duration=60)  # Request GPU for batch inference
def batch_inference(data_type, num_images):
    """Run detection on the first ``num_images`` images of a dataset split.

    If no model is loaded, trained weights must exist on disk; should they
    fail to load, the stock ``yolo11n.pt`` / ``yolov8n.pt`` weights are tried.

    Args:
        data_type: Dataset split directory name under the dataset root.
        num_images: Maximum number of images to process.

    Returns:
        ``(annotated PIL images, status message)``; the list is empty on failure.
    """
    global model
    # Try to load the trained model if not already loaded
    if model is None:
        trained_model_path = "./xray_detection/train/weights/best.pt"
        if not os.path.exists(trained_model_path):
            return [], "No trained model found. Please train the model first!"
        candidates = (
            (trained_model_path, "Loaded trained model for batch inference"),
            ("yolo11n.pt", "Loaded default model for batch inference"),
            ("yolov8n.pt", "Loaded YOLOv8 model as fallback for batch inference"),
        )
        for weights, message in candidates:
            try:
                model = YOLO(weights)
            except Exception as exc:
                print(f"Could not load {weights}: {exc}")
                continue
            print(message)
            break
        else:
            return [], "Please train the model first!"
    if dataset_path is None:
        return [], "Please download the dataset first!"
    try:
        image_dir = f"{dataset_path}/{data_type}/images"
        if not os.path.exists(image_dir):
            return [], f"Directory {image_dir} not found!"
        # Sorted for a stable selection; int() because a slice bound must be an integer.
        image_files = sorted(glob(f"{image_dir}/*"))[:int(num_images)]
        if not image_files:
            return [], f"No images found in {image_dir}"
        results_images = []
        detection_counts = []
        for img_path in image_files:
            result = model(img_path, verbose=False, conf=0.25, imgsz=640)[0]
            # plot() yields a BGR array; reverse the channel axis for PIL (RGB).
            results_images.append(Image.fromarray(result.plot()[..., ::-1]))
            # Count detections
            detection_counts.append(len(result.boxes) if result.boxes is not None else 0)
        # An 80-class model is taken to be the stock COCO weights
        model_type = "X-ray detection model" if len(model.names) != 80 else "General COCO model"
        avg_detections = sum(detection_counts) / len(detection_counts) if detection_counts else 0
        return results_images, f"Processed {len(results_images)} images using {model_type}\nAverage detections per image: {avg_detections:.1f}"
    except Exception as e:
        return [], f"Error during batch inference: {str(e)}"
def get_dataset_info():
    """Summarise the downloaded dataset as a Markdown string.

    Reads the class names from ``<dataset_path>/data.yaml`` and counts the
    files under each split's ``images`` directory.

    Returns:
        Markdown text, or a short message when the dataset or its
        configuration file is missing or unreadable.
    """
    if dataset_path is None:
        return "Dataset not downloaded yet."
    try:
        yaml_path = f"{dataset_path}/data.yaml"
        if not os.path.exists(yaml_path):
            return "Dataset configuration file not found."
        with open(yaml_path, 'r') as file:
            data = yaml.safe_load(file) or {}
        names = data.get('names') or []
        if isinstance(names, dict):
            # `names` may be an {index: name} mapping instead of a list.
            class_names = [str(names[key]) for key in sorted(names)]
        else:
            class_names = [str(name) for name in names]
        num_classes = len(class_names)

        # Count images (glob yields nothing for a missing directory)
        train_images = len(glob(f"{dataset_path}/train/images/*"))
        valid_images = len(glob(f"{dataset_path}/valid/images/*"))
        test_images = len(glob(f"{dataset_path}/test/images/*"))

        return (
            "### 📊 X-ray Baggage Dataset Info\n\n"
            f"**Classes ({num_classes}):** {', '.join(class_names)}\n\n"
            "**Dataset Split:**\n"
            f"- Training: {train_images} images\n"
            f"- Validation: {valid_images} images\n"
            f"- Test: {test_images} images\n"
            f"- Total: {train_images + valid_images + test_images} images\n\n"
            "**What to expect:** The model will learn to detect these prohibited items in X-ray scans."
        )
    except Exception as e:
        return f"Error reading dataset info: {str(e)}"


# NOTE(review): the statements below originally followed the final `return`
# of get_dataset_info() with no `def` line, so they were unreachable and
# referred to an undefined `model_path`. They are restored here as their own
# function; the name and default argument are reconstructed - confirm.
def load_model(model_path="./xray_detection/train/weights/best.pt"):
    """Load a pre-trained model from a path or hub identifier.

    Args:
        model_path: Local weights file, or an identifier handed to ``YOLO``
            as-is when it starts with ``hf://`` or contains ``/`` and does
            not exist locally.

    Returns:
        A status message describing what was loaded or why loading failed.
    """
    global model
    try:
        # Check if it's a HuggingFace model path
        if model_path.startswith("hf://") or ("/" in model_path and not os.path.exists(model_path)):
            # Load from HuggingFace Hub
            model = YOLO(model_path)
            return f"Model loaded successfully from HuggingFace: {model_path}"
        if not os.path.exists(model_path):
            # Try default paths
            default_paths = [
                "./xray_detection/train/weights/best.pt",
                "./xray_detection/train/weights/last.pt",
                "yolo11n.pt",
                "yolov8n.pt",
            ]
            model_path = next((path for path in default_paths if os.path.exists(path)), model_path)
        if not os.path.exists(model_path):
            return "Model file not found. Please train a model first or provide a valid path."
        model = YOLO(model_path)
        # Report the first few class names when the model exposes them
        names = getattr(model, 'names', None)
        try:
            if names:
                class_names = ", ".join([f"{i}: {name}" for i, name in names.items()][:5])
                if len(names) > 5:
                    class_names += f"... (총 {len(names)} 클래스)"
                return f"Model loaded successfully from {model_path}\n클래스: {class_names}"
        except (AttributeError, TypeError):
            # Unexpected `names` shape: fall back to the plain message.
            pass
        return f"Model loaded successfully from {model_path}"
    except Exception as e:
        return f"Error loading model: {str(e)}"
def load_pretrained_model(model_file):
    """Load a model from an uploaded weights file into the global ``model``.

    Args:
        model_file: Filesystem path of the uploaded ``.pt`` file, or None
            when nothing has been uploaded.

    Returns:
        A status message; an 80-class model is reported as a COCO model
        that is not trained for X-ray detection.
    """
    global model
    if model_file is None:
        return "Please upload a model file (.pt)"
    try:
        # model_file is already a filepath string when type="filepath"
        model = YOLO(model_file)
    except Exception as e:
        return f"Error loading model: {str(e)}"

    # Check model info; any surprise here still counts as a successful load.
    try:
        names = getattr(model, 'names', None)
        if not names:
            return "✅ Model loaded successfully!"
        num_classes = len(names)
        class_names = ", ".join(f"{name}" for name in list(names.values())[:5])
        if num_classes > 5:
            class_names += f"... (총 {num_classes} 클래스)"
        if num_classes == 80:
            return f"⚠️ Loaded COCO model with {num_classes} classes. This is not trained for X-ray detection.\nClasses: {class_names}"
        return f"✅ Model loaded successfully!\nClasses ({num_classes}): {class_names}"
    except (AttributeError, TypeError):
        return "✅ Model loaded successfully!"
def check_model_status():
    """Describe the currently loaded model, loading trained weights if needed.

    When no model is loaded, ``./xray_detection/train/weights/best.pt`` is
    loaded into the global ``model`` if it exists.

    Returns:
        A status message for display.
    """
    global model
    if model is None:
        # Try to load trained model
        trained_path = "./xray_detection/train/weights/best.pt"
        if os.path.exists(trained_path):
            try:
                model = YOLO(trained_path)
                num_classes = len(model.names)
                class_names = ', '.join(list(model.names.values()))
                return f"✅ Trained model loaded: {num_classes} classes\n📋 Classes: {class_names}"
            except Exception as exc:
                # Keep the user-facing message, but leave a trace of the cause.
                print(f"Could not load {trained_path}: {exc}")
        return "❌ No model loaded. Please train or load a model first."

    try:
        num_classes = len(model.names)
        class_names = ', '.join(list(model.names.values()))
    except (AttributeError, TypeError):
        return "✅ Model loaded"
    if num_classes == 80:
        return f"⚠️ Default COCO model loaded ({num_classes} classes). For X-ray detection, please train on the X-ray dataset."
    return f"✅ Model loaded: {num_classes} classes\n📋 Classes: {class_names}"
# Create Gradio interface
with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# ๐ŸŽฏ X-ray Baggage Anomaly Detection with YOLO
This application allows you to:
1. Download and visualize the X-ray baggage dataset
2. Analyze class distributions
3. Train a YOLO model for object detection
4. Run inference on new images
**Note:** GPU will be automatically allocated when needed for training and inference.
""")
# Check if there's a pre-existing model
initial_model_status = "๐Ÿ” Checking for existing models..."
if os.path.exists("./xray_detection/train/weights/best.pt"):
try:
model = YOLO("./xray_detection/train/weights/best.pt")
initial_model_status = "โœ… Found previously trained model! Ready to use."
except:
initial_model_status = "โŒ No model loaded. Please train or upload a model."
else:
initial_model_status = "โŒ No model loaded. Please train or upload a model."
gr.Markdown(f"**Model Status:** {initial_model_status}")
# Add instructions for Kaggle API setup
with gr.Accordion("๐Ÿ“ Setup Instructions", open=False):
gr.Markdown("""
### Kaggle API Setup
1. Get your Kaggle API credentials from https://www.kaggle.com/settings
2. Set the KDATA_API environment variable in Hugging Face Spaces settings:
```
KDATA_API={"username":"your_username","key":"your_api_key"}
```
### Model Persistence on Hugging Face Spaces
- Models trained on Spaces are **temporary** and will be lost when the Space restarts
- After training, download your model using the "๐Ÿ“ฅ Download Model" button
- Upload the downloaded model file to reuse it after Space restarts
- No need for HuggingFace Hub or complex setups!
""")
with gr.Tab("๐Ÿ“Š Dataset"):
with gr.Row():
download_btn = gr.Button("Download Dataset", variant="primary", scale=1)
download_status = gr.Textbox(label="Status", interactive=False, scale=3)
download_btn.click(download_dataset, outputs=download_status)
# Dataset info section
with gr.Row():
dataset_info = gr.Markdown(value="Dataset not downloaded yet.")
info_btn = gr.Button("๐Ÿ”„ Refresh Dataset Info", scale=0)
def update_dataset_info():
return get_dataset_info()
info_btn.click(update_dataset_info, outputs=dataset_info)
gr.Markdown("### Visualize Dataset Samples")
with gr.Row():
data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples")
viz_btn = gr.Button("Visualize Samples")
viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto")
viz_status = gr.Textbox(label="Status", interactive=False)
viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples],
outputs=[viz_gallery, viz_status])
gr.Markdown("### Analyze Class Distribution")
with gr.Row():
data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
analyze_btn = gr.Button("Analyze Distribution")
distribution_plot = gr.Image(label="Class Distribution", type="pil")
analysis_status = gr.Textbox(label="Status", interactive=False)
analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
outputs=[distribution_plot, analysis_status])
gr.Markdown("### Visualize Dataset Samples")
with gr.Row():
data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples")
viz_btn = gr.Button("Visualize Samples")
viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto")
viz_status = gr.Textbox(label="Status", interactive=False)
viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples],
outputs=[viz_gallery, viz_status])
gr.Markdown("### Analyze Class Distribution")
with gr.Row():
data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type")
analyze_btn = gr.Button("Analyze Distribution")
distribution_plot = gr.Image(label="Class Distribution", type="pil")
analysis_status = gr.Textbox(label="Status", interactive=False)
analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis,
outputs=[distribution_plot, analysis_status])
with gr.Tab("๐Ÿš€ Training"):
gr.Markdown("### Train YOLO Model")
gr.Markdown("""
**Note:** Training will automatically use GPU if available. This may take several minutes.
**Recommended Settings for X-ray Detection:**
- **Epochs:** 20-30 for good results
- **Batch Size:** 2-4 for better convergence
- **Image Size:** 640 for best quality
- **Expected time:** ~2-5 minutes for 20 epochs
โš ๏ธ **Important**: Models are temporary on Spaces! Download your model after training.
""")
with gr.Row():
epochs_input = gr.Slider(1, 50, 20, step=1, label="Epochs (20+ recommended)")
batch_size_input = gr.Slider(2, 16, 4, step=2, label="Batch Size (lower for better results)")
img_size_input = gr.Slider(320, 640, 640, step=32, label="Image Size (640 recommended)")
device_input = gr.Radio(["Auto", "GPU", "CPU"], value="Auto", label="Device")
train_btn = gr.Button("Start Training", variant="primary")
training_gallery = gr.Gallery(label="Training Results", columns=3, height="auto")
training_status = gr.Textbox(label="Training Status", interactive=False)
train_btn.click(train_model,
inputs=[epochs_input, batch_size_input, img_size_input, device_input],
outputs=[training_gallery, training_status])
gr.Markdown("### ๐Ÿ“ฅ Model Management")
with gr.Row():
with gr.Column():
gr.Markdown("#### 1๏ธโƒฃ Download Trained Model")
gr.Markdown("After training, download your model to save it permanently.")
# Function to prepare model for download
def prepare_model_download():
model_path = "./xray_detection/train/weights/best.pt"
if os.path.exists(model_path):
return gr.update(value=model_path, visible=True), "โœ… Model ready for download!"
else:
return gr.update(value=None, visible=False), "โŒ No trained model found. Please train a model first."
download_btn = gr.Button("๐Ÿ“ฅ Download Model (.pt)", variant="secondary")
download_file = gr.File(label="Download Model File", visible=False)
download_status = gr.Textbox(label="Download Status", interactive=False)
download_btn.click(prepare_model_download, outputs=[download_file, download_status])
with gr.Column():
gr.Markdown("#### 2๏ธโƒฃ Upload & Load Model")
gr.Markdown("Upload a previously trained model file to continue using it.")
model_upload = gr.File(
label="Upload Model File (.pt)",
file_types=[".pt"],
type="filepath"
)
load_btn = gr.Button("๐Ÿ“ค Load Uploaded Model", variant="secondary")
load_status = gr.Textbox(label="Load Status", interactive=False)
load_btn.click(load_pretrained_model, inputs=model_upload, outputs=load_status)
# Auto-load when file is uploaded
model_upload.change(load_pretrained_model, inputs=model_upload, outputs=load_status)
with gr.Tab("๐Ÿ” Inference"):
# Model status check
with gr.Row():
model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False)
refresh_status_btn = gr.Button("๐Ÿ”„ Refresh Status", scale=0)
refresh_status_btn.click(check_model_status, outputs=model_status)
gr.Markdown("""
## ๐ŸŽฏ ๋ชจ๋ธ์ด ๊ฐ์ฒด๋ฅผ ๊ฐ์ง€ํ•˜์ง€ ๋ชปํ•˜๋‚˜์š”?
**๊ถŒ์žฅ ํ•™์Šต ์„ค์ •:**
- **Epochs: 30** (์ตœ์†Œ 20 ์ด์ƒ)
- **Batch Size: 2 ๋˜๋Š” 4**
- **Image Size: 640**
**์ฒดํฌ๋ฆฌ์ŠคํŠธ:**
1. โœ… X-ray ์ด๋ฏธ์ง€์ธ๊ฐ€? (์ผ๋ฐ˜ ์‚ฌ์ง„์€ ์ž‘๋™ ์•ˆ ํ•จ)
2. โœ… ์ถฉ๋ถ„ํžˆ ํ•™์Šตํ–ˆ๋‚˜? (20+ epochs)
3. โœ… Confidence threshold๋ฅผ 0.01๋กœ ๋‚ฎ์ถฐ๋ดค๋‚˜?
4. โœ… ๋ชจ๋ธ์ด ์ œ๋Œ€๋กœ ๋กœ๋“œ๋˜์—ˆ๋‚˜? (์ƒํƒœ ํ™•์ธ)
**์„ฑ๊ณต์ ์ธ ํ•™์Šต ํ›„ ์˜ˆ์ƒ ๊ฒฐ๊ณผ:**
- Firearm (์ด๊ธฐ๋ฅ˜) ๊ฐ์ง€
- Knife (์นผ) ๊ฐ์ง€
- Pliers (ํŽœ์น˜) ๊ฐ์ง€
- Scissors (๊ฐ€์œ„) ๊ฐ์ง€
- Wrench (๋ Œ์น˜) ๊ฐ์ง€
""")
gr.Markdown("### Single Image Inference")
gr.Markdown("Upload an X-ray baggage image to detect prohibited items.")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Upload X-ray Image")
conf_threshold = gr.Slider(0.01, 0.9, 0.25, step=0.01, label="Confidence Threshold (๋‚ฎ์„์ˆ˜๋ก ๋” ๋งŽ์ด ๊ฐ์ง€)")
# Debug options
with gr.Row():
inference_btn = gr.Button("Run Detection", variant="primary")
test_btn = gr.Button("Test with 0.01 threshold", variant="secondary", scale=0)
# Add example images if dataset is available
example_images = []
if dataset_path and os.path.exists(f"{dataset_path}/test/images"):
test_images = glob(f"{dataset_path}/test/images/*")[:5]
example_images.extend(test_images)
if example_images:
gr.Examples(
examples=[[img] for img in example_images],
inputs=input_image,
label="Example X-ray Images (Click to load)"
)
with gr.Column():
output_image = gr.Image(type="pil", label="Detection Result")
detection_info = gr.Textbox(label="Detection Info", lines=8)
inference_btn.click(run_inference,
inputs=[input_image, conf_threshold],
outputs=[output_image, detection_info])
# Test with very low threshold
test_btn.click(
lambda img: run_inference(img, 0.01),
inputs=[input_image],
outputs=[output_image, detection_info]
)
# Auto-refresh model status after inference
inference_btn.click(check_model_status, outputs=model_status)
gr.Markdown("### Batch Inference")
gr.Markdown("Run detection on multiple images from the test dataset.")
with gr.Row():
batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type")
batch_num_images = gr.Slider(1, 10, 5, step=1, label="Number of Images")
batch_btn = gr.Button("Run Batch Inference")
batch_gallery = gr.Gallery(label="Batch Results", columns=3, height="auto")
batch_status = gr.Textbox(label="Status", interactive=False)
batch_btn.click(batch_inference,
inputs=[batch_data_type, batch_num_images],
outputs=[batch_gallery, batch_status])
# Footer
gr.Markdown("---")
gr.Markdown("""
<div style='text-align: center; font-size: 14px; color: #666;'>
๐Ÿ’ก <b>Quick Start:</b> Download Dataset โ†’ Train Model (20+ epochs) โ†’ Run Inference<br>
๐Ÿ” <b>No detections?</b> Try lowering threshold to 0.01 or train for more epochs<br>
๐Ÿš€ Built with Gradio, YOLOv8, and โค๏ธ for X-ray security
</div>
""")
# Launch the app
if __name__ == "__main__":
    # Server-side rendering is disabled in both environments; a public share
    # link is only requested when not running on Hugging Face Spaces.
    launch_options = {"ssr_mode": False}
    if not ON_SPACES:
        launch_options["share"] = True
    demo.launch(**launch_options)