|
|
import os |
|
|
|
|
|
os.environ['OMP_NUM_THREADS'] = '1' |
|
|
os.environ['MKL_NUM_THREADS'] = '1' |
|
|
import cv2 |
|
|
import yaml |
|
|
import torch |
|
|
import random |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import kagglehub |
|
|
from PIL import Image |
|
|
from glob import glob |
|
|
import matplotlib |
|
|
matplotlib.use('Agg') |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib import patches |
|
|
from torchvision import transforms as T |
|
|
from ultralytics import YOLO |
|
|
import shutil |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
import json |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
# Hugging Face ZeroGPU integration: use the real ``spaces`` package when it is
# installed, otherwise provide a stand-in whose ``GPU`` decorator is a no-op.
try:
    import spaces
except ImportError:
    ON_SPACES = False

    class spaces:
        """Minimal local replacement for the ``spaces`` package."""

        @staticmethod
        def GPU(duration=60):
            """Return an identity decorator; ``duration`` is accepted and ignored."""
            return lambda func: func
else:
    ON_SPACES = True
|
|
|
|
|
|
|
|
def apply_kaggle_credentials(raw_value):
    """Export Kaggle credentials from the ``KDATA_API`` secret into the environment.

    ``raw_value`` is expected to be a JSON object such as
    ``{"username": "...", "key": "..."}``. Values without a ``{`` are ignored,
    as before. Malformed JSON or a non-object payload is reported with a warning
    instead of crashing the app at import time. Missing or empty fields no longer
    overwrite credentials that are already present in the environment.

    Args:
        raw_value: Raw secret string, or None when the variable is unset.

    Returns:
        True if at least one of ``KAGGLE_USERNAME`` / ``KAGGLE_KEY`` was set.
    """
    if not raw_value or "{" not in raw_value:
        return False

    try:
        key_data = json.loads(raw_value)
    except json.JSONDecodeError as err:
        print(f"Warning: KDATA_API is not valid JSON ({err}); Kaggle credentials not set")
        return False

    if not isinstance(key_data, dict):
        print("Warning: KDATA_API must be a JSON object with 'username' and 'key'")
        return False

    applied = False
    for field, env_name in (("username", "KAGGLE_USERNAME"), ("key", "KAGGLE_KEY")):
        value = key_data.get(field)
        if value:
            os.environ[env_name] = str(value)
            applied = True
    return applied


