"""
HuggingFace-adapted IPAD Training Script
Trains on HF infrastructure with ZeroGPU, Accelerate, and automatic checkpointing
"""
import torch
import torch.nn as nn
from torch.optim import Adam
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Optional
# HF infrastructure
from huggingface_hub import HfApi, create_repo
from accelerate import Accelerator
# Local imports
from IPAD.model.video_swin_transformer import VST
from IPAD.model.entropy_loss import EntropyLossEncap
from dataset import create_dataloaders, download_and_extract_dataset
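# Model components imported above: VST is the Video Swin Transformer autoencoder
# with a memory module; EntropyLossEncap penalizes the entropy of the memory
# attention weights so that each query attends to only a few memory slots.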
class IPADTrainer:
"""
IPAD Model Trainer with HF Integration
"""
def __init__(
self,
device_name: str = "S01",
mem_dim: int = 2000,
shrink_thres: float = 0.0025,
lr: float = 1e-4,
batch_size: int = 4,
epochs: int = 200,
entropy_loss_weight: float = 0.0002,
period_loss_weight: float = 0.02,
checkpoint_dir: str = "./checkpoints",
wandb_project: Optional[str] = "ipad-vad",
hf_repo: Optional[str] = "MSherbinii/ipad-vad-checkpoints"
):
self.device_name = device_name
self.mem_dim = mem_dim
self.shrink_thres = shrink_thres
self.lr = lr
self.batch_size = batch_size
self.epochs = epochs
self.entropy_loss_weight = entropy_loss_weight
self.period_loss_weight = period_loss_weight
self.checkpoint_dir = Path(checkpoint_dir)
self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
self.wandb_project = wandb_project
self.hf_repo = hf_repo
# Initialize Accelerator for distributed training
self.accelerator = Accelerator(
mixed_precision='fp16',
gradient_accumulation_steps=1,
log_with="wandb" if wandb_project else None
)
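        # With mixed_precision='fp16', Accelerate handles loss scaling internally,
        # so no manual torch.cuda.amp GradScaler is needed in the training loop.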
# Model
self.model = VST(mem_dim=mem_dim, shrink_thres=shrink_thres)
# Losses
self.recon_criterion = nn.MSELoss()
self.entropy_criterion = EntropyLossEncap()
self.period_criterion = nn.CrossEntropyLoss()
# Optimizer
self.optimizer = Adam(self.model.parameters(), lr=lr)
# HF API
self.hf_api = HfApi()
        if hf_repo:
            try:
                create_repo(hf_repo, repo_type="model", private=False, exist_ok=True)
            except Exception as e:
                print(f"⚠️ Could not create HF repo '{hf_repo}': {e}")
def setup_data(self, dataset_path: str):
"""Setup dataloaders"""
self.train_loader, self.test_loader = create_dataloaders(
dataset_path=dataset_path,
device_name=self.device_name,
batch_size=self.batch_size,
num_workers=4,
clip_length=16,
frame_size=(256, 256)
)
# Prepare with Accelerator
self.model, self.optimizer, self.train_loader, self.test_loader = \
self.accelerator.prepare(
self.model, self.optimizer, self.train_loader, self.test_loader
)
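        # From here on, device placement and fp16 autocast for the model,
        # optimizer, and dataloaders are managed by the Accelerator wrappers.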
def train_epoch(self, epoch: int) -> Dict[str, float]:
"""Train for one epoch"""
self.model.train()
total_loss = 0.0
recon_loss_sum = 0.0
entropy_loss_sum = 0.0
period_loss_sum = 0.0
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.epochs}")
for batch_idx, clips in enumerate(pbar):
# clips shape: [B, C, T, H, W]
with self.accelerator.autocast():
# Forward pass
outputs = self.model(clips)
reconstructed = outputs['output']
att = outputs['att']
period_pred = outputs['recon_index']
# Reconstruction loss
recon_loss = self.recon_criterion(reconstructed, clips)
# Entropy loss on attention weights
entropy_loss = self.entropy_criterion(att)
                # Period classification loss.
                # Placeholder targets: class indices drawn uniformly at random;
                # a full implementation would use actual period annotations.
                period_labels = torch.randint(0, 200, (clips.size(0),), device=clips.device)
                period_loss = self.period_criterion(period_pred, period_labels)
# Combined loss
loss = (recon_loss +
self.entropy_loss_weight * entropy_loss +
self.period_loss_weight * period_loss)
# Backward pass
self.accelerator.backward(loss)
self.optimizer.step()
self.optimizer.zero_grad()
# Accumulate losses
total_loss += loss.item()
recon_loss_sum += recon_loss.item()
entropy_loss_sum += entropy_loss.item()
period_loss_sum += period_loss.item()
# Update progress bar
pbar.set_postfix({
'loss': f'{loss.item():.4f}',
'recon': f'{recon_loss.item():.4f}',
'entropy': f'{entropy_loss.item():.6f}',
'period': f'{period_loss.item():.4f}'
})
num_batches = len(self.train_loader)
return {
'train_loss': total_loss / num_batches,
'train_recon_loss': recon_loss_sum / num_batches,
'train_entropy_loss': entropy_loss_sum / num_batches,
'train_period_loss': period_loss_sum / num_batches
}
@torch.no_grad()
def validate(self) -> Dict[str, float]:
"""Validate on test set"""
        self.model.eval()
        recon_loss_sum = 0.0
        for clips in tqdm(self.test_loader, desc="Validating"):
            with self.accelerator.autocast():
                outputs = self.model(clips)
                reconstructed = outputs['output']
                recon_loss = self.recon_criterion(reconstructed, clips)
            recon_loss_sum += recon_loss.item()
        num_batches = len(self.test_loader)
        avg_recon_loss = recon_loss_sum / num_batches
        # Validation tracks reconstruction error only, so both keys report it;
        # in reconstruction-based VAD, high test-time reconstruction error is
        # the anomaly signal, so this also serves as an anomaly-score proxy.
        return {
            'val_loss': avg_recon_loss,
            'val_recon_loss': avg_recon_loss
        }
def save_checkpoint(self, epoch: int, metrics: Dict[str, float]):
"""Save checkpoint locally and upload to HF Hub"""
checkpoint_name = f"{self.device_name}_epoch_{epoch:03d}.pth"
checkpoint_path = self.checkpoint_dir / checkpoint_name
# Save checkpoint
checkpoint = {
'epoch': epoch,
'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'metrics': metrics,
'config': {
'device_name': self.device_name,
'mem_dim': self.mem_dim,
'shrink_thres': self.shrink_thres,
'lr': self.lr,
'batch_size': self.batch_size
}
}
torch.save(checkpoint, checkpoint_path)
print(f"๐Ÿ’พ Checkpoint saved: {checkpoint_path}")
# Upload to HF Hub
if self.hf_repo:
try:
self.hf_api.upload_file(
path_or_fileobj=str(checkpoint_path),
path_in_repo=f"checkpoints/{checkpoint_name}",
repo_id=self.hf_repo,
repo_type="model",
commit_message=f"Epoch {epoch} - {self.device_name}"
)
print(f"โ˜๏ธ Uploaded to HF Hub: {self.hf_repo}")
except Exception as e:
print(f"โš ๏ธ Failed to upload to HF Hub: {e}")
def train(self, dataset_path: str):
"""Full training loop"""
print(f"\n๐Ÿš€ Starting training for {self.device_name}")
print(f"๐Ÿ“Š Epochs: {self.epochs}, Batch Size: {self.batch_size}, LR: {self.lr}")
# Setup data
self.setup_data(dataset_path)
# Initialize wandb
if self.wandb_project:
self.accelerator.init_trackers(
project_name=self.wandb_project,
config={
'device_name': self.device_name,
'mem_dim': self.mem_dim,
'lr': self.lr,
'batch_size': self.batch_size,
'epochs': self.epochs
}
)
# Training loop
best_val_loss = float('inf')
for epoch in range(1, self.epochs + 1):
# Train
train_metrics = self.train_epoch(epoch)
# Validate every 10 epochs
if epoch % 10 == 0:
val_metrics = self.validate()
metrics = {**train_metrics, **val_metrics}
# Save best model
if val_metrics['val_loss'] < best_val_loss:
best_val_loss = val_metrics['val_loss']
self.save_checkpoint(epoch, metrics)
# Log metrics
if self.wandb_project:
self.accelerator.log(metrics, step=epoch)
print(f"\n๐Ÿ“Š Epoch {epoch} - Train Loss: {train_metrics['train_loss']:.4f}, Val Loss: {val_metrics['val_loss']:.4f}")
# Save checkpoint every 50 epochs
if epoch % 50 == 0:
self.save_checkpoint(epoch, train_metrics)
print(f"\nโœ… Training complete for {self.device_name}!")
print(f"๐Ÿ“‚ Checkpoints saved to: {self.checkpoint_dir}")
if self.hf_repo:
print(f"โ˜๏ธ Model available at: https://huggingface.co/{self.hf_repo}")
def main():
"""Main training entry point"""
import argparse
parser = argparse.ArgumentParser(description="Train IPAD VAD model on HF infrastructure")
parser.add_argument("--device", type=str, default="S01", help="Device name (S01-S12, R01-R04)")
parser.add_argument("--epochs", type=int, default=200, help="Number of epochs")
parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--mem-dim", type=int, default=2000, help="Memory dimension")
parser.add_argument("--no-wandb", action="store_true", help="Disable wandb logging")
parser.add_argument("--dataset-path", type=str, default=None, help="Path to dataset (downloads if not provided)")
args = parser.parse_args()
# Download dataset if needed
if args.dataset_path is None:
dataset_path = download_and_extract_dataset()
else:
dataset_path = Path(args.dataset_path)
# Create trainer
trainer = IPADTrainer(
device_name=args.device,
epochs=args.epochs,
batch_size=args.batch_size,
lr=args.lr,
mem_dim=args.mem_dim,
wandb_project=None if args.no_wandb else "ipad-vad"
)
# Train
trainer.train(str(dataset_path))
if __name__ == "__main__":
main()