harikrishna1985 committed on
Commit
e8417f4
·
verified ·
1 Parent(s): 7585e98

Upload src/03_evaluate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/03_evaluate.py +123 -0
src/03_evaluate.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ import joblib
5
+ import pandas as pd
6
+ from huggingface_hub import hf_hub_download
7
+ from sklearn.metrics import (
8
+ accuracy_score,
9
+ f1_score,
10
+ classification_report,
11
+ confusion_matrix,
12
+ )
13
+
14
+
15
# =========================
# CONFIG
# =========================
# Hugging Face dataset repo that holds the processed train/test splits.
DATASET_REPO_ID = "harikrishna1985/Engine_data"
# Hugging Face model repo that holds the trained model artifacts.
MODEL_REPO_ID = "harikrishna1985/predictive-maintenance-model"

# File paths inside the respective HF repos.
TEST_FILENAME = "processed/test.csv"
MODEL_FILENAME = "best_model.pkl"
MODEL_INFO_FILENAME = "best_model_info.json"

# Raw target column name; code normalizes it to snake_case before use.
TARGET_COLUMN = "engine_condition"

# Local output directory for evaluation artifacts.
# NOTE: created eagerly at import time (module-level side effect).
LOCAL_EVAL_DIR = Path("artifacts")
LOCAL_EVAL_DIR.mkdir(parents=True, exist_ok=True)

# Output files written by evaluate().
EVAL_SUMMARY_FILE = LOCAL_EVAL_DIR / "evaluation_summary.json"
CLASSIFICATION_REPORT_FILE = LOCAL_EVAL_DIR / "classification_report.csv"
CONFUSION_MATRIX_FILE = LOCAL_EVAL_DIR / "confusion_matrix.csv"
33
+
34
+
35
def load_test_data():
    """Download the processed test split from the HF dataset repo.

    Returns a DataFrame whose column names are normalized to snake_case
    (stripped, lowercased, spaces replaced with underscores).
    """
    csv_path = hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=TEST_FILENAME,
        repo_type="dataset",
    )
    frame = pd.read_csv(csv_path)
    # Normalize headers so downstream lookups are case/space insensitive.
    frame = frame.rename(columns=lambda c: c.strip().lower().replace(" ", "_"))
    return frame
44
+
45
+
46
def load_model_and_info():
    """Fetch the serialized model and its metadata JSON from the HF model repo.

    Returns a tuple ``(model, model_info)`` where ``model`` is the unpickled
    estimator and ``model_info`` is the parsed metadata dict.
    """
    # Download both artifacts; dict preserves insertion order, so the model
    # file is fetched first, matching the original call sequence.
    local_paths = {
        name: hf_hub_download(repo_id=MODEL_REPO_ID, filename=name, repo_type="model")
        for name in (MODEL_FILENAME, MODEL_INFO_FILENAME)
    }

    model = joblib.load(local_paths[MODEL_FILENAME])
    with open(local_paths[MODEL_INFO_FILENAME], "r", encoding="utf-8") as fh:
        model_info = json.load(fh)

    return model, model_info
63
+
64
+
65
def prepare_test_features(test_df: pd.DataFrame, feature_columns: list[str], target_column=None):
    """Split the test frame into features and target, aligned to training columns.

    Args:
        test_df: Test data with snake_case column names.
        feature_columns: Exact feature column list used at training time; the
            one-hot-encoded test features are reindexed to match it.
        target_column: Optional override for the target column name; defaults
            to the module-level ``TARGET_COLUMN``. The name is normalized
            (stripped, lowercased, spaces -> underscores) before lookup.

    Returns:
        Tuple ``(X_test, y_test)``.

    Raises:
        ValueError: If the (normalized) target column is absent from ``test_df``.
    """
    raw_target = TARGET_COLUMN if target_column is None else target_column
    target_col_clean = raw_target.strip().lower().replace(" ", "_")

    if target_col_clean not in test_df.columns:
        raise ValueError(f"Target column '{target_col_clean}' missing in test data.")

    X_test = test_df.drop(columns=[target_col_clean])
    y_test = test_df[target_col_clean]

    X_test = pd.get_dummies(X_test, drop_first=False)

    # Align to training features: drops columns unseen at train time and
    # zero-fills training columns absent from the test one-hot encoding.
    X_test = X_test.reindex(columns=feature_columns, fill_value=0)

    return X_test, y_test
80
+
81
+
82
def evaluate():
    """Evaluate the published model on the published test split.

    Downloads the test data and model artifacts, computes accuracy and
    weighted F1, and writes three local artifacts: a JSON summary, a
    classification-report CSV, and a confusion-matrix CSV.
    """
    test_df = load_test_data()
    model, model_info = load_model_and_info()

    feature_columns = model_info["feature_columns"]

    X_test, y_test = prepare_test_features(test_df, feature_columns)

    preds = model.predict(X_test)

    # BUG FIX: sklearn metric functions return numpy scalar types
    # (numpy.float64), which json.dump cannot serialize — cast to the
    # built-in float so the summary JSON write does not raise TypeError.
    acc = float(accuracy_score(y_test, preds))
    f1 = float(f1_score(y_test, preds, average="weighted"))

    report = classification_report(y_test, preds, output_dict=True)
    report_df = pd.DataFrame(report).transpose()

    # Compare labels as strings so any dtype mismatch between y_test and
    # preds cannot break the confusion-matrix label alignment.
    labels = sorted(y_test.astype(str).unique().tolist())
    cm = confusion_matrix(y_test.astype(str), pd.Series(preds).astype(str), labels=labels)
    cm_df = pd.DataFrame(cm, index=labels, columns=labels)

    summary = {
        "model_name": model_info.get("model_name"),
        "params": model_info.get("params"),
        "accuracy": acc,
        "f1_weighted": f1,
    }

    with open(EVAL_SUMMARY_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    report_df.to_csv(CLASSIFICATION_REPORT_FILE, index=True)
    cm_df.to_csv(CONFUSION_MATRIX_FILE, index=True)

    print("Evaluation completed.")
    print(json.dumps(summary, indent=2))
    print(f"Saved: {EVAL_SUMMARY_FILE}")
    print(f"Saved: {CLASSIFICATION_REPORT_FILE}")
    print(f"Saved: {CONFUSION_MATRIX_FILE}")
120
+
121
+
122
# Script entry point: run the full evaluation pipeline when executed directly.
if __name__ == "__main__":
    evaluate()