Spaces:
Sleeping
Sleeping
"""
🎯 YOLO Trainer & Detector
Train YOLOv8 on a custom dataset and run inference — all from a Gradio UI.
"""
| import os | |
| import io | |
| import time | |
| import queue | |
| import threading | |
| import zipfile | |
| import yaml | |
| import cv2 | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from PIL import Image as PILImage | |
| from pathlib import Path | |
| import gradio as gr | |
| from ultralytics import YOLO | |
| from huggingface_hub import hf_hub_download | |
# ───────────────────────────────────────────────────────────────
# Constants
# ───────────────────────────────────────────────────────────────
# Hugging Face Hub dataset repo that hosts yolo_format/{train,val,test}.zip
# and yolo_format/data.yaml.
DATASET_REPO = "dharshanzeb/yolo-detection-dataset"
# Local scratch locations for the extracted dataset and Ultralytics runs.
DATASET_DIR = "/tmp/yolo_dataset"
RUNS_DIR = "/tmp/yolo_runs"
# Checkpoints written by the "gradio_train" run that train_yolo launches.
_WEIGHTS_DIR = os.path.join(RUNS_DIR, "gradio_train", "weights")
BEST_MODEL_PATH = os.path.join(_WEIGHTS_DIR, "best.pt")
LAST_MODEL_PATH = os.path.join(_WEIGHTS_DIR, "last.pt")
# Class names in class-ID order (0=car … 4=bicycle, as in the About tab).
CLASS_NAMES = ["car", "person", "dog", "cat", "bicycle"]

# Global state shared by the training generator and the inference handler.
trained_model_path = None  # weights produced by the latest successful run
is_training = False  # busy flag: True while a training run is in progress
# ───────────────────────────────────────────────────────────────
# Dataset download & preparation
# ───────────────────────────────────────────────────────────────
def download_dataset():
    """Fetch the YOLO dataset from the HF Hub and write a local data.yaml.

    The download is skipped when both ``images/train`` and ``data.yaml``
    already exist under DATASET_DIR.

    Returns:
        tuple[str, str]: path of the local data.yaml, and a human-readable
        log describing what happened (image counts only cover ``*.jpg``).
    """
    local_yaml = os.path.join(DATASET_DIR, "data.yaml")
    already_extracted = os.path.exists(os.path.join(DATASET_DIR, "images", "train"))
    if already_extracted and os.path.exists(local_yaml):
        return local_yaml, "β Dataset already downloaded."

    splits = ("train", "val", "test")
    log_lines = ["π₯ Downloading dataset from HF Hub...\n"]
    os.makedirs(DATASET_DIR, exist_ok=True)
    for split_name in splits:
        log_lines.append(f" Downloading {split_name}.zip...\n")
        archive_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=f"yolo_format/{split_name}.zip",
            repo_type="dataset",
        )
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(DATASET_DIR)

    # Fetch the published data.yaml and point its `path` root at the local
    # extraction directory (presumably so relative split paths resolve here).
    remote_yaml = hf_hub_download(
        repo_id=DATASET_REPO,
        filename="yolo_format/data.yaml",
        repo_type="dataset",
    )
    with open(remote_yaml) as src:
        config = yaml.safe_load(src)
    config["path"] = DATASET_DIR
    with open(local_yaml, "w") as dst:
        yaml.dump(config, dst)

    image_counts = {
        split_name: sum(1 for _ in Path(DATASET_DIR, "images", split_name).glob("*.jpg"))
        for split_name in splits
    }
    log_lines.append("\nβ Dataset ready!\n")
    log_lines.append(f" Train: {image_counts['train']} images\n")
    log_lines.append(f" Val: {image_counts['val']} images\n")
    log_lines.append(f" Test: {image_counts['test']} images\n")
    log_lines.append(f" Classes: {CLASS_NAMES}\n")
    return local_yaml, "".join(log_lines)
# ───────────────────────────────────────────────────────────────
# Metrics chart
# ───────────────────────────────────────────────────────────────
def _style_dark_axis(ax, title, xlabel, ylabel):
    """Apply the dashboard's dark theme, labels, legend, and grid to one subplot."""
    ax.set_title(title, color="white", fontsize=14, fontweight="bold")
    ax.set_xlabel(xlabel, color="white")
    ax.set_ylabel(ylabel, color="white")
    # Only add a legend when something labelled was plotted (e.g. mAP curves
    # are skipped while all values are 0); otherwise matplotlib warns and
    # draws an empty legend box.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(facecolor="#16213e", edgecolor="white", labelcolor="white")
    ax.tick_params(colors="white")
    ax.grid(True, alpha=0.2, color="white")
    for spine in ax.spines.values():
        spine.set_color("white")


