""" HuggingFace-adapted IPAD Training Script Trains on HF infrastructure with ZeroGPU, Accelerate, and automatic checkpointing """ import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import Adam from torch.cuda.amp import autocast, GradScaler from pathlib import Path import json from datetime import datetime from tqdm import tqdm import wandb from typing import Dict, Optional import os # HF infrastructure from huggingface_hub import HfApi, create_repo from accelerate import Accelerator # Local imports from IPAD.model.video_swin_transformer import VST from IPAD.model.entropy_loss import EntropyLossEncap from dataset import create_dataloaders, download_and_extract_dataset class IPADTrainer: """ IPAD Model Trainer with HF Integration """ def __init__( self, device_name: str = "S01", mem_dim: int = 2000, shrink_thres: float = 0.0025, lr: float = 1e-4, batch_size: int = 4, epochs: int = 200, entropy_loss_weight: float = 0.0002, period_loss_weight: float = 0.02, checkpoint_dir: str = "./checkpoints", wandb_project: Optional[str] = "ipad-vad", hf_repo: Optional[str] = "MSherbinii/ipad-vad-checkpoints" ): self.device_name = device_name self.mem_dim = mem_dim self.shrink_thres = shrink_thres self.lr = lr self.batch_size = batch_size self.epochs = epochs self.entropy_loss_weight = entropy_loss_weight self.period_loss_weight = period_loss_weight self.checkpoint_dir = Path(checkpoint_dir) self.checkpoint_dir.mkdir(exist_ok=True, parents=True) self.wandb_project = wandb_project self.hf_repo = hf_repo # Initialize Accelerator for distributed training self.accelerator = Accelerator( mixed_precision='fp16', gradient_accumulation_steps=1, log_with="wandb" if wandb_project else None ) # Model self.model = VST(mem_dim=mem_dim, shrink_thres=shrink_thres) # Losses self.recon_criterion = nn.MSELoss() self.entropy_criterion = EntropyLossEncap() self.period_criterion = nn.CrossEntropyLoss() # Optimizer self.optimizer = Adam(self.model.parameters(), lr=lr) # HF API self.hf_api = HfApi() if hf_repo: try: create_repo(hf_repo, repo_type="model", private=False, exist_ok=True) except: pass def setup_data(self, dataset_path: str): """Setup dataloaders""" self.train_loader, self.test_loader = create_dataloaders( dataset_path=dataset_path, device_name=self.device_name, batch_size=self.batch_size, num_workers=4, clip_length=16, frame_size=(256, 256) ) # Prepare with Accelerator self.model, self.optimizer, self.train_loader, self.test_loader = \ self.accelerator.prepare( self.model, self.optimizer, self.train_loader, self.test_loader ) def train_epoch(self, epoch: int) -> Dict[str, float]: """Train for one epoch""" self.model.train() total_loss = 0.0 recon_loss_sum = 0.0 entropy_loss_sum = 0.0 period_loss_sum = 0.0 pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.epochs}") for batch_idx, clips in enumerate(pbar): # clips shape: [B, C, T, H, W] with self.accelerator.autocast(): # Forward pass outputs = self.model(clips) reconstructed = outputs['output'] att = outputs['att'] period_pred = outputs['recon_index'] # Reconstruction loss recon_loss = self.recon_criterion(reconstructed, clips) # Entropy loss on attention weights entropy_loss = self.entropy_criterion(att) # Period classification loss # Create pseudo-labels (uniform distribution for now) # In full implementation, this would use actual period annotations period_labels = torch.randint(0, 200, (clips.size(0),)).to(clips.device) period_loss = self.period_criterion(period_pred, period_labels) # Combined loss loss = (recon_loss + 
                loss = (recon_loss
                        + self.entropy_loss_weight * entropy_loss
                        + self.period_loss_weight * period_loss)

            # Backward pass
            self.accelerator.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Accumulate losses
            total_loss += loss.item()
            recon_loss_sum += recon_loss.item()
            entropy_loss_sum += entropy_loss.item()
            period_loss_sum += period_loss.item()

            # Update progress bar
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'recon': f'{recon_loss.item():.4f}',
                'entropy': f'{entropy_loss.item():.6f}',
                'period': f'{period_loss.item():.4f}'
            })

        num_batches = len(self.train_loader)
        return {
            'train_loss': total_loss / num_batches,
            'train_recon_loss': recon_loss_sum / num_batches,
            'train_entropy_loss': entropy_loss_sum / num_batches,
            'train_period_loss': period_loss_sum / num_batches
        }

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Validate on the test set (reconstruction loss only)."""
        self.model.eval()
        total_loss = 0.0
        recon_loss_sum = 0.0

        for clips in tqdm(self.test_loader, desc="Validating"):
            with self.accelerator.autocast():
                outputs = self.model(clips)
                reconstructed = outputs['output']
                recon_loss = self.recon_criterion(reconstructed, clips)

            total_loss += recon_loss.item()
            recon_loss_sum += recon_loss.item()

        num_batches = len(self.test_loader)
        return {
            'val_loss': total_loss / num_batches,
            'val_recon_loss': recon_loss_sum / num_batches
        }

    def save_checkpoint(self, epoch: int, metrics: Dict[str, float]):
        """Save a checkpoint locally and upload it to the HF Hub."""
        # Only the main process should write/upload in distributed runs
        if not self.accelerator.is_main_process:
            return

        checkpoint_name = f"{self.device_name}_epoch_{epoch:03d}.pth"
        checkpoint_path = self.checkpoint_dir / checkpoint_name

        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': metrics,
            'config': {
                'device_name': self.device_name,
                'mem_dim': self.mem_dim,
                'shrink_thres': self.shrink_thres,
                'lr': self.lr,
                'batch_size': self.batch_size
            }
        }
        torch.save(checkpoint, checkpoint_path)
        print(f"šŸ’¾ Checkpoint saved: {checkpoint_path}")

        # Upload to HF Hub
        if self.hf_repo:
            try:
                self.hf_api.upload_file(
                    path_or_fileobj=str(checkpoint_path),
                    path_in_repo=f"checkpoints/{checkpoint_name}",
                    repo_id=self.hf_repo,
                    repo_type="model",
                    commit_message=f"Epoch {epoch} - {self.device_name}"
                )
                print(f"ā˜ļø Uploaded to HF Hub: {self.hf_repo}")
            except Exception as e:
                print(f"āš ļø Failed to upload to HF Hub: {e}")

    def train(self, dataset_path: str):
        """Full training loop."""
        print(f"\nšŸš€ Starting training for {self.device_name}")
        print(f"šŸ“Š Epochs: {self.epochs}, Batch Size: {self.batch_size}, LR: {self.lr}")

        # Setup data
        self.setup_data(dataset_path)

        # Initialize wandb tracking via Accelerate
        if self.wandb_project:
            self.accelerator.init_trackers(
                project_name=self.wandb_project,
                config={
                    'device_name': self.device_name,
                    'mem_dim': self.mem_dim,
                    'lr': self.lr,
                    'batch_size': self.batch_size,
                    'epochs': self.epochs
                }
            )

        # Training loop
        best_val_loss = float('inf')
        for epoch in range(1, self.epochs + 1):
            # Train
            train_metrics = self.train_epoch(epoch)

            # Validate every 10 epochs
            if epoch % 10 == 0:
                val_metrics = self.validate()
                metrics = {**train_metrics, **val_metrics}

                # Save best model
                if val_metrics['val_loss'] < best_val_loss:
                    best_val_loss = val_metrics['val_loss']
                    self.save_checkpoint(epoch, metrics)

                # Log metrics
                if self.wandb_project:
                    self.accelerator.log(metrics, step=epoch)

                print(f"\nšŸ“Š Epoch {epoch} - "
                      f"Train Loss: {train_metrics['train_loss']:.4f}, "
                      f"Val Loss: {val_metrics['val_loss']:.4f}")

            # Save checkpoint every 50 epochs
            if epoch % 50 == 0:
                self.save_checkpoint(epoch, train_metrics)
print(f"\nāœ… Training complete for {self.device_name}!") print(f"šŸ“‚ Checkpoints saved to: {self.checkpoint_dir}") if self.hf_repo: print(f"ā˜ļø Model available at: https://huggingface.co/{self.hf_repo}") def main(): """Main training entry point""" import argparse parser = argparse.ArgumentParser(description="Train IPAD VAD model on HF infrastructure") parser.add_argument("--device", type=str, default="S01", help="Device name (S01-S12, R01-R04)") parser.add_argument("--epochs", type=int, default=200, help="Number of epochs") parser.add_argument("--batch-size", type=int, default=4, help="Batch size") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") parser.add_argument("--mem-dim", type=int, default=2000, help="Memory dimension") parser.add_argument("--no-wandb", action="store_true", help="Disable wandb logging") parser.add_argument("--dataset-path", type=str, default=None, help="Path to dataset (downloads if not provided)") args = parser.parse_args() # Download dataset if needed if args.dataset_path is None: dataset_path = download_and_extract_dataset() else: dataset_path = Path(args.dataset_path) # Create trainer trainer = IPADTrainer( device_name=args.device, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, mem_dim=args.mem_dim, wandb_project=None if args.no_wandb else "ipad-vad" ) # Train trainer.train(str(dataset_path)) if __name__ == "__main__": main()