apply_kaggle_credentials(os.getenv("KDATA_API"))
|
|
|
|
|
|
|
|
# Process-wide state shared by the Gradio callbacks.
# `model` holds the active YOLO model; `dataset_path` holds the local dataset root
# (assigned by download_dataset()). Both start unset.
model = dataset_path = None
# Set while train_model() runs so a second click is rejected.
training_in_progress = False
|
|
|
|
|
class Visualization:
    """Load YOLO-format labels for dataset splits and render previews/statistics.

    Expected layout under ``root``::

        root/data.yaml
        root/<split>/images/<name>.<ext>
        root/<split>/labels/<name>.txt   # rows of "cls x_center y_center w h"

    Args:
        root: Dataset root directory.
        data_types: Split names to load, e.g. ``["train"]``.
        n_ims: Requested number of images (stored only).
        rows: Requested grid rows (stored only).
        cmap: Optional colour-map name (stored only).
    """

    def __init__(self, root, data_types, n_ims, rows, cmap=None):
        self.n_ims, self.rows = n_ims, rows
        self.cmap, self.data_types = cmap, data_types
        self.colors = ["firebrick", "darkorange", "blueviolet"]
        self.root = root

        self.get_cls_names()
        self.get_bboxes()

    def get_cls_names(self):
        """Populate ``self.class_dict`` ({class_id: name}) from ``root/data.yaml``."""
        yaml_path = f"{self.root}/data.yaml"
        if not os.path.exists(yaml_path):
            print(f"Warning: {yaml_path} not found")
            self.class_dict = {}
            return

        with open(yaml_path, 'r') as file:
            data = yaml.safe_load(file) or {}

        names = data.get('names', [])
        # data.yaml may list names as a sequence or as an {id: name} mapping;
        # enumerating a mapping would pair ids with ids.
        if isinstance(names, dict):
            self.class_dict = {int(index): str(name) for index, name in names.items()}
        else:
            self.class_dict = {index: str(name) for index, name in enumerate(names)}

        if self.class_dict:
            print(f"Dataset classes: {', '.join(self.class_dict.values())}")

    def get_bboxes(self):
        """Parse the label files of every split in ``self.data_types``.

        Fills three dicts keyed by split name:

        * ``vis_datas[split]``: one list per image of ``[class_name, xc, yc, w, h]``
        * ``analysis_datas[split]``: ``{class_name: object_count}``
        * ``im_paths[split]``: image paths, index-aligned with ``vis_datas[split]``

        Images without a label file are skipped in both lists so the indices stay
        aligned. Rows with fewer than five fields, such as blank lines, are ignored.
        A class id missing from ``class_dict`` is shown as its number.
        """
        self.vis_datas, self.analysis_datas, self.im_paths = {}, {}, {}

        for data_type in self.data_types:
            all_bboxes, all_analysis_datas, labelled_paths = [], {}, []
            label_dir = f"{self.root}/{data_type}/labels"

            for im_path in sorted(glob(f"{self.root}/{data_type}/images/*")):
                stem = os.path.splitext(os.path.basename(im_path))[0]
                lbl_path = f"{label_dir}/{stem}.txt"
                if not os.path.isfile(lbl_path):
                    continue

                bboxes = []
                with open(lbl_path) as label_file:
                    for line in label_file:
                        parts = line.split()[:5]
                        if len(parts) < 5:
                            continue  # blank or malformed row
                        class_id = int(parts[0])
                        cls_name = self.class_dict.get(class_id, str(class_id))
                        bboxes.append([cls_name] + [float(x) for x in parts[1:]])
                        all_analysis_datas[cls_name] = all_analysis_datas.get(cls_name, 0) + 1

                all_bboxes.append(bboxes)
                labelled_paths.append(im_path)

            self.vis_datas[data_type] = all_bboxes
            self.analysis_datas[data_type] = all_analysis_datas
            self.im_paths[data_type] = labelled_paths

    def plot_single(self, im_path, bboxes):
        """Draw normalised ``[name, xc, yc, w, h]`` boxes on the image; return a PIL image."""
        with Image.open(im_path) as pil_im:
            or_im = np.array(pil_im.convert("RGB"))
        height, width, _ = or_im.shape

        for _cls_name, x_center, y_center, w, h in bboxes:
            # Convert normalised centre/size to pixel corner coordinates.
            x_min = int((x_center - w / 2) * width)
            y_min = int((y_center - h / 2) * height)
            x_max = int((x_center + w / 2) * width)
            y_max = int((y_center + h / 2) * height)

            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
            cv2.rectangle(img=or_im, pt1=(x_min, y_min), pt2=(x_max, y_max),
                          color=color, thickness=3)

        cv2.putText(or_im, f"Objects: {len(bboxes)}", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        return Image.fromarray(or_im)

    def vis_samples(self, data_type, n_samples=4):
        """Render up to ``n_samples`` distinct random images of ``data_type``.

        Returns a list of PIL images, or None when the split was not loaded.
        """
        if data_type not in self.vis_datas:
            return None

        available = len(self.vis_datas[data_type])
        # Sample without replacement; int() tolerates float slider values.
        indices = random.sample(range(available), min(max(int(n_samples), 0), available))

        return [self.plot_single(self.im_paths[data_type][idx], self.vis_datas[data_type][idx])
                for idx in indices]

    def data_analysis(self, data_type):
        """Return a PIL bar chart of per-class object counts, or None if not loaded."""
        if data_type not in self.analysis_datas:
            return None

        plt.style.use('default')
        fig, ax = plt.subplots(figsize=(12, 6))

        cls_names = list(self.analysis_datas[data_type].keys())
        counts = list(self.analysis_datas[data_type].values())

        color_map = {"train": "firebrick", "valid": "darkorange", "test": "blueviolet"}
        color = color_map.get(data_type, "steelblue")

        indices = np.arange(len(counts))
        bars = ax.bar(indices, counts, 0.7, color=color)

        ax.set_xlabel("Class Names", fontsize=12)
        ax.set_xticks(indices)
        ax.set_xticklabels(cls_names, rotation=45, ha='right')
        ax.set_ylabel("Data Counts", fontsize=12)
        ax.set_title(f"{data_type.upper()} Dataset Class Distribution", fontsize=14)

        for bar, value in zip(bars, counts):
            ax.text(bar.get_x() + bar.get_width() / 2, value + 1, str(value),
                    ha='center', va='bottom', fontsize=10, color='navy')

        fig.tight_layout()

        # Render to an in-memory PNG so Gradio receives a PIL image. Always close
        # the figure so repeated calls do not accumulate open matplotlib figures.
        buf = BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        finally:
            plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)
        img.load()
        return img
|
|
|
|
|
def download_dataset():
    """Fetch the X-ray baggage dataset with kagglehub and mirror it into ./xray_dataset.

    Updates the module-level ``dataset_path``. Returns a status message for the
    UI; failures are reported in that message instead of being raised.
    """
    global dataset_path
    target_dir = "./xray_dataset"
    try:
        # kagglehub returns the directory it downloaded (or cached) the data into.
        dataset_path = kagglehub.dataset_download("orvile/x-ray-baggage-anomaly-detection")

        needs_copy = dataset_path != target_dir and os.path.exists(dataset_path)
        if needs_copy:
            # Replace any previous local copy so it matches the fresh download.
            if os.path.exists(target_dir):
                shutil.rmtree(target_dir)
            shutil.copytree(dataset_path, target_dir)
            dataset_path = target_dir
    except Exception as err:
        return (f"Error downloading dataset: {str(err)}\n\n"
                "Please ensure KDATA_API environment variable is set correctly.")
    return f"Dataset downloaded successfully to: {dataset_path}"
