PD03 committed on
Commit
9794e0d
·
verified ·
1 Parent(s): e7df474

Update utils/model_trainer.py

Browse files
Files changed (1) hide show
  1. utils/model_trainer.py +10 -7
utils/model_trainer.py CHANGED
@@ -22,6 +22,9 @@ class EmbeddedChurnTrainer:
22
  def __init__(self):
23
  self.model_path = Path('models/churn_model_v1.pkl')
24
  self.metadata_path = Path('models/model_metadata.json')
 
 
 
25
 
26
  def model_exists(self):
27
  """Check if trained model exists"""
@@ -33,18 +36,18 @@ class EmbeddedChurnTrainer:
33
  try:
34
  conn = duckdb.connect(':memory:')
35
 
36
- # Load SAP datasets
37
  conn.execute("""
38
  CREATE TABLE customers AS
39
  SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
40
  LIMIT 5000
41
- """) # Limit for HF Spaces performance
42
 
43
  conn.execute("""
44
  CREATE TABLE sales_docs AS
45
  SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
46
  LIMIT 10000
47
- """) # Limit for HF Spaces performance
48
 
49
  # Join data
50
  training_data = conn.execute("""
@@ -71,7 +74,7 @@ class EmbeddedChurnTrainer:
71
  return pd.DataFrame()
72
 
73
  def train_model_if_needed(self):
74
- """Train model if it doesn't exist, with progress bar"""
75
  if self.model_exists():
76
  return self.load_existing_metadata()
77
 
@@ -115,7 +118,7 @@ class EmbeddedChurnTrainer:
115
  return None
116
 
117
  def engineer_features(self, data):
118
- """Streamlined feature engineering for HF Spaces"""
119
  # Customer-level aggregation
120
  customer_features = data.groupby('Customer').agg({
121
  'CustomerName': 'first',
@@ -178,9 +181,9 @@ class EmbeddedChurnTrainer:
178
  X, y, test_size=0.2, random_state=42, stratify=y
179
  )
180
 
181
- # Train model
182
  self.model = RandomForestClassifier(
183
- n_estimators=50, # Reduced for HF Spaces performance
184
  max_depth=8,
185
  min_samples_split=20,
186
  class_weight='balanced',
 
22
  def __init__(self):
23
  self.model_path = Path('models/churn_model_v1.pkl')
24
  self.metadata_path = Path('models/model_metadata.json')
25
+ self.model = None
26
+ self.label_encoders = {}
27
+ self.feature_columns = []
28
 
29
  def model_exists(self):
30
  """Check if trained model exists"""
 
36
  try:
37
  conn = duckdb.connect(':memory:')
38
 
39
+ # Load SAP datasets with limits for HF Spaces performance
40
  conn.execute("""
41
  CREATE TABLE customers AS
42
  SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
43
  LIMIT 5000
44
+ """)
45
 
46
  conn.execute("""
47
  CREATE TABLE sales_docs AS
48
  SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
49
  LIMIT 10000
50
+ """)
51
 
52
  # Join data
53
  training_data = conn.execute("""
 
74
  return pd.DataFrame()
75
 
76
  def train_model_if_needed(self):
77
+ """Train model if it doesn't exist, with progress updates"""
78
  if self.model_exists():
79
  return self.load_existing_metadata()
80
 
 
118
  return None
119
 
120
  def engineer_features(self, data):
121
+ """Feature engineering for churn prediction"""
122
  # Customer-level aggregation
123
  customer_features = data.groupby('Customer').agg({
124
  'CustomerName': 'first',
 
181
  X, y, test_size=0.2, random_state=42, stratify=y
182
  )
183
 
184
+ # Train model (optimized for HF Spaces)
185
  self.model = RandomForestClassifier(
186
+ n_estimators=50, # Reduced for performance
187
  max_depth=8,
188
  min_samples_split=20,
189
  class_weight='balanced',