Spaces:
Configuration error
Configuration error
| import sys | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| import yaml | |
| from huggingface_hub import snapshot_download | |
| from tqdm import tqdm | |
| from PIL import Image | |
# =========================================================================================
# 1. SETUP & CONFIGURATION
# =========================================================================================
print("Starting App for YOLOv8-MPEB Training on CPU...")

# Every path below is anchored at the directory the script is launched from.
CURRENT_DIR = Path.cwd()

# Hugging Face dataset repository id, and where its snapshot is stored locally.
DATASET_REPO = "jeyanthangj2004/Visdrone-raw"
DATASET_DIR = CURRENT_DIR.joinpath("visdrone_dataset")

# Location of the dataset description file handed to the trainer.
DATA_YAML_PATH = CURRENT_DIR.joinpath("data.yaml")
# =========================================================================================
# 2. DOWNLOAD DATASET
# =========================================================================================
# Pull a snapshot of the dataset repository into DATASET_DIR. Any failure is
# reported on stdout and ends the script with exit status 1.
print(f"Downloading dataset from {DATASET_REPO}...")
try:
    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir=DATASET_DIR,
    )
except Exception as download_error:
    print(f"Error downloading dataset: {download_error}")
    sys.exit(1)
else:
    print("Dataset download complete.")
# =========================================================================================
# 3. DATASET CONVERSION (If needed)
# =========================================================================================
# The snapshot is expected to contain raw VisDrone split folders named
# "VisDrone2019-DET-<split>" (e.g. Visdrone-raw/VisDrone2019-DET-train/), either
# directly under the dataset directory or nested somewhere below it. Each split
# that is found is rearranged into the YOLO layout images/<split> + labels/<split>.


def _find_split_dir(dir_path, target_folder_name):
    """Return the directory called *target_folder_name* under *dir_path*, or None.

    The direct child is preferred; otherwise the first directory with that
    name found by a recursive search is returned.
    """
    direct_child = dir_path / target_folder_name
    if direct_child.is_dir():
        return direct_child
    return next(
        (p for p in dir_path.rglob(target_folder_name) if p.is_dir()),
        None,
    )


def _visdrone_line_to_yolo(line, dw, dh):
    """Convert one comma-separated VisDrone annotation line to a YOLO label line.

    Args:
        line: Raw annotation line; fields 0-3 are read as integer x, y, w, h,
            field 4 as a flag and field 5 as the 1-based category id.
        dw: Reciprocal of the image width.
        dh: Reciprocal of the image height.

    Returns:
        "<cls> <x_center> <y_center> <w> <h>\n" with normalized coordinates, or
        None when the line is too short, flagged as ignored (field 4 == "0"),
        or its category falls outside 1-10.

    Raises:
        ValueError: If a numeric field cannot be parsed as an integer.
    """
    row = line.strip().split(",")
    if len(row) < 6:
        return None
    if row[4] == "0":  # Skip ignored regions
        return None
    x, y, w, h = map(int, row[:4])
    cls = int(row[5]) - 1  # VisDrone categories 1-10 -> YOLO classes 0-9
    if not 0 <= cls <= 9:
        return None
    x_center, y_center = (x + w / 2) * dw, (y + h / 2) * dh
    w_norm, h_norm = w * dw, h * dh
    return f"{cls} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}\n"


def visdrone2yolo(dir_path, split):
    """Convert VisDrone annotations to YOLO format.

    Locates the "VisDrone2019-DET-<split>" folder under *dir_path*, moves its
    ``images/*.jpg`` files to ``dir_path/images/<split>`` and writes one YOLO
    label file per annotation file to ``dir_path/labels/<split>``.

    The split is skipped (with a message) when its folder cannot be found or
    when ``labels/<split>`` already contains files. A failure while converting
    a single annotation file is reported and does not stop the other files.

    Args:
        dir_path: Dataset root as a ``pathlib.Path``.
        split: Split name, e.g. "train", "val" or "test-dev".

    Returns:
        None.
    """
    print(f"Checking/Converting {split} data in {dir_path}...")

    target_folder_name = f"VisDrone2019-DET-{split}"
    source_dir = _find_split_dir(dir_path, target_folder_name)
    if source_dir is None:
        print(f"Warning: Could not find directory for split '{split}' ({target_folder_name}). Skipping.")
        return

    # Destination paths - strictly following YOLO structure
    images_dest_dir = dir_path / "images" / split
    labels_dest_dir = dir_path / "labels" / split

    # Existing label files are taken as proof that a previous run converted this split.
    if labels_dest_dir.exists() and any(labels_dest_dir.iterdir()):
        print(f"Labels for {split} seem to exist. Skipping conversion.")
        return

    labels_dest_dir.mkdir(parents=True, exist_ok=True)
    images_dest_dir.mkdir(parents=True, exist_ok=True)

    # Images are moved rather than copied so the dataset is not stored twice.
    source_images_dir = source_dir / "images"
    if source_images_dir.exists():
        print(f"Moving images from {source_images_dir} to {images_dest_dir}...")
        for img in source_images_dir.glob("*.jpg"):
            shutil.move(str(img), str(images_dest_dir / img.name))

    source_annotations_dir = source_dir / "annotations"
    if not source_annotations_dir.exists():
        return

    print(f"Converting annotations from {source_annotations_dir}...")
    # The glob is materialized so the progress bar knows the total up front.
    annotation_files = list(source_annotations_dir.glob("*.txt"))
    for f in tqdm(annotation_files, desc=f"Converting {split}"):
        try:
            img_path = images_dest_dir / f.with_suffix(".jpg").name
            if not img_path.exists():
                continue

            # Only the dimensions are needed; close the file handle right away.
            with Image.open(img_path) as image:
                width, height = image.size
            dw, dh = 1.0 / width, 1.0 / height

            with open(f, encoding="utf-8") as file:
                converted = (_visdrone_line_to_yolo(line, dw, dh) for line in file)
                lines = [label for label in converted if label is not None]

            (labels_dest_dir / f.name).write_text("".join(lines), encoding="utf-8")
        except Exception as e:
            # Best effort: one bad file must not abort the rest of the split.
            print(f"Error converting {f.name}: {e}")