|
|
|
|
|
def visualize_data(data_type, num_samples):
    """Render up to ``num_samples`` labelled sample images from one dataset split.

    Args:
        data_type: Split name ("train", "valid" or "test").
        num_samples: Requested number of images. A Gradio slider may deliver a
            float, so the value is coerced to a non-negative int.

    Returns:
        ``(list_of_PIL_images, status_message)`` for a Gallery + Textbox pair.
    """
    if dataset_path is None:
        return [], "Please download the dataset first!"

    try:
        n_samples = max(int(num_samples), 0)
        vis = Visualization(root=dataset_path, data_types=[data_type],
                            n_ims=n_samples, rows=2, cmap="rgb")
        figs = vis.vis_samples(data_type, n_samples)
        # Treat "split loaded but nothing to show" like a missing split.
        if not figs:
            return [], f"No data found for {data_type} dataset"
        return figs, f"Showing {len(figs)} samples from {data_type} dataset"
    except Exception as e:
        return [], f"Error visualizing data: {str(e)}"
|
|
|
|
|
def analyze_class_distribution(data_type):
    """Build a per-class object-count bar chart for one dataset split.

    Returns ``(PIL image or None, status_message)``.
    """
    if dataset_path is None:
        return None, "Please download the dataset first!"

    try:
        chart = Visualization(
            root=dataset_path,
            data_types=[data_type],
            n_ims=20,
            rows=5,
            cmap="rgb",
        ).data_analysis(data_type)
    except Exception as exc:
        return None, f"Error analyzing data: {str(exc)}"

    if chart is None:
        return None, f"No data found for {data_type} dataset"
    return chart, f"Class distribution for {data_type} dataset"
|
|
|
|
|
def _select_training_device(device_selection):
    """Map the UI device choice to an ultralytics ``device`` value (0 = first CUDA GPU)."""
    cuda_available = torch.cuda.is_available()
    if ON_SPACES and cuda_available:
        return 0  # on Spaces, always use the GPU allocated for this call
    if device_selection == "CPU":
        return "cpu"
    # "Auto" and "GPU" both prefer the GPU and fall back to CPU when none exists.
    return 0 if cuda_available else "cpu"


def _collect_training_plots(results_path):
    """Open whichever of the standard ultralytics result images exist in ``results_path``."""
    plot_files = ("results.png", "confusion_matrix.png", "val_batch0_pred.jpg",
                  "train_batch0.jpg", "val_batch0_labels.jpg")
    return [Image.open(os.path.join(results_path, name))
            for name in plot_files
            if os.path.exists(os.path.join(results_path, name))]


def _load_trained_model(model_path):
    """Reload the best checkpoint and smoke-test it on a blank 640x640 image.

    Returns ``(model_or_None, status_suffix)``. The model is None only when the
    checkpoint is missing or could not be constructed; a model that loaded but
    failed the smoke test is still returned.
    """
    if not os.path.exists(model_path):
        return None, "\n❌ Model file not found!"

    trained = None
    try:
        trained = YOLO(model_path)
        info = (f"\n✅ Trained on {len(trained.names)} classes: "
                f"{', '.join(list(trained.names.values()))}")
        trained(np.zeros((640, 640, 3), dtype=np.uint8), verbose=False)
        info += "\n✅ Model test passed - ready for inference!"
    except Exception as e:
        info = f"\n⚠️ Model loaded but test failed: {str(e)}"
    return trained, info


@spaces.GPU(duration=300)
def train_model(epochs, batch_size, img_size, device_selection):
    """Train YOLOv11 (falling back to YOLOv8) on the downloaded dataset.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size.
        img_size: Square training image size in pixels.
        device_selection: "Auto", "GPU" or "CPU".

        The three numeric arguments come from Gradio sliders, which may deliver
        floats, so they are cast to int before being passed to ultralytics.

    Returns:
        ``(plots, status_message)``, where ``plots`` is a list of PIL images that
        ultralytics wrote to the run directory.
    """
    global model, training_in_progress

    if dataset_path is None:
        return [], "Please download the dataset first!"
    if training_in_progress:
        return [], "Training already in progress!"

    training_in_progress = True
    try:
        device = _select_training_device(device_selection)

        yaml_path = f"{dataset_path}/data.yaml"
        with open(yaml_path, 'r') as file:
            data_config = yaml.safe_load(file) or {}
        class_names = data_config.get('names', [])
        print(f"Training on {len(class_names)} classes: {class_names}")

        try:
            model = YOLO("yolo11n.pt")
        except Exception as e:
            print(f"YOLOv11 not available: {e}, falling back to YOLOv8")
            model = YOLO("yolov8n.pt")

        project_dir = "./xray_detection"
        os.makedirs(project_dir, exist_ok=True)

        model.train(
            data=yaml_path,
            epochs=int(epochs),
            imgsz=int(img_size),
            batch=int(batch_size),
            device=device,
            project=project_dir,
            name="train",
            exist_ok=True,
            verbose=True,
            patience=5,
            save_period=5,
            workers=0,
            single_cls=False,
            rect=False,
            cache=False,
            amp=True,
            # Optimiser / schedule
            optimizer='AdamW',
            lr0=0.001,
            lrf=0.01,
            momentum=0.937,
            weight_decay=0.0005,
            warmup_epochs=3.0,
            warmup_momentum=0.8,
            warmup_bias_lr=0.1,
            # Loss weights
            box=7.5,
            cls=0.5,
            dfl=1.5,
            # Augmentation: hue/saturation disabled, only mild brightness jitter.
            hsv_h=0.0,
            hsv_s=0.0,
            hsv_v=0.1,
            degrees=0.0,
            translate=0.1,
            scale=0.5,
            shear=0.0,
            perspective=0.0,
            flipud=0.0,
            fliplr=0.5,
            mosaic=1.0,
            mixup=0.0,
            copy_paste=0.0,
        )

        results_path = os.path.join(project_dir, "train")
        plots = _collect_training_plots(results_path)

        model_path = os.path.join(results_path, "weights", "best.pt")
        trained, class_info = _load_trained_model(model_path)
        if trained is not None:
            model = trained

        save_instructions = """

✅ **Training Complete!**

📥 **Next Steps:**
1. Click "📥 Download Model (.pt)" button below to save your model
2. Keep the downloaded file safe - you'll need it after Space restarts
3. To reuse: Upload the model file in the "Upload & Load Model" section

⚠️ **Important**: This model will be lost when the Space restarts!
"""

        return plots, f"Model saved to {model_path}{class_info}{save_instructions}"

    except Exception as e:
        return [], f"Error during training: {str(e)}"
    finally:
        # Clear the guard on success and failure alike.
        training_in_progress = False
