File size: 2,084 Bytes
2979822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Train MiniFASNet V2 SE for face anti-spoofing (2-class: Real, Spoof)."""

from src.minifasv2.config import TrainConfig
from src.minifasv2.main import Trainer
import argparse
import os

def build_parser():
    """Build the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--crop_dir``,
        ``--input_size``, ``--batch_size``, ``--resume``,
        ``--transfer_learning`` and ``--output_dir``.
    """
    p = argparse.ArgumentParser(
        description="Training Face-AntiSpoofing Model (2-class: Real, Spoof)"
    )
    p.add_argument(
        "--crop_dir", type=str, default="data", help="Subdir with cropped images"
    )
    p.add_argument(
        "--input_size",
        type=int,
        default=128,
        help="Input size of images passed to model",
    )
    p.add_argument(
        "--batch_size", type=int, default=256, help="Count of images in the batch"
    )
    p.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Path to checkpoint file to resume training from",
    )
    p.add_argument(
        "--transfer_learning",
        action="store_true",
        help="Use transfer learning mode (load only model weights, reset optimizer/scheduler)",
    )
    p.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="Output directory for checkpoints and logs",
    )
    return p


def resolve_resume_path(resume_arg, model_path):
    """Decide which checkpoint path (if any) to hand to the trainer.

    Args:
        resume_arg: Value of ``--resume``. ``None`` means the flag was not
            given; an empty string means the user explicitly asked to start
            fresh; anything else is used as the checkpoint path unchanged.
        model_path: Directory searched for ``checkpoint_latest.pth`` when
            ``resume_arg`` is ``None``.

    Returns:
        The checkpoint path to resume from, or ``None`` to start fresh.
    """
    if resume_arg is not None:
        # `--resume ""` is the advertised way to skip auto-resume. Normalise
        # it to None (the value already used when no checkpoint exists)
        # instead of forwarding an empty path to the trainer.
        return resume_arg or None

    checkpoint_latest = os.path.join(model_path, "checkpoint_latest.pth")
    if os.path.exists(checkpoint_latest):
        return checkpoint_latest
    return None


def main(argv=None):
    """Parse CLI arguments, build the training config and run training.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Source-label groups, one inner list per output class.
    # NOTE(review): presumably [0] is "Real" and the second group lists the
    # labels merged into "Spoof" -- confirm against TrainConfig / the dataset.
    spoof_categories = [[0], [1, 2, 3, 7, 8, 9]]

    config = TrainConfig(
        crop_dir=args.crop_dir,
        input_size=args.input_size,
        batch_size=args.batch_size,
        spoof_categories=spoof_categories,
        output_dir=args.output_dir,
    )
    # set_job is called before config.model_path is read, as in the original
    # statement order.
    config.set_job("MINIFAS")
    print("Device:", config.device)

    resume_path = resolve_resume_path(args.resume, config.model_path)
    if args.resume is None and resume_path is not None:
        # Only announce auto-resume when the user did not pass --resume.
        print(f"Found existing checkpoint: {resume_path}")
        print('Resuming training automatically. Use --resume "" to start fresh.')

    trainer = Trainer(
        config, resume_from=resume_path, transfer_learning=args.transfer_learning
    )
    trainer.train_model()
    print("Finished")


if __name__ == "__main__":
    main()