| from __future__ import annotations |
|
|
| import argparse |
| from typing import Any |
|
|
| from .config import load_config |
| from .data_discovery import prepare_data |
| from .utils import environment_summary, get_logger, save_json |
|
|
|
|
LOGGER = get_logger(__name__)

# Model keys routed to train_classical_models() by train_all(); the order here
# is the order in which the selected keys are passed to that trainer.
CLASSICAL_KEYS = ["hog_svm", "lbp_svm"]

# Model keys routed to train_dl_models() by train_all(); the order here is the
# order in which the selected keys are passed to that trainer.
# Together with CLASSICAL_KEYS these form the allowed ``--models`` choices.
DL_KEYS = [
    "mobilenet_v3",
    "resnet50",
    "efficientnet_b0",
    "densenet121",
    "xception",
    "vit_small",
]
|
|
|
|
def _select_keys(
    candidates: list[str],
    selected: set[str],
    enabled: dict[str, Any],
) -> list[str]:
    """Return the members of *candidates* that should be trained, in *candidates* order.

    An explicit (non-empty) *selected* set wins outright; otherwise a key is chosen
    when its value in *enabled* is truthy (keys missing from *enabled* are skipped).
    """
    if selected:
        return [key for key in candidates if key in selected]
    return [key for key in candidates if enabled.get(key, False)]


def train_all(
    config: dict[str, Any],
    model_keys: list[str] | None = None,
    skip_classical: bool = False,
    skip_dl: bool = False,
) -> None:
    """Prepare the data, train the selected models, and record the runtime environment.

    Args:
        config: Parsed configuration. The whole dict is forwarded to ``prepare_data``
            and to the trainers; this function itself reads
            ``config["models"]["enabled"]`` (only when no explicit selection is given)
            and ``config["paths"]["output_dir"]``.
        model_keys: Explicit model keys to train. When ``None`` or empty, every key
            whose ``config["models"]["enabled"]`` entry is truthy is trained instead.
            Keys not in ``CLASSICAL_KEYS`` or ``DL_KEYS`` are ignored with a warning.
        skip_classical: Do not train any key from ``CLASSICAL_KEYS``.
        skip_dl: Do not train any key from ``DL_KEYS``.

    Raises:
        KeyError: If a config entry this function reads is missing.
    """
    selected = set(model_keys or [])
    unknown = selected.difference(CLASSICAL_KEYS, DL_KEYS)
    if unknown:
        LOGGER.warning("Ignoring unknown model keys: %s", ", ".join(sorted(unknown)))

    # The config's enabled flags only matter when the caller made no explicit selection.
    enabled = {} if selected else config["models"]["enabled"]

    classical_keys = [] if skip_classical else _select_keys(CLASSICAL_KEYS, selected, enabled)
    dl_keys = [] if skip_dl else _select_keys(DL_KEYS, selected, enabled)

    # Import each trainer module only when its family will actually run, so skipping
    # a family never requires that module (and everything it imports) to load.
    # Importing before prepare_data() keeps a missing dependency a fail-fast error.
    trainers = []
    if classical_keys:
        from .train_classical import train_classical_models

        trainers.append((train_classical_models, classical_keys))
    if dl_keys:
        from .train_dl import train_dl_models

        trainers.append((train_dl_models, dl_keys))
    if not trainers:
        LOGGER.warning("No models selected for training; only preparing data.")

    prepare_data(config)
    # Classical models first, then deep-learning models (same order as before).
    for train_fn, keys in trainers:
        train_fn(config, keys)
    save_json(environment_summary(), f"{config['paths']['output_dir']}/environment.json")
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse arguments, load the config, and call ``train_all``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (useful for tests and
            programmatic invocation). ``None`` keeps the standard argparse behaviour.
    """
    all_keys = CLASSICAL_KEYS + DL_KEYS
    parser = argparse.ArgumentParser(description="Prepare data and train all enabled egg classifiers.")
    parser.add_argument(
        "--config",
        default="configs/default.yaml",
        help="Path to the configuration file (default: %(default)s).",
    )
    # nargs="*" means a bare ``--models`` yields an empty list, which train_all treats
    # the same as omitting the flag (fall back to the config's enabled models).
    parser.add_argument(
        "--models",
        nargs="*",
        default=None,
        choices=all_keys,
        help="Train only these models. If omitted or empty, train the models enabled in the config.",
    )
    parser.add_argument(
        "--skip-classical",
        action="store_true",
        help=f"Skip the classical models ({', '.join(CLASSICAL_KEYS)}).",
    )
    parser.add_argument(
        "--skip-dl",
        action="store_true",
        help=f"Skip the deep-learning models ({', '.join(DL_KEYS)}).",
    )
    args = parser.parse_args(argv)
    config = load_config(args.config)
    train_all(config, args.models, skip_classical=args.skip_classical, skip_dl=args.skip_dl)
|
|
|
|
# Script entry point. Because this module uses relative imports (``from .config ...``),
# it must be run as part of its package (``python -m <package>.<module>``), not as a
# standalone file.
if __name__ == "__main__":
    main()
|
|