|
|
|
|
|
|
|
|
@spaces.GPU(duration=60)
def run_inference(input_image, conf_threshold):
    """Run detection on a single image and describe every detected item.

    When no model is loaded yet, the function tries the weights from a previous
    training run and then the stock ``yolo11n.pt`` / ``yolov8n.pt`` weights.

    Args:
        input_image: PIL image from the upload widget, or None.
        conf_threshold: Minimum confidence for a detection to be reported.

    Returns:
        ``(annotated PIL image, detection_text)`` on success, otherwise
        ``(None, message)``.
    """
    global model

    if model is None:
        trained_model_path = "./xray_detection/train/weights/best.pt"
        if os.path.exists(trained_model_path):
            try:
                model = YOLO(trained_model_path)
                print(f"Loaded trained model from {trained_model_path}")
            except Exception as e:
                # Not fatal: fall through to the stock weights below.
                print(f"Could not load {trained_model_path}: {e}")

    if model is None:
        for fallback in ("yolo11n.pt", "yolov8n.pt"):
            try:
                model = YOLO(fallback)
                print(f"Loaded fallback model: {fallback}")
                break
            except Exception:
                continue
        if model is None:
            return None, "Please train the model first or load a pre-trained model!"

    if input_image is None:
        return None, "Please upload an image!"

    temp_path = None
    try:
        # Use a unique temp file per request; a fixed filename would let concurrent
        # users overwrite each other's upload.
        fd, temp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        if input_image.mode != 'RGB':
            input_image = input_image.convert('RGB')
        input_image.save(temp_path, format='JPEG', quality=95)

        results = model(
            temp_path,
            conf=conf_threshold,
            verbose=False,
            device=0 if torch.cuda.is_available() else 'cpu',
            imgsz=640,
            augment=False,
            agnostic_nms=False,
            max_det=300,
        )
        result = results[0]

        # Results.plot() returns a BGR array; reverse the channels for PIL (RGB).
        annotated_bgr = result.plot(conf=True, labels=True, boxes=True, masks=False, probs=False)
        annotated_image = Image.fromarray(np.ascontiguousarray(annotated_bgr[..., ::-1]))

        danger_set = {'bomb', 'pistol', 'spring', 'grenade', 'eod_gear', 'battery'}
        boxes = result.boxes if result.boxes is not None else []
        detections = []
        for idx, box in enumerate(boxes, start=1):
            cls = int(box.cls)
            conf_val = float(box.conf)
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            cls_name = model.names.get(cls, f"Class {cls}")
            # Flag classes treated as dangerous.
            prefix = "‼️ " if cls_name in danger_set else ""
            detections.append(
                f"{idx}. {prefix}{cls_name}: {conf_val:.3f} "
                f"| Box: [{x1}, {y1}, {x2}, {y2}]"
            )

        class_names = list(model.names.values())
        shown = ', '.join(class_names[:10]) + ("..." if len(class_names) > 10 else "")
        det_text_header = (
            f"Model classes ({len(class_names)}): {shown}\n"
            f"Confidence threshold: {conf_threshold}\n\n"
        )
        if detections:
            detection_text = (
                det_text_header
                + f"✅ Found {len(detections)} object(s):\n\n"
                + "\n".join(detections)
            )
        else:
            detection_text = det_text_header + "❌ No objects detected."

        return annotated_image, detection_text

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error during inference: {str(e)}"
    finally:
        # Remove the temp file whether inference succeeded or not.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
