egg-damage-top3-classifier / src /egg_damage /classical_models.py
budijuarto's picture
Upload src/egg_damage/classical_models.py
298b9fd verified
from __future__ import annotations
from itertools import product
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from .classical_features import extract_feature_matrix
from .data_discovery import CANONICAL_LABELS
from .paths import ensure_dir
from .utils import get_logger, now_stamp, save_json
# Module-level logger for this file, created through the project's get_logger helper.
LOGGER = get_logger(__name__)
def _as_list(value: Any) -> list[Any]:
return value if isinstance(value, list) else [value]
def candidate_svm_params(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Expand the ``config["classical"]["svm"]`` section into a grid of SVM settings.

    ``kernel``, ``C`` and ``gamma`` may each be a scalar or a list. Scalars are
    treated as one-element lists. Missing keys fall back to ``"rbf"``, ``1.0``
    and ``"scale"`` respectively. The Cartesian product is returned in
    ``itertools.product`` order: kernel varies slowest, gamma fastest.

    Every candidate shares the section's ``class_weight``, which defaults to
    ``"balanced"``. ``C`` is coerced to ``float``.
    """
    section = config["classical"]["svm"]
    kernels = _as_list(section.get("kernel", ["rbf"]))
    c_grid = _as_list(section.get("C", [1.0]))
    gammas = _as_list(section.get("gamma", ["scale"]))
    return [
        {
            "kernel": kernel,
            "C": float(c),
            "gamma": gamma,
            "class_weight": section.get("class_weight", "balanced"),
        }
        for kernel, c, gamma in product(kernels, c_grid, gammas)
    ]
def make_svm_pipeline(params: dict[str, Any]) -> Pipeline:
    """Build an unfitted ``StandardScaler`` -> ``SVC`` pipeline for one grid point.

    ``params`` must contain ``kernel``, ``C`` and ``gamma``. ``class_weight`` is
    optional and defaults to ``"balanced"``. The SVC is created with
    ``probability=True``, so the fitted pipeline supports ``predict_proba``,
    and with a fixed ``random_state=42``.
    """
    classifier = SVC(
        kernel=params["kernel"],
        C=params["C"],
        gamma=params["gamma"],
        class_weight=params.get("class_weight", "balanced"),
        probability=True,
        random_state=42,
    )
    steps = [("scaler", StandardScaler()), ("svm", classifier)]
    return Pipeline(steps)
def _positive_class_probabilities(
    pipeline: "Pipeline", x: Any, positive_label: Any = 1
) -> np.ndarray:
    """Return the predicted probability of ``positive_label`` for each row of ``x``.

    The probability column is located through ``pipeline.classes_`` instead of
    being hard-coded as index 1. ROC-AUC is therefore always computed for the
    same label that ``f1_score`` treats as positive (its default ``pos_label=1``).

    Raises:
        ValueError: if ``positive_label`` is not among the fitted classes.
    """
    classes = list(pipeline.classes_)
    if positive_label not in classes:
        raise ValueError(
            f"Positive label {positive_label!r} not found in fitted classes {classes!r}."
        )
    return pipeline.predict_proba(x)[:, classes.index(positive_label)]


def _score_candidate(
    pipeline: "Pipeline", x_val: Any, y_val: Any
) -> tuple[float, float | None]:
    """Return ``(val_f1, val_roc_auc)`` for a fitted pipeline on validation data.

    ``val_roc_auc`` is ``None`` when ROC-AUC is undefined, meaning
    ``roc_auc_score`` raised ``ValueError`` or returned NaN. This happens,
    for example, when ``y_val`` contains only one class.
    """
    pred = pipeline.predict(x_val)
    prob = _positive_class_probabilities(pipeline, x_val)
    f1 = float(f1_score(y_val, pred, zero_division=0))
    try:
        auc = float(roc_auc_score(y_val, prob))
    except ValueError:
        return f1, None
    return f1, (None if np.isnan(auc) else auc)


def train_classical_model(
    feature_type: str,
    model_name: str,
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> Path:
    """Grid-search SVM pipelines on classical features and persist the best one.

    Rows of ``splits_df`` whose ``"split"`` column equals ``"train"`` are used
    for fitting. Rows equal to ``"val"`` are used for model selection. Training
    features are extracted with ``balance_train=True`` and validation features
    with ``balance_train=False``.

    Each candidate from ``candidate_svm_params(config)`` is ranked by
    validation F1 first. Validation ROC-AUC breaks ties, with an undefined AUC
    counting as 0.0. On an exact tie, the earliest candidate wins.

    NOTE(review): labels are assumed to be encoded with ``1`` as the positive
    ("Damaged") class. Both the ``f1_score`` default and the probability
    column lookup rely on this. TODO confirm against ``extract_feature_matrix``.

    Args:
        feature_type: Feature family passed to ``extract_feature_matrix``.
        model_name: Base file name for the saved artifacts.
        splits_df: Frame with a ``"split"`` column holding ``"train"``/``"val"``.
        config: Project config. Reads ``paths.model_dir``,
            ``preprocessing.image_size`` and the ``classical.svm`` grid.

    Returns:
        Path to ``<model_dir>/<model_name>.joblib``. The file contains a
        ``{"pipeline", "metadata"}`` bundle. The metadata is also written to
        ``<model_name>_metadata.json``.

    Raises:
        ValueError: if the train or validation split is empty.
        RuntimeError: if the configured hyper-parameter grid is empty.
    """
    train_df = splits_df[splits_df["split"] == "train"].reset_index(drop=True)
    val_df = splits_df[splits_df["split"] == "val"].reset_index(drop=True)
    if train_df.empty or val_df.empty:
        raise ValueError(f"{model_name} needs non-empty train and validation splits.")

    # Validate the grid before the (expensive) feature extraction, so a
    # misconfigured config fails fast.
    candidates = candidate_svm_params(config)
    if not candidates:
        raise RuntimeError(
            f"No SVM candidates configured for {model_name}; "
            "check config['classical']['svm']."
        )
    model_dir = ensure_dir(config["paths"]["model_dir"])

    LOGGER.info("Extracting %s training features.", feature_type.upper())
    x_train, y_train, expanded_train = extract_feature_matrix(
        train_df, feature_type, config, balance_train=True
    )
    LOGGER.info("Extracting %s validation features.", feature_type.upper())
    x_val, y_val, _ = extract_feature_matrix(val_df, feature_type, config, balance_train=False)

    best_pipeline: Pipeline | None = None
    best_result: dict[str, Any] | None = None
    best_score: tuple[float, float] | None = None
    for params in candidates:
        LOGGER.info("Training %s candidate params: %s", model_name, params)
        pipeline = make_svm_pipeline(params)
        pipeline.fit(x_train, y_train)
        f1, auc = _score_candidate(pipeline, x_val, y_val)
        LOGGER.info("%s candidate val_f1=%.4f val_roc_auc=%s", model_name, f1, auc)
        # F1 first, ROC-AUC (undefined -> 0.0) as tie-breaker; strict ">" keeps
        # the earliest candidate on an exact tie.
        score = (f1, 0.0 if auc is None else auc)
        if best_score is None or score > best_score:
            best_pipeline, best_score = pipeline, score
            best_result = {"params": params, "val_f1": f1, "val_roc_auc": auc}

    if best_pipeline is None or best_result is None:
        # Unreachable with a non-empty grid; kept as a guard for type narrowing.
        raise RuntimeError(f"No SVM candidate completed for {model_name}.")

    metadata = {
        "model_name": model_name,
        "model_type": "classical",
        "feature_type": feature_type,
        "class_names": list(CANONICAL_LABELS),
        "positive_class": "Damaged",
        "created_at": now_stamp(),
        "image_size": int(config["preprocessing"]["image_size"]),
        "best_validation": best_result,
        "train_rows": int(len(train_df)),
        "train_rows_after_balancing": int(len(expanded_train)),
        "validation_rows": int(len(val_df)),
        "config": config,
    }
    bundle = {"pipeline": best_pipeline, "metadata": metadata}
    path = model_dir / f"{model_name}.joblib"
    joblib.dump(bundle, path)
    save_json(metadata, model_dir / f"{model_name}_metadata.json")
    LOGGER.info("Saved %s model: %s", model_name, path)
    return path