File size: 716 Bytes
9c2e807 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | """Phase 1 training entrypoint."""
from __future__ import annotations
import argparse
from dipauglib.utils.io import load_yaml
from dipaugnet.models.dipaugnet import DIPAugNet
from dipaugnet.training.engine import fit_phase1
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for phase 1 training.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is exactly the
            previous behavior; passing a list allows programmatic use and tests.

    Returns:
        Namespace whose ``config`` attribute holds the config path string.

    Raises:
        SystemExit: via argparse's standard error handling when ``--config`` is
            missing or an unrecognized option is supplied.
    """
    parser = argparse.ArgumentParser(description="Train DIPAug-Net phase 1 scaffold.")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to the YAML config file for this run.",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Run phase 1 training from the YAML config given via ``--config``.

    The model's class count comes from ``dataset.num_classes`` in the loaded
    config; the value returned by ``fit_phase1`` is printed to stdout.
    """
    cfg = load_yaml(parse_args().config)
    net = DIPAugNet(num_classes=cfg["dataset"]["num_classes"])
    # NOTE(review): no optimizer or scheduler is passed here; presumably
    # fit_phase1 tolerates None for this scaffold -- confirm before real runs.
    print(fit_phase1(model=net, optimizer=None, scheduler=None))


if __name__ == "__main__":
    # main() runs only when this file is executed as a script, not on import.
    main()
|