def make_metrics_chart(history):
    """Render training-loss and validation-mAP curves as an image array.

    Args:
        history: list of per-epoch dicts as queued by train_yolo's callback,
            each with "epoch" (int), "loss" (name -> float) and "metrics"
            (name -> float) entries.

    Returns:
        numpy.ndarray of the rendered PNG (as decoded by PIL), or None when
        ``history`` is empty.
    """
    if not history:
        return None

    epochs = [h["epoch"] for h in history]
    fig, (ax_loss, ax_map) = plt.subplots(1, 2, figsize=(14, 5))
    # Always release the figure, even if plotting or encoding fails, so
    # pyplot's figure registry does not grow across calls.
    try:
        fig.patch.set_facecolor("#1a1a2e")
        ax_loss.set_facecolor("#16213e")
        ax_map.set_facecolor("#16213e")

        # Loss curves: one line per loss component present in the first epoch.
        loss_colors = ["#e94560", "#f5a623", "#50fa7b"]
        for i, key in enumerate(history[0].get("loss", {})):
            values = [h["loss"].get(key, 0) for h in history]
            ax_loss.plot(
                epochs, values, marker="o", markersize=4,
                label=key.split("/")[-1],  # "train/box_loss" -> "box_loss"
                color=loss_colors[i % len(loss_colors)], linewidth=2,
            )
        _style_dark_axis(ax_loss, "Training Loss", "Epoch", "Loss")

        # mAP curves: skipped until at least one epoch reports a non-zero value.
        map_series = (
            ("metrics/mAP50(B)", "mAP@50", "#00d2ff"),
            ("metrics/mAP50-95(B)", "mAP@50-95", "#7b2ff7"),
        )
        for key, label, color in map_series:
            values = [h["metrics"].get(key, 0) for h in history]
            if any(v > 0 for v in values):
                ax_map.plot(epochs, values, marker="s", markersize=4, label=label,
                            color=color, linewidth=2)
        ax_map.set_ylim(0, 1)
        _style_dark_axis(ax_map, "Validation mAP", "Epoch", "mAP")

        # Use the figure's own methods rather than pyplot's global "current
        # figure", which another concurrent handler could have changed.
        fig.tight_layout(pad=2)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
        buf.seek(0)
        with PILImage.open(buf) as png:
            return np.array(png)
    finally:
        plt.close(fig)
# ───────────────────────────────────────────────────────────────
# Training function (generator for streaming logs)
# ───────────────────────────────────────────────────────────────
def _format_epoch_line(item):
    """Format one per-epoch callback entry as a log line.

    Args:
        item: dict with "epoch", "epochs", "loss" (name -> float) and
            "metrics" (name -> float), as queued by the training callback.

    Returns:
        tuple[str, float]: the newline-terminated log line, and the epoch's
        mAP@50 (0 when the metric is absent).
    """
    loss_parts = [f"{k.split('/')[-1]}={v:.4f}" for k, v in item["loss"].items()]
    loss_str = " | ".join(loss_parts) if loss_parts else "N/A"
    map50 = item["metrics"].get("metrics/mAP50(B)", 0)
    map50_95 = item["metrics"].get("metrics/mAP50-95(B)", 0)
    line = (
        f"π Epoch {item['epoch']:>3d}/{item['epochs']} | {loss_str} | "
        f"mAP50={map50:.4f} | mAP50-95={map50_95:.4f}\n"
    )
    return line, map50


def _release_busy_flag_when_done(thread):
    """Wait for *thread* to finish, then clear the module-level busy flag."""
    global is_training
    thread.join()
    is_training = False


