budijuarto's picture
Upload src/egg_damage/train_all.py
c9fb532 verified
from __future__ import annotations
import argparse
from typing import Any
from .config import load_config
from .data_discovery import prepare_data
from .utils import environment_summary, get_logger, save_json
# Module-level logger obtained from the project's get_logger helper.
LOGGER = get_logger(__name__)
# Model keys that train_all forwards to train_classical_models; also offered as --models choices.
CLASSICAL_KEYS = ["hog_svm", "lbp_svm"]
# Model keys that train_all forwards to train_dl_models; also offered as --models choices.
DL_KEYS = ["mobilenet_v3", "resnet50", "efficientnet_b0", "densenet121", "xception", "vit_small"]
def train_all(
    config: dict[str, Any],
    model_keys: list[str] | None = None,
    skip_classical: bool = False,
    skip_dl: bool = False,
) -> None:
    """Prepare the data, train the chosen models, and record the environment.

    Args:
        config: Configuration mapping. This function itself reads
            ``config["models"]["enabled"]`` and
            ``config["paths"]["output_dir"]``; the whole mapping is also
            handed to ``prepare_data`` and to both trainer functions.
        model_keys: Optional explicit selection of model keys. When it is
            non-empty, exactly the listed keys (that appear in
            ``CLASSICAL_KEYS`` / ``DL_KEYS``) are trained and the enabled
            flags in ``config`` are not consulted. ``None`` or an empty
            list falls back to the enabled flags.
        skip_classical: When true, no key from ``CLASSICAL_KEYS`` is trained.
        skip_dl: When true, no key from ``DL_KEYS`` is trained.
    """
    # Trainer modules are imported at call time rather than module import
    # time (presumably to defer heavy dependencies -- confirm).
    from .train_classical import train_classical_models
    from .train_dl import train_dl_models

    prepare_data(config)

    requested = set(model_keys or [])
    enabled_flags = config["models"]["enabled"]

    def _choose(candidates: list[str]) -> list[str]:
        """Filter *candidates*, keeping their declared order."""
        if requested:
            # An explicit request overrides the config's enabled flags.
            return [name for name in candidates if name in requested]
        return [name for name in candidates if enabled_flags.get(name, False)]

    if not skip_classical:
        classical = _choose(CLASSICAL_KEYS)
        if classical:
            train_classical_models(config, classical)

    if not skip_dl:
        deep = _choose(DL_KEYS)
        if deep:
            train_dl_models(config, deep)

    # The environment summary is written last, after any training has run.
    summary = environment_summary()
    destination = f"{config['paths']['output_dir']}/environment.json"
    save_json(summary, destination)
def main() -> None:
    """Command-line entry point: parse arguments, load the config, train.

    Options:
        --config: Path handed to ``load_config`` (default
            ``configs/default.yaml``).
        --models: Zero or more model keys, each restricted to the entries
            of ``CLASSICAL_KEYS`` and ``DL_KEYS``; forwarded to
            ``train_all`` as ``model_keys``.
        --skip-classical / --skip-dl: Boolean switches forwarded to
            ``train_all`` under the same names.
    """
    known_models = [*CLASSICAL_KEYS, *DL_KEYS]

    cli = argparse.ArgumentParser(
        description="Prepare data and train all enabled egg classifiers.",
    )
    cli.add_argument("--config", default="configs/default.yaml")
    cli.add_argument("--models", nargs="*", default=None, choices=known_models)
    for switch in ("--skip-classical", "--skip-dl"):
        cli.add_argument(switch, action="store_true")
    options = cli.parse_args()

    train_all(
        load_config(options.config),
        options.models,
        skip_classical=options.skip_classical,
        skip_dl=options.skip_dl,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()