dipaug-project-hub / scripts /train_phase1.py
abersbail's picture
Deploy fixed DIPAug project hub
b5c1055 verified
raw
history blame contribute delete
716 Bytes
"""Phase 1 training entrypoint."""
from __future__ import annotations

import argparse
from collections.abc import Sequence

from dipauglib.utils.io import load_yaml
from dipaugnet.models.dipaugnet import DIPAugNet
from dipaugnet.training.engine import fit_phase1
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the phase 1 training script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            previous zero-argument behaviour. Passing an explicit list lets
            tests and other callers drive the parser without touching
            ``sys.argv``.

    Returns:
        Namespace whose ``config`` attribute is the value given to ``--config``.

    Raises:
        SystemExit: If ``--config`` is missing or the arguments are invalid
            (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(description="Train DIPAug-Net phase 1 scaffold.")
    parser.add_argument(
        "--config",
        required=True,
        help="Path to the YAML configuration file.",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Run the phase 1 training scaffold from the command line.

    Steps:
      1. Read ``--config`` from the command line.
      2. Load that YAML file.
      3. Build a ``DIPAugNet`` sized by ``config["dataset"]["num_classes"]``.
      4. Call ``fit_phase1`` with that model and ``None`` for both the
         optimizer and the scheduler.
      5. Print whatever ``fit_phase1`` returns.
    """
    cli_args = parse_args()
    cfg = load_yaml(cli_args.config)

    # Only the class count is read from the config here.
    dataset_cfg = cfg["dataset"]
    net = DIPAugNet(num_classes=dataset_cfg["num_classes"])

    # NOTE(review): optimizer/scheduler are passed as None; presumably
    # fit_phase1 handles that in this scaffold -- confirm against engine.py.
    outcome = fit_phase1(model=net, optimizer=None, scheduler=None)
    print(outcome)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when it is imported.
    main()