# mpebtraining / app.py
# jeyanthangj2004's picture
# Update app.py
# e839198 verified
import sys
import os
import threading
import http.server
import socketserver
from pathlib import Path
import shutil
import yaml
from huggingface_hub import snapshot_download
from tqdm import tqdm
from PIL import Image
# =========================================================================================
# 0. HEALTH CHECK SERVER (CRITICAL FOR DOCKER SPACES)
# =========================================================================================
# Hugging Face Spaces expects the application to listen on port 7860.
# If it doesn't listen within 30 minutes, the Space is flagged as "Runtime Error".
# We start a minimal HTTP server in the background to satisfy this requirement immediately.
# TCP port the status server binds to.
PORT = 7860
# Global variable to track status. The status page embeds this text verbatim, so it
# may contain HTML markup such as <br>; it is reassigned by the main script and by
# the training callbacks as the run progresses.
TRAINING_STATUS = "Initializing..."
def start_health_check_server():
    """Serve the training-status page on ``PORT`` until the process ends.

    Every GET request (any path) is answered with a small auto-refreshing HTML
    page that embeds the current value of the module-level ``TRAINING_STATUS``.
    HEAD requests receive the same headers without a body. The call blocks in
    ``serve_forever()``, so it is meant to be used as a thread target.

    Failures while starting or running the server are printed instead of
    raised, so a problem with the status page does not propagate to the caller.
    """

    class HealthCheckHandler(http.server.SimpleHTTPRequestHandler):
        """Request handler that reports training status instead of serving files."""

        def _status_page(self):
            """Return the UTF-8 encoded status page for the current status."""
            # TRAINING_STATUS is read on every request, so the page always shows
            # the latest value. It is inserted unescaped on purpose: the status
            # strings assigned in this file contain markup such as <br>.
            # Simple HTML with auto-refresh every 30 seconds
            html_content = f"""
<html>
<head>
<meta http-equiv="refresh" content="30">
<title>YOLOv8-MPEB Training Status</title>
<style>
body {{ font-family: sans-serif; padding: 50px; text-align: center; }}
h1 {{ color: #333; }}
.status {{ font-size: 24px; padding: 20px; border: 1px solid #ddd; border-radius: 8px; background: #f9f9f9; display: inline-block; }}
</style>
</head>
<body>
<h1>YOLOv8-MPEB Training Monitor</h1>
<div class="status">{TRAINING_STATUS}</div>
<p><em>Autorefreshing every 30 seconds...</em></p>
</body>
</html>
"""
            return html_content.encode()

        def _send_page_headers(self, body):
            """Send a 200 response header block describing *body*."""
            self.send_response(200)
            # encode() above produces UTF-8, so say so explicitly.
            self.send_header('Content-type', 'text/html; charset=utf-8')
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()

        def do_GET(self):
            body = self._status_page()
            self._send_page_headers(body)
            self.wfile.write(body)

        def do_HEAD(self):
            # Without this override the inherited SimpleHTTPRequestHandler.do_HEAD
            # would answer HEAD requests from files in the working directory.
            self._send_page_headers(self._status_page())

        def log_message(self, format, *args):
            # Suppress the default per-request logging.
            pass

    class HealthCheckServer(socketserver.ThreadingTCPServer):
        """TCP server variant used for the status page."""

        # Allow re-binding the port right after a restart instead of failing
        # with "Address already in use".
        allow_reuse_address = True
        # Each connection gets its own daemon thread, so a stalled client can
        # neither block later health checks nor keep the process alive.
        daemon_threads = True

    try:
        with HealthCheckServer(("", PORT), HealthCheckHandler) as httpd:
            print(f"Health check server serving at port {PORT}")
            httpd.serve_forever()
    except Exception as e:
        # Best effort: report the failure and let the caller's thread end.
        print(f"Failed to start health check server: {e}")
