egg-damage-top3-classifier / src /egg_damage /train_classical.py
budijuarto's picture
Upload src/egg_damage/train_classical.py
29aaa24 verified
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Any
import pandas as pd
from .config import load_config
from .data_discovery import prepare_data
from .utils import get_logger
LOGGER = get_logger(__name__)
# Registry of the classical models this module can train. Each key is the
# name accepted by ``--models`` and looked up under ``config["models"]["enabled"]``;
# each value holds the ``feature_type`` and ``model_name`` arguments that are
# forwarded to ``train_classical_model``.
CLASSICAL_REGISTRY = dict(
    hog_svm=dict(feature_type="hog", model_name="hog_svm"),
    lbp_svm=dict(feature_type="lbp", model_name="lbp_svm"),
)
def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame:
    """Return the split metadata, preparing the data when no CSV exists yet.

    Args:
        config: Project configuration. ``config["paths"]["split_csv"]`` is
            the location of the split metadata CSV.

    Returns:
        The contents of the split CSV when that file exists; otherwise the
        value returned by ``prepare_data(config)``.
    """
    csv_path = Path(config["paths"]["split_csv"])

    # Guard clause: nothing on disk yet, so delegate to the preparation step.
    if not csv_path.exists():
        LOGGER.info("Split metadata not found; preparing data first.")
        return prepare_data(config)

    LOGGER.info("Loading split metadata from %s", csv_path)
    return pd.read_csv(csv_path)
def train_classical_models(
    config: dict[str, Any],
    model_keys: list[str] | None = None,
) -> list[Path]:
    """Train the requested classical models and collect the trainer results.

    Args:
        config: Project configuration. ``config["models"]["enabled"]`` is
            read to choose the default model set, and the whole mapping is
            forwarded to the split loader and to ``train_classical_model``.
        model_keys: Keys of ``CLASSICAL_REGISTRY`` to train. When ``None``
            or an empty list, every registry key whose entry in
            ``config["models"]["enabled"]`` is truthy is trained instead.

    Returns:
        One item per model that was trained, in request order, holding the
        value returned by ``train_classical_model``. Keys that are not in
        ``CLASSICAL_REGISTRY`` contribute nothing to the list.
    """
    # Kept as a call-time import, exactly as before, so importing this module
    # does not import ``classical_models``.
    from .classical_models import train_classical_model

    splits_df = load_or_prepare_splits(config)
    enabled = config["models"]["enabled"]
    requested = model_keys or [key for key in CLASSICAL_REGISTRY if enabled.get(key, False)]

    paths: list[Path] = []
    for key in requested:
        spec = CLASSICAL_REGISTRY.get(key)
        if spec is None:
            # Unknown keys are still skipped (best-effort behaviour is kept),
            # but the skip is now reported so a misspelled key passed by a
            # programmatic caller does not silently yield fewer results.
            LOGGER.warning(
                "Skipping unknown classical model key %r; expected one of %s",
                key,
                sorted(CLASSICAL_REGISTRY),
            )
            continue
        paths.append(
            train_classical_model(spec["feature_type"], spec["model_name"], splits_df, config)
        )
    return paths
def main() -> None:
    """Command-line entry point: parse options, load the config, then train.

    ``--config`` is the configuration file path (default
    ``configs/default.yaml``). ``--models`` takes zero or more keys of
    ``CLASSICAL_REGISTRY``; when the option is omitted ``None`` is passed on,
    so ``train_classical_models`` falls back to the models enabled in the
    configuration.
    """
    cli = argparse.ArgumentParser(description="Train classical HOG/LBP SVM egg classifiers.")
    cli.add_argument("--config", default="configs/default.yaml")
    cli.add_argument(
        "--models",
        nargs="*",
        default=None,
        choices=[*CLASSICAL_REGISTRY],
    )
    options = cli.parse_args()

    train_classical_models(load_config(options.config), options.models)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()