#!/usr/bin/env python3
"""
Direct training without Gradio - forces module reload
"""
import sys

# Force reload of modules to pick up bugfixes
if 'IPAD.model.memory_module' in sys.modules:
    del sys.modules['IPAD.model.memory_module']
if 'IPAD.model.video_swin_transformer' in sys.modules:
    del sys.modules['IPAD.model.video_swin_transformer']
if 'train_hf' in sys.modules:
    del sys.modules['train_hf']

print("=" * 70)
print("🚀 IPAD VAD Direct Training (with module reload)")
print("=" * 70)
print()

# Now import fresh modules
from train_hf import IPADTrainer
import torch
from datetime import datetime

print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print()

# Configuration
device_name = "S01"
epochs = 10
batch_size = 4
lr = 1e-4

print("📋 Configuration:")
print(f" Device: {device_name}")
print(f" Epochs: {epochs}")
print(f" Batch Size: {batch_size}")
print(f" Learning Rate: {lr}")
print()

# Check GPU
print("🔍 Hardware:")
print(f" CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print(" Running on CPU (no @spaces.GPU decorator)")
print()

# Create trainer
print("📦 Initializing trainer...")
trainer = IPADTrainer(
    device_name=device_name,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    mem_dim=2000,
    checkpoint_dir="./checkpoints",
    wandb_project=None,
    hf_repo=None
)
print("✅ Trainer initialized")
print()

# Train
dataset_path = "/app/cache/IPAD_dataset"
print("🏋️ Starting training...")
print(f" Dataset: {dataset_path}")
print()

import time
start_time = time.time()

try:
    trainer.train(dataset_path)
    end_time = time.time()

    print()
    print("=" * 70)
    print(f"✅ Training completed in {(end_time - start_time) / 60:.1f} minutes!")
    print("=" * 70)

    # Check checkpoints
    from pathlib import Path
    checkpoint_dir = Path("./checkpoints")
    checkpoints = list(checkpoint_dir.glob(f"{device_name}_*.pth"))

    if checkpoints:
        print()
        print("💾 Checkpoints saved:")
        for ckpt in sorted(checkpoints):
            size_mb = ckpt.stat().st_size / (1024 * 1024)
            print(f" - {ckpt.name} ({size_mb:.1f} MB)")

            # Load and check checkpoint
            if ckpt.name.endswith("_010.pth"):  # Final checkpoint
                checkpoint = torch.load(ckpt, map_location='cpu')
                print()
                print("📊 Final Metrics:")
                if 'metrics' in checkpoint:
                    for key, value in checkpoint['metrics'].items():
                        print(f" {key}: {value:.6f}")

except Exception as e:
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

print()
print("=" * 70)
print("🏁 Training script finished")
print("=" * 70)