sol9x-sagar's picture
initial setup
2979822
"""Train MiniFASNet V2 SE for face anti-spoofing (2-class: Real, Spoof)."""
from src.minifasv2.config import TrainConfig
from src.minifasv2.main import Trainer
import argparse
import os
def build_parser():
    """Return the command-line parser for the MiniFASNet training script."""
    p = argparse.ArgumentParser(
        description="Training Face-AntiSpoofing Model (2-class: Real, Spoof)"
    )
    p.add_argument(
        "--crop_dir", type=str, default="data", help="Subdir with cropped images"
    )
    p.add_argument(
        "--input_size",
        type=int,
        default=128,
        help="Input size of images passed to model",
    )
    p.add_argument(
        "--batch_size", type=int, default=256, help="Count of images in the batch"
    )
    p.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Path to checkpoint file to resume training from",
    )
    p.add_argument(
        "--transfer_learning",
        action="store_true",
        help="Use transfer learning mode (load only model weights, reset optimizer/scheduler)",
    )
    p.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="Output directory for checkpoints and logs",
    )
    return p


def resolve_resume_path(resume, model_path):
    """Decide which checkpoint, if any, training should resume from.

    Args:
        resume: Value of ``--resume``. ``None`` means the flag was not given;
            in that case ``<model_path>/checkpoint_latest.pth`` is used when it
            exists. An empty string explicitly requests a fresh run.
        model_path: Directory searched for ``checkpoint_latest.pth``.

    Returns:
        Path of the checkpoint to resume from, or ``None`` to start fresh.
    """
    if resume is not None:
        # `--resume ""` is the advertised way to opt out of auto-resume. Map it
        # to None, the same value used below when no checkpoint exists, rather
        # than handing the trainer an empty path.
        return resume or None
    checkpoint_latest = os.path.join(model_path, "checkpoint_latest.pth")
    if os.path.exists(checkpoint_latest):
        print(f"Found existing checkpoint: {checkpoint_latest}")
        print('Resuming training automatically. Use --resume "" to start fresh.')
        return checkpoint_latest
    return None


def main(argv=None):
    """Parse CLI arguments, build the training config and run training.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    # Label grouping handed to TrainConfig: presumably group 0 (label 0) is
    # "Real" and group 1 (labels 1,2,3,7,8,9) is "Spoof", matching the
    # 2-class description. NOTE(review): confirm against TrainConfig and the
    # dataset's label map.
    spoof_categories = [[0], [1, 2, 3, 7, 8, 9]]
    config = TrainConfig(
        crop_dir=args.crop_dir,
        input_size=args.input_size,
        batch_size=args.batch_size,
        spoof_categories=spoof_categories,
        output_dir=args.output_dir,
    )
    config.set_job("MINIFAS")
    print("Device:", config.device)
    # Must run after set_job(): the auto-resume lookup reads config.model_path.
    resume_path = resolve_resume_path(args.resume, config.model_path)
    trainer = Trainer(
        config, resume_from=resume_path, transfer_learning=args.transfer_learning
    )
    trainer.train_model()
    print("Finished")


if __name__ == "__main__":
    main()