|
|
|
|
|
|
|
|
@spaces.GPU(duration=60)
def batch_inference(data_type, num_images):
    """Run detection on the first ``num_images`` images (sorted by path) of a split.

    A trained checkpoint must exist before any model is auto-loaded. If that
    checkpoint cannot be loaded, the stock YOLO11 and then YOLOv8 weights are
    tried in turn.

    Returns:
        ``(list of annotated PIL images, status_message)``.
    """
    global model

    if model is None:
        trained_model_path = "./xray_detection/train/weights/best.pt"
        if not os.path.exists(trained_model_path):
            return [], "No trained model found. Please train the model first!"
        candidates = (
            (trained_model_path, "Loaded trained model for batch inference"),
            ("yolo11n.pt", "Loaded default model for batch inference"),
            ("yolov8n.pt", "Loaded YOLOv8 model as fallback for batch inference"),
        )
        for weights, note in candidates:
            try:
                model = YOLO(weights)
                print(note)
                break
            except Exception:
                continue
        else:
            return [], "Please train the model first!"

    if dataset_path is None:
        return [], "Please download the dataset first!"

    try:
        image_dir = f"{dataset_path}/{data_type}/images"
        if not os.path.exists(image_dir):
            return [], f"Directory {image_dir} not found!"

        # Sort so the selection is deterministic; sliders may deliver floats.
        image_files = sorted(glob(f"{image_dir}/*"))[:max(int(num_images), 0)]
        if not image_files:
            return [], f"No images found in {image_dir}"

        results_images = []
        detection_counts = []
        for img_path in image_files:
            result = model(img_path, verbose=False, conf=0.25, imgsz=640)[0]
            # Results.plot() returns BGR; reverse the channels for PIL (RGB).
            results_images.append(Image.fromarray(np.ascontiguousarray(result.plot()[..., ::-1])))
            detection_counts.append(len(result.boxes) if result.boxes is not None else 0)

        # 80 classes is the COCO label set of the stock weights.
        model_type = "X-ray detection model" if len(model.names) != 80 else "General COCO model"
        avg_detections = sum(detection_counts) / len(detection_counts)

        return results_images, (
            f"Processed {len(results_images)} images using {model_type}\n"
            f"Average detections per image: {avg_detections:.1f}"
        )
    except Exception as e:
        return [], f"Error during batch inference: {str(e)}"
|
|
|
|
|
def get_dataset_info():
    """Summarise the downloaded dataset (classes and images per split) as Markdown."""
    if dataset_path is None:
        return "Dataset not downloaded yet."

    try:
        yaml_path = f"{dataset_path}/data.yaml"
        if not os.path.exists(yaml_path):
            return "Dataset configuration file not found."

        with open(yaml_path, 'r') as file:
            data = yaml.safe_load(file) or {}

        class_names = data.get('names', [])
        # data.yaml may store names as a list or as an {id: name} mapping.
        if isinstance(class_names, dict):
            class_names = list(class_names.values())
        class_names = [str(name) for name in class_names]

        def count_images(split):
            """Number of files in ``<dataset>/<split>/images`` (0 if the folder is absent)."""
            image_dir = f"{dataset_path}/{split}/images"
            return len(glob(f"{image_dir}/*")) if os.path.exists(image_dir) else 0

        train_images = count_images("train")
        valid_images = count_images("valid")
        test_images = count_images("test")

        return (
            "### 📊 X-ray Baggage Dataset Info\n\n"
            f"**Classes ({len(class_names)}):** {', '.join(class_names)}\n\n"
            "**Dataset Split:**\n"
            f"- Training: {train_images} images\n"
            f"- Validation: {valid_images} images\n"
            f"- Test: {test_images} images\n"
            f"- Total: {train_images + valid_images + test_images} images\n\n"
            "**What to expect:** The model will learn to detect these prohibited items in X-ray scans."
        )
    except Exception as e:
        return f"Error reading dataset info: {str(e)}"


def load_model(model_path="./xray_detection/train/weights/best.pt"):
    """Load a pre-trained model into the module-level ``model``.

    Args:
        model_path: A local checkpoint path or a remote reference. Remote means
            an ``hf://`` URI, or an ``owner/name`` string that does not exist
            locally and is not written as ``./``, ``/`` or ``~``. If a local path
            is missing, the first existing file among the training outputs and
            the stock weights is used instead.

    Returns:
        A status message. Errors are reported in the message, not raised.
    """
    global model
    try:
        # Explicit filesystem paths must never go to the remote branch; otherwise
        # the local fallback search below could never run for them.
        explicit_local = model_path.startswith((".", "/", "~"))
        is_remote = model_path.startswith("hf://") or (
            "/" in model_path and not explicit_local and not os.path.exists(model_path)
        )
        if is_remote:
            model = YOLO(model_path)
            return f"Model loaded successfully from HuggingFace: {model_path}"

        if not os.path.exists(model_path):
            default_paths = [
                "./xray_detection/train/weights/best.pt",
                "./xray_detection/train/weights/last.pt",
                "yolo11n.pt",
                "yolov8n.pt",
            ]
            model_path = next((p for p in default_paths if os.path.exists(p)), model_path)

        if not os.path.exists(model_path):
            return "Model file not found. Please train a model first or provide a valid path."

        model = YOLO(model_path)

        try:
            if getattr(model, 'names', None):
                class_names = ", ".join(
                    f"{i}: {name}" for i, name in list(model.names.items())[:5]
                )
                if len(model.names) > 5:
                    class_names += f"... (총 {len(model.names)} 클래스)"
                return f"Model loaded successfully from {model_path}\n클래스: {class_names}"
        except Exception:
            pass  # the class listing is cosmetic; the model itself loaded fine

        return f"Model loaded successfully from {model_path}"
    except Exception as e:
        return f"Error loading model: {str(e)}"
