# mpeb / app.py
# Uploaded by jeyanthangj2004 ("Upload 21 files", commit 558d0f4, verified)
import sys
import os
from pathlib import Path
import shutil
import yaml
from huggingface_hub import snapshot_download
from tqdm import tqdm
from PIL import Image
# =========================================================================================
# 1. SETUP & CONFIGURATION
# =========================================================================================
print("Starting App for YOLOv8-MPEB Training on CPU...")

# All inputs and outputs of this script are anchored at the directory the
# process was launched from.
CURRENT_DIR = Path.cwd()

# Hugging Face dataset repository id and the local folder it is mirrored into.
DATASET_REPO = "jeyanthangj2004/Visdrone-raw"
DATASET_DIR = CURRENT_DIR.joinpath("visdrone_dataset")

# Location of the dataset description file generated further down.
DATA_YAML_PATH = CURRENT_DIR.joinpath("data.yaml")
# =========================================================================================
# 2. DOWNLOAD DATASET
# =========================================================================================
print(f"Downloading dataset from {DATASET_REPO}...")
try:
    # Mirror the whole dataset repository into DATASET_DIR.
    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir=DATASET_DIR,
    )
except Exception as download_error:
    # Nothing below can work without the data, so report and stop.
    print(f"Error downloading dataset: {download_error}")
    sys.exit(1)
else:
    print("Dataset download complete.")
# =========================================================================================
# 3. DATASET CONVERSION (If needed)
# =========================================================================================
# Check if dataset is already in YOLO format (images/labels folders) or raw VisDrone format
# Structure assumption based on user request: Visdrone-raw/VisDrone2019-DET-train/
# We will check and convert if we find the raw annotations.
def _find_split_dir(dir_path, target_folder_name):
    """Return the folder named *target_folder_name* at or below *dir_path*, or None."""
    direct = dir_path / target_folder_name
    if direct.exists():
        return direct
    # The download may nest the split folder under extra directories, so fall
    # back to the first matching directory found by a recursive search.
    return next(
        (p for p in dir_path.rglob(target_folder_name) if p.is_dir()),
        None,
    )


def _visdrone_row_to_yolo(line, dw, dh):
    """Convert one comma-separated annotation row to a YOLO label line.

    Args:
        line: Raw text of one annotation row.
        dw: Reciprocal of the image width in pixels.
        dh: Reciprocal of the image height in pixels.

    Returns:
        The label line ``"cls xc yc w h\\n"`` with normalized coordinates, or
        ``None`` when the row is too short, is flagged as ignored (fifth
        field is ``0``), or its class id falls outside 0-9 after shifting.

    Raises:
        ValueError: If a box coordinate or the class field is not an integer.
    """
    # Strip each field so whitespace-padded values compare and parse the same
    # way as unpadded ones.
    row = [field.strip() for field in line.strip().split(",")]
    if len(row) < 6:
        return None
    if row[4] == "0":  # Skip ignored regions
        return None
    x, y, w, h = map(int, row[:4])
    cls = int(row[5]) - 1  # source ids are 1-based; labels are written 0-based
    if not 0 <= cls <= 9:
        return None
    x_center = (x + w / 2) * dw
    y_center = (y + h / 2) * dh
    return f"{cls} {x_center:.6f} {y_center:.6f} {w * dw:.6f} {h * dh:.6f}\n"


def visdrone2yolo(dir_path, split):
    """Convert VisDrone annotations to YOLO format.

    Locates the ``VisDrone2019-DET-<split>`` folder at or below *dir_path*,
    moves its ``*.jpg`` images into ``<dir_path>/images/<split>`` and writes
    one label file per annotation file into ``<dir_path>/labels/<split>``.

    The function returns without doing anything when the split folder cannot
    be found or when the labels folder for the split already contains entries.
    Annotation files without a matching moved image are skipped, and a failure
    while converting one file is printed and does not stop the remaining files.

    Args:
        dir_path: ``pathlib.Path`` of the dataset root directory.
        split: Split name, e.g. ``"train"``, ``"val"`` or ``"test-dev"``.

    Returns:
        None.
    """
    print(f"Checking/Converting {split} data in {dir_path}...")

    target_folder_name = f"VisDrone2019-DET-{split}"
    source_dir = _find_split_dir(dir_path, target_folder_name)
    if source_dir is None:
        print(f"Warning: Could not find directory for split '{split}' ({target_folder_name}). Skipping.")
        return

    # Destination paths - strictly following YOLO structure
    images_dest_dir = dir_path / "images" / split
    labels_dest_dir = dir_path / "labels" / split

    # Existing labels are taken as proof that a previous run already converted
    # this split.
    if labels_dest_dir.exists() and any(labels_dest_dir.iterdir()):
        print(f"Labels for {split} seem to exist. Skipping conversion.")
        return

    labels_dest_dir.mkdir(parents=True, exist_ok=True)
    images_dest_dir.mkdir(parents=True, exist_ok=True)

    source_images_dir = source_dir / "images"
    if source_images_dir.exists():
        print(f"Moving images from {source_images_dir} to {images_dest_dir}...")
        # Snapshot the listing first: files are moved out of the directory
        # that is being listed. Moving (not copying) saves disk space.
        for img in list(source_images_dir.glob("*.jpg")):
            shutil.move(str(img), str(images_dest_dir / img.name))

    source_annotations_dir = source_dir / "annotations"
    if not source_annotations_dir.exists():
        return

    print(f"Converting annotations from {source_annotations_dir}...")
    # Materialize the listing so the progress bar knows the total count.
    for f in tqdm(list(source_annotations_dir.glob("*.txt")), desc=f"Converting {split}"):
        try:
            img_path = images_dest_dir / f.with_suffix(".jpg").name
            if not img_path.exists():
                continue
            # Only the dimensions are needed; the context manager closes the
            # underlying file instead of leaking one handle per image.
            with Image.open(img_path) as image:
                img_w, img_h = image.size
            dw, dh = 1.0 / img_w, 1.0 / img_h
            with open(f, encoding="utf-8") as file:
                converted = (_visdrone_row_to_yolo(line, dw, dh) for line in file)
                lines = [label for label in converted if label is not None]
            (labels_dest_dir / f.name).write_text("".join(lines), encoding="utf-8")
        except Exception as e:
            # Best effort per file: report the failure and keep converting.
            print(f"Error converting {f.name}: {e}")
