Spaces:
Running
Running
Muhammad Ridzki Nugraha committed on
Upload folder using huggingface_hub
Browse files
- scripts/train_from_files.py +43 -8
scripts/train_from_files.py
CHANGED
|
@@ -24,6 +24,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| 24 |
from training.train_regression import ResolutionTimePredictor
|
| 25 |
from training.train_nlp import SeverityClassifier, IssueTypeClassifier, SimpleSummarizer
|
| 26 |
from training.train_tfidf_classifier import train_tfidf_classifier
|
|
|
|
| 27 |
|
| 28 |
def clean_header(header):
|
| 29 |
"""Normalize headers to match what training scripts expect"""
|
|
@@ -55,14 +56,34 @@ def load_local_csv(filepath, sheet_name):
|
|
| 55 |
return data
|
| 56 |
|
| 57 |
def main():
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
logger.info(f"Total records for training: {len(all_data)}")
|
| 68 |
|
|
@@ -80,6 +101,12 @@ def main():
|
|
| 80 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 81 |
predictor.save(os.path.join(model_dir_reg, f"resolution_predictor_{timestamp}.pkl"))
|
| 82 |
predictor.save(os.path.join(model_dir_reg, "resolution_predictor_latest.pkl"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
# 2. Train NLP BERT Models
|
| 85 |
# logger.info("\n" + "="*30 + " Training NLP BERT Models " + "="*30)
|
|
@@ -123,6 +150,14 @@ def main():
|
|
| 123 |
"total_samples": len(all_data)
|
| 124 |
}
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
with open("models/training_summary.json", "w") as f:
|
| 127 |
json.dump(summary_metrics, f, indent=2, default=str)
|
| 128 |
|
|
|
|
| 24 |
from training.train_regression import ResolutionTimePredictor
|
| 25 |
from training.train_nlp import SeverityClassifier, IssueTypeClassifier, SimpleSummarizer
|
| 26 |
from training.train_tfidf_classifier import train_tfidf_classifier
|
| 27 |
+
from data.root_cause_service import RootCauseService
|
| 28 |
|
| 29 |
def clean_header(header):
|
| 30 |
"""Normalize headers to match what training scripts expect"""
|
|
|
|
| 56 |
return data
|
| 57 |
|
| 58 |
def main():
|
| 59 |
+
# Preferred local cache path (single combined file)
|
| 60 |
+
cache_path = os.path.join(os.path.dirname(__file__), "..", "data", "training_data_cache.csv")
|
| 61 |
+
cache_path = os.path.abspath(cache_path)
|
| 62 |
+
|
| 63 |
+
all_data = []
|
| 64 |
+
|
| 65 |
+
if os.path.exists(cache_path):
|
| 66 |
+
logger.info(f"Using cached training data: {cache_path}")
|
| 67 |
+
df = pd.read_csv(cache_path)
|
| 68 |
+
df = df.fillna("")
|
| 69 |
+
all_data = df.to_dict(orient="records")
|
| 70 |
+
# Ensure _row_id and _sheet_name exist for downstream components
|
| 71 |
+
for i, r in enumerate(all_data):
|
| 72 |
+
if "_row_id" not in r:
|
| 73 |
+
r["_row_id"] = f"ALL_{i+2}"
|
| 74 |
+
if "_sheet_name" not in r:
|
| 75 |
+
r["_sheet_name"] = "ALL"
|
| 76 |
+
else:
|
| 77 |
+
# Fallback: separate CGO and NON CARGO CSVs (if provided locally)
|
| 78 |
+
cgo_path = "/Users/nrzngr/Desktop/ai-model/Acc Data 2 - Irregularity Report - Manual for Dashboard - CGO (1).csv"
|
| 79 |
+
non_cargo_path = "/Users/nrzngr/Desktop/ai-model/Acc Data 2 - Irregularity Report - Manual for Dashboard - NON CARGO.csv"
|
| 80 |
+
cgo_data = []
|
| 81 |
+
non_cargo_data = []
|
| 82 |
+
if os.path.exists(cgo_path):
|
| 83 |
+
cgo_data = load_local_csv(cgo_path, "CGO")
|
| 84 |
+
if os.path.exists(non_cargo_path):
|
| 85 |
+
non_cargo_data = load_local_csv(non_cargo_path, "NON CARGO")
|
| 86 |
+
all_data = cgo_data + non_cargo_data
|
| 87 |
|
| 88 |
logger.info(f"Total records for training: {len(all_data)}")
|
| 89 |
|
|
|
|
| 101 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 102 |
predictor.save(os.path.join(model_dir_reg, f"resolution_predictor_{timestamp}.pkl"))
|
| 103 |
predictor.save(os.path.join(model_dir_reg, "resolution_predictor_latest.pkl"))
|
| 104 |
+
# Save metrics JSON alongside latest
|
| 105 |
+
try:
|
| 106 |
+
with open(os.path.join(model_dir_reg, "resolution_predictor_latest_metrics.json"), "w") as f:
|
| 107 |
+
json.dump(reg_metrics, f, indent=2, default=str)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logger.warning(f"Failed to save regression metrics JSON: {e}")
|
| 110 |
|
| 111 |
# 2. Train NLP BERT Models
|
| 112 |
# logger.info("\n" + "="*30 + " Training NLP BERT Models " + "="*30)
|
|
|
|
| 150 |
"total_samples": len(all_data)
|
| 151 |
}
|
| 152 |
|
| 153 |
+
# 5. Train Root Cause Classifier (TF-IDF + LogisticRegression)
|
| 154 |
+
try:
|
| 155 |
+
rc_service = RootCauseService()
|
| 156 |
+
rc_metrics = rc_service.train_from_data(all_data)
|
| 157 |
+
summary_metrics["root_cause"] = rc_metrics
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logger.warning(f"Failed training root cause classifier: {e}")
|
| 160 |
+
|
| 161 |
with open("models/training_summary.json", "w") as f:
|
| 162 |
json.dump(summary_metrics, f, indent=2, default=str)
|
| 163 |
|