def train_yolo(model_size, epochs, batch_size, learning_rate, img_size):
    """Train YOLOv8 on the dataset, streaming progress to the Gradio UI.

    Generator: yields ``(log_text, metrics_chart, status)`` tuples during
    setup, once per completed epoch, and once with the final summary.
    Training runs in a daemon thread. An ``on_fit_epoch_end`` callback pushes
    per-epoch stats onto a queue that this generator drains.

    Args:
        model_size: YOLOv8 variant suffix ("n", "s", "m"), used as
            f"yolov8{model_size}.pt".
        epochs, batch_size, learning_rate, img_size: hyper-parameters, coerced
            with int()/float() before being passed to ``model.train``.

    Side effects:
        Sets the module globals ``is_training`` (busy flag) and, when a
        checkpoint is found after training, ``trained_model_path``.
    """
    global trained_model_path, is_training
    if is_training:
        yield "β οΈ Training already in progress. Please wait.", None, "β οΈ Busy"
        return
    is_training = True
    log_queue_local = queue.Queue()
    history = []
    accumulated_log = ""
    train_exception = [None]  # written by the worker thread, read after it ends
    train_thread = None
    try:
        # Step 1: Download dataset
        yield "π₯ Preparing dataset...", None, "π₯ Downloading..."
        data_yaml, dl_log = download_dataset()
        accumulated_log += dl_log + "\n"
        yield accumulated_log, None, "π₯ Dataset ready"

        # Step 2: Load model
        model_variant = f"yolov8{model_size}.pt"
        accumulated_log += f"π Loading {model_variant}...\n"
        yield accumulated_log, None, f"π Loading {model_variant}"
        model = YOLO(model_variant)

        # Step 3: Per-epoch callback (invoked on the training thread) -> queue
        def on_fit_epoch_end(trainer):
            try:
                loss_dict = {}
                if trainer.tloss is not None:
                    loss_dict = trainer.label_loss_items(trainer.tloss)
                entry = {
                    "epoch": trainer.epoch + 1,
                    "epochs": trainer.epochs,
                    "metrics": dict(trainer.metrics) if trainer.metrics else {},
                    "fitness": float(trainer.fitness) if trainer.fitness else 0.0,
                    "loss": loss_dict,
                }
                log_queue_local.put(entry)
            except Exception as cb_exc:
                # A reporting failure must not abort training; surface it instead.
                log_queue_local.put({"error": str(cb_exc)})

        model.add_callback("on_fit_epoch_end", on_fit_epoch_end)

        # Step 4: Run training in a background thread
        def run_training():
            try:
                import torch

                device = 0 if torch.cuda.is_available() else "cpu"
                model.train(
                    data=data_yaml,
                    epochs=int(epochs),
                    batch=int(batch_size),
                    lr0=float(learning_rate),
                    imgsz=int(img_size),
                    device=device,
                    workers=0,
                    project=RUNS_DIR,
                    name="gradio_train",
                    exist_ok=True,
                    pretrained=True,
                    mosaic=1.0,
                    mixup=0.1,
                    patience=50,
                    verbose=False,
                )
            except Exception as train_exc:
                train_exception[0] = train_exc
            finally:
                log_queue_local.put(None)  # sentinel: no more entries

        accumulated_log += f"\nπ Starting training: {model_variant} | {int(epochs)} epochs | batch={int(batch_size)} | lr={learning_rate}\n"
        accumulated_log += f"{'β' * 60}\n"
        yield accumulated_log, None, "π Training started..."
        train_thread = threading.Thread(target=run_training, daemon=True)
        train_thread.start()

        # Step 5: Stream per-epoch entries until the sentinel arrives
        while True:
            try:
                item = log_queue_local.get(timeout=120)
            except queue.Empty:
                accumulated_log += "β³ Waiting for update...\n"
                yield accumulated_log, make_metrics_chart(history), "β³ Waiting..."
                continue
            if item is None:
                break
            if "error" in item:
                accumulated_log += f"β οΈ Callback error: {item['error']}\n"
                yield accumulated_log, make_metrics_chart(history), "β οΈ Error"
                continue
            history.append(item)
            line, map50 = _format_epoch_line(item)
            accumulated_log += line
            chart = make_metrics_chart(history)
            status = f"ποΈ Epoch {item['epoch']}/{item['epochs']} | mAP50={map50:.4f}"
            yield accumulated_log, chart, status
        train_thread.join(timeout=10)

        # Step 6: Check results
        if train_exception[0]:
            accumulated_log += f"\nβ Training error: {train_exception[0]}\n"
            yield accumulated_log, make_metrics_chart(history), "β Failed"
            return

        # Prefer the best checkpoint, fall back to the last one.
        if os.path.exists(BEST_MODEL_PATH):
            trained_model_path = BEST_MODEL_PATH
        elif os.path.exists(LAST_MODEL_PATH):
            trained_model_path = LAST_MODEL_PATH

        accumulated_log += f"\n{'β' * 60}\n"
        accumulated_log += "π TRAINING COMPLETE!\n"
        accumulated_log += f"{'β' * 60}\n"
        if trained_model_path:
            accumulated_log += f"π Model saved: {trained_model_path}\n"
            accumulated_log += "π Switch to the Inference tab to test your model!\n"
        else:
            accumulated_log += "β οΈ No model file found after training.\n"
        if history:
            final_map = history[-1]["metrics"].get("metrics/mAP50(B)", 0)
            accumulated_log += f"\nπ Final mAP@50: {final_map:.4f}\n"
        chart = make_metrics_chart(history)
        yield accumulated_log, chart, "β Training complete!"
    except Exception as e:
        accumulated_log += f"\nβ Error: {str(e)}\n"
        yield accumulated_log, make_metrics_chart(history) if history else None, "β Error"
    finally:
        # The generator can be closed early (client disconnect / GC) while the
        # daemon thread is still training. Keep the busy flag set until that
        # run really ends so a second run cannot share the same output dir.
        if train_thread is not None and train_thread.is_alive():
            threading.Thread(
                target=_release_busy_flag_when_done, args=(train_thread,), daemon=True
            ).start()
        else:
            is_training = False