# Convert every split in turn; "test-dev" is optional and is skipped with a
# warning by visdrone2yolo when its folder is absent.
for split_name in ("train", "val", "test-dev"):
    visdrone2yolo(DATASET_DIR, split_name)
# =========================================================================================
# 4. CREATE DATA.YAML
# =========================================================================================
# The position of each name in this list becomes its integer key under 'names'.
class_names = [
    'pedestrian',
    'people',
    'bicycle',
    'car',
    'van',
    'truck',
    'tricycle',
    'awning-tricycle',
    'bus',
    'motor',
]

# Dataset description: absolute root plus image folders relative to it.
data_yaml_content = {
    'path': str(DATASET_DIR.absolute()),
    'train': 'images/train',
    'val': 'images/val',
    'test': 'images/test-dev',
    'names': dict(enumerate(class_names)),
}

with DATA_YAML_PATH.open('w') as yaml_file:
    yaml.dump(data_yaml_content, stream=yaml_file)
print(f"Created data.yaml at {DATA_YAML_PATH}")
# =========================================================================================
# 5. PATCH & LOAD MODEL
# =========================================================================================
# Put the working directory first on the import path so the local
# 'yolov8_mpeb_modules' module can be imported.
sys.path.insert(0, str(CURRENT_DIR))

try:
    from yolov8_mpeb_modules import MobileNetBlock, EMA, C2f_EMA, BiFPN_Fusion
    import ultralytics.nn.modules as modules
    import ultralytics.nn.modules.block as block
    import ultralytics.nn.tasks as tasks

    print("Patching Ultralytics modules...")

    # Ultralytics attribute name -> custom class that replaces it.
    replacements = {"GhostBottleneck": MobileNetBlock, "C3": C2f_EMA}

    # These two namespaces are patched unconditionally.
    for attr_name, replacement in replacements.items():
        setattr(block, attr_name, replacement)
        setattr(modules, attr_name, replacement)

    # 'tasks' is patched only where it already exposes the name.
    for attr_name, replacement in replacements.items():
        if hasattr(tasks, attr_name):
            setattr(tasks, attr_name, replacement)

    if hasattr(tasks, "block"):
        for attr_name, replacement in replacements.items():
            setattr(tasks.block, attr_name, replacement)

    # Imported only after the replacements above are in place.
    from ultralytics import YOLO
except ImportError as import_error:
    print(f"Error importing modules: {import_error}")
    print("Ensure 'yolov8_mpeb_modules.py' and 'yolov8_mpeb.yaml' are in the same directory.")
    sys.exit(1)
# =========================================================================================
# 6. TRAIN
# =========================================================================================
print("Initializing Model...")
model_yaml = CURRENT_DIR / "yolov8_mpeb.yaml"
if model_yaml.exists():
    model = YOLO(str(model_yaml))
else:
    # Without the architecture file there is nothing to build.
    print(f"Error: {model_yaml} not found.")
    sys.exit(1)

print("Starting Training...")
# 200 epochs on CPU only; outputs go to runs/train/visdrone_mpeb and an
# existing run folder of that name is reused.
train_settings = {
    "data": str(DATA_YAML_PATH),
    "epochs": 200,
    "device": 'cpu',
    "project": 'runs/train',
    "name": 'visdrone_mpeb',
    "batch": 16,  # adjust to the available CPU/RAM if needed
    "workers": 4,
    "exist_ok": True,
}
results = model.train(**train_settings)
# =========================================================================================
# 7. FINALIZE
# =========================================================================================
print("Training Complete.")

# Copy the best checkpoint from the run folder to the working directory.
best_weight_path = Path("runs/train/visdrone_mpeb/weights/best.pt")
destination_path = CURRENT_DIR / "best.pt"
if not best_weight_path.exists():
    print("Warning: best.pt not found in runs directory.")
else:
    shutil.copy(best_weight_path, destination_path)
    print(f"Successfully saved best.pt to {destination_path}")

print("Exiting...")