File size: 716 Bytes
9c2e807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""Phase 1 training entrypoint."""

from __future__ import annotations

import argparse
from collections.abc import Sequence

from dipauglib.utils.io import load_yaml
from dipaugnet.models.dipaugnet import DIPAugNet
from dipaugnet.training.engine import fit_phase1


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train DIPAug-Net phase 1 scaffold.")
    parser.add_argument("--config", required=True)
    return parser.parse_args()


def main() -> None:
    """Load the YAML config named by ``--config`` and run phase 1 training.

    Builds a ``DIPAugNet`` with ``num_classes`` taken from
    ``config["dataset"]["num_classes"]``. It hands the model to ``fit_phase1``
    with no optimizer or scheduler, then prints whatever ``fit_phase1``
    returns.
    """
    cli = parse_args()
    cfg = load_yaml(cli.config)
    n_classes = cfg["dataset"]["num_classes"]
    net = DIPAugNet(num_classes=n_classes)
    # NOTE(review): optimizer/scheduler are passed as None for this scaffold;
    # presumably fit_phase1 tolerates that — confirm against its definition.
    outcome = fit_phase1(model=net, optimizer=None, scheduler=None)
    print(outcome)


if __name__ == "__main__":
    main()