File size: 1,802 Bytes
d406944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import json
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Input columns, in the exact order used to build the design matrix; the
# fitted coefficients are serialized in this same order alongside it.
FEATURE_ORDER = [
    "overall",
    "p90",
    "high_ratio",
    "mid_ratio",
    "std",
]


def clip01(x):
    """Clamp *x* element-wise into the closed interval [0, 1].

    Accepts whatever ``np.clip`` accepts (scalars or array-likes) and
    returns the clipped NumPy result.
    """
    lower, upper = 0.0, 1.0
    return np.clip(x, lower, upper)


def _load_xy(path, target):
    """Read the training CSV at *path* and return ``(X, y)`` as float arrays.

    ``X`` holds the FEATURE_ORDER columns in that order; ``y`` holds the
    *target* column.

    Raises:
        ValueError: if *target* is itself a feature column, if any required
            column is missing, if the CSV has no data rows, or if any
            required column contains missing values.
    """
    if target in FEATURE_ORDER:
        # Fitting on the target as one of its own inputs leaks the answer
        # into the model and makes the reported metrics meaningless.
        raise ValueError(
            f"Target column {target!r} must not also be a feature column "
            f"(features: {FEATURE_ORDER})"
        )

    df = pd.read_csv(path)

    required = FEATURE_ORDER + [target]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    if df.empty:
        raise ValueError(f"No data rows in {path}")

    frame = df[required].astype(float)
    # Name the offending columns instead of letting sklearn fail later with a
    # generic "Input contains NaN" message.
    with_nan = [c for c in required if frame[c].isna().any()]
    if with_nan:
        raise ValueError(f"Columns contain missing values: {with_nan}")

    X = frame[FEATURE_ORDER].to_numpy(dtype=float)
    y = frame[target].to_numpy(dtype=float)
    return X, y


def _finite_or_none(value):
    """Return *value* as a float, or ``None`` if it is NaN or infinite.

    ``json.dumps`` would otherwise emit bare ``NaN``/``Infinity`` tokens,
    which are not valid JSON.
    """
    value = float(value)
    return value if np.isfinite(value) else None


def _regression_metrics(y_true, y_pred):
    """Return MAE, RMSE, R^2 and sample count for *y_pred* against *y_true*.

    A metric that is undefined for the given data (e.g. ``r2_score`` with
    fewer than two samples returns NaN) is reported as ``None``.
    """
    return {
        "mae": _finite_or_none(mean_absolute_error(y_true, y_pred)),
        "rmse": _finite_or_none(np.sqrt(mean_squared_error(y_true, y_pred))),
        "r2": _finite_or_none(r2_score(y_true, y_pred)),
        "n": int(len(y_true)),
    }


def main():
    """Fit a linear calibration model from a CSV and write it as JSON.

    Command-line options:
        --input   CSV containing every FEATURE_ORDER column plus the target.
        --target  Name of the target column (default ``kn_rate``).
        --out     Destination JSON path (default ``calibration/model.json``);
                  missing parent directories are created.

    The JSON records the model type, the feature order, the fitted
    coefficients (in feature order) and intercept, and metrics computed on
    the training data itself.

    Raises:
        ValueError: on invalid input data (see ``_load_xy``).
    """
    parser = argparse.ArgumentParser(description="Train kn-like calibration model")
    parser.add_argument("--input", required=True, help="CSV with feature columns + target")
    parser.add_argument("--target", default="kn_rate", help="target column name")
    parser.add_argument("--out", default="calibration/model.json", help="output model json")
    args = parser.parse_args()

    X, y = _load_xy(args.input, args.target)

    model = LinearRegression()
    model.fit(X, y)

    # Metrics are in-sample (same rows used for fitting) and computed on
    # predictions clipped to [0, 1].
    # NOTE(review): the payload does not record this clipping; a consumer
    # rebuilding predictions from coef/intercept must clip to match.
    pred = clip01(model.predict(X))
    metrics = _regression_metrics(y, pred)

    payload = {
        "model_type": "linear",
        "feature_order": FEATURE_ORDER,
        "coef": [float(v) for v in model.coef_],
        "intercept": float(model.intercept_),
        "train_metrics": metrics,
    }

    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")

    print("Saved:", out)
    print("Metrics:", metrics)


if __name__ == "__main__":
    main()