# Convert every split in turn; "test-dev" is optional, and a split whose folder
# is missing is skipped inside visdrone2yolo.
for split_name in ("train", "val", "test-dev"):
    visdrone2yolo(DATASET_DIR, split_name)
# =========================================================================================
# 4. CREATE DATA.YAML
# =========================================================================================
# Class names listed in index order; the position in the tuple is the class id.
class_names = (
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
)

# Dataset description: absolute dataset root, per-split image folders relative
# to that root, and the id -> name table.
data_yaml_content = {
    'path': str(DATASET_DIR.absolute()),
    'train': 'images/train',
    'val': 'images/val',
    'test': 'images/test-dev',
    'names': dict(enumerate(class_names)),
}

DATA_YAML_PATH.write_text(yaml.dump(data_yaml_content))
print(f"Created data.yaml at {DATA_YAML_PATH}")
# =========================================================================================
# 5. PATCH & LOAD MODEL
# =========================================================================================
# Put the working directory first on the import path so the local
# 'yolov8_mpeb_modules' module can be imported.
sys.path.insert(0, str(CURRENT_DIR))

try:
    from yolov8_mpeb_modules import MobileNetBlock, EMA, C2f_EMA, BiFPN_Fusion
    import ultralytics.nn.modules as modules
    import ultralytics.nn.modules.block as block
    import ultralytics.nn.tasks as tasks

    print("Patching Ultralytics modules...")

    # Ultralytics attribute name -> custom class installed under that name.
    replacements = {
        'GhostBottleneck': MobileNetBlock,
        'C3': C2f_EMA,
    }
    for attr_name, custom_cls in replacements.items():
        setattr(block, attr_name, custom_cls)
        setattr(modules, attr_name, custom_cls)
        # 'tasks' is only patched where it already exposes the name.
        if hasattr(tasks, attr_name):
            setattr(tasks, attr_name, custom_cls)

    if hasattr(tasks, 'block'):
        for attr_name, custom_cls in replacements.items():
            setattr(tasks.block, attr_name, custom_cls)

    # Imported only after the replacements above are in place.
    from ultralytics import YOLO
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Ensure 'yolov8_mpeb_modules.py' and 'yolov8_mpeb.yaml' are in the same directory.")
    sys.exit(1)
# =========================================================================================
# 6. TRAIN
# =========================================================================================
print("Initializing Model...")

# The model architecture file must sit in the working directory.
model_yaml = CURRENT_DIR / "yolov8_mpeb.yaml"
if not model_yaml.exists():
    print(f"Error: {model_yaml} not found.")
    sys.exit(1)

model = YOLO(str(model_yaml))

print("Starting Training...")

# Train 200 epochs, CPU only
train_settings = {
    'data': str(DATA_YAML_PATH),
    'epochs': 200,
    'device': 'cpu',
    'project': 'runs/train',
    'name': 'visdrone_mpeb',
    # Adjust batch size for CPU if needed (16 or 32 usually safe on modern CPUs)
    'batch': 16,
    'workers': 4,
    'exist_ok': True,
}
results = model.train(**train_settings)
# =========================================================================================
# 7. FINALIZE
# =========================================================================================
print("Training Complete.")

# Copy the best checkpoint of the run next to the script, if one was produced.
best_weight_path = Path("runs/train/visdrone_mpeb/weights/best.pt")
destination_path = CURRENT_DIR / "best.pt"

if not best_weight_path.exists():
    print("Warning: best.pt not found in runs directory.")
else:
    shutil.copy(best_weight_path, destination_path)
    print(f"Successfully saved best.pt to {destination_path}")

print("Exiting...")