import os
import shutil
import sys
from pathlib import Path

import yaml
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm import tqdm

# =========================================================================================
# 1. SETUP & CONFIGURATION
# =========================================================================================
print("Starting App for YOLOv8-MPEB Training on CPU...")

# Every artifact (dataset, data.yaml, weights) is placed relative to the launch directory.
CURRENT_DIR = Path.cwd()
DATASET_REPO = "jeyanthangj2004/Visdrone-raw"
DATASET_DIR = CURRENT_DIR / "visdrone_dataset"
DATA_YAML_PATH = CURRENT_DIR / "data.yaml"

# =========================================================================================
# 2. DOWNLOAD DATASET
# =========================================================================================
print(f"Downloading dataset from {DATASET_REPO}...")
try:
    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir=DATASET_DIR,
    )
except Exception as download_error:
    # Top-level boundary: without the dataset nothing else can run, so report and abort.
    print(f"Error downloading dataset: {download_error}")
    sys.exit(1)
else:
    print("Dataset download complete.")

# =========================================================================================
# 3. DATASET CONVERSION (If needed)
# =========================================================================================
# The snapshot may already be in YOLO layout (images/ + labels/ folders) or in the raw
# VisDrone layout, e.g. Visdrone-raw/VisDrone2019-DET-train/{images,annotations}.
# The converter below locates each raw split and converts it when raw annotations exist.
def _find_split_dir(dir_path: Path, target_folder_name: str):
    """Locate the raw VisDrone split folder ``target_folder_name`` below ``dir_path``.

    The direct child ``dir_path / target_folder_name`` wins if it exists. Otherwise the
    tree is searched recursively. Matches under hidden directories are ignored. This
    matters because ``snapshot_download(local_dir=...)`` keeps a
    ``.cache/huggingface/download/...`` tree that mirrors the repo's folder names but
    only contains ``*.metadata`` files. Among the remaining matches the shallowest one
    is chosen, so the result does not depend on filesystem iteration order.

    Returns:
        The split directory as a ``Path``, or ``None`` if none was found.
    """
    direct = dir_path / target_folder_name
    if direct.is_dir():
        return direct
    candidates = [
        p
        for p in dir_path.rglob(target_folder_name)
        if p.is_dir()
        and not any(part.startswith(".") for part in p.relative_to(dir_path).parts)
    ]
    if not candidates:
        return None
    return min(candidates, key=lambda p: (len(p.parts), str(p)))


def _move_images(source_images_dir: Path, images_dest_dir: Path) -> None:
    """Move every ``*.jpg`` from ``source_images_dir`` into ``images_dest_dir``.

    Moving rather than copying avoids keeping two copies of the dataset on disk.
    """
    print(f"Moving images from {source_images_dir} to {images_dest_dir}...")
    for img in source_images_dir.glob("*.jpg"):
        shutil.move(str(img), str(images_dest_dir / img.name))


def _visdrone_to_yolo_lines(annotation_lines, img_w: int, img_h: int) -> list:
    """Convert raw VisDrone annotation rows for one image into YOLO label lines.

    Each row is comma-separated with at least six fields:
    ``left,top,width,height,score,category,...`` (pixel units).

    Rows are skipped when:
      * they have fewer than 6 fields, or
      * the score field is ``"0"`` (ignored region), or
      * ``category - 1`` falls outside the 10 trained classes (0-9).

    Boxes are clipped to the image, and boxes with no area left are dropped.
    Ultralytics rejects a whole label file if any coordinate falls outside [0, 1].

    Args:
        annotation_lines: Iterable of raw text lines (e.g. an open file).
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        A list of ``"cls xc yc w h\\n"`` strings with coordinates normalized to [0, 1].

    Raises:
        ValueError: If a kept row has non-integer box or category fields.
    """
    dw, dh = 1.0 / img_w, 1.0 / img_h
    out = []
    for line in annotation_lines:
        row = line.strip().split(",")
        if len(row) < 6 or row[4] == "0":
            continue
        x, y, w, h = map(int, row[:4])
        cls = int(row[5]) - 1  # VisDrone categories 1-10 -> YOLO classes 0-9
        if not 0 <= cls <= 9:
            continue
        # Clip the box to the image bounds in pixel space.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)
        if x1 <= x0 or y1 <= y0:
            continue
        bw, bh = x1 - x0, y1 - y0
        out.append(
            f"{cls} {(x0 + bw / 2) * dw:.6f} {(y0 + bh / 2) * dh:.6f} "
            f"{bw * dw:.6f} {bh * dh:.6f}\n"
        )
    return out


