| """ | |
| HuggingFace-adapted IPAD Training Script | |
| Trains on HF infrastructure with ZeroGPU, Accelerate, and automatic checkpointing | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.optim import Adam | |
| from torch.cuda.amp import autocast, GradScaler | |
| from pathlib import Path | |
| import json | |
| from datetime import datetime | |
| from tqdm import tqdm | |
| import wandb | |
| from typing import Dict, Optional | |
| import os | |
| # HF infrastructure | |
| from huggingface_hub import HfApi, create_repo | |
| from accelerate import Accelerator | |
| # Local imports | |
| from IPAD.model.video_swin_transformer import VST | |
| from IPAD.model.entropy_loss import EntropyLossEncap | |
| from dataset import create_dataloaders, download_and_extract_dataset | |


class IPADTrainer:
    """
    IPAD Model Trainer with HF Integration
    """

    def __init__(
        self,
        device_name: str = "S01",
        mem_dim: int = 2000,
        shrink_thres: float = 0.0025,
        lr: float = 1e-4,
        batch_size: int = 4,
        epochs: int = 200,
        entropy_loss_weight: float = 0.0002,
        period_loss_weight: float = 0.02,
        checkpoint_dir: str = "./checkpoints",
        wandb_project: Optional[str] = "ipad-vad",
        hf_repo: Optional[str] = "MSherbinii/ipad-vad-checkpoints"
    ):
        self.device_name = device_name
        self.mem_dim = mem_dim
        self.shrink_thres = shrink_thres
        self.lr = lr
        self.batch_size = batch_size
        self.epochs = epochs
        self.entropy_loss_weight = entropy_loss_weight
        self.period_loss_weight = period_loss_weight
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
        self.wandb_project = wandb_project
        self.hf_repo = hf_repo

        # Initialize Accelerator for distributed training
        self.accelerator = Accelerator(
            mixed_precision='fp16',
            gradient_accumulation_steps=1,
            log_with="wandb" if wandb_project else None
        )
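        # Accelerator owns device placement, fp16 autocast, and gradient
        # scaling, so no explicit .to(device) or GradScaler is needed below.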

        # Model
        self.model = VST(mem_dim=mem_dim, shrink_thres=shrink_thres)

        # Losses
        self.recon_criterion = nn.MSELoss()
        self.entropy_criterion = EntropyLossEncap()
        self.period_criterion = nn.CrossEntropyLoss()
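        # MSE drives frame reconstruction, the entropy term encourages sparse
        # memory addressing, and cross-entropy supervises the period head;
        # train_epoch combines them with the configured weights.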

        # Optimizer
        self.optimizer = Adam(self.model.parameters(), lr=lr)

        # HF API
        self.hf_api = HfApi()
        if hf_repo:
            try:
                create_repo(hf_repo, repo_type="model", private=False, exist_ok=True)
            except Exception as e:
                print(f"⚠️ Could not create HF repo {hf_repo}: {e}")

    def setup_data(self, dataset_path: str):
        """Setup dataloaders"""
        self.train_loader, self.test_loader = create_dataloaders(
            dataset_path=dataset_path,
            device_name=self.device_name,
            batch_size=self.batch_size,
            num_workers=4,
            clip_length=16,
            frame_size=(256, 256)
        )

        # Prepare with Accelerator
        self.model, self.optimizer, self.train_loader, self.test_loader = \
            self.accelerator.prepare(
                self.model, self.optimizer, self.train_loader, self.test_loader
            )
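        # prepare() returns the wrapped objects in the order they were passed;
        # under multi-process launches it also shards the dataloaders so each
        # process sees its own slice of the data.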

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train for one epoch"""
        self.model.train()

        total_loss = 0.0
        recon_loss_sum = 0.0
        entropy_loss_sum = 0.0
        period_loss_sum = 0.0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.epochs}")
        for batch_idx, clips in enumerate(pbar):
            # clips shape: [B, C, T, H, W]
            with self.accelerator.autocast():
                # Forward pass
                outputs = self.model(clips)
                reconstructed = outputs['output']
                att = outputs['att']
                period_pred = outputs['recon_index']

                # Reconstruction loss
                recon_loss = self.recon_criterion(reconstructed, clips)

                # Entropy loss on attention weights
                entropy_loss = self.entropy_criterion(att)

                # Period classification loss.
                # Random pseudo-labels for now; a full implementation would use
                # actual period annotations.
                period_labels = torch.randint(0, 200, (clips.size(0),), device=clips.device)
                period_loss = self.period_criterion(period_pred, period_labels)

                # Combined loss
                loss = (recon_loss
                        + self.entropy_loss_weight * entropy_loss
                        + self.period_loss_weight * period_loss)

            # Backward pass (outside autocast; Accelerator handles loss scaling for fp16)
            self.accelerator.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()
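            # If gradient_accumulation_steps were > 1, the step/zero_grad pair
            # above would normally sit inside
            # `with self.accelerator.accumulate(self.model):` so Accelerate can
            # skip optimizer steps on accumulation batches.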

            # Accumulate losses
            total_loss += loss.item()
            recon_loss_sum += recon_loss.item()
            entropy_loss_sum += entropy_loss.item()
            period_loss_sum += period_loss.item()

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'recon': f'{recon_loss.item():.4f}',
                'entropy': f'{entropy_loss.item():.6f}',
                'period': f'{period_loss.item():.4f}'
            })

        num_batches = len(self.train_loader)
        return {
            'train_loss': total_loss / num_batches,
            'train_recon_loss': recon_loss_sum / num_batches,
            'train_entropy_loss': entropy_loss_sum / num_batches,
            'train_period_loss': period_loss_sum / num_batches
        }

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Validate on test set"""
        self.model.eval()

        total_loss = 0.0
        recon_loss_sum = 0.0

        for clips in tqdm(self.test_loader, desc="Validating"):
            with self.accelerator.autocast():
                outputs = self.model(clips)
                reconstructed = outputs['output']
                recon_loss = self.recon_criterion(reconstructed, clips)

            total_loss += recon_loss.item()
            recon_loss_sum += recon_loss.item()

        num_batches = len(self.test_loader)
        return {
            'val_loss': total_loss / num_batches,
            'val_recon_loss': recon_loss_sum / num_batches
        }

    def save_checkpoint(self, epoch: int, metrics: Dict[str, float]):
        """Save checkpoint locally and upload to HF Hub"""
        # Only the main process should write and upload checkpoints
        if not self.accelerator.is_main_process:
            return

        checkpoint_name = f"{self.device_name}_epoch_{epoch:03d}.pth"
        checkpoint_path = self.checkpoint_dir / checkpoint_name

        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
            'config': {
                'device_name': self.device_name,
                'mem_dim': self.mem_dim,
                'shrink_thres': self.shrink_thres,
                'lr': self.lr,
                'batch_size': self.batch_size
            }
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"💾 Checkpoint saved: {checkpoint_path}")

        # Upload to HF Hub
        if self.hf_repo:
            try:
                self.hf_api.upload_file(
                    path_or_fileobj=str(checkpoint_path),
                    path_in_repo=f"checkpoints/{checkpoint_name}",
                    repo_id=self.hf_repo,
                    repo_type="model",
                    commit_message=f"Epoch {epoch} - {self.device_name}"
                )
                print(f"☁️ Uploaded to HF Hub: {self.hf_repo}")
            except Exception as e:
                print(f"⚠️ Failed to upload to HF Hub: {e}")

    def train(self, dataset_path: str):
        """Full training loop"""
        print(f"\n🚀 Starting training for {self.device_name}")
        print(f"📊 Epochs: {self.epochs}, Batch Size: {self.batch_size}, LR: {self.lr}")

        # Setup data
        self.setup_data(dataset_path)

        # Initialize wandb
        if self.wandb_project:
            self.accelerator.init_trackers(
                project_name=self.wandb_project,
                config={
                    'device_name': self.device_name,
                    'mem_dim': self.mem_dim,
                    'lr': self.lr,
                    'batch_size': self.batch_size,
                    'epochs': self.epochs
                }
            )

        # Training loop
        best_val_loss = float('inf')
        for epoch in range(1, self.epochs + 1):
            # Train
            train_metrics = self.train_epoch(epoch)

            # Validate every 10 epochs
            if epoch % 10 == 0:
                val_metrics = self.validate()
                metrics = {**train_metrics, **val_metrics}

                # Save best model
                if val_metrics['val_loss'] < best_val_loss:
                    best_val_loss = val_metrics['val_loss']
                    self.save_checkpoint(epoch, metrics)

                # Log metrics
                if self.wandb_project:
                    self.accelerator.log(metrics, step=epoch)

                print(f"\n📈 Epoch {epoch} - Train Loss: {train_metrics['train_loss']:.4f}, "
                      f"Val Loss: {val_metrics['val_loss']:.4f}")

            # Save checkpoint every 50 epochs
            if epoch % 50 == 0:
                self.save_checkpoint(epoch, train_metrics)

        # Close trackers cleanly
        if self.wandb_project:
            self.accelerator.end_training()

        print(f"\n✅ Training complete for {self.device_name}!")
        print(f"📁 Checkpoints saved to: {self.checkpoint_dir}")
        if self.hf_repo:
            print(f"☁️ Model available at: https://huggingface.co/{self.hf_repo}")


def main():
    """Main training entry point"""
    import argparse

    parser = argparse.ArgumentParser(description="Train IPAD VAD model on HF infrastructure")
    parser.add_argument("--device", type=str, default="S01", help="Device name (S01-S12, R01-R04)")
    parser.add_argument("--epochs", type=int, default=200, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--mem-dim", type=int, default=2000, help="Memory dimension")
    parser.add_argument("--no-wandb", action="store_true", help="Disable wandb logging")
    parser.add_argument("--dataset-path", type=str, default=None, help="Path to dataset (downloads if not provided)")
    args = parser.parse_args()

    # Download dataset if needed
    if args.dataset_path is None:
        dataset_path = download_and_extract_dataset()
    else:
        dataset_path = Path(args.dataset_path)

    # Create trainer
    trainer = IPADTrainer(
        device_name=args.device,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        mem_dim=args.mem_dim,
        wandb_project=None if args.no_wandb else "ipad-vad"
    )

    # Train
    trainer.train(str(dataset_path))


if __name__ == "__main__":
    main()
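
# Example CLI usage (a sketch, assuming this file is saved as train.py):
#   python train.py --device S01 --epochs 200 --batch-size 4
#   python train.py --device R01 --no-wandb --dataset-path ./data/IPAD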