Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| import threading | |
| import http.server | |
| import socketserver | |
| from pathlib import Path | |
| import shutil | |
| import yaml | |
| from huggingface_hub import snapshot_download | |
| from tqdm import tqdm | |
| from PIL import Image | |
# =========================================================================================
# 0. HEALTH CHECK SERVER (CRITICAL FOR DOCKER SPACES)
# =========================================================================================
# A Docker Space on Hugging Face is expected to have something listening on
# port 7860. When nothing answers there within 30 minutes the Space is marked
# "Runtime Error", so a tiny HTTP server is brought up in the background
# before any slow work begins.
PORT = 7860

# Status text rendered on the monitoring page; later stages of the script
# rebind this module-level name as the run progresses.
TRAINING_STATUS = "Initializing..."
def start_health_check_server():
    """Serve the training-status page on ``PORT`` until the process exits.

    The call blocks inside ``serve_forever()``, so it is meant to be the
    target of a background thread. Every GET request, whatever its path, is
    answered with ``200`` and a small HTML page that embeds the current value
    of the module-level ``TRAINING_STATUS`` and asks the browser to reload
    every 30 seconds. HEAD requests receive the same headers without a body.

    ``TRAINING_STATUS`` is inserted into the page verbatim (not escaped)
    because the rest of the script stores HTML markup such as ``<br>`` in it.

    Any exception raised while binding or serving is printed and swallowed;
    in that case the function simply returns.
    """

    class HealthCheckHandler(http.server.SimpleHTTPRequestHandler):
        """Answer every request with the status page instead of serving files."""

        def _send_status_page(self, include_body):
            # The page is rebuilt per request so it always reflects the
            # latest TRAINING_STATUS value.
            html_content = f"""
            <html>
            <head>
            <meta http-equiv="refresh" content="30">
            <title>YOLOv8-MPEB Training Status</title>
            <style>
            body {{ font-family: sans-serif; padding: 50px; text-align: center; }}
            h1 {{ color: #333; }}
            .status {{ font-size: 24px; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background: #f9f9f9; display: inline-block; }}
            </style>
            </head>
            <body>
            <h1>YOLOv8-MPEB Training Monitor</h1>
            <div class="status">{TRAINING_STATUS}</div>
            <p><em>Autorefreshing every 30 seconds...</em></p>
            </body>
            </html>
            """
            payload = html_content.encode("utf-8")
            self.send_response(200)
            # Declare the charset that was actually used to encode the body.
            self.send_header('Content-type', 'text/html; charset=utf-8')
            self.send_header('Content-Length', str(len(payload)))
            self.end_headers()
            if include_body:
                self.wfile.write(payload)

        def do_GET(self):
            self._send_status_page(include_body=True)

        def do_HEAD(self):
            # Without this override the inherited do_HEAD would answer with
            # metadata of files in the working directory.
            self._send_status_page(include_body=False)

        def log_message(self, format, *args):
            # Silence per-request logging so the training logs stay readable.
            pass

    class HealthCheckServer(socketserver.TCPServer):
        """TCP server that can rebind the port right after a restart."""

        # Plain TCPServer leaves this off, which makes bind() fail with
        # "Address already in use" while an old socket lingers in TIME_WAIT.
        allow_reuse_address = True

    try:
        with HealthCheckServer(("", PORT), HealthCheckHandler) as httpd:
            print(f"Health check server serving at port {PORT}")
            httpd.serve_forever()
    except Exception as e:
        # Best effort: a broken status page must not take the training down.
        print(f"Failed to start health check server: {e}")
# Launch the status server on a daemon thread: it runs alongside the rest of
# the script and never keeps the interpreter alive by itself.
server_thread = threading.Thread(target=start_health_check_server, daemon=True)
server_thread.start()

# =========================================================================================
# 1. SETUP & CONFIGURATION
# =========================================================================================
print("Starting App for YOLOv8-MPEB Training on CPU...")

# All working paths hang off the directory the script was started from.
CURRENT_DIR = Path.cwd()
DATASET_REPO = "jeyanthangj2004/Visdrone-raw"
DATASET_DIR = CURRENT_DIR.joinpath("visdrone_dataset")
DATA_YAML_PATH = CURRENT_DIR.joinpath("data.yaml")

# =========================================================================================
# 2. DOWNLOAD DATASET
# =========================================================================================
print(f"Downloading dataset from {DATASET_REPO}...")
try:
    snapshot_download(repo_id=DATASET_REPO, repo_type="dataset", local_dir=DATASET_DIR)
except Exception as download_error:
    # Nothing further can run without the data, so stop with a failure code.
    print(f"Error downloading dataset: {download_error}")
    sys.exit(1)
else:
    print("Dataset download complete.")
| # ========================================================================================= | |
| # 3. DATASET CONVERSION | |
| # ========================================================================================= | |
def _find_split_dir(dir_path, folder_name):
    """Return the path named *folder_name* under *dir_path*, or ``None``.

    A direct child is preferred; otherwise the first directory with that name
    found by a recursive search is used.
    """
    direct_child = dir_path / folder_name
    if direct_child.exists():
        return direct_child
    return next((p for p in dir_path.rglob(folder_name) if p.is_dir()), None)