# Run the status server on a background daemon thread so that it never blocks
# the main script.
server_thread = threading.Thread(
    daemon=True,
    target=start_health_check_server,
)
server_thread.start()
# =========================================================================================
# 1. SETUP & CONFIGURATION
# =========================================================================================
print("Starting App for YOLOv8-MPEB Training on CPU...")

# All paths used by the script are anchored at the launch directory.
CURRENT_DIR = Path.cwd()
# Hub repository id of the dataset.
DATASET_REPO = "jeyanthangj2004/Visdrone-raw"
# Local folder for the dataset and the location of the generated data.yaml.
DATASET_DIR = CURRENT_DIR.joinpath("visdrone_dataset")
DATA_YAML_PATH = CURRENT_DIR.joinpath("data.yaml")
# =========================================================================================
# 2. DOWNLOAD DATASET
# =========================================================================================
print(f"Downloading dataset from {DATASET_REPO}...")
try:
    # Fetch the whole dataset repository into DATASET_DIR.
    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir=DATASET_DIR,
    )
except Exception as download_error:
    # Nothing below can run without the dataset, so stop with a failure code.
    print(f"Error downloading dataset: {download_error}")
    sys.exit(1)
else:
    print("Dataset download complete.")
# =========================================================================================
# 3. DATASET CONVERSION
# =========================================================================================
def _find_visdrone_split_dir(dir_path, folder_name):
    """Return the entry named *folder_name* under *dir_path*, or None if absent."""
    direct_child = dir_path / folder_name
    if direct_child.exists():
        return direct_child
    # Fall back to the first matching directory anywhere below dir_path.
    return next((p for p in dir_path.rglob(folder_name) if p.is_dir()), None)


def _visdrone_file_to_yolo_lines(annotation_path, img_width, img_height):
    """Translate one VisDrone annotation file into a list of YOLO label lines."""
    dw, dh = 1.0 / img_width, 1.0 / img_height
    lines = []
    with open(annotation_path, encoding="utf-8") as file:
        for line in file:
            row = line.strip().split(",")
            if len(row) < 6:
                continue
            if row[4] == "0":  # Skip ignored regions
                continue
            x, y, w, h = map(int, row[:4])
            cls = int(row[5]) - 1
            # Only class ids 0..9 (after the shift by one) are written out.
            if 0 <= cls <= 9:
                x_center, y_center = (x + w / 2) * dw, (y + h / 2) * dh
                w_norm, h_norm = w * dw, h * dh
                lines.append(f"{cls} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}\n")
    return lines


def visdrone2yolo(dir_path, split):
    """Convert VisDrone annotations to YOLO format.

    Looks for a ``VisDrone2019-DET-<split>`` folder directly under *dir_path*
    (falling back to a recursive search), moves its ``images/*.jpg`` files into
    ``dir_path/images/<split>`` and writes one YOLO label file per annotation
    file into ``dir_path/labels/<split>``. Annotation files without a matching
    image are skipped; a failure on one file is printed and does not stop the
    remaining files.

    Args:
        dir_path: Dataset root; expected to be a ``pathlib.Path`` (it is used
            with ``/`` and ``rglob``).
        split: Split suffix such as ``"train"``, ``"val"`` or ``"test-dev"``.

    Returns:
        None. The split is skipped with a message when its source folder is
        missing or when ``labels/<split>`` already contains entries.
    """
    print(f"Checking/Converting {split} data in {dir_path}...")
    target_folder_name = f"VisDrone2019-DET-{split}"
    source_dir = _find_visdrone_split_dir(dir_path, target_folder_name)
    if source_dir is None:
        print(f"Warning: Could not find directory for split '{split}' ({target_folder_name}). Skipping.")
        return

    images_dest_dir = dir_path / "images" / split
    labels_dest_dir = dir_path / "labels" / split
    # Any existing entry in the labels folder is taken as "already converted".
    if labels_dest_dir.exists() and any(labels_dest_dir.iterdir()):
        print(f"Labels for {split} seem to exist. Skipping conversion.")
        return
    labels_dest_dir.mkdir(parents=True, exist_ok=True)
    images_dest_dir.mkdir(parents=True, exist_ok=True)

    source_images_dir = source_dir / "images"
    if source_images_dir.exists():
        print(f"Moving images from {source_images_dir} to {images_dest_dir}...")
        for img in source_images_dir.glob("*.jpg"):
            shutil.move(str(img), str(images_dest_dir / img.name))

    source_annotations_dir = source_dir / "annotations"
    if not source_annotations_dir.exists():
        return
    print(f"Converting annotations from {source_annotations_dir}...")
    # The list is materialised so the progress bar knows the total up front.
    for f in tqdm(list(source_annotations_dir.glob("*.txt")), desc=f"Converting {split}"):
        try:
            img_path = images_dest_dir / f.with_suffix(".jpg").name
            if not img_path.exists():
                continue
            # Use a context manager so the image file handle is released
            # immediately instead of lingering until garbage collection.
            with Image.open(img_path) as image:
                img_width, img_height = image.size
            lines = _visdrone_file_to_yolo_lines(f, img_width, img_height)
            (labels_dest_dir / f.name).write_text("".join(lines), encoding="utf-8")
        except Exception as e:
            # Best effort: report the broken file and keep converting the rest.
            print(f"Error converting {f.name}: {e}")
