# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
#!/usr/bin/env python3
"""
Direct training without Gradio - forces module reload.

Purges any cached copies of the IPAD model modules and the trainer module
from ``sys.modules`` before importing, so that a re-run of this script in a
long-lived interpreter (e.g. a HF Space) picks up edited code on disk.
"""
import sys
import importlib  # kept from the original import block (currently unused here)

# Force reload of modules to pick up bugfixes: deleting the cache entry makes
# the next import re-execute the module from disk.
for _stale in (
    'IPAD.model.memory_module',
    'IPAD.model.video_swin_transformer',
    'train_hf',
):
    if _stale in sys.modules:
        del sys.modules[_stale]

print("=" * 70)
print("IPAD VAD Direct Training (with module reload)")
print("=" * 70)
print()

# Now import fresh modules (after the cache purge above).
from train_hf import IPADTrainer
import torch
from datetime import datetime

print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print()
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
device_name = "S01"  # IPAD dataset device/scene identifier — TODO confirm against dataset layout
epochs = 10
batch_size = 4
lr = 1e-4  # optimizer learning rate passed through to the trainer

print("Configuration:")
print(f" Device: {device_name}")
print(f" Epochs: {epochs}")
print(f" Batch Size: {batch_size}")
print(f" Learning Rate: {lr}")
print()

# ---------------------------------------------------------------------------
# Hardware report (informational only; device selection happens in the trainer)
# ---------------------------------------------------------------------------
print("Hardware:")
print(f" CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f" Memory: {total_gb:.1f} GB")
else:
    print(" Running on CPU (no @spaces.GPU decorator)")
print()
# ---------------------------------------------------------------------------
# Create trainer
# ---------------------------------------------------------------------------
print("Initializing trainer...")
trainer = IPADTrainer(
    device_name=device_name,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    mem_dim=2000,                     # memory-module size — presumably slot count; verify in IPADTrainer
    checkpoint_dir="./checkpoints",   # trainer writes {device_name}_*.pth files here
    wandb_project=None,               # disable Weights & Biases logging
    hf_repo=None,                     # do not push checkpoints to the Hub
)
print("Trainer initialized")
print()
# ---------------------------------------------------------------------------
# Train, then report elapsed time and inspect the saved checkpoints.
# ---------------------------------------------------------------------------
import time
from pathlib import Path

dataset_path = "/app/cache/IPAD_dataset"
print("Starting training...")
print(f" Dataset: {dataset_path}")
print()

start_time = time.time()
try:
    trainer.train(dataset_path)
    elapsed_min = (time.time() - start_time) / 60
    print()
    print("=" * 70)
    print(f"Training completed in {elapsed_min:.1f} minutes!")
    print("=" * 70)

    # Check checkpoints written by the trainer.
    checkpoint_dir = Path("./checkpoints")
    checkpoints = list(checkpoint_dir.glob(f"{device_name}_*.pth"))
    if checkpoints:
        print()
        print("Checkpoints saved:")
        # Derive the final-checkpoint suffix from the configured epoch count
        # instead of the previously hard-coded "_010.pth".
        final_suffix = f"_{epochs:03d}.pth"
        for ckpt in sorted(checkpoints):
            size_mb = ckpt.stat().st_size / (1024 * 1024)
            print(f" - {ckpt.name} ({size_mb:.1f} MB)")
            if ckpt.name.endswith(final_suffix):
                # NOTE(review): torch.load unpickles arbitrary objects; safe
                # here only because these checkpoints were produced locally.
                checkpoint = torch.load(ckpt, map_location='cpu')
                print()
                print("Final Metrics:")
                for key, value in checkpoint.get('metrics', {}).items():
                    print(f" {key}: {value:.6f}")
except Exception as e:
    # Top-level script boundary: report, dump the traceback, still print footer.
    print(f"Training failed: {e}")
    import traceback
    traceback.print_exc()

print()
print("=" * 70)
print("Training script finished")
print("=" * 70)