# NOTE(review): the following lines are Hugging Face web-page scrape artifacts,
# not source code; commented out so the file is valid Python.
# vit-trm-ssv2 / train_ssv2.py — uploaded by bcgxtberg via huggingface_hub (commit 77191d4, verified)
#!/usr/bin/env python3
"""
Train ViT-TRM on Something-Something V2.
Examples:
# From scratch on local SSv2 data:
python train_ssv2.py --data_dir /path/to/ssv2
# Transfer from HMDB51 pretrained checkpoint:
python train_ssv2.py --data_dir /path/to/ssv2 --pretrained_ckpt ../vit-trm-hmdb51/vit-trm-epoch=29-val_acc=0.7113.ckpt
# From HF Hub (if you have access):
python train_ssv2.py --from_hub
# Quick smoke test (2 epochs, 1 batch):
python train_ssv2.py --data_dir /path/to/ssv2 --fast_dev_run
"""
import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from vit_trm_video import ViTTRMVideo
from ssv2_datamodule import SSv2DataModule
def main():
    """Parse CLI args, build the SSv2 datamodule and ViT-TRM model, and train.

    Unless ``--fast_dev_run`` is set, the best checkpoint (by ``val_acc``)
    is evaluated on the test split after training.
    """
    parser = argparse.ArgumentParser(description="Train ViT-TRM on SSv2")
    # Data
    parser.add_argument("--data_dir", type=str, default=None, help="Local SSv2 data directory")
    parser.add_argument("--from_hub", action="store_true", help="Load SSv2 from HF Hub")
    parser.add_argument("--hf_dataset_id", type=str, default="HuggingFaceM4/something-something-v2")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--frame_stride", type=int, default=2, help="SSv2 videos are short, use stride=2")
    parser.add_argument("--img_size", type=int, default=224)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--num_clips_val", type=int, default=4)
    # Model
    parser.add_argument("--vit_name", type=str, default="vit_tiny_patch16_224")
    # FIX: the original used action="store_true" with default=True, which made it
    # impossible to disable pretrained weights from the CLI. BooleanOptionalAction
    # keeps --vit_pretrained working and additionally accepts --no-vit_pretrained.
    parser.add_argument("--vit_pretrained", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--vit_freeze", action="store_true", default=False)
    parser.add_argument("--trm_H_cycles", type=int, default=2)
    parser.add_argument("--trm_L_layers", type=int, default=2)
    parser.add_argument("--trm_num_heads", type=int, default=4)
    # SSv2 has 174 action classes.
    parser.add_argument("--num_classes", type=int, default=174)
    parser.add_argument("--pretrained_ckpt", type=str, default=None,
                        help="Path to HMDB51 checkpoint to transfer backbone+TRM weights from")
    # Training
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--warmup_epochs", type=int, default=5)
    parser.add_argument("--max_epochs", type=int, default=30)
    parser.add_argument("--label_smoothing", type=float, default=0.1)
    parser.add_argument("--iterative_refinement", action="store_true", default=False)
    # Trainer
    parser.add_argument("--accelerator", type=str, default="auto")
    parser.add_argument("--devices", type=int, default=1)
    parser.add_argument("--precision", type=str, default="16-mixed")
    parser.add_argument("--fast_dev_run", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # FIX: fail early with a usage error instead of silently passing
    # data_dir=None to the datamodule when no data source was specified.
    if args.data_dir is None and not args.from_hub:
        parser.error("specify a data source: --data_dir /path/to/ssv2 or --from_hub")

    pl.seed_everything(args.seed)

    # Data: --from_hub takes precedence; the datamodule then loads from the Hub.
    data_dir = args.data_dir if not args.from_hub else None
    dm = SSv2DataModule(
        data_dir=data_dir,
        hf_dataset_id=args.hf_dataset_id,
        num_frames=args.num_frames,
        frame_stride=args.frame_stride,
        img_size=args.img_size,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        num_clips_val=args.num_clips_val,
    )

    # Model (optionally transferring backbone+TRM weights from an HMDB51 ckpt).
    model = ViTTRMVideo(
        img_size=args.img_size,
        vit_name=args.vit_name,
        vit_pretrained=args.vit_pretrained,
        vit_freeze=args.vit_freeze,
        trm_H_cycles=args.trm_H_cycles,
        trm_L_layers=args.trm_L_layers,
        trm_num_heads=args.trm_num_heads,
        num_classes=args.num_classes,
        lr=args.lr,
        weight_decay=args.weight_decay,
        warmup_epochs=args.warmup_epochs,
        max_epochs=args.max_epochs,
        label_smoothing=args.label_smoothing,
        iterative_refinement=args.iterative_refinement,
        pretrained_ckpt=args.pretrained_ckpt,
    )

    # Callbacks: keep the 3 best checkpoints by validation accuracy.
    ckpt_callback = ModelCheckpoint(
        dirpath="checkpoints",
        filename="vit-trm-ssv2-{epoch:02d}-{val_acc:.4f}",
        monitor="val_acc",
        mode="max",
        save_top_k=3,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    # Trainer
    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,
        max_epochs=args.max_epochs,
        callbacks=[ckpt_callback, lr_monitor],
        fast_dev_run=args.fast_dev_run,
        log_every_n_steps=50,
    )
    trainer.fit(model, dm)

    # Test with the best checkpoint; skipped under fast_dev_run because no
    # checkpoint is saved in that mode.
    if not args.fast_dev_run:
        trainer.test(model, dm, ckpt_path="best")


if __name__ == "__main__":
    main()