def _visdrone_line_to_yolo(line, dw, dh):
    """Convert one VisDrone annotation line to a YOLO label line.

    *dw* and *dh* are the reciprocals of the image width and height. Returns
    ``None`` when the line is dropped: fewer than six fields, fifth field
    equal to ``"0"`` (ignored region), or a class index outside 0-9 after
    shifting by one. Raises ``ValueError`` when a numeric field is malformed.
    """
    # Strip each field so stray whitespace cannot defeat the "0" comparison.
    row = [field.strip() for field in line.strip().split(",")]
    if len(row) < 6 or row[4] == "0":
        return None
    x, y, w, h = map(int, row[:4])
    cls = int(row[5]) - 1
    if not 0 <= cls <= 9:
        return None
    x_center, y_center = (x + w / 2) * dw, (y + h / 2) * dh
    w_norm, h_norm = w * dw, h * dh
    return f"{cls} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}\n"


def visdrone2yolo(dir_path, split):
    """Convert VisDrone annotations to YOLO format.

    Looks for ``VisDrone2019-DET-<split>`` under *dir_path*, moves its
    ``images/*.jpg`` files to ``<dir_path>/images/<split>`` and writes one
    YOLO label file per annotation file to ``<dir_path>/labels/<split>``.

    Args:
        dir_path: ``pathlib.Path`` of the dataset root.
        split: Split name used in the folder names, e.g. ``"train"``.

    Returns:
        ``None``. The function returns early, without touching the disk, when
        the split directory cannot be found or when the labels directory for
        the split already contains entries.

    Annotation files whose image is missing are skipped. A failure while
    converting a single file is printed and the run continues with the next.
    """
    print(f"Checking/Converting {split} data in {dir_path}...")

    target_folder_name = f"VisDrone2019-DET-{split}"
    source_dir = _find_split_dir(dir_path, target_folder_name)
    if source_dir is None:
        print(f"Warning: Could not find directory for split '{split}' ({target_folder_name}). Skipping.")
        return

    images_dest_dir = dir_path / "images" / split
    labels_dest_dir = dir_path / "labels" / split

    # A non-empty labels directory is taken as proof of an earlier conversion.
    if labels_dest_dir.exists() and any(labels_dest_dir.iterdir()):
        print(f"Labels for {split} seem to exist. Skipping conversion.")
        return

    labels_dest_dir.mkdir(parents=True, exist_ok=True)
    images_dest_dir.mkdir(parents=True, exist_ok=True)

    source_images_dir = source_dir / "images"
    if source_images_dir.exists():
        print(f"Moving images from {source_images_dir} to {images_dest_dir}...")
        for img in source_images_dir.glob("*.jpg"):
            shutil.move(str(img), str(images_dest_dir / img.name))

    source_annotations_dir = source_dir / "annotations"
    if not source_annotations_dir.exists():
        return

    print(f"Converting annotations from {source_annotations_dir}...")
    # Materialise the listing so tqdm knows the total up front.
    annotation_files = list(source_annotations_dir.glob("*.txt"))
    for f in tqdm(annotation_files, desc=f"Converting {split}"):
        try:
            img_path = images_dest_dir / f.with_suffix(".jpg").name
            if not img_path.exists():
                continue

            # Close the image right away: only its size is needed, and an
            # unclosed handle per image adds up over a whole dataset.
            with Image.open(img_path) as image:
                img_width, img_height = image.size
            dw, dh = 1.0 / img_width, 1.0 / img_height

            with open(f, encoding="utf-8") as file:
                converted = (_visdrone_line_to_yolo(line, dw, dh) for line in file)
                lines = [label for label in converted if label is not None]

            # A label file is written even when no box survived filtering.
            (labels_dest_dir / f.name).write_text("".join(lines), encoding="utf-8")
        except Exception as e:
            # Best effort per file: report the problem and keep converting.
            print(f"Error converting {f.name}: {e}")
# Convert every split that ships with the dataset.
for split_name in ("train", "val", "test-dev"):
    visdrone2yolo(DATASET_DIR, split_name)

# =========================================================================================
# 4. CREATE DATA.YAML
# =========================================================================================
# Class names listed in index order; the index becomes the key in data.yaml.
class_names = [
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
]
data_yaml_content = {
    'path': str(DATASET_DIR.absolute()),
    'train': 'images/train',
    'val': 'images/val',
    'test': 'images/test-dev',
    'names': dict(enumerate(class_names)),
}
with open(DATA_YAML_PATH, 'w') as yaml_file:
    yaml.dump(data_yaml_content, yaml_file)
