File size: 1,824 Bytes
29aaa24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | from __future__ import annotations
import argparse
from pathlib import Path
from typing import Any
import pandas as pd
from .config import load_config
from .data_discovery import prepare_data
from .utils import get_logger
# Module-level logger obtained from the project's ``get_logger`` helper,
# named after this module.
LOGGER = get_logger(__name__)
# Maps each classical model key (the values accepted by ``--models`` and the
# keys looked up in ``config["models"]["enabled"]``) to the arguments passed to
# ``train_classical_model``: which feature extractor to use and the model name.
CLASSICAL_REGISTRY = {
    model_name: {"feature_type": feature_type, "model_name": model_name}
    for model_name, feature_type in (("hog_svm", "hog"), ("lbp_svm", "lbp"))
}
def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame:
    """Return the split metadata table, preparing the data if it is missing.

    When the file at ``config["paths"]["split_csv"]`` exists, it is read with
    ``pandas.read_csv``. Otherwise ``prepare_data(config)`` is called and its
    result returned.
    """
    csv_path = Path(config["paths"]["split_csv"])
    if not csv_path.exists():
        LOGGER.info("Split metadata not found; preparing data first.")
        return prepare_data(config)
    LOGGER.info("Loading split metadata from %s", csv_path)
    return pd.read_csv(csv_path)
def train_classical_models(config: dict[str, Any], model_keys: list[str] | None = None) -> list[Path]:
    """Train the requested classical models listed in ``CLASSICAL_REGISTRY``.

    Args:
        config: Loaded project configuration. It is passed through to
            ``load_or_prepare_splits`` and ``train_classical_model``.
            ``config["models"]["enabled"]`` is read only when ``model_keys``
            is ``None`` or empty.
        model_keys: Registry keys to train. When ``None`` or empty, every key
            of ``CLASSICAL_REGISTRY`` whose entry in
            ``config["models"]["enabled"]`` is truthy is trained instead.
            Keys not in the registry are skipped with a warning. Duplicate
            keys are trained only once.

    Returns:
        The values returned by ``train_classical_model``, one per trained
        model, in request order. The list is empty (and no split data is
        loaded) when no valid key was selected.
    """
    # Imported inside the function rather than at module top; the reason is
    # not visible here (possibly to avoid loading training dependencies at
    # import time) — TODO confirm.
    from .classical_models import train_classical_model

    if model_keys:
        requested = model_keys
    else:
        # Only consult the config when the caller did not choose explicitly,
        # so a missing "models"/"enabled" section does not break explicit runs.
        enabled = config["models"]["enabled"]
        requested = [key for key in CLASSICAL_REGISTRY if enabled.get(key, False)]

    # dict.fromkeys drops duplicates while preserving the original order.
    selected: list[str] = []
    for key in dict.fromkeys(requested):
        if key in CLASSICAL_REGISTRY:
            selected.append(key)
        else:
            LOGGER.warning(
                "Skipping unknown classical model key %r; expected one of %s",
                key,
                list(CLASSICAL_REGISTRY),
            )

    if not selected:
        # Avoid loading/preparing split data when there is nothing to train.
        LOGGER.info("No classical models selected; nothing to train.")
        return []

    splits_df = load_or_prepare_splits(config)
    paths: list[Path] = []
    for key in selected:
        spec = CLASSICAL_REGISTRY[key]
        paths.append(train_classical_model(spec["feature_type"], spec["model_name"], splits_df, config))
    return paths
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse arguments, load the config, train models.

    Args:
        argv: Argument list to parse instead of the process command line.
            ``None`` (the default) makes ``argparse`` read ``sys.argv[1:]``,
            which preserves the normal CLI behavior. Passing a list allows
            programmatic invocation and testing.
    """
    parser = argparse.ArgumentParser(description="Train classical HOG/LBP SVM egg classifiers.")
    parser.add_argument("--config", default="configs/default.yaml")
    # With nargs="*", ``--models`` given with no values yields [], which
    # train_classical_models treats the same as omitting the flag: the
    # models enabled in the config are trained.
    parser.add_argument("--models", nargs="*", default=None, choices=list(CLASSICAL_REGISTRY))
    args = parser.parse_args(argv)
    config = load_config(args.config)
    train_classical_models(config, args.models)
# Run the CLI when executed directly. The relative imports at the top of the
# file mean this must run in package context (e.g. ``python -m <package>.<module>``).
if __name__ == "__main__":
    main()
|