|
|
|
|
|
def load_pretrained_model(model_file):
    """Load a YOLO ``.pt`` checkpoint uploaded through the Gradio File widget.

    Args:
        model_file: The uploaded file. The widget uses ``type="filepath"`` and so
            passes a path string, but a file object exposing ``.name`` (as older
            Gradio versions passed) is also accepted. None means nothing was
            uploaded.

    Returns:
        A status message for the UI.
    """
    global model

    if model_file is None:
        return "Please upload a model file (.pt)"

    try:
        model = YOLO(getattr(model_file, "name", model_file))
    except Exception as e:
        return f"Error loading model: {str(e)}"

    # Describing the classes is best-effort; the model itself already loaded.
    try:
        class_list = list(model.names.values()) if getattr(model, "names", None) else []
    except Exception:
        class_list = []

    if not class_list:
        return "✅ Model loaded successfully!"

    num_classes = len(class_list)
    class_names = ", ".join(str(name) for name in class_list[:5])
    if num_classes > 5:
        class_names += f"... (총 {num_classes} 클래스)"

    # 80 classes is the stock COCO label set, not an X-ray model.
    if num_classes == 80:
        return (f"⚠️ Loaded COCO model with {num_classes} classes. "
                f"This is not trained for X-ray detection.\nClasses: {class_names}")
    return f"✅ Model loaded successfully!\nClasses ({num_classes}): {class_names}"
|
|
|
|
|
def check_model_status():
    """Describe the current model, first loading a trained checkpoint if none is loaded.

    Returns:
        A status message for the UI.
    """
    global model
    no_model_msg = "❌ No model loaded. Please train or load a model first."

    if model is None:
        trained_path = "./xray_detection/train/weights/best.pt"
        if not os.path.exists(trained_path):
            return no_model_msg
        try:
            model = YOLO(trained_path)
            num_classes = len(model.names)
            class_names = ', '.join(list(model.names.values()))
            return f"✅ Trained model loaded: {num_classes} classes\n📋 Classes: {class_names}"
        except Exception:
            return no_model_msg

    try:
        num_classes = len(model.names)
        class_names = ', '.join(list(model.names.values()))
    except Exception:
        return "✅ Model loaded"

    # 80 classes is the stock COCO label set.
    if num_classes == 80:
        return (f"⚠️ Default COCO model loaded ({num_classes} classes). "
                "For X-ray detection, please train on the X-ray dataset.")
    return f"✅ Model loaded: {num_classes} classes\n📋 Classes: {class_names}"