# ───────────────────────────────────────────────────────────────
# Inference function
# ───────────────────────────────────────────────────────────────
def _to_bgr(image):
    """Return a contiguous 3-channel BGR copy of an RGB, RGBA or grayscale array."""
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        return np.repeat(image[..., np.newaxis], 3, axis=2)
    # Channels 2, 1, 0 -> B, G, R; this also drops the alpha channel of RGBA.
    return np.ascontiguousarray(image[..., 2::-1])


def run_inference(image, conf_threshold, use_pretrained):
    """Run YOLO detection on an uploaded image and summarise the detections.

    Args:
        image: numpy array from ``gr.Image(type="numpy")`` (RGB channel
            order), or None when nothing was uploaded.
        conf_threshold: minimum confidence for a detection to be kept.
        use_pretrained: True to use COCO-pretrained ``yolov8n.pt``; otherwise
            the weights recorded in ``trained_model_path``.

    Returns:
        tuple: (annotated RGB image or None, Markdown summary string).
    """
    if image is None:
        gr.Warning("Please upload an image first!")
        return None, "β οΈ No image uploaded"

    # Select model
    if use_pretrained:
        model_path = "yolov8n.pt"
        model_label = "YOLOv8n (COCO pretrained)"
    elif trained_model_path and os.path.exists(trained_model_path):
        model_path = trained_model_path
        model_label = f"Custom trained ({os.path.basename(trained_model_path)})"
    else:
        gr.Warning("No trained model found! Train first or use pretrained COCO model.")
        return None, "β οΈ No trained model. Train first or check 'Use Pretrained'."

    try:
        import torch

        device = 0 if torch.cuda.is_available() else "cpu"
        model = YOLO(model_path)
        # Ultralytics interprets numpy sources as BGR (OpenCV order), but
        # Gradio supplies RGB. Without this conversion the model sees swapped
        # red/blue channels.
        results = model.predict(
            source=_to_bgr(image),
            conf=float(conf_threshold),
            iou=0.45,
            device=device,
            verbose=False,
        )
        result = results[0]

        # plot() draws on the (BGR) source image and returns BGR; flip it
        # back to RGB for display.
        plotted_bgr = result.plot(conf=True, labels=True, line_width=2)
        plotted_rgb = cv2.cvtColor(plotted_bgr, cv2.COLOR_BGR2RGB)

        # Build detection summary
        n_detections = len(result.boxes)
        if n_detections == 0:
            summary = f"π **{model_label}**\n\nNo objects detected (conf > {conf_threshold})"
        else:
            lines = [f"π **{model_label}** β Found **{n_detections}** objects:\n"]
            lines.append("| # | Class | Confidence | Bbox (x1,y1,x2,y2) |")
            lines.append("|---|-------|-----------|---------------------|")
            for i, box in enumerate(result.boxes):
                cls_id = int(box.cls[0].item())
                cls_name = result.names[cls_id]
                conf = box.conf[0].item()
                coords = [round(v, 1) for v in box.xyxy[0].tolist()]
                lines.append(f"| {i+1} | **{cls_name}** | {conf:.2f} | {coords} |")
            summary = "\n".join(lines)
        return plotted_rgb, summary
    except Exception as e:
        return None, f"β Error: {str(e)}"