def visdrone2yolo(dir_path: Path, split: str) -> None:
    """Convert one raw VisDrone split under ``dir_path`` to YOLO layout, in place.

    Results are written to ``dir_path/images/<split>`` and ``dir_path/labels/<split>``.
    Source images are moved, not copied. Conversion is skipped when the labels folder
    already has content, or when the raw split folder cannot be found. A per-file
    failure is reported and conversion continues with the next file.
    """
    print(f"Checking/Converting {split} data in {dir_path}...")

    target_folder_name = f"VisDrone2019-DET-{split}"
    source_dir = _find_split_dir(dir_path, target_folder_name)
    if source_dir is None:
        print(f"Warning: Could not find directory for split '{split}' ({target_folder_name}). Skipping.")
        return

    # Destination paths follow the standard YOLO structure.
    images_dest_dir = dir_path / "images" / split
    labels_dest_dir = dir_path / "labels" / split

    # Non-empty labels folder: assume a previous run already converted this split.
    if labels_dest_dir.exists() and any(labels_dest_dir.iterdir()):
        print(f"Labels for {split} seem to exist. Skipping conversion.")
        return

    labels_dest_dir.mkdir(parents=True, exist_ok=True)
    images_dest_dir.mkdir(parents=True, exist_ok=True)

    source_images_dir = source_dir / "images"
    if source_images_dir.exists():
        _move_images(source_images_dir, images_dest_dir)

    source_annotations_dir = source_dir / "annotations"
    if not source_annotations_dir.exists():
        return

    print(f"Converting annotations from {source_annotations_dir}...")
    for ann_path in tqdm(sorted(source_annotations_dir.glob("*.txt")), desc=f"Converting {split}"):
        img_path = images_dest_dir / ann_path.with_suffix(".jpg").name
        if not img_path.exists():
            continue
        try:
            # Context manager closes the image file right after reading its header.
            with Image.open(img_path) as im:
                img_w, img_h = im.size
            with open(ann_path, encoding="utf-8") as fh:
                label_lines = _visdrone_to_yolo_lines(fh, img_w, img_h)
            # An empty label file is still written: it marks a background-only image.
            (labels_dest_dir / ann_path.name).write_text("".join(label_lines), encoding="utf-8")
        except Exception as e:  # best-effort per file: report and keep going
            print(f"Error converting {ann_path.name}: {e}")


# Process datasets
visdrone2yolo(DATASET_DIR, "train")
visdrone2yolo(DATASET_DIR, "val")
visdrone2yolo(DATASET_DIR, "test-dev")  # Optional

# =========================================================================================
# 4. CREATE DATA.YAML
# =========================================================================================
data_yaml_content = {
    'path': str(DATASET_DIR.absolute()),
    'train': 'images/train',
    'val': 'images/val',
}
# The test split is optional; only reference it if its images were actually staged.
if (DATASET_DIR / 'images' / 'test-dev').is_dir():
    data_yaml_content['test'] = 'images/test-dev'
data_yaml_content['names'] = {
    0: 'pedestrian',
    1: 'people',
    2: 'bicycle',
    3: 'car',
    4: 'van',
    5: 'truck',
    6: 'tricycle',
    7: 'awning-tricycle',
    8: 'bus',
    9: 'motor',
}

with open(DATA_YAML_PATH, 'w', encoding='utf-8') as f:
    yaml.dump(data_yaml_content, f, sort_keys=False)

print(f"Created data.yaml at {DATA_YAML_PATH}")
# =========================================================================================
# 5. PATCH & LOAD MODEL
# =========================================================================================
# Put the launch directory first on the import path so the custom module file is found.
sys.path.insert(0, str(CURRENT_DIR))

try:
    from yolov8_mpeb_modules import MobileNetBlock, EMA, C2f_EMA, BiFPN_Fusion
    import ultralytics.nn.modules as modules
    import ultralytics.nn.modules.block as block
    import ultralytics.nn.tasks as tasks

    print("Patching Ultralytics modules...")
    # Replace the stock blocks with the MPEB variants in the block and package namespaces.
    for patched_namespace in (block, modules):
        patched_namespace.GhostBottleneck = MobileNetBlock
        patched_namespace.C3 = C2f_EMA

    # tasks gets the same swap, but only for names it actually exposes.
    for stock_name, mpeb_cls in {"GhostBottleneck": MobileNetBlock, "C3": C2f_EMA}.items():
        if hasattr(tasks, stock_name):
            setattr(tasks, stock_name, mpeb_cls)

    if hasattr(tasks, 'block'):
        tasks.block.GhostBottleneck = MobileNetBlock
        tasks.block.C3 = C2f_EMA

    # Imported only after patching is done.
    from ultralytics import YOLO
except ImportError as import_error:
    print(f"Error importing modules: {import_error}")
    print("Ensure 'yolov8_mpeb_modules.py' and 'yolov8_mpeb.yaml' are in the same directory.")
    sys.exit(1)

# =========================================================================================
# 6. TRAIN
# =========================================================================================
print("Initializing Model...")
model_yaml = CURRENT_DIR / "yolov8_mpeb.yaml"
if not model_yaml.exists():
    print(f"Error: {model_yaml} not found.")
    sys.exit(1)

model = YOLO(str(model_yaml))

print("Starting Training...")
# 200 epochs, CPU only. Batch 16 is the default here; lower it if the host runs out of RAM.
# NOTE(review): this script has no `if __name__ == "__main__":` guard. With workers > 0
# on spawn-start platforms (Windows/macOS), DataLoader workers re-import this file.
# Confirm the deployment target uses fork (Linux).
train_kwargs = dict(
    data=str(DATA_YAML_PATH),
    epochs=200,
    device='cpu',
    project='runs/train',
    name='visdrone_mpeb',
    batch=16,
    workers=4,
    exist_ok=True,
)
results = model.train(**train_kwargs)
# =========================================================================================
# 7. FINALIZE
# =========================================================================================
print("Training Complete.")

# Ultralytics decides the final run directory, so first ask the trainer where it wrote
# best.pt. The hard-coded location from the train() arguments is only a fallback.
_trainer_best = getattr(getattr(model, "trainer", None), "best", None)
_best_candidates = [Path(_trainer_best)] if _trainer_best else []
_best_candidates.append(Path("runs/train/visdrone_mpeb/weights/best.pt"))
best_weight_path = next((p for p in _best_candidates if p.exists()), None)

destination_path = CURRENT_DIR / "best.pt"

if best_weight_path is not None:
    shutil.copy(best_weight_path, destination_path)
    print(f"Successfully saved best.pt to {destination_path}")
else:
    print("Warning: best.pt not found in runs directory.")

print("Exiting...")