| from __future__ import annotations |
|
|
| from itertools import product |
| from pathlib import Path |
| from typing import Any |
|
|
| import joblib |
| import numpy as np |
| import pandas as pd |
| from sklearn.metrics import f1_score, roc_auc_score |
| from sklearn.pipeline import Pipeline |
| from sklearn.preprocessing import StandardScaler |
| from sklearn.svm import SVC |
|
|
| from .classical_features import extract_feature_matrix |
| from .data_discovery import CANONICAL_LABELS |
| from .paths import ensure_dir |
| from .utils import get_logger, now_stamp, save_json |
|
|
|
|
# Module-level logger, obtained from the project's ``get_logger`` helper and
# keyed by this module's import name.
LOGGER = get_logger(__name__)
|
|
|
|
| def _as_list(value: Any) -> list[Any]: |
| return value if isinstance(value, list) else [value] |
|
|
|
|
def candidate_svm_params(config: dict[str, Any]) -> list[dict[str, Any]]:
    """Expand the SVM search settings into one parameter dict per combination.

    Reads ``config["classical"]["svm"]`` and forms the Cartesian product of its
    ``kernel``, ``C`` and ``gamma`` entries. Each entry may be a single value
    or a list; missing entries fall back to ``["rbf"]``, ``[1.0]`` and
    ``["scale"]`` respectively.

    Args:
        config: Configuration mapping containing a ``classical.svm`` section.

    Returns:
        A list of dicts with keys ``kernel``, ``C`` (coerced to ``float``),
        ``gamma`` and ``class_weight`` (``"balanced"`` unless configured).
        Kernel varies slowest and gamma fastest.
    """
    svm_section = config["classical"]["svm"]
    kernels = _as_list(svm_section.get("kernel", ["rbf"]))
    c_values = _as_list(svm_section.get("C", [1.0]))
    gammas = _as_list(svm_section.get("gamma", ["scale"]))
    return [
        {
            "kernel": kernel,
            "C": float(c_value),
            "gamma": gamma,
            "class_weight": svm_section.get("class_weight", "balanced"),
        }
        for kernel, c_value, gamma in product(kernels, c_values, gammas)
    ]
|
|
|
|
def make_svm_pipeline(params: dict[str, Any]) -> Pipeline:
    """Build an unfitted ``StandardScaler`` -> ``SVC`` pipeline.

    Args:
        params: Must contain ``kernel``, ``C`` and ``gamma``; ``class_weight``
            is optional and defaults to ``"balanced"``.

    Returns:
        A two-step pipeline with steps named ``"scaler"`` and ``"svm"``.

    Raises:
        KeyError: If ``kernel``, ``C`` or ``gamma`` is missing from ``params``.
    """
    scaler = StandardScaler()
    classifier = SVC(
        kernel=params["kernel"],
        C=params["C"],
        gamma=params["gamma"],
        class_weight=params.get("class_weight", "balanced"),
        # Probability estimates are enabled so the fitted model exposes
        # predict_proba; the seed is fixed for repeatable fits.
        probability=True,
        random_state=42,
    )
    steps = [("scaler", scaler), ("svm", classifier)]
    return Pipeline(steps)
|
|
|
|
def train_classical_model(
    feature_type: str,
    model_name: str,
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> Path:
    """Fit every SVM candidate, keep the best on validation, and persist it.

    Rows of ``splits_df`` whose ``split`` column equals ``"train"`` / ``"val"``
    are turned into feature matrices via ``extract_feature_matrix``. Each
    candidate from ``candidate_svm_params`` is fitted on the training features
    and scored on the validation features by ``(F1, ROC AUC)``; a strictly
    higher score replaces the current best, so ties keep the earlier candidate.

    The winning pipeline and its metadata are written to
    ``<model_dir>/<model_name>.joblib``; the metadata alone is also written to
    ``<model_dir>/<model_name>_metadata.json``.

    Args:
        feature_type: Feature extractor identifier forwarded to
            ``extract_feature_matrix`` and recorded in the metadata.
        model_name: Base name for the saved artefacts and log messages.
        splits_df: Frame with a ``split`` column marking ``train``/``val`` rows.
        config: Project configuration; reads ``paths.model_dir``,
            ``preprocessing.image_size`` and the SVM search settings.

    Returns:
        Path of the saved ``.joblib`` bundle.

    Raises:
        ValueError: If the train or validation split is empty.
        RuntimeError: If no candidate was produced/evaluated.
    """
    model_dir = ensure_dir(config["paths"]["model_dir"])

    split_column = splits_df["split"]
    train_df = splits_df[split_column == "train"].reset_index(drop=True)
    val_df = splits_df[split_column == "val"].reset_index(drop=True)
    if train_df.empty or val_df.empty:
        raise ValueError(f"{model_name} needs non-empty train and validation splits.")

    feature_label = feature_type.upper()
    LOGGER.info("Extracting %s training features.", feature_label)
    # The third return value is only used for its length below (rows after
    # balancing) -- presumably the expanded training frame; TODO confirm.
    x_train, y_train, expanded_train = extract_feature_matrix(
        train_df, feature_type, config, balance_train=True
    )
    LOGGER.info("Extracting %s validation features.", feature_label)
    x_val, y_val, _ = extract_feature_matrix(
        val_df, feature_type, config, balance_train=False
    )

    winner: Pipeline | None = None
    winner_summary: dict[str, Any] | None = None
    winner_score: tuple[float, float] | None = None
    for params in candidate_svm_params(config):
        LOGGER.info("Training %s candidate params: %s", model_name, params)
        candidate = make_svm_pipeline(params)
        candidate.fit(x_train, y_train)

        predictions = candidate.predict(x_val)
        # Column 1 of predict_proba is used as the positive-class score
        # (assumes the positive class sorts second -- TODO confirm).
        positive_scores = candidate.predict_proba(x_val)[:, 1]

        f1 = float(f1_score(y_val, predictions, zero_division=0))
        try:
            auc = roc_auc_score(y_val, positive_scores)
        except ValueError:
            # AUC could not be computed for this validation set; treat it as
            # missing rather than aborting the search.
            auc = float("nan")
        auc_value = None if np.isnan(auc) else float(auc)

        # A missing AUC ranks as 0.0 so it only ever acts as a tie-breaker.
        score = (f1, 0.0 if auc_value is None else auc_value)
        if winner_score is None or score > winner_score:
            winner = candidate
            winner_score = score
            winner_summary = {
                "params": params,
                "val_f1": f1,
                "val_roc_auc": auc_value,
            }

    if winner is None or winner_summary is None:
        raise RuntimeError(f"No SVM candidate completed for {model_name}.")

    metadata = {
        "model_name": model_name,
        "model_type": "classical",
        "feature_type": feature_type,
        "class_names": list(CANONICAL_LABELS),
        "positive_class": "Damaged",
        "created_at": now_stamp(),
        "image_size": int(config["preprocessing"]["image_size"]),
        "best_validation": dict(winner_summary),
        "train_rows": int(len(train_df)),
        "train_rows_after_balancing": int(len(expanded_train)),
        "validation_rows": int(len(val_df)),
        "config": config,
    }

    path = model_dir / f"{model_name}.joblib"
    joblib.dump({"pipeline": winner, "metadata": metadata}, path)
    save_json(metadata, model_dir / f"{model_name}_metadata.json")
    LOGGER.info("Saved %s model: %s", model_name, path)
    return path
|
|
|
|