print(f"Created data.yaml at {DATA_YAML_PATH}")
# =========================================================================================
# 5. PATCH & LOAD MODEL
# =========================================================================================
# Put the script directory first on the import path so yolov8_mpeb_modules
# is resolved from there.
sys.path.insert(0, str(CURRENT_DIR))
try:
    from yolov8_mpeb_modules import MobileNetBlock, EMA, C2f_EMA, BiFPN_Fusion
    import ultralytics.nn.modules as modules
    import ultralytics.nn.modules.block as block
    import ultralytics.nn.tasks as tasks

    print("Patching Ultralytics modules...")
    # Ultralytics class name -> custom class installed under that name.
    overrides = {"GhostBottleneck": MobileNetBlock, "C3": C2f_EMA}
    for attr_name, replacement in overrides.items():
        setattr(block, attr_name, replacement)
        setattr(modules, attr_name, replacement)
        # The tasks module is only patched where it already exposes the name.
        if hasattr(tasks, attr_name):
            setattr(tasks, attr_name, replacement)
    if hasattr(tasks, 'block'):
        for attr_name, replacement in overrides.items():
            setattr(tasks.block, attr_name, replacement)

    from ultralytics import YOLO
except ImportError as import_error:
    print(f"Error importing modules: {import_error}")
    sys.exit(1)
# =========================================================================================
# 6. TRAIN
# =========================================================================================
print("Initializing Model...")

# The architecture definition is expected next to the script.
model_yaml = CURRENT_DIR.joinpath("yolov8_mpeb.yaml")
if not model_yaml.exists():
    print(f"Error: {model_yaml} not found.")
    sys.exit(1)

model = YOLO(str(model_yaml))

# Reflect the new phase on the status page before the long-running call.
TRAINING_STATUS = "Starting Training... (Check Logs for details)"
print("Starting Training...")
| # Define callbacks to update status | |
def on_train_epoch_end(trainer):
    """Publish epoch progress (and mAP50 when available) to the status page.

    Reads ``trainer.epoch`` (zero-based, hence the ``+ 1``), ``trainer.epochs``
    and ``trainer.metrics`` and rebinds the module-level ``TRAINING_STATUS``.
    When the mAP50 value cannot be read or formatted, the status falls back to
    the epoch counter alone.
    """
    global TRAINING_STATUS
    current_epoch = trainer.epoch + 1
    total_epochs = trainer.epochs
    progress = f"Training in progress...<br>Epoch: {current_epoch}/{total_epochs}"
    try:
        map50 = trainer.metrics.get("metrics/mAP50(B)", 0)
        TRAINING_STATUS = f"{progress}<br>mAP50: {map50:.4f}"
    except Exception:
        # The status text is cosmetic, so a missing or odd metrics object must
        # not abort training. Catching Exception rather than everything lets
        # KeyboardInterrupt and SystemExit still stop the run.
        TRAINING_STATUS = progress
def on_train_batch_end(trainer):
    """Publish batch progress to the status page on every fifth batch.

    Does nothing unless ``trainer`` has a truthy ``pbar`` attribute whose
    ``n`` counter is a multiple of five. Otherwise the module-level
    ``TRAINING_STATUS`` is rebound to an HTML snippet with the epoch and
    batch counters; a falsy ``pbar.total`` is shown as ``?``.
    """
    global TRAINING_STATUS
    progress_bar = getattr(trainer, 'pbar', None)
    if not progress_bar:
        return
    # Throttle updates so the status text is not rewritten on every batch.
    if progress_bar.n % 5 != 0:
        return

    epoch_number = trainer.epoch + 1
    epoch_count = trainer.epochs
    batches_done = progress_bar.n
    batches_total = progress_bar.total if progress_bar.total else "?"
    TRAINING_STATUS = f"<b>Training Running...</b><br>Epoch: {epoch_number}/{epoch_count}<br>Batch: {batches_done}/{batches_total}<br><i>(CPU training is slow, please wait)</i>"
# Hook the status callbacks into the training loop.
for event_name, handler in (
    ("on_train_epoch_end", on_train_epoch_end),
    ("on_train_batch_end", on_train_batch_end),
):
    model.add_callback(event_name, handler)

# Train 200 epochs, CPU only
train_settings = {
    "data": str(DATA_YAML_PATH),
    "epochs": 200,
    "device": 'cpu',
    "project": 'runs/train',
    "name": 'visdrone_mpeb',
    "batch": 4,
    "workers": 1,
    "exist_ok": True,
}
results = model.train(**train_settings)
# =========================================================================================
# 7. FINALIZE
# =========================================================================================
TRAINING_STATUS = "Training Complete! Finalizing..."
print("Training Complete.")

# Location of the best checkpoint inside the run directory, and where a copy
# is placed next to the script.
best_weight_path = Path("runs/train/visdrone_mpeb/weights/best.pt")
destination_path = CURRENT_DIR / "best.pt"

if not best_weight_path.exists():
    print("Warning: best.pt not found in runs directory.")
    TRAINING_STATUS = "Training Finished, but best.pt was not found. Check logs."
else:
    shutil.copy(best_weight_path, destination_path)
    print(f"Successfully saved best.pt to {destination_path}")
    TRAINING_STATUS = f"Training Successfully Completed!<br>Model saved to: {destination_path}<br>You can now download the file or stop the space."

# NOTE(review): the status server runs on a daemon thread, so it presumably
# stops once the script ends here and the final status may only be visible
# briefly - confirm this is intended.
print("Exiting...")