# Convert each split in turn: train, val, then test-dev.
for split_name in ("train", "val", "test-dev"):
    visdrone2yolo(DATASET_DIR, split_name)
# =========================================================================================
# 4. CREATE DATA.YAML
# =========================================================================================
# Class names in id order; position in the list is the class id.
class_names = [
    'pedestrian', 'people', 'bicycle', 'car', 'van',
    'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor',
]
# Dataset description: root path, per-split image folders and id -> name map.
data_yaml_content = dict(
    path=str(DATASET_DIR.absolute()),
    train='images/train',
    val='images/val',
    test='images/test-dev',
    names=dict(enumerate(class_names)),
)
with DATA_YAML_PATH.open('w') as yaml_file:
    yaml.dump(data_yaml_content, yaml_file)
print(f"Created data.yaml at {DATA_YAML_PATH}")
# =========================================================================================
# 5. PATCH & LOAD MODEL
# =========================================================================================
# Put the launch directory first on the import path so that the
# yolov8_mpeb_modules import below is resolved from there.
sys.path.insert(0, str(CURRENT_DIR))

try:
    from yolov8_mpeb_modules import MobileNetBlock, EMA, C2f_EMA, BiFPN_Fusion
    import ultralytics.nn.modules as modules
    import ultralytics.nn.modules.block as block
    import ultralytics.nn.tasks as tasks

    print("Patching Ultralytics modules...")
    # Ultralytics attribute name -> custom class bound to that name.
    _replacements = (
        ("GhostBottleneck", MobileNetBlock),
        ("C3", C2f_EMA),
    )
    for _attr, _custom_cls in _replacements:
        setattr(block, _attr, _custom_cls)
        setattr(modules, _attr, _custom_cls)
    # In the tasks module only overwrite names that already exist there.
    for _attr, _custom_cls in _replacements:
        if hasattr(tasks, _attr):
            setattr(tasks, _attr, _custom_cls)
    if hasattr(tasks, 'block'):
        for _attr, _custom_cls in _replacements:
            setattr(tasks.block, _attr, _custom_cls)

    # Imported only after the patches above have been applied.
    from ultralytics import YOLO
except ImportError as import_error:
    print(f"Error importing modules: {import_error}")
    sys.exit(1)
# =========================================================================================
# 6. TRAIN
# =========================================================================================
print("Initializing Model...")
# The model definition is expected in the launch directory.
model_yaml = CURRENT_DIR.joinpath("yolov8_mpeb.yaml")
if model_yaml.exists():
    model = YOLO(str(model_yaml))
else:
    print(f"Error: {model_yaml} not found.")
    sys.exit(1)

