| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
| from typing import Any |
|
|
| import pandas as pd |
|
|
| from .config import load_config |
| from .data_discovery import prepare_data |
| from .utils import get_logger |
|
|
|
|
# Module-wide logger, obtained from the project's get_logger helper and keyed
# by this module's import name.
LOGGER = get_logger(__name__)
|
|
|
|
# Maps each selectable model key to the two values that are forwarded to
# train_classical_model: the feature type and the model name. The keys double
# as the accepted values of the --models command-line option.
CLASSICAL_REGISTRY: dict[str, dict[str, str]] = {
    "hog_svm": {
        "feature_type": "hog",
        "model_name": "hog_svm",
    },
    "lbp_svm": {
        "feature_type": "lbp",
        "model_name": "lbp_svm",
    },
}
|
|
|
|
def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame:
    """Return the split metadata table, preparing the data if no CSV exists yet.

    Args:
        config: Configuration mapping; ``config["paths"]["split_csv"]`` gives
            the location of the split CSV file.

    Returns:
        The frame read from the split CSV when that file exists, otherwise
        whatever ``prepare_data(config)`` returns.
    """
    csv_path = Path(config["paths"]["split_csv"])

    # Guard clause: without a split file on disk, delegate to prepare_data.
    if not csv_path.exists():
        LOGGER.info("Split metadata not found; preparing data first.")
        return prepare_data(config)

    LOGGER.info("Loading split metadata from %s", csv_path)
    return pd.read_csv(csv_path)
|
|
|
|
def train_classical_models(config: dict[str, Any], model_keys: list[str] | None = None) -> list[Path]:
    """Train the selected classical models and collect what each run returns.

    Args:
        config: Configuration mapping. ``config["models"]["enabled"]`` is used
            to choose models when ``model_keys`` is ``None`` or empty.
        model_keys: Registry keys to train. A falsy value (``None`` or an
            empty list) falls back to every ``CLASSICAL_REGISTRY`` key whose
            entry in ``config["models"]["enabled"]`` is truthy.

    Returns:
        One entry per trained model, in request order, holding the value
        returned by ``train_classical_model``. Keys that are not present in
        ``CLASSICAL_REGISTRY`` are skipped with a warning instead of raising.
    """
    # Function-scope import kept from the original; presumably it defers
    # loading the training code until it is actually needed.
    from .classical_models import train_classical_model

    splits_df = load_or_prepare_splits(config)
    enabled = config["models"]["enabled"]
    requested = model_keys or [key for key in CLASSICAL_REGISTRY if enabled.get(key, False)]

    paths: list[Path] = []
    for key in requested:
        spec = CLASSICAL_REGISTRY.get(key)
        if spec is None:
            # Still best-effort, but no longer silent: a mistyped key would
            # otherwise train nothing and leave no trace of why.
            LOGGER.warning(
                "Skipping unknown classical model %r; known models: %s",
                key,
                ", ".join(CLASSICAL_REGISTRY),
            )
            continue
        paths.append(train_classical_model(spec["feature_type"], spec["model_name"], splits_df, config))

    if not paths:
        LOGGER.warning("No classical models were trained.")
    return paths
|
|
|
|
def main() -> None:
    """Command-line entry point: parse options, load the config, train models.

    ``--config`` names the configuration file (default
    ``configs/default.yaml``); ``--models`` optionally restricts training to
    the given ``CLASSICAL_REGISTRY`` keys.
    """
    cli = argparse.ArgumentParser(description="Train classical HOG/LBP SVM egg classifiers.")
    cli.add_argument("--config", default="configs/default.yaml")
    cli.add_argument(
        "--models",
        nargs="*",
        default=None,
        choices=list(CLASSICAL_REGISTRY),
    )
    options = cli.parse_args()

    train_classical_models(load_config(options.config), options.models)
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
|
|