import sys
import os
import threading
import http.server
import socketserver
from pathlib import Path
import shutil

import yaml
from huggingface_hub import snapshot_download
from tqdm import tqdm
from PIL import Image

# =========================================================================================
# 0. HEALTH CHECK SERVER (CRITICAL FOR DOCKER SPACES)
# =========================================================================================
# Hugging Face Spaces expects the application to listen on port 7860.
# If it doesn't listen within 30 minutes, the Space is flagged as "Runtime Error".
# We start a minimal HTTP server in the background to satisfy this requirement immediately.
PORT = 7860

# Global variable to track status. The main thread and the training callbacks write it;
# the server thread reads it on every request. It may contain simple HTML markup
# (e.g. <br>), so it is inserted into the page without escaping.
TRAINING_STATUS = "Initializing..."


def _render_status_page():
    """Return the status page as UTF-8 bytes, embedding the current TRAINING_STATUS.

    The <meta http-equiv="refresh"> tag makes the browser reload the page every 30 s.
    """
    return f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="30">
<title>YOLOv8-MPEB Training Status</title>
</head>
<body>
<h1>YOLOv8-MPEB Training Monitor</h1>
<div>{TRAINING_STATUS}</div>
<p>Autorefreshing every 30 seconds...</p>
</body>
</html>
""".encode("utf-8")


def start_health_check_server():
    """Serve the status page on PORT; blocks the calling thread until the process exits.

    Startup failures (e.g. the port is already in use) are printed rather than raised,
    so the rest of the script keeps running without the monitor.
    """

    class HealthCheckHandler(http.server.BaseHTTPRequestHandler):
        # BaseHTTPRequestHandler instead of SimpleHTTPRequestHandler: the latter would
        # answer HEAD requests by looking up files in the working directory.

        def _send_status_headers(self, body):
            self.send_response(200)
            self.send_header("Content-type", "text/html; charset=utf-8")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()

        def do_GET(self):
            body = _render_status_page()
            self._send_status_headers(body)
            self.wfile.write(body)

        def do_HEAD(self):
            # Same headers as GET, no body.
            self._send_status_headers(_render_status_page())

        def log_message(self, format, *args):
            # Silence per-request logging; the page polls itself every 30 s.
            pass

    try:
        # ThreadingHTTPServer handles each request in its own thread, so a slow client
        # cannot stall the health check. Unlike plain TCPServer it also enables
        # allow_reuse_address, so a quick restart does not fail to bind the port.
        with http.server.ThreadingHTTPServer(("", PORT), HealthCheckHandler) as httpd:
            print(f"Health check server serving at port {PORT}")
            httpd.serve_forever()
    except Exception as e:
        print(f"Failed to start health check server: {e}")


# Start the server in a daemon thread so it never blocks interpreter shutdown.
server_thread = threading.Thread(target=start_health_check_server, daemon=True)
server_thread.start()

# =========================================================================================
# 1. SETUP & CONFIGURATION
# =========================================================================================
print("Starting App for YOLOv8-MPEB Training on CPU...")

# Define paths
CURRENT_DIR = Path(os.getcwd())
DATASET_REPO = "jeyanthangj2004/Visdrone-raw"
DATASET_DIR = CURRENT_DIR / "visdrone_dataset"
DATA_YAML_PATH = CURRENT_DIR / "data.yaml"

# =========================================================================================
# 2. DOWNLOAD DATASET
# =========================================================================================
TRAINING_STATUS = "Downloading dataset..."
print(f"Downloading dataset from {DATASET_REPO}...")
try:
    snapshot_download(repo_id=DATASET_REPO, repo_type="dataset", local_dir=DATASET_DIR)
except Exception as e:
    print(f"Error downloading dataset: {e}")
    sys.exit(1)
else:
    print("Dataset download complete.")
# =========================================================================================
# 3. DATASET CONVERSION
# =========================================================================================
def _find_split_dir(dir_path, folder_name):
    """Return dir_path/folder_name if it exists, else the first directory with that name below dir_path, else None."""
    direct = dir_path / folder_name
    if direct.exists():
        return direct
    return next((p for p in dir_path.rglob(folder_name) if p.is_dir()), None)


def _move_images(source_images_dir, images_dest_dir):
    """Move every *.jpg from source_images_dir into images_dest_dir; no-op if the source is missing."""
    if not source_images_dir.exists():
        return
    print(f"Moving images from {source_images_dir} to {images_dest_dir}...")
    for img in source_images_dir.glob("*.jpg"):
        shutil.move(str(img), str(images_dest_dir / img.name))


def _convert_annotation_file(ann_path, img_size):
    """Return the YOLO label lines for one VisDrone annotation file.

    Each comma-separated row is read as x, y, w, h (integer pixels) followed by at least
    two more fields. Rows whose field 4 is "0" are skipped (treated as ignored regions).
    Field 5 minus one becomes the class id, and only ids 0..9 (the ten names in
    data.yaml) are kept. Coordinates are normalised by img_size (width, height).

    Raises:
        ValueError: if a kept row has non-integer coordinates or class.
        ZeroDivisionError: if img_size has a zero dimension.
    """
    dw, dh = 1.0 / img_size[0], 1.0 / img_size[1]
    lines = []
    with open(ann_path, encoding="utf-8") as file:
        for line in file:
            row = line.strip().split(",")
            if len(row) < 6 or row[4] == "0":  # malformed row or ignored region
                continue
            x, y, w, h = map(int, row[:4])
            cls = int(row[5]) - 1
            if not 0 <= cls <= 9:
                continue
            x_center, y_center = (x + w / 2) * dw, (y + h / 2) * dh
            w_norm, h_norm = w * dw, h * dh
            lines.append(f"{cls} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}\n")
    return lines


def visdrone2yolo(dir_path, split):
    """Convert VisDrone annotations to YOLO format.

    Locates the ``VisDrone2019-DET-{split}`` folder, either directly under dir_path or
    anywhere below it. Its ``images/*.jpg`` are moved into ``dir_path/images/{split}``,
    and one YOLO label file per annotation is written into ``dir_path/labels/{split}``.

    The split is skipped with a warning if its folder is missing, and conversion is
    skipped if the labels directory is already non-empty. Annotations without a matching
    image are skipped. Errors in an individual file are printed and that file is skipped.
    """
    print(f"Checking/Converting {split} data in {dir_path}...")

    target_folder_name = f"VisDrone2019-DET-{split}"
    source_dir = _find_split_dir(dir_path, target_folder_name)
    if source_dir is None:
        print(f"Warning: Could not find directory for split '{split}' ({target_folder_name}). Skipping.")
        return

    images_dest_dir = dir_path / "images" / split
    labels_dest_dir = dir_path / "labels" / split

    if labels_dest_dir.exists() and any(labels_dest_dir.iterdir()):
        print(f"Labels for {split} seem to exist. Skipping conversion.")
        return

    labels_dest_dir.mkdir(parents=True, exist_ok=True)
    images_dest_dir.mkdir(parents=True, exist_ok=True)

    _move_images(source_dir / "images", images_dest_dir)

    source_annotations_dir = source_dir / "annotations"
    if not source_annotations_dir.exists():
        return

    print(f"Converting annotations from {source_annotations_dir}...")
    for f in tqdm(list(source_annotations_dir.glob("*.txt")), desc=f"Converting {split}"):
        img_path = images_dest_dir / f.with_suffix(".jpg").name
        if not img_path.exists():
            continue
        try:
            # Image.open is lazy and keeps the file open until closed; the context
            # manager releases the handle once the header (size) has been read.
            with Image.open(img_path) as im:
                img_size = im.size
            lines = _convert_annotation_file(f, img_size)
            (labels_dest_dir / f.name).write_text("".join(lines), encoding="utf-8")
        except Exception as e:
            print(f"Error converting {f.name}: {e}")


# Process datasets
TRAINING_STATUS = "Preparing dataset..."
visdrone2yolo(DATASET_DIR, "train")
visdrone2yolo(DATASET_DIR, "val")
visdrone2yolo(DATASET_DIR, "test-dev")

# =========================================================================================
# 4. CREATE DATA.YAML
# =========================================================================================
data_yaml_content = {
    'path': str(DATASET_DIR.absolute()),
    'train': 'images/train',
    'val': 'images/val',
    'test': 'images/test-dev',
    'names': {
        0: 'pedestrian', 1: 'people', 2: 'bicycle', 3: 'car', 4: 'van',
        5: 'truck', 6: 'tricycle', 7: 'awning-tricycle', 8: 'bus', 9: 'motor',
    },
}

with open(DATA_YAML_PATH, 'w', encoding="utf-8") as f:
    yaml.safe_dump(data_yaml_content, f)
print(f"Created data.yaml at {DATA_YAML_PATH}")
# =========================================================================================
# 5. PATCH & LOAD MODEL
# =========================================================================================
import traceback

sys.path.insert(0, str(CURRENT_DIR))
try:
    from yolov8_mpeb_modules import MobileNetBlock, EMA, C2f_EMA, BiFPN_Fusion
    import ultralytics.nn.modules as modules
    import ultralytics.nn.modules.block as block
    import ultralytics.nn.tasks as tasks

    # Rebind the stock GhostBottleneck / C3 names to the custom MPEB blocks in every
    # namespace that may resolve them. Presumably yolov8_mpeb.yaml refers to these stock
    # names so that the model builder picks up the custom classes -- confirm against the YAML.
    print("Patching Ultralytics modules...")
    block.GhostBottleneck = MobileNetBlock
    modules.GhostBottleneck = MobileNetBlock
    block.C3 = C2f_EMA
    modules.C3 = C2f_EMA

    if hasattr(tasks, 'GhostBottleneck'):
        tasks.GhostBottleneck = MobileNetBlock
    if hasattr(tasks, 'C3'):
        tasks.C3 = C2f_EMA
    if hasattr(tasks, 'block'):
        tasks.block.GhostBottleneck = MobileNetBlock
        tasks.block.C3 = C2f_EMA

    from ultralytics import YOLO
except ImportError as e:
    print(f"Error importing modules: {e}")
    sys.exit(1)

# =========================================================================================
# 6. TRAIN
# =========================================================================================
print("Initializing Model...")
model_yaml = CURRENT_DIR / "yolov8_mpeb.yaml"
if not model_yaml.exists():
    print(f"Error: {model_yaml} not found.")
    sys.exit(1)

model = YOLO(str(model_yaml))

# Update Status
TRAINING_STATUS = "Starting Training... (Check Logs for details)"
print("Starting Training...")

# Batches completed in the current epoch. Counted here because the Ultralytics trainer
# keeps its progress bar as a local variable, not as a `trainer.pbar` attribute.
_batches_this_epoch = 0


def on_train_epoch_start(trainer):
    """Reset the per-epoch batch counter at the start of each training epoch."""
    global _batches_this_epoch
    _batches_this_epoch = 0


def on_train_batch_end(trainer):
    """Refresh the status page every 5 batches with the current epoch/batch position."""
    global TRAINING_STATUS, _batches_this_epoch
    _batches_this_epoch += 1
    # Update status every 5 batches to avoid over-refreshing
    if _batches_this_epoch % 5 != 0:
        return
    try:
        total_batch = len(trainer.train_loader)
    except (AttributeError, TypeError):
        total_batch = "?"
    TRAINING_STATUS = (
        "Training Running...<br>"
        f"Epoch: {trainer.epoch + 1}/{trainer.epochs}<br>"
        f"Batch: {_batches_this_epoch}/{total_batch}<br>"
        "(CPU training is slow, please wait)"
    )


def on_fit_epoch_end(trainer):
    """Report epoch progress and the validation mAP50 after the epoch's validation.

    Registered on "on_fit_epoch_end" rather than "on_train_epoch_end". The latter fires
    before the epoch is validated, so trainer.metrics would still hold the previous
    epoch's values.
    """
    global TRAINING_STATUS
    status = f"Training in progress...<br>Epoch: {trainer.epoch + 1}/{trainer.epochs}"
    metrics = getattr(trainer, "metrics", None)
    map50 = metrics.get("metrics/mAP50(B)") if isinstance(metrics, dict) else None
    if map50 is not None:
        try:
            status += f"<br>mAP50: {float(map50):.4f}"
        except (TypeError, ValueError):
            pass  # metric not numeric; show progress only
    TRAINING_STATUS = status


model.add_callback("on_train_epoch_start", on_train_epoch_start)
model.add_callback("on_train_batch_end", on_train_batch_end)
model.add_callback("on_fit_epoch_end", on_fit_epoch_end)

# Train 200 epochs, CPU only
training_succeeded = False
try:
    model.train(
        data=str(DATA_YAML_PATH),
        epochs=200,
        device='cpu',
        project='runs/train',
        name='visdrone_mpeb',
        batch=4,
        workers=1,
        exist_ok=True,
    )
    training_succeeded = True
except Exception:
    # Surface the failure on the status page instead of leaving it stuck on
    # "Starting Training..."; the full traceback goes to the logs.
    traceback.print_exc()
    TRAINING_STATUS = "Training failed! Check logs for details."

# =========================================================================================
# 7. FINALIZE
# =========================================================================================
if training_succeeded:
    TRAINING_STATUS = "Training Complete! Finalizing..."
    print("Training Complete.")

    # Prefer the path recorded by the trainer; fall back to the project/name layout above.
    default_best = CURRENT_DIR / "runs/train/visdrone_mpeb/weights/best.pt"
    best_weight_path = Path(getattr(getattr(model, "trainer", None), "best", default_best))
    destination_path = CURRENT_DIR / "best.pt"

    if best_weight_path.exists():
        shutil.copy(best_weight_path, destination_path)
        print(f"Successfully saved best.pt to {destination_path}")
        TRAINING_STATUS = (
            "Training Successfully Completed!<br>"
            f"Model saved to: {destination_path}<br>"
            "You can now download the file or stop the space."
        )
    else:
        print("Warning: best.pt not found in runs directory.")
        TRAINING_STATUS = "Training Finished, but best.pt was not found. Check logs."

# The status server runs in a daemon thread, so it dies as soon as the main thread
# finishes. Block on it to keep the final status page reachable until the Space is stopped.
if server_thread.is_alive():
    print("Status page remains online; stop the Space when you are done.")
    server_thread.join()

print("Exiting...")