# Publish the new state on the status page before training starts.
TRAINING_STATUS = "Starting Training... (Check Logs for details)"
print("Starting Training...")
# Define callbacks to update status
def on_train_epoch_end(trainer):
    """Update ``TRAINING_STATUS`` with the epoch counter and, if available, mAP50.

    Args:
        trainer: Object exposing ``epoch`` (zero-based), ``epochs`` and
            ``metrics``. ``metrics`` is read with ``.get("metrics/mAP50(B)", 0)``.
            NOTE(review): presumably ``metrics`` holds the most recent
            validation results and may be empty or None early on - confirm
            against the trainer implementation.

    When ``metrics`` is missing, is not dict-like, or holds a value that cannot
    be formatted as a float, the status is written without the mAP50 part.
    """
    global TRAINING_STATUS
    status = f"Training in progress...<br>Epoch: {trainer.epoch + 1}/{trainer.epochs}"
    try:
        map50 = trainer.metrics.get("metrics/mAP50(B)", 0)
        status += f"<br>mAP50: {map50:.4f}"
    except (AttributeError, TypeError, ValueError):
        # Only the failures a missing/odd metric can cause are tolerated here;
        # KeyboardInterrupt and SystemExit must still be able to stop training.
        pass
    TRAINING_STATUS = status
def on_train_batch_end(trainer):
    """Update ``TRAINING_STATUS`` with epoch and batch progress.

    The status is only rewritten when *trainer* has a truthy ``pbar`` whose
    counter ``n`` is a multiple of 5; otherwise the call does nothing.
    """
    global TRAINING_STATUS
    # NOTE(review): confirm that the trainer really exposes ``pbar``; if it
    # does not, this callback never changes the status.
    progress = getattr(trainer, 'pbar', None)
    # Update status every 5 batches to avoid over-refreshing
    if not progress or progress.n % 5 != 0:
        return
    epoch_label = f"{trainer.epoch + 1}/{trainer.epochs}"
    # A progress bar without a known total is shown as "?".
    batch_label = f"{progress.n}/{progress.total or '?'}"
    TRAINING_STATUS = (
        "<b>Training Running...</b>"
        f"<br>Epoch: {epoch_label}"
        f"<br>Batch: {batch_label}"
        "<br><i>(CPU training is slow, please wait)</i>"
    )
# Hook the status callbacks into the training loop.
for event_name, handler in (
    ("on_train_epoch_end", on_train_epoch_end),
    ("on_train_batch_end", on_train_batch_end),
):
    model.add_callback(event_name, handler)

# Train 200 epochs, CPU only
train_settings = dict(
    data=str(DATA_YAML_PATH),
    epochs=200,
    device='cpu',
    project='runs/train',
    name='visdrone_mpeb',
    batch=4,
    workers=1,
    exist_ok=True,
)
results = model.train(**train_settings)
# =========================================================================================
# 7. FINALIZE
# =========================================================================================
TRAINING_STATUS = "Training Complete! Finalizing..."
print("Training Complete.")

# Location of the best checkpoint, matching the project/name passed to train().
best_weight_path = Path("runs/train/visdrone_mpeb/weights/best.pt")
destination_path = CURRENT_DIR / "best.pt"
if best_weight_path.exists():
    # Copy the checkpoint next to the script so it is easy to find.
    shutil.copy(best_weight_path, destination_path)
    print(f"Successfully saved best.pt to {destination_path}")
    TRAINING_STATUS = f"Training Successfully Completed!<br>Model saved to: {destination_path}<br>You can now download the file or stop the space."
else:
    print("Warning: best.pt not found in runs directory.")
    TRAINING_STATUS = "Training Finished, but best.pt was not found. Check logs."

# The status server runs on a daemon thread, which is killed the moment the
# main thread returns. Wait on it here so the final status set above stays
# reachable; if the server never started, join() returns immediately and the
# script exits as before.
try:
    server_thread.join()
except KeyboardInterrupt:
    pass
print("Exiting...")