|
|
|
|
|
|
|
|
with gr.Blocks(title="X-ray Baggage Anomaly Detection", theme=gr.themes.Soft()) as demo: |
|
|
gr.Markdown(""" |
|
|
# ๐ฏ X-ray Baggage Anomaly Detection with YOLO |
|
|
|
|
|
This application allows you to: |
|
|
1. Download and visualize the X-ray baggage dataset |
|
|
2. Analyze class distributions |
|
|
3. Train a YOLO model for object detection |
|
|
4. Run inference on new images |
|
|
|
|
|
**Note:** GPU will be automatically allocated when needed for training and inference. |
|
|
""") |
|
|
|
|
|
|
|
|
initial_model_status = "๐ Checking for existing models..." |
|
|
if os.path.exists("./xray_detection/train/weights/best.pt"): |
|
|
try: |
|
|
model = YOLO("./xray_detection/train/weights/best.pt") |
|
|
initial_model_status = "โ
Found previously trained model! Ready to use." |
|
|
except: |
|
|
initial_model_status = "โ No model loaded. Please train or upload a model." |
|
|
else: |
|
|
initial_model_status = "โ No model loaded. Please train or upload a model." |
|
|
|
|
|
gr.Markdown(f"**Model Status:** {initial_model_status}") |
|
|
|
|
|
|
|
|
with gr.Accordion("๐ Setup Instructions", open=False): |
|
|
gr.Markdown(""" |
|
|
### Kaggle API Setup |
|
|
1. Get your Kaggle API credentials from https://www.kaggle.com/settings |
|
|
2. Set the KDATA_API environment variable in Hugging Face Spaces settings: |
|
|
``` |
|
|
KDATA_API={"username":"your_username","key":"your_api_key"} |
|
|
``` |
|
|
|
|
|
### Model Persistence on Hugging Face Spaces |
|
|
- Models trained on Spaces are **temporary** and will be lost when the Space restarts |
|
|
- After training, download your model using the "๐ฅ Download Model" button |
|
|
- Upload the downloaded model file to reuse it after Space restarts |
|
|
- No need for HuggingFace Hub or complex setups! |
|
|
""") |
|
|
|
|
|
with gr.Tab("๐ Dataset"): |
|
|
with gr.Row(): |
|
|
download_btn = gr.Button("Download Dataset", variant="primary", scale=1) |
|
|
download_status = gr.Textbox(label="Status", interactive=False, scale=3) |
|
|
|
|
|
download_btn.click(download_dataset, outputs=download_status) |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
dataset_info = gr.Markdown(value="Dataset not downloaded yet.") |
|
|
info_btn = gr.Button("๐ Refresh Dataset Info", scale=0) |
|
|
|
|
|
def update_dataset_info(): |
|
|
return get_dataset_info() |
|
|
|
|
|
info_btn.click(update_dataset_info, outputs=dataset_info) |
|
|
|
|
|
gr.Markdown("### Visualize Dataset Samples") |
|
|
with gr.Row(): |
|
|
data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") |
|
|
num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples") |
|
|
viz_btn = gr.Button("Visualize Samples") |
|
|
|
|
|
viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto") |
|
|
viz_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples], |
|
|
outputs=[viz_gallery, viz_status]) |
|
|
|
|
|
gr.Markdown("### Analyze Class Distribution") |
|
|
with gr.Row(): |
|
|
data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") |
|
|
analyze_btn = gr.Button("Analyze Distribution") |
|
|
|
|
|
distribution_plot = gr.Image(label="Class Distribution", type="pil") |
|
|
analysis_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis, |
|
|
outputs=[distribution_plot, analysis_status]) |
|
|
|
|
|
gr.Markdown("### Visualize Dataset Samples") |
|
|
with gr.Row(): |
|
|
data_type_viz = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") |
|
|
num_samples = gr.Slider(1, 8, 4, step=1, label="Number of Samples") |
|
|
viz_btn = gr.Button("Visualize Samples") |
|
|
|
|
|
viz_gallery = gr.Gallery(label="Sample Images", columns=2, height="auto") |
|
|
viz_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
viz_btn.click(visualize_data, inputs=[data_type_viz, num_samples], |
|
|
outputs=[viz_gallery, viz_status]) |
|
|
|
|
|
gr.Markdown("### Analyze Class Distribution") |
|
|
with gr.Row(): |
|
|
data_type_analysis = gr.Dropdown(["train", "valid", "test"], value="train", label="Dataset Type") |
|
|
analyze_btn = gr.Button("Analyze Distribution") |
|
|
|
|
|
distribution_plot = gr.Image(label="Class Distribution", type="pil") |
|
|
analysis_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
analyze_btn.click(analyze_class_distribution, inputs=data_type_analysis, |
|
|
outputs=[distribution_plot, analysis_status]) |
|
|
|
|
|
with gr.Tab("๐ Training"): |
|
|
gr.Markdown("### Train YOLO Model") |
|
|
gr.Markdown(""" |
|
|
**Note:** Training will automatically use GPU if available. This may take several minutes. |
|
|
|
|
|
**Recommended Settings for X-ray Detection:** |
|
|
- **Epochs:** 20-30 for good results |
|
|
- **Batch Size:** 2-4 for better convergence |
|
|
- **Image Size:** 640 for best quality |
|
|
- **Expected time:** ~2-5 minutes for 20 epochs |
|
|
|
|
|
โ ๏ธ **Important**: Models are temporary on Spaces! Download your model after training. |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
epochs_input = gr.Slider(1, 50, 20, step=1, label="Epochs (20+ recommended)") |
|
|
batch_size_input = gr.Slider(2, 16, 4, step=2, label="Batch Size (lower for better results)") |
|
|
img_size_input = gr.Slider(320, 640, 640, step=32, label="Image Size (640 recommended)") |
|
|
device_input = gr.Radio(["Auto", "GPU", "CPU"], value="Auto", label="Device") |
|
|
|
|
|
train_btn = gr.Button("Start Training", variant="primary") |
|
|
|
|
|
training_gallery = gr.Gallery(label="Training Results", columns=3, height="auto") |
|
|
training_status = gr.Textbox(label="Training Status", interactive=False) |
|
|
|
|
|
train_btn.click(train_model, |
|
|
inputs=[epochs_input, batch_size_input, img_size_input, device_input], |
|
|
outputs=[training_gallery, training_status]) |
|
|
|
|
|
gr.Markdown("### ๐ฅ Model Management") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
gr.Markdown("#### 1๏ธโฃ Download Trained Model") |
|
|
gr.Markdown("After training, download your model to save it permanently.") |
|
|
|
|
|
|
|
|
def prepare_model_download(): |
|
|
model_path = "./xray_detection/train/weights/best.pt" |
|
|
if os.path.exists(model_path): |
|
|
return gr.update(value=model_path, visible=True), "โ
Model ready for download!" |
|
|
else: |
|
|
return gr.update(value=None, visible=False), "โ No trained model found. Please train a model first." |
|
|
|
|
|
download_btn = gr.Button("๐ฅ Download Model (.pt)", variant="secondary") |
|
|
download_file = gr.File(label="Download Model File", visible=False) |
|
|
download_status = gr.Textbox(label="Download Status", interactive=False) |
|
|
|
|
|
download_btn.click(prepare_model_download, outputs=[download_file, download_status]) |
|
|
|
|
|
with gr.Column(): |
|
|
gr.Markdown("#### 2๏ธโฃ Upload & Load Model") |
|
|
gr.Markdown("Upload a previously trained model file to continue using it.") |
|
|
|
|
|
model_upload = gr.File( |
|
|
label="Upload Model File (.pt)", |
|
|
file_types=[".pt"], |
|
|
type="filepath" |
|
|
) |
|
|
load_btn = gr.Button("๐ค Load Uploaded Model", variant="secondary") |
|
|
load_status = gr.Textbox(label="Load Status", interactive=False) |
|
|
|
|
|
load_btn.click(load_pretrained_model, inputs=model_upload, outputs=load_status) |
|
|
|
|
|
|
|
|
model_upload.change(load_pretrained_model, inputs=model_upload, outputs=load_status) |
|
|
|
|
|
with gr.Tab("๐ Inference"): |
|
|
|
|
|
with gr.Row(): |
|
|
model_status = gr.Textbox(label="Model Status", value=check_model_status(), interactive=False) |
|
|
refresh_status_btn = gr.Button("๐ Refresh Status", scale=0) |
|
|
|
|
|
refresh_status_btn.click(check_model_status, outputs=model_status) |
|
|
|
|
|
gr.Markdown(""" |
|
|
## ๐ฏ ๋ชจ๋ธ์ด ๊ฐ์ฒด๋ฅผ ๊ฐ์งํ์ง ๋ชปํ๋์? |
|
|
|
|
|
**๊ถ์ฅ ํ์ต ์ค์ :** |
|
|
- **Epochs: 30** (์ต์ 20 ์ด์) |
|
|
- **Batch Size: 2 ๋๋ 4** |
|
|
- **Image Size: 640** |
|
|
|
|
|
**์ฒดํฌ๋ฆฌ์คํธ:** |
|
|
1. โ
X-ray ์ด๋ฏธ์ง์ธ๊ฐ? (์ผ๋ฐ ์ฌ์ง์ ์๋ ์ ํจ) |
|
|
2. โ
์ถฉ๋ถํ ํ์ตํ๋? (20+ epochs) |
|
|
3. โ
Confidence threshold๋ฅผ 0.01๋ก ๋ฎ์ถฐ๋ดค๋? |
|
|
4. โ
๋ชจ๋ธ์ด ์ ๋๋ก ๋ก๋๋์๋? (์ํ ํ์ธ) |
|
|
|
|
|
**์ฑ๊ณต์ ์ธ ํ์ต ํ ์์ ๊ฒฐ๊ณผ:** |
|
|
- Firearm (์ด๊ธฐ๋ฅ) ๊ฐ์ง |
|
|
- Knife (์นผ) ๊ฐ์ง |
|
|
- Pliers (ํ์น) ๊ฐ์ง |
|
|
- Scissors (๊ฐ์) ๊ฐ์ง |
|
|
- Wrench (๋ ์น) ๊ฐ์ง |
|
|
""") |
|
|
|
|
|
gr.Markdown("### Single Image Inference") |
|
|
gr.Markdown("Upload an X-ray baggage image to detect prohibited items.") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
input_image = gr.Image(type="pil", label="Upload X-ray Image") |
|
|
conf_threshold = gr.Slider(0.01, 0.9, 0.25, step=0.01, label="Confidence Threshold (๋ฎ์์๋ก ๋ ๋ง์ด ๊ฐ์ง)") |
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
inference_btn = gr.Button("Run Detection", variant="primary") |
|
|
test_btn = gr.Button("Test with 0.01 threshold", variant="secondary", scale=0) |
|
|
|
|
|
|
|
|
example_images = [] |
|
|
if dataset_path and os.path.exists(f"{dataset_path}/test/images"): |
|
|
test_images = glob(f"{dataset_path}/test/images/*")[:5] |
|
|
example_images.extend(test_images) |
|
|
|
|
|
if example_images: |
|
|
gr.Examples( |
|
|
examples=[[img] for img in example_images], |
|
|
inputs=input_image, |
|
|
label="Example X-ray Images (Click to load)" |
|
|
) |
|
|
|
|
|
with gr.Column(): |
|
|
output_image = gr.Image(type="pil", label="Detection Result") |
|
|
detection_info = gr.Textbox(label="Detection Info", lines=8) |
|
|
|
|
|
inference_btn.click(run_inference, |
|
|
inputs=[input_image, conf_threshold], |
|
|
outputs=[output_image, detection_info]) |
|
|
|
|
|
|
|
|
test_btn.click( |
|
|
lambda img: run_inference(img, 0.01), |
|
|
inputs=[input_image], |
|
|
outputs=[output_image, detection_info] |
|
|
) |
|
|
|
|
|
|
|
|
inference_btn.click(check_model_status, outputs=model_status) |
|
|
|
|
|
gr.Markdown("### Batch Inference") |
|
|
gr.Markdown("Run detection on multiple images from the test dataset.") |
|
|
|
|
|
with gr.Row(): |
|
|
batch_data_type = gr.Dropdown(["test", "valid"], value="test", label="Dataset Type") |
|
|
batch_num_images = gr.Slider(1, 10, 5, step=1, label="Number of Images") |
|
|
batch_btn = gr.Button("Run Batch Inference") |
|
|
|
|
|
batch_gallery = gr.Gallery(label="Batch Results", columns=3, height="auto") |
|
|
batch_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
batch_btn.click(batch_inference, |
|
|
inputs=[batch_data_type, batch_num_images], |
|
|
outputs=[batch_gallery, batch_status]) |
|
|
|
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown(""" |
|
|
<div style='text-align: center; font-size: 14px; color: #666;'> |
|
|
๐ก <b>Quick Start:</b> Download Dataset โ Train Model (20+ epochs) โ Run Inference<br> |
|
|
๐ <b>No detections?</b> Try lowering threshold to 0.01 or train for more epochs<br> |
|
|
๐ Built with Gradio, YOLOv8, and โค๏ธ for X-ray security |
|
|
</div> |
|
|
""") |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Server-side rendering is always disabled; a public share link is requested
    # only when not running on Hugging Face Spaces.
    launch_options = {"ssr_mode": False}
    if not ON_SPACES:
        launch_options["share"] = True
    demo.launch(**launch_options)