File size: 4,627 Bytes
298b9fd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | from __future__ import annotations
from itertools import product
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from .classical_features import extract_feature_matrix
from .data_discovery import CANONICAL_LABELS
from .paths import ensure_dir
from .utils import get_logger, now_stamp, save_json
# Module-level logger, named after this module via the project's get_logger helper.
LOGGER = get_logger(__name__)
def _as_list(value: Any) -> list[Any]:
return value if isinstance(value, list) else [value]
def candidate_svm_params(config: dict[str, Any]) -> list[dict[str, Any]]:
svm_cfg = config["classical"]["svm"]
candidates: list[dict[str, Any]] = []
for kernel, c_value, gamma in product(
_as_list(svm_cfg.get("kernel", ["rbf"])),
_as_list(svm_cfg.get("C", [1.0])),
_as_list(svm_cfg.get("gamma", ["scale"])),
):
candidates.append(
{
"kernel": kernel,
"C": float(c_value),
"gamma": gamma,
"class_weight": svm_cfg.get("class_weight", "balanced"),
}
)
return candidates
def make_svm_pipeline(params: dict[str, Any]) -> Pipeline:
    """Build an unfitted "standardize, then SVC" pipeline for one candidate.

    Args:
        params: Candidate hyperparameters. Must contain ``kernel``, ``C`` and
            ``gamma``; ``class_weight`` is optional and defaults to
            ``"balanced"`` when absent.

    Returns:
        A two-step ``Pipeline``: ``"scaler"`` (``StandardScaler``) followed by
        ``"svm"`` (``SVC`` with probability estimates enabled and
        ``random_state=42``).
    """
    classifier = SVC(
        kernel=params["kernel"],
        C=params["C"],
        gamma=params["gamma"],
        class_weight=params.get("class_weight", "balanced"),
        # Required for predict_proba, which callers use for ROC-AUC.
        probability=True,
        # Fixed seed so the probability estimates are reproducible across runs.
        random_state=42,
    )
    steps = [("scaler", StandardScaler()), ("svm", classifier)]
    return Pipeline(steps)
def _validation_scores(pipeline: Pipeline, x_val: Any, y_val: Any) -> tuple[Any, float]:
    """Return ``(f1, roc_auc)`` for a fitted pipeline on the validation features.

    ``roc_auc`` is NaN when ``roc_auc_score`` raises ``ValueError`` (for
    example, when ``y_val`` contains a single class).
    """
    predictions = pipeline.predict(x_val)
    # Column 1 of predict_proba is used as the positive-class score.
    # NOTE(review): assumes pipeline.classes_[1] is the positive label — confirm.
    positive_scores = pipeline.predict_proba(x_val)[:, 1]
    f1 = f1_score(y_val, predictions, zero_division=0)
    try:
        roc_auc = roc_auc_score(y_val, positive_scores)
    except ValueError:
        roc_auc = float("nan")
    return f1, roc_auc


def train_classical_model(
    feature_type: str,
    model_name: str,
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> Path:
    """Grid-search SVM candidates for one feature type and save the best model.

    Features are built with ``extract_feature_matrix`` for the ``"train"`` rows
    (with ``balance_train=True``) and the ``"val"`` rows of ``splits_df``. Each
    candidate from ``candidate_svm_params`` is fitted on the training features
    and ranked on validation by F1. ROC-AUC breaks ties, with an undefined AUC
    counted as 0.0.

    The winning pipeline plus a metadata dict are dumped with joblib to
    ``<model_dir>/<model_name>.joblib``. The metadata is also saved separately
    to ``<model_dir>/<model_name>_metadata.json``.

    Args:
        feature_type: Feature family name. It is passed to
            ``extract_feature_matrix`` and recorded in the metadata.
        model_name: Base file name for the saved artifacts.
        splits_df: Rows with a ``"split"`` column. Only the ``"train"`` and
            ``"val"`` rows are used.
        config: Project config. Reads ``paths.model_dir``,
            ``preprocessing.image_size`` and (via ``candidate_svm_params``)
            ``classical.svm``. It is embedded verbatim in the metadata.

    Returns:
        Path to the saved ``.joblib`` bundle.

    Raises:
        ValueError: If the train or validation split is empty.
        RuntimeError: If no candidate was evaluated (empty parameter grid).
    """
    model_dir = ensure_dir(config["paths"]["model_dir"])

    split_column = splits_df["split"]
    train_df = splits_df[split_column == "train"].reset_index(drop=True)
    val_df = splits_df[split_column == "val"].reset_index(drop=True)
    if train_df.empty or val_df.empty:
        raise ValueError(f"{model_name} needs non-empty train and validation splits.")

    feature_label = feature_type.upper()
    LOGGER.info("Extracting %s training features.", feature_label)
    x_train, y_train, expanded_train = extract_feature_matrix(
        train_df, feature_type, config, balance_train=True
    )
    LOGGER.info("Extracting %s validation features.", feature_label)
    x_val, y_val, _ = extract_feature_matrix(
        val_df, feature_type, config, balance_train=False
    )

    winner: Pipeline | None = None
    winner_validation: dict[str, Any] | None = None
    winner_key: tuple[float, float] | None = None
    for params in candidate_svm_params(config):
        LOGGER.info("Training %s candidate params: %s", model_name, params)
        pipeline = make_svm_pipeline(params)
        pipeline.fit(x_train, y_train)
        f1, roc_auc = _validation_scores(pipeline, x_val, y_val)

        auc_missing = bool(np.isnan(roc_auc))
        ranking_key = (float(f1), 0.0 if auc_missing else float(roc_auc))
        # Strict ">" keeps the earliest candidate on an exact tie.
        if winner_key is not None and not ranking_key > winner_key:
            continue

        winner = pipeline
        winner_key = ranking_key
        winner_validation = {
            "params": params,
            "val_f1": float(f1),
            "val_roc_auc": None if auc_missing else float(roc_auc),
        }

    if winner is None or winner_validation is None:
        raise RuntimeError(f"No SVM candidate completed for {model_name}.")

    metadata = {
        "model_name": model_name,
        "model_type": "classical",
        "feature_type": feature_type,
        "class_names": list(CANONICAL_LABELS),
        "positive_class": "Damaged",
        "created_at": now_stamp(),
        "image_size": int(config["preprocessing"]["image_size"]),
        "best_validation": winner_validation,
        "train_rows": int(len(train_df)),
        # Length of the training rows returned by extract_feature_matrix with
        # balance_train=True (presumably after class rebalancing — confirm).
        "train_rows_after_balancing": int(len(expanded_train)),
        "validation_rows": int(len(val_df)),
        "config": config,
    }

    model_path = model_dir / f"{model_name}.joblib"
    joblib.dump({"pipeline": winner, "metadata": metadata}, model_path)
    save_json(metadata, model_dir / f"{model_name}_metadata.json")
    LOGGER.info("Saved %s model: %s", model_name, model_path)
    return model_path
|