File size: 1,824 Bytes
29aaa24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Any

import pandas as pd

from .config import load_config
from .data_discovery import prepare_data
from .utils import get_logger


# Module-level logger, created through the project's ``get_logger`` helper.
LOGGER = get_logger(__name__)


# Registry of the classical models this module can train.
# Each key is the user-facing model identifier (also offered as a ``--models``
# CLI choice); its value supplies the ``feature_type`` and ``model_name``
# arguments that ``train_classical_models`` forwards to the trainer.
CLASSICAL_REGISTRY = {
    "hog_svm": {"feature_type": "hog", "model_name": "hog_svm"},
    "lbp_svm": {"feature_type": "lbp", "model_name": "lbp_svm"},
}


def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame:
    """Return the split metadata, reading the cached CSV when one exists.

    Args:
        config: Configuration mapping; ``config["paths"]["split_csv"]`` names
            the CSV file that is looked up.

    Returns:
        The contents of the split CSV if that file exists, otherwise whatever
        ``prepare_data(config)`` returns.
    """
    csv_path = Path(config["paths"]["split_csv"])
    if not csv_path.exists():
        # Nothing cached yet: delegate to the data-preparation step.
        LOGGER.info("Split metadata not found; preparing data first.")
        return prepare_data(config)
    LOGGER.info("Loading split metadata from %s", csv_path)
    return pd.read_csv(csv_path)


def train_classical_models(config: dict[str, Any], model_keys: list[str] | None = None) -> list[Path]:
    """Train the selected classical models and collect the trainer's results.

    Args:
        config: Configuration mapping. ``config["models"]["enabled"]`` is read
            to build the default selection, and the whole mapping is passed on
            to the split loader and to the trainer.
        model_keys: Keys of ``CLASSICAL_REGISTRY`` to train. When ``None`` or
            empty, every registry key whose entry in the ``enabled`` mapping
            is truthy is trained instead.

    Returns:
        One value per trained model, in the order the models were requested,
        as returned by ``train_classical_model``.

    Keys that are not present in ``CLASSICAL_REGISTRY`` are skipped, and a
    warning is logged for each so that a misspelled key does not go unnoticed.
    """
    # Function-scope import kept from the original; presumably it defers
    # loading the training dependencies until needed (TODO confirm).
    from .classical_models import train_classical_model

    splits_df = load_or_prepare_splits(config)
    enabled = config["models"]["enabled"]
    requested = model_keys or [key for key in CLASSICAL_REGISTRY if enabled.get(key, False)]
    paths: list[Path] = []
    for key in requested:
        spec = CLASSICAL_REGISTRY.get(key)
        if spec is None:
            # Stay best-effort (skip rather than raise), but make it visible.
            LOGGER.warning(
                "Skipping unknown classical model key %r; expected one of %s",
                key,
                sorted(CLASSICAL_REGISTRY),
            )
            continue
        paths.append(train_classical_model(spec["feature_type"], spec["model_name"], splits_df, config))
    return paths


def main() -> None:
    """Command-line entry point: parse arguments, load the config, train.

    ``--config`` is the path handed to ``load_config``; ``--models`` optionally
    restricts training to the given ``CLASSICAL_REGISTRY`` keys.
    """
    cli = argparse.ArgumentParser(description="Train classical HOG/LBP SVM egg classifiers.")
    cli.add_argument("--config", default="configs/default.yaml")
    cli.add_argument(
        "--models",
        nargs="*",
        default=None,
        choices=list(CLASSICAL_REGISTRY),
    )
    options = cli.parse_args()
    train_classical_models(load_config(options.config), options.models)


# Run the CLI only when this file is executed as the main module, not on import.
# NOTE(review): the module uses relative imports, so it presumably has to be
# launched as part of its package (``python -m <package>.<module>``) — confirm.
if __name__ == "__main__":
    main()