Muhammad Ridzki Nugraha committed on
Commit
20005ea
·
verified ·
1 Parent(s): 214606f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_from_files.py +43 -8
scripts/train_from_files.py CHANGED
@@ -24,6 +24,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
  from training.train_regression import ResolutionTimePredictor
25
  from training.train_nlp import SeverityClassifier, IssueTypeClassifier, SimpleSummarizer
26
  from training.train_tfidf_classifier import train_tfidf_classifier
 
27
 
28
  def clean_header(header):
29
  """Normalize headers to match what training scripts expect"""
@@ -55,14 +56,34 @@ def load_local_csv(filepath, sheet_name):
55
  return data
56
 
57
  def main():
58
- # File paths
59
- cgo_path = "/Users/nrzngr/Desktop/ai-model/Acc Data 2 - Irregularity Report - Manual for Dashboard - CGO (1).csv"
60
- non_cargo_path = "/Users/nrzngr/Desktop/ai-model/Acc Data 2 - Irregularity Report - Manual for Dashboard - NON CARGO.csv"
61
-
62
- # Load and combine data
63
- cgo_data = load_local_csv(cgo_path, "CGO")
64
- non_cargo_data = load_local_csv(non_cargo_path, "NON CARGO")
65
- all_data = cgo_data + non_cargo_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  logger.info(f"Total records for training: {len(all_data)}")
68
 
@@ -80,6 +101,12 @@ def main():
80
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
81
  predictor.save(os.path.join(model_dir_reg, f"resolution_predictor_{timestamp}.pkl"))
82
  predictor.save(os.path.join(model_dir_reg, "resolution_predictor_latest.pkl"))
 
 
 
 
 
 
83
 
84
  # 2. Train NLP BERT Models
85
  # logger.info("\n" + "="*30 + " Training NLP BERT Models " + "="*30)
@@ -123,6 +150,14 @@ def main():
123
  "total_samples": len(all_data)
124
  }
125
 
 
 
 
 
 
 
 
 
126
  with open("models/training_summary.json", "w") as f:
127
  json.dump(summary_metrics, f, indent=2, default=str)
128
 
 
24
  from training.train_regression import ResolutionTimePredictor
25
  from training.train_nlp import SeverityClassifier, IssueTypeClassifier, SimpleSummarizer
26
  from training.train_tfidf_classifier import train_tfidf_classifier
27
+ from data.root_cause_service import RootCauseService
28
 
29
  def clean_header(header):
30
  """Normalize headers to match what training scripts expect"""
 
56
  return data
57
 
58
  def main():
59
+ # Preferred local cache path (single combined file)
60
+ cache_path = os.path.join(os.path.dirname(__file__), "..", "data", "training_data_cache.csv")
61
+ cache_path = os.path.abspath(cache_path)
62
+
63
+ all_data = []
64
+
65
+ if os.path.exists(cache_path):
66
+ logger.info(f"Using cached training data: {cache_path}")
67
+ df = pd.read_csv(cache_path)
68
+ df = df.fillna("")
69
+ all_data = df.to_dict(orient="records")
70
+ # Ensure _row_id and _sheet_name exist for downstream components
71
+ for i, r in enumerate(all_data):
72
+ if "_row_id" not in r:
73
+ r["_row_id"] = f"ALL_{i+2}"
74
+ if "_sheet_name" not in r:
75
+ r["_sheet_name"] = "ALL"
76
+ else:
77
+ # Fallback: separate CGO and NON CARGO CSVs (if provided locally)
78
+ cgo_path = "/Users/nrzngr/Desktop/ai-model/Acc Data 2 - Irregularity Report - Manual for Dashboard - CGO (1).csv"
79
+ non_cargo_path = "/Users/nrzngr/Desktop/ai-model/Acc Data 2 - Irregularity Report - Manual for Dashboard - NON CARGO.csv"
80
+ cgo_data = []
81
+ non_cargo_data = []
82
+ if os.path.exists(cgo_path):
83
+ cgo_data = load_local_csv(cgo_path, "CGO")
84
+ if os.path.exists(non_cargo_path):
85
+ non_cargo_data = load_local_csv(non_cargo_path, "NON CARGO")
86
+ all_data = cgo_data + non_cargo_data
87
 
88
  logger.info(f"Total records for training: {len(all_data)}")
89
 
 
101
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
102
  predictor.save(os.path.join(model_dir_reg, f"resolution_predictor_{timestamp}.pkl"))
103
  predictor.save(os.path.join(model_dir_reg, "resolution_predictor_latest.pkl"))
104
+ # Save metrics JSON alongside latest
105
+ try:
106
+ with open(os.path.join(model_dir_reg, "resolution_predictor_latest_metrics.json"), "w") as f:
107
+ json.dump(reg_metrics, f, indent=2, default=str)
108
+ except Exception as e:
109
+ logger.warning(f"Failed to save regression metrics JSON: {e}")
110
 
111
  # 2. Train NLP BERT Models
112
  # logger.info("\n" + "="*30 + " Training NLP BERT Models " + "="*30)
 
150
  "total_samples": len(all_data)
151
  }
152
 
153
+ # 5. Train Root Cause Classifier (TF-IDF + LogisticRegression)
154
+ try:
155
+ rc_service = RootCauseService()
156
+ rc_metrics = rc_service.train_from_data(all_data)
157
+ summary_metrics["root_cause"] = rc_metrics
158
+ except Exception as e:
159
+ logger.warning(f"Failed training root cause classifier: {e}")
160
+
161
  with open("models/training_summary.json", "w") as f:
162
  json.dump(summary_metrics, f, indent=2, default=str)
163