"""Command-line entry point: prepare data, then train the selected classifiers."""

from __future__ import annotations

import argparse
from typing import Any

from .config import load_config
from .data_discovery import prepare_data
from .utils import environment_summary, get_logger, save_json

LOGGER = get_logger(__name__)

# Model keys grouped by the training routine that receives them.
CLASSICAL_KEYS = ["hog_svm", "lbp_svm"]
DL_KEYS = ["mobilenet_v3", "resnet50", "efficientnet_b0", "densenet121", "xception", "vit_small"]


def _select_keys(candidates: list[str], selected: set[str], enabled: dict[str, Any]) -> list[str]:
    """Return the subset of *candidates* to train, preserving their order.

    A non-empty *selected* set takes precedence over the config: only keys
    named in it are kept. Otherwise a key is kept when its entry in *enabled*
    is truthy (missing entries count as disabled).
    """
    if selected:
        return [key for key in candidates if key in selected]
    return [key for key in candidates if enabled.get(key, False)]


def train_all(
    config: dict[str, Any],
    model_keys: list[str] | None = None,
    skip_classical: bool = False,
    skip_dl: bool = False,
) -> None:
    """Prepare the data, train the chosen models and save an environment summary.

    Args:
        config: Configuration mapping. This function reads
            ``config["models"]["enabled"]`` and ``config["paths"]["output_dir"]``
            and passes the whole mapping on to the data and training helpers.
        model_keys: Explicit model keys to train. When empty or ``None`` the
            truthy entries of ``config["models"]["enabled"]`` decide instead.
        skip_classical: Do not train any key from ``CLASSICAL_KEYS``.
        skip_dl: Do not train any key from ``DL_KEYS``.

    Raises:
        KeyError: If ``config`` lacks the ``models``/``enabled`` or
            ``paths``/``output_dir`` entries.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably to keep the training dependencies off the
    # import path of this module -- confirm before hoisting.
    from .train_classical import train_classical_models
    from .train_dl import train_dl_models

    prepare_data(config)

    selected = set(model_keys or [])
    enabled = config["models"]["enabled"]

    # Keys outside both registries can never match below; say so instead of
    # dropping them silently (the CLI already rejects them via ``choices``,
    # but programmatic callers get no such check).
    unknown = sorted(str(key) for key in selected.difference(CLASSICAL_KEYS, DL_KEYS))
    if unknown:
        LOGGER.warning("Ignoring unknown model keys: %s", ", ".join(unknown))

    trained_any = False

    if not skip_classical:
        classical_keys = _select_keys(CLASSICAL_KEYS, selected, enabled)
        if classical_keys:
            train_classical_models(config, classical_keys)
            trained_any = True

    if not skip_dl:
        dl_keys = _select_keys(DL_KEYS, selected, enabled)
        if dl_keys:
            train_dl_models(config, dl_keys)
            trained_any = True

    if not trained_any:
        # e.g. ``--models hog_svm --skip-classical`` or nothing enabled in config.
        LOGGER.warning(
            "No models were trained; check the requested model keys, the skip "
            "flags and config['models']['enabled']."
        )

    # Written on every run, whether or not any model was trained.
    save_json(environment_summary(), f"{config['paths']['output_dir']}/environment.json")


def main() -> None:
    """Parse command-line arguments, load the config file and run ``train_all``."""
    all_keys = CLASSICAL_KEYS + DL_KEYS

    parser = argparse.ArgumentParser(description="Prepare data and train all enabled egg classifiers.")
    parser.add_argument("--config", default="configs/default.yaml")
    # ``--models`` with no values yields an empty list, which train_all treats
    # the same as omitting the flag: fall back to the config's enabled models.
    parser.add_argument("--models", nargs="*", default=None, choices=all_keys)
    parser.add_argument("--skip-classical", action="store_true")
    parser.add_argument("--skip-dl", action="store_true")
    args = parser.parse_args()

    config = load_config(args.config)
    train_all(config, args.models, skip_classical=args.skip_classical, skip_dl=args.skip_dl)


if __name__ == "__main__":
    main()