"""
Train ViT-TRM on Something-Something V2.

Examples:
    # From scratch on local SSv2 data:
    python train_ssv2.py --data_dir /path/to/ssv2

    # Transfer from an HMDB51 pretrained checkpoint:
    python train_ssv2.py --data_dir /path/to/ssv2 --pretrained_ckpt ../vit-trm-hmdb51/vit-trm-epoch=29-val_acc=0.7113.ckpt

    # From the HF Hub (if you have access):
    python train_ssv2.py --from_hub

    # Quick smoke test (runs a single train/val batch):
    python train_ssv2.py --data_dir /path/to/ssv2 --fast_dev_run
"""
|
|
| import argparse |
| import pytorch_lightning as pl |
| from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor |
|
|
| from vit_trm_video import ViTTRMVideo |
| from ssv2_datamodule import SSv2DataModule |
|
|
|
|
def build_parser():
    """Build the command-line interface for SSv2 training.

    Kept separate from ``main()`` so the CLI can be constructed (and unit
    tested) without importing data, model, or trainer machinery.

    Returns:
        argparse.ArgumentParser: fully configured parser.
    """
    parser = argparse.ArgumentParser(description="Train ViT-TRM on SSv2")

    # --- Data ---
    parser.add_argument("--data_dir", type=str, default=None,
                        help="Local SSv2 data directory")
    parser.add_argument("--from_hub", action="store_true",
                        help="Load SSv2 from HF Hub")
    parser.add_argument("--hf_dataset_id", type=str,
                        default="HuggingFaceM4/something-something-v2")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_stride", type=int, default=2,
                        help="SSv2 videos are short, use stride=2")
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--num_clips_val", type=int, default=4)

    # --- Model ---
    parser.add_argument("--vit_name", type=str, default="vit_tiny_patch16_224")
    # BUGFIX: the original flag was `store_true` with default=True, making it a
    # no-op with no way to turn pretraining OFF. `--vit_pretrained` is kept
    # (backward compatible) and an explicit off switch is added.
    parser.add_argument("--vit_pretrained", action="store_true", default=True,
                        help="Initialize the ViT backbone from pretrained weights (default: on)")
    parser.add_argument("--no_vit_pretrained", action="store_false", dest="vit_pretrained",
                        help="Train the ViT backbone from scratch")
    parser.add_argument("--vit_freeze", action="store_true",
                        help="Freeze the ViT backbone during training")
    parser.add_argument("--trm_H_cycles", type=int, default=2)
    parser.add_argument("--trm_L_layers", type=int, default=2)
    parser.add_argument("--trm_num_heads", type=int, default=4)
    parser.add_argument("--num_classes", type=int, default=174)  # SSv2 label count
    parser.add_argument("--pretrained_ckpt", type=str, default=None,
                        help="Path to HMDB51 checkpoint to transfer backbone+TRM weights from")

    # --- Optimization ---
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--warmup_epochs", type=int, default=5)
    parser.add_argument("--max_epochs", type=int, default=30)
    parser.add_argument("--label_smoothing", type=float, default=0.1)
    parser.add_argument("--iterative_refinement", action="store_true")

    # --- Trainer / runtime ---
    parser.add_argument("--accelerator", type=str, default="auto")
    parser.add_argument("--devices", type=int, default=1)
    parser.add_argument("--precision", type=str, default="16-mixed")
    parser.add_argument("--fast_dev_run", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    return parser


def main():
    """Train (and, unless fast_dev_run, test) ViT-TRM on SSv2."""
    parser = build_parser()
    args = parser.parse_args()

    # Fail fast: without a local directory or the hub flag, the datamodule
    # would have no data source and fail much later with a confusing error.
    if args.data_dir is None and not args.from_hub:
        parser.error("either --data_dir or --from_hub must be given")

    pl.seed_everything(args.seed)

    # When streaming from the Hub, ignore any local directory.
    data_dir = None if args.from_hub else args.data_dir
    dm = SSv2DataModule(
        data_dir=data_dir,
        hf_dataset_id=args.hf_dataset_id,
        num_frames=args.num_frames,
        frame_stride=args.frame_stride,
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        num_clips_val=args.num_clips_val,
    )

    model = ViTTRMVideo(
        img_size=args.img_size,
        vit_name=args.vit_name,
        vit_pretrained=args.vit_pretrained,
        vit_freeze=args.vit_freeze,
        trm_H_cycles=args.trm_H_cycles,
        trm_L_layers=args.trm_L_layers,
        trm_num_heads=args.trm_num_heads,
        num_classes=args.num_classes,
        lr=args.lr,
        weight_decay=args.weight_decay,
        warmup_epochs=args.warmup_epochs,
        max_epochs=args.max_epochs,
        label_smoothing=args.label_smoothing,
        iterative_refinement=args.iterative_refinement,
        pretrained_ckpt=args.pretrained_ckpt,
    )

    # Keep the 3 best checkpoints by validation accuracy.
    ckpt_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="vit-trm-ssv2-{epoch:02d}-{val_acc:.4f}",
        monitor="val_acc",
        mode="max",
        save_top_k=3,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,
        max_epochs=args.max_epochs,
        callbacks=[ckpt_callback, lr_monitor],
        fast_dev_run=args.fast_dev_run,
        log_every_n_steps=50,
    )

    trainer.fit(model, dm)

    # fast_dev_run produces no checkpoint, so ckpt_path="best" would fail.
    if not args.fast_dev_run:
        trainer.test(model, dm, ckpt_path="best")
|
|
|
|
# Standard script guard: run training only when executed directly,
# not when this module is imported (e.g. to reuse its argument parser).
if __name__ == "__main__":
    main()
|
|