# ───────────────────────────────────────────────────────────────
# Sample images from dataset for inference demo
# ───────────────────────────────────────────────────────────────
def load_sample_image():
    """Return a random ``*.jpg`` from the test split as an RGB numpy array.

    Downloads the dataset first when the test split is missing.

    Returns:
        numpy.ndarray (H, W, 3) in RGB order, or None when no test images
        are available.
    """
    import random

    test_dir = Path(DATASET_DIR) / "images" / "test"
    if not test_dir.exists():
        download_dataset()
    if not test_dir.exists():
        return None

    images = list(test_dir.glob("*.jpg"))
    if not images:
        return None

    img_path = random.choice(images)
    # The context manager closes the file handle promptly. convert("RGB")
    # normalises grayscale/CMYK JPEGs so the Detect tab always receives a
    # 3-channel array.
    with PILImage.open(img_path) as img:
        return np.array(img.convert("RGB"))
# ───────────────────────────────────────────────────────────────
# Build Gradio UI
# ───────────────────────────────────────────────────────────────
# Page CSS: centre the app at a max width, and give the training log (tagged
# with elem_classes=["train-log"] below) a monospace font.
css = """
.gradio-container { max-width: 1100px !important; margin: 0 auto !important; }
.train-log { font-family: 'Courier New', monospace !important; font-size: 13px !important; }
"""
# The nesting and order of the `with gr.Tabs/Tab/Row/Column` contexts below
# determine where each component is placed, so statement order matters here.
with gr.Blocks(css=css, title="π― YOLO Trainer & Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# π― YOLO Trainer & Detector
**Train** a YOLOv8 model on the [yolo-detection-dataset](https://huggingface.co/datasets/dharshanzeb/yolo-detection-dataset)
and **run inference** on any image β all from this UI.
""")
    with gr.Tabs():
        # ──────────────────────────────────────────────────────
        # TAB 1: TRAINING
        # ──────────────────────────────────────────────────────
        with gr.Tab("ποΈ Train"):
            with gr.Row():
                # Left column: hyper-parameter controls and the start button
                with gr.Column(scale=1, min_width=280):
                    gr.Markdown("### βοΈ Training Configuration")
                    # (label, value) choices: the value ("n"/"s"/"m") is what
                    # train_yolo receives and turns into f"yolov8{model_size}.pt".
                    model_size = gr.Dropdown(
                        choices=[
                            ("YOLOv8 Nano (fastest)", "n"),
                            ("YOLOv8 Small", "s"),
                            ("YOLOv8 Medium", "m"),
                        ],
                        value="n",
                        label="Model Size",
                    )
                    epochs = gr.Slider(
                        minimum=1, maximum=100, value=20, step=1,
                        label="Epochs",
                    )
                    batch_size = gr.Slider(
                        minimum=4, maximum=64, value=16, step=4,
                        label="Batch Size",
                    )
                    learning_rate = gr.Slider(
                        minimum=0.0001, maximum=0.1, value=0.01, step=0.0001,
                        label="Learning Rate",
                    )
                    img_size = gr.Dropdown(
                        choices=[320, 416, 512, 640],
                        value=640,
                        label="Image Size",
                    )
                    train_btn = gr.Button(
                        "π Start Training", variant="primary", size="lg"
                    )
                    train_status = gr.Markdown("*Ready to train*")
                    # Static dataset description. These counts are hard-coded
                    # text, not read from the downloaded data.
                    gr.Markdown("""
---
### π Dataset Info
- **5 classes**: car, person, dog, cat, bicycle
- **500** train / **100** val / **50** test images
- **640Γ640** resolution
- 15% hard negatives included
""")
                # Right column: live metrics chart + streaming text log
                with gr.Column(scale=2):
                    gr.Markdown("### π Training Progress")
                    metrics_chart = gr.Image(
                        label="Loss & mAP Curves",
                        interactive=False,
                        height=320,
                    )
                    train_log = gr.Textbox(
                        label="Training Log",
                        lines=18,
                        max_lines=30,
                        interactive=False,
                        autoscroll=True,
                        elem_classes=["train-log"],
                    )
            # train_yolo is a generator yielding (log, chart, status) tuples,
            # so each yield streams an update into these three outputs.
            train_btn.click(
                fn=train_yolo,
                inputs=[model_size, epochs, batch_size, learning_rate, img_size],
                outputs=[train_log, metrics_chart, train_status],
            )
        # ──────────────────────────────────────────────────────
        # TAB 2: INFERENCE
        # ──────────────────────────────────────────────────────
        with gr.Tab("π Detect"):
            gr.Markdown("### Upload an image to detect objects")
            with gr.Row():
                # Left: input image and detection controls
                with gr.Column(scale=1):
                    # type="numpy": run_inference receives a numpy array
                    # (or None when empty).
                    input_image = gr.Image(
                        label="π€ Upload Image",
                        type="numpy",
                        sources=["upload", "clipboard"],
                        height=400,
                    )
                    with gr.Row():
                        conf_threshold = gr.Slider(
                            minimum=0.05, maximum=0.95, value=0.25, step=0.05,
                            label="Confidence Threshold",
                        )
                    with gr.Row():
                        # When checked, run_inference uses yolov8n.pt instead
                        # of the model produced by the Train tab.
                        use_pretrained = gr.Checkbox(
                            value=False,
                            label="Use Pretrained COCO Model (YOLOv8n)",
                            info="Check this if you haven't trained yet",
                        )
                    with gr.Row():
                        detect_btn = gr.Button(
                            "π Detect Objects", variant="primary", size="lg"
                        )
                        sample_btn = gr.Button(
                            "π² Load Sample", variant="secondary", size="lg"
                        )
                # Right: annotated output and Markdown detection table
                with gr.Column(scale=1):
                    output_image = gr.Image(
                        label="πΈ Detection Result",
                        type="numpy",
                        interactive=False,
                        height=400,
                    )
                    detection_summary = gr.Markdown("*Upload an image and click Detect*")
            # run_inference returns (annotated image, Markdown summary).
            detect_btn.click(
                fn=run_inference,
                inputs=[input_image, conf_threshold, use_pretrained],
                outputs=[output_image, detection_summary],
            )
            # Fills the input image with a random test-split sample (no inputs).
            sample_btn.click(
                fn=load_sample_image,
                outputs=[input_image],
            )
        # ──────────────────────────────────────────────────────
        # TAB 3: ABOUT
        # ──────────────────────────────────────────────────────
        # Static help text.
        # NOTE(review): "best model (by mAP)" below describes Ultralytics'
        # best.pt selection; confirm the wording against the Ultralytics docs.
        with gr.Tab("βΉοΈ About"):
            gr.Markdown("""
## How It Works
### ποΈ Training
1. The dataset ([dharshanzeb/yolo-detection-dataset](https://huggingface.co/datasets/dharshanzeb/yolo-detection-dataset))
is auto-downloaded from the HF Hub
2. YOLOv8 is initialized with COCO-pretrained weights (transfer learning)
3. Training runs with your configured hyperparameters
4. Real-time metrics (loss + mAP) are displayed after each epoch
5. The best model (by mAP) is saved automatically
### π Inference
- **Custom model**: Uses the model you just trained
- **Pretrained COCO**: Uses YOLOv8n trained on 80 COCO classes (good for real photos)
### π― Dataset Classes
| ID | Class | Description |
|---|---|---|
| 0 | Car | Red car shapes with windows & wheels |
| 1 | Person | Blue stick figures |
| 2 | Dog | Brown dog shapes |
| 3 | Cat | Orange cat shapes with ears |
| 4 | Bicycle | Green bicycles with wheels |
### π‘ Tips
- **First time?** Start with `YOLOv8 Nano` + `20 epochs` β trains in ~5 min on GPU
- **Better accuracy?** Try `YOLOv8 Small` + `50 epochs` + `lr=0.01`
- **No GPU?** Training works on CPU too (just slower). Use pretrained COCO for instant inference.
- **Low mAP?** Increase epochs or try a larger model size
### π Links
- [Dataset](https://huggingface.co/datasets/dharshanzeb/yolo-detection-dataset)
- [Ultralytics YOLOv8](https://docs.ultralytics.com/)
- [Gradio](https://gradio.app/)
""")
if __name__ == "__main__":
    # Enable Gradio's request queue with at most 5 pending jobs, then serve.
    app = demo.queue(max_size=5)
    app.launch()