from __future__ import annotations import argparse from pathlib import Path from typing import Any import pandas as pd from .config import load_config from .data_discovery import prepare_data from .utils import get_logger LOGGER = get_logger(__name__) CLASSICAL_REGISTRY = { "hog_svm": {"feature_type": "hog", "model_name": "hog_svm"}, "lbp_svm": {"feature_type": "lbp", "model_name": "lbp_svm"}, } def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame: split_csv = Path(config["paths"]["split_csv"]) if split_csv.exists(): LOGGER.info("Loading split metadata from %s", split_csv) return pd.read_csv(split_csv) LOGGER.info("Split metadata not found; preparing data first.") return prepare_data(config) def train_classical_models(config: dict[str, Any], model_keys: list[str] | None = None) -> list[Path]: from .classical_models import train_classical_model splits_df = load_or_prepare_splits(config) enabled = config["models"]["enabled"] requested = model_keys or [key for key in CLASSICAL_REGISTRY if enabled.get(key, False)] paths: list[Path] = [] for key in requested: if key not in CLASSICAL_REGISTRY: continue spec = CLASSICAL_REGISTRY[key] paths.append(train_classical_model(spec["feature_type"], spec["model_name"], splits_df, config)) return paths def main() -> None: parser = argparse.ArgumentParser(description="Train classical HOG/LBP SVM egg classifiers.") parser.add_argument("--config", default="configs/default.yaml") parser.add_argument("--models", nargs="*", default=None, choices=list(CLASSICAL_REGISTRY)) args = parser.parse_args() config = load_config(args.config) train_classical_models(config, args.models) if __name__ == "__main__": main()