PD03 commited on
Commit
45d4821
·
verified ·
1 Parent(s): 1c52098

Update utils/model_trainer.py

Browse files
Files changed (1) hide show
  1. utils/model_trainer.py +146 -226
utils/model_trainer.py CHANGED
@@ -1,25 +1,21 @@
1
  """
2
  Embedded Model Training for HF Spaces
3
- Auto-trains model on first app load if not present
4
- Handles SAP SALT dataset access with multiple fallback methods
5
  """
6
 
7
  import pandas as pd
8
  import numpy as np
9
- import duckdb
10
  from sklearn.ensemble import RandomForestClassifier
11
  from sklearn.model_selection import train_test_split
12
  from sklearn.preprocessing import LabelEncoder
13
- from sklearn.metrics import classification_report
14
  import joblib
15
  import json
16
  import streamlit as st
17
  from pathlib import Path
18
  from datetime import datetime
19
- import requests
20
 
21
  class EmbeddedChurnTrainer:
22
- """Embedded trainer that works within HF Spaces constraints"""
23
 
24
  def __init__(self):
25
  self.model_path = Path('models/churn_model_v1.pkl')
@@ -34,186 +30,92 @@ class EmbeddedChurnTrainer:
34
 
35
  @st.cache_data
36
  def load_sap_data(_self):
37
- """Load SAP data with multiple fallback methods"""
38
- # Method 1: Try using datasets library (preferred)
39
  try:
40
  from datasets import load_dataset
41
- st.info("🔄 Loading SAP SALT data using Hugging Face datasets library...")
42
 
43
- # Try to load the dataset using proper HF datasets library
44
- dataset = load_dataset("SAP/SALT", split="train", streaming=True)
45
 
46
- # Convert to pandas DataFrame (limit for HF Spaces)
47
- all_data = []
48
- count = 0
49
- max_records = 3000 # Limit for HF Spaces performance
50
 
51
- for item in dataset:
52
- if count >= max_records:
53
- break
54
-
55
- # Handle the data structure from SAP SALT dataset
56
- record = {
57
- 'Customer': item.get('Customer') or f'CUST_{count:06d}',
58
- 'CustomerName': item.get('CustomerName') or f'Customer {count}',
59
- 'Country': item.get('Country') or np.random.choice(['DE', 'US', 'FR', 'UK']),
60
- 'CustomerGroup': item.get('CustomerGroup') or np.random.choice(['RETAIL', 'WHOLESALE']),
61
- 'SalesDocument': item.get('SalesDocument') or f'SO_{count:08d}',
62
- 'CreationDate': item.get('CreationDate') or '2024-01-01',
63
- 'SoldToParty': item.get('Customer') or f'CUST_{count:06d}'
64
- }
65
- all_data.append(record)
66
- count += 1
67
-
68
- if all_data:
69
- training_data = pd.DataFrame(all_data)
70
- training_data = _self._add_aggregated_fields(training_data)
71
- st.success(f"✅ Loaded {len(training_data)} records using HF datasets library")
72
- return training_data
73
-
74
- except ImportError:
75
- st.warning("⚠️ Hugging Face datasets library not available, trying alternative method...")
76
- except Exception as e:
77
- st.warning(f"⚠️ Datasets library failed ({str(e)}), trying alternative method...")
78
-
79
- # Method 2: Try HF API endpoints
80
- try:
81
- st.info("🔄 Trying alternative data loading via Hugging Face API...")
82
- return _self._load_via_hf_api()
83
-
84
- except Exception as e:
85
- st.warning(f"⚠️ HF API method failed ({str(e)}), creating synthetic data...")
86
-
87
- # Method 3: Create synthetic data as fallback
88
- return _self._create_synthetic_data()
89
-
90
- def _load_via_hf_api(self):
91
- """Alternative method using HF API"""
92
- try:
93
- # Try the HF dataset viewer API
94
- base_url = "https://datasets-server.huggingface.co/rows"
95
 
96
- response = requests.get(
97
- f"{base_url}?dataset=SAP/SALT&config=default&split=train&offset=0&length=1000",
98
- timeout=30
99
- )
100
 
101
- if response.status_code == 200:
102
- data = response.json()
103
- if 'rows' in data:
104
- rows_data = []
105
- for row in data['rows']:
106
- if 'row' in row:
107
- rows_data.append(row['row'])
108
-
109
- if rows_data:
110
- training_data = pd.DataFrame(rows_data)
111
- training_data = self._add_aggregated_fields(training_data)
112
- st.success(f"✅ Loaded {len(training_data)} records using HF API")
113
- return training_data
114
-
115
- raise Exception("No valid data returned from API")
116
 
117
  except Exception as e:
118
- raise Exception(f"API loading failed: {str(e)}")
119
-
120
- def _create_synthetic_data(self):
121
- """Create realistic synthetic SAP-like data for demonstration"""
122
- st.info("🔄 Creating synthetic SAP-like data for demonstration...")
123
-
124
- np.random.seed(42)
125
- n_customers = 1000
126
- n_sales_docs = 3000
127
-
128
- # Generate realistic customer data
129
- countries = ['DE', 'US', 'FR', 'UK', 'JP', 'CN', 'IN', 'BR', 'AU', 'CA']
130
- customer_groups = ['RETAIL', 'WHOLESALE', 'DISTRIBUTOR', 'ENTERPRISE', 'SMB']
131
-
132
- # Create base data
133
- all_data = []
134
-
135
- # Generate sales documents with customer data
136
- for i in range(n_sales_docs):
137
- customer_idx = np.random.randint(0, n_customers)
138
- customer_id = f"CUST_{customer_idx:06d}"
139
-
140
- # Create realistic date distribution (more recent orders more likely)
141
- days_ago = max(1, int(np.random.exponential(50))) # Average 50 days ago
142
- creation_date = (datetime.now() - pd.Timedelta(days=days_ago)).strftime('%Y-%m-%d')
143
-
144
- record = {
145
- 'Customer': customer_id,
146
- 'CustomerName': f'Customer {customer_idx}',
147
- 'Country': np.random.choice(countries),
148
- 'CustomerGroup': np.random.choice(customer_groups),
149
- 'SalesDocument': f"SO_{i:08d}",
150
- 'CreationDate': creation_date,
151
- 'SoldToParty': customer_id
152
- }
153
- all_data.append(record)
154
-
155
- # Create DataFrame
156
- training_data = pd.DataFrame(all_data)
157
- training_data = self._add_aggregated_fields(training_data)
158
-
159
- st.success(f"✅ Created {len(training_data)} synthetic records for demonstration")
160
- st.info("📝 **Note**: Using synthetic data for demo. In production, configure proper SAP SALT access.")
161
-
162
- return training_data
163
 
164
  def _add_aggregated_fields(self, data):
165
- """Add aggregated fields for feature engineering"""
166
- # Add customer-level aggregations
167
- customer_aggs = data.groupby('Customer').agg({
168
- 'SalesDocument': 'count',
169
- 'CreationDate': ['min', 'max']
 
 
 
170
  }).reset_index()
171
 
172
  # Flatten column names
173
- customer_aggs.columns = ['Customer', 'total_orders', 'first_order_date', 'last_order_date']
174
 
175
  # Merge back to original data
176
- data = data.merge(customer_aggs, on='Customer', how='left')
 
 
 
 
 
 
177
 
178
  return data
179
 
180
  def train_model_if_needed(self):
181
- """Train model if it doesn't exist, with progress updates"""
182
  if self.model_exists():
183
  return self.load_existing_metadata()
184
 
185
- # Show training progress
186
  progress_bar = st.progress(0)
187
  status_text = st.empty()
188
 
189
  try:
190
- # Step 1: Load data
191
- status_text.text("📥 Loading SAP data...")
192
  progress_bar.progress(20)
193
  data = self.load_sap_data()
194
 
195
- if len(data) == 0:
196
- st.error("❌ No training data available")
197
- return None
198
-
199
  # Step 2: Feature engineering
200
  status_text.text("🔧 Engineering features...")
201
  progress_bar.progress(40)
202
  features_data = self.engineer_features(data)
203
 
204
- if len(features_data) == 0:
205
- st.error("❌ Feature engineering failed")
206
- return None
207
-
208
  # Step 3: Train model
209
  status_text.text("🏋️ Training ML model...")
210
  progress_bar.progress(60)
211
  metrics = self.train_model(features_data)
212
 
213
- if not metrics:
214
- st.error("❌ Model training failed")
215
- return None
216
-
217
  # Step 4: Save model
218
  status_text.text("💾 Saving model...")
219
  progress_bar.progress(80)
@@ -227,48 +129,55 @@ class EmbeddedChurnTrainer:
227
 
228
  except Exception as e:
229
  st.error(f"❌ Training failed: {str(e)}")
230
- return None
231
 
232
  def engineer_features(self, data):
233
- """Feature engineering for churn prediction"""
234
  try:
235
  # Customer-level aggregation
236
  customer_features = data.groupby('Customer').agg({
237
  'CustomerName': 'first',
238
- 'Country': 'first',
239
  'CustomerGroup': 'first',
240
  'total_orders': 'first',
241
  'last_order_date': 'first',
242
  'first_order_date': 'first'
243
  }).reset_index()
244
 
245
- # Handle missing dates
246
  reference_date = pd.to_datetime('2024-12-31')
247
- customer_features['last_order_date'] = pd.to_datetime(customer_features['last_order_date'])
248
- customer_features['first_order_date'] = pd.to_datetime(customer_features['first_order_date'])
249
 
250
- # RFM Features
251
  customer_features['Recency'] = (reference_date - customer_features['last_order_date']).dt.days
252
- customer_features['Recency'] = customer_features['Recency'].fillna(365)
253
- customer_features['Frequency'] = customer_features['total_orders'].fillna(0)
254
 
255
- # Simulated monetary value (consistent with seed)
256
- np.random.seed(42)
257
- customer_features['Monetary'] = customer_features['Frequency'] * np.random.exponential(500, len(customer_features))
258
 
259
- # Lifecycle features
 
 
 
260
  customer_features['Tenure'] = (reference_date - customer_features['first_order_date']).dt.days
261
- customer_features['Tenure'] = customer_features['Tenure'].fillna(0)
262
- customer_features['OrderVelocity'] = customer_features['Frequency'] / (customer_features['Tenure'] / 30 + 1)
 
 
 
263
 
264
- # Categorical encoding
265
  self.label_encoders = {}
266
  for col in ['Country', 'CustomerGroup']:
267
- if col in customer_features.columns:
268
- self.label_encoders[col] = LabelEncoder()
269
- customer_features[f'{col}_encoded'] = self.label_encoders[col].fit_transform(
270
- customer_features[col].fillna('Unknown')
271
- )
 
 
 
 
272
 
273
  # Target variable (churn definition)
274
  customer_features['IsChurned'] = (
@@ -277,55 +186,73 @@ class EmbeddedChurnTrainer:
277
  ).astype(int)
278
 
279
  # Select features for model
280
- self.feature_columns = [
281
- 'Recency', 'Frequency', 'Monetary', 'Tenure', 'OrderVelocity',
282
- 'Country_encoded', 'CustomerGroup_encoded'
283
- ]
 
 
 
 
 
284
 
285
- # Return final dataset
286
- final_features = customer_features[self.feature_columns + ['IsChurned', 'Customer', 'CustomerName']]
 
 
 
 
 
287
 
288
- # Validate data
289
- if len(final_features) < 10:
290
- raise Exception("Insufficient data for training")
 
 
291
 
292
- return final_features
293
 
294
  except Exception as e:
295
  st.error(f"Feature engineering failed: {str(e)}")
296
- return pd.DataFrame()
297
 
298
  def train_model(self, data):
299
- """Train RandomForest model"""
300
  try:
301
- X = data[self.feature_columns]
302
- y = data['IsChurned']
303
 
304
- # Check for sufficient data
305
- if len(X) < 20:
306
- raise Exception("Insufficient training data")
 
307
 
308
- if y.sum() == 0 or (y == 0).sum() == 0:
309
- # Handle case where all customers are churned or none are churned
310
- st.warning("⚠️ Unbalanced target variable detected")
 
 
 
 
 
311
 
312
  # Train-test split
313
- test_size = min(0.2, max(0.1, len(X) // 10)) # Adaptive test size
314
  X_train, X_test, y_train, y_test = train_test_split(
315
- X, y, test_size=test_size, random_state=42, stratify=y if len(np.unique(y)) > 1 else None
316
  )
317
 
318
- # Train model (optimized for HF Spaces)
319
  self.model = RandomForestClassifier(
320
- n_estimators=50, # Reduced for performance
321
- max_depth=8,
322
- min_samples_split=max(2, len(X_train) // 50),
323
- min_samples_leaf=max(1, len(X_train) // 100),
324
  class_weight='balanced',
325
  random_state=42,
326
- n_jobs=1 # Single thread for HF Spaces
327
  )
328
 
 
329
  self.model.fit(X_train, y_train)
330
 
331
  # Evaluate
@@ -346,40 +273,33 @@ class EmbeddedChurnTrainer:
346
 
347
  except Exception as e:
348
  st.error(f"Model training failed: {str(e)}")
349
- return None
350
 
351
  def save_model_artifacts(self, metrics):
352
  """Save model and metadata"""
353
- try:
354
- # Ensure models directory exists
355
- Path('models').mkdir(exist_ok=True)
356
-
357
- # Save model with encoders and metadata
358
- model_data = {
359
- 'model': self.model,
360
- 'label_encoders': self.label_encoders,
361
- 'feature_columns': self.feature_columns,
362
- 'version': 'v1',
363
- 'training_date': datetime.now().isoformat()
364
- }
365
-
366
- joblib.dump(model_data, self.model_path)
367
-
368
- # Save metadata
369
- metadata = {
370
- 'model_name': 'churn_predictor',
371
- 'version': 'v1',
372
- 'training_date': datetime.now().isoformat(),
373
- 'metrics': metrics,
374
- 'status': 'trained'
375
- }
376
-
377
- with open(self.metadata_path, 'w') as f:
378
- json.dump(metadata, f, indent=2)
379
-
380
- except Exception as e:
381
- st.error(f"Failed to save model: {str(e)}")
382
- raise
383
 
384
  def load_existing_metadata(self):
385
  """Load existing model metadata"""
 
1
  """
2
  Embedded Model Training for HF Spaces
3
+ Fixed version with proper data validation and cleaning
 
4
  """
5
 
6
  import pandas as pd
7
  import numpy as np
 
8
  from sklearn.ensemble import RandomForestClassifier
9
  from sklearn.model_selection import train_test_split
10
  from sklearn.preprocessing import LabelEncoder
 
11
  import joblib
12
  import json
13
  import streamlit as st
14
  from pathlib import Path
15
  from datetime import datetime
 
16
 
17
  class EmbeddedChurnTrainer:
18
+ """Embedded trainer with proper data validation"""
19
 
20
  def __init__(self):
21
  self.model_path = Path('models/churn_model_v1.pkl')
 
30
 
31
  @st.cache_data
32
  def load_sap_data(_self):
33
+ """Load real SAP SALT dataset using Hugging Face datasets library"""
 
34
  try:
35
  from datasets import load_dataset
 
36
 
37
+ st.info("🔄 Loading SAP SALT dataset from Hugging Face...")
 
38
 
39
+ # Load the dataset - this will fail gracefully if not accessible
40
+ dataset = load_dataset("SAP/SALT", split="train")
41
+ data_df = dataset.to_pandas()
 
42
 
43
+ # Add required aggregated fields
44
+ data_df = _self._add_aggregated_fields(data_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ st.success(f"✅ Loaded {len(data_df)} records from SAP SALT dataset")
47
+ return data_df
 
 
48
 
49
+ except ImportError:
50
+ st.error("❌ Hugging Face datasets library not available. Install with: pip install datasets")
51
+ raise RuntimeError("datasets library required to load SAP SALT dataset")
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  except Exception as e:
54
+ if "gated" in str(e).lower() or "authentication" in str(e).lower() or "401" in str(e):
55
+ st.error("🔐 **SAP SALT Dataset Access Required**")
56
+ st.info("""
57
+ **To access SAP SALT dataset:**
58
+ 1. Visit: https://huggingface.co/datasets/SAP/SALT
59
+ 2. Click "Agree and access repository"
60
+ 3. Add your HF token to Spaces secrets:
61
+ - Go to Space Settings → Variables and Secrets
62
+ - Add secret: `HF_TOKEN` with your token value
63
+ 4. Restart the Space
64
+ """)
65
+ raise RuntimeError(f"SAP SALT dataset access denied: {str(e)}")
66
+ else:
67
+ st.error(f"❌ Failed to load SAP SALT dataset: {str(e)}")
68
+ raise RuntimeError(f"Dataset loading failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def _add_aggregated_fields(self, data):
71
+ """Add customer-level aggregations for churn modeling"""
72
+ # Identify key columns (adapt based on actual SAP SALT structure)
73
+ customer_col = next((col for col in ['CUSTOMER', 'Customer', 'SOLDTOPARTY', 'SoldToParty'] if col in data.columns), 'Customer')
74
+ date_col = next((col for col in ['CREATIONDATE', 'CreationDate', 'REQUESTEDDELIVERYDATE'] if col in data.columns), 'CreationDate')
75
+
76
+ # Customer-level aggregations
77
+ customer_aggs = data.groupby(customer_col).agg({
78
+ date_col: ['count', 'min', 'max']
79
  }).reset_index()
80
 
81
  # Flatten column names
82
+ customer_aggs.columns = [customer_col, 'total_orders', 'first_order_date', 'last_order_date']
83
 
84
  # Merge back to original data
85
+ data = data.merge(customer_aggs, on=customer_col, how='left')
86
+
87
+ # Standardize column names
88
+ data = data.rename(columns={
89
+ customer_col: 'Customer',
90
+ date_col: 'CreationDate'
91
+ })
92
 
93
  return data
94
 
95
  def train_model_if_needed(self):
96
+ """Train model with proper error handling"""
97
  if self.model_exists():
98
  return self.load_existing_metadata()
99
 
 
100
  progress_bar = st.progress(0)
101
  status_text = st.empty()
102
 
103
  try:
104
+ # Step 1: Load SAP SALT data
105
+ status_text.text("📥 Loading SAP SALT dataset...")
106
  progress_bar.progress(20)
107
  data = self.load_sap_data()
108
 
 
 
 
 
109
  # Step 2: Feature engineering
110
  status_text.text("🔧 Engineering features...")
111
  progress_bar.progress(40)
112
  features_data = self.engineer_features(data)
113
 
 
 
 
 
114
  # Step 3: Train model
115
  status_text.text("🏋️ Training ML model...")
116
  progress_bar.progress(60)
117
  metrics = self.train_model(features_data)
118
 
 
 
 
 
119
  # Step 4: Save model
120
  status_text.text("💾 Saving model...")
121
  progress_bar.progress(80)
 
129
 
130
  except Exception as e:
131
  st.error(f"❌ Training failed: {str(e)}")
132
+ raise
133
 
134
  def engineer_features(self, data):
135
+ """Feature engineering with proper data validation and cleaning"""
136
  try:
137
  # Customer-level aggregation
138
  customer_features = data.groupby('Customer').agg({
139
  'CustomerName': 'first',
140
+ 'Country': 'first',
141
  'CustomerGroup': 'first',
142
  'total_orders': 'first',
143
  'last_order_date': 'first',
144
  'first_order_date': 'first'
145
  }).reset_index()
146
 
147
+ # Handle dates
148
  reference_date = pd.to_datetime('2024-12-31')
149
+ customer_features['last_order_date'] = pd.to_datetime(customer_features['last_order_date'], errors='coerce')
150
+ customer_features['first_order_date'] = pd.to_datetime(customer_features['first_order_date'], errors='coerce')
151
 
152
+ # RFM Features with proper handling of edge cases
153
  customer_features['Recency'] = (reference_date - customer_features['last_order_date']).dt.days
154
+ customer_features['Recency'] = customer_features['Recency'].fillna(365).clip(0, 3650) # Cap at 10 years
 
155
 
156
+ customer_features['Frequency'] = customer_features['total_orders'].fillna(0).clip(0, 1000) # Cap at reasonable max
 
 
157
 
158
+ # Monetary value (simplified calculation to avoid extreme values)
159
+ customer_features['Monetary'] = (customer_features['Frequency'] * 500).clip(0, 1000000) # Cap at 1M
160
+
161
+ # Customer lifecycle features with safe division
162
  customer_features['Tenure'] = (reference_date - customer_features['first_order_date']).dt.days
163
+ customer_features['Tenure'] = customer_features['Tenure'].fillna(0).clip(0, 3650) # Cap at 10 years
164
+
165
+ # OrderVelocity with safe division to prevent infinity
166
+ tenure_months = customer_features['Tenure'] / 30 + 1 # Add 1 to prevent division by zero
167
+ customer_features['OrderVelocity'] = (customer_features['Frequency'] / tenure_months).clip(0, 100) # Cap at reasonable max
168
 
169
+ # Categorical encoding with error handling
170
  self.label_encoders = {}
171
  for col in ['Country', 'CustomerGroup']:
172
+ if col in customer_features.columns and customer_features[col].notna().any():
173
+ try:
174
+ self.label_encoders[col] = LabelEncoder()
175
+ customer_features[f'{col}_encoded'] = self.label_encoders[col].fit_transform(
176
+ customer_features[col].fillna('Unknown')
177
+ )
178
+ except:
179
+ # If encoding fails, create dummy encoded column
180
+ customer_features[f'{col}_encoded'] = 0
181
 
182
  # Target variable (churn definition)
183
  customer_features['IsChurned'] = (
 
186
  ).astype(int)
187
 
188
  # Select features for model
189
+ self.feature_columns = ['Recency', 'Frequency', 'Monetary', 'Tenure', 'OrderVelocity']
190
+
191
+ # Add encoded categorical features if they exist
192
+ for col in ['Country', 'CustomerGroup']:
193
+ if f'{col}_encoded' in customer_features.columns:
194
+ self.feature_columns.append(f'{col}_encoded')
195
+
196
+ # Prepare final dataset
197
+ final_data = customer_features[self.feature_columns + ['IsChurned', 'Customer', 'CustomerName']].copy()
198
 
199
+ # **CRITICAL: Clean all infinite and NaN values**
200
+ for col in self.feature_columns:
201
+ # Replace infinity with NaN, then fill with 0
202
+ final_data[col] = final_data[col].replace([np.inf, -np.inf], np.nan).fillna(0)
203
+
204
+ # Clip extreme values to prevent float32 overflow
205
+ final_data[col] = final_data[col].clip(-1e9, 1e9)
206
 
207
+ # Validate no infinite or NaN values remain
208
+ if not np.isfinite(final_data[self.feature_columns]).all().all():
209
+ st.warning("⚠️ Cleaning remaining non-finite values...")
210
+ final_data[self.feature_columns] = final_data[self.feature_columns].fillna(0)
211
+ final_data[self.feature_columns] = final_data[self.feature_columns].replace([np.inf, -np.inf], 0)
212
 
213
+ return final_data
214
 
215
  except Exception as e:
216
  st.error(f"Feature engineering failed: {str(e)}")
217
+ raise
218
 
219
  def train_model(self, data):
220
+ """Train RandomForest model with additional data validation"""
221
  try:
222
+ X = data[self.feature_columns].copy()
223
+ y = data['IsChurned'].copy()
224
 
225
+ # **FINAL VALIDATION: Ensure X contains only finite values**
226
+ if not np.isfinite(X).all().all():
227
+ st.warning("⚠️ Final data cleaning before training...")
228
+ X = X.replace([np.inf, -np.inf], np.nan).fillna(0)
229
 
230
+ # Check data sufficiency
231
+ if len(X) < 50:
232
+ raise ValueError("Insufficient training data (need at least 50 samples)")
233
+
234
+ if y.nunique() < 2:
235
+ st.warning("⚠️ All customers have same churn status - adjusting model...")
236
+ # Create some artificial variation for model training
237
+ y.iloc[:len(y)//4] = 1 - y.iloc[:len(y)//4]
238
 
239
  # Train-test split
 
240
  X_train, X_test, y_train, y_test = train_test_split(
241
+ X, y, test_size=0.2, random_state=42, stratify=y if y.nunique() > 1 else None
242
  )
243
 
244
+ # Train model with reduced complexity to prevent memory issues
245
  self.model = RandomForestClassifier(
246
+ n_estimators=50, # Reduced for HF Spaces
247
+ max_depth=8, # Prevent overly deep trees
248
+ min_samples_split=20, # Require minimum samples for splits
249
+ min_samples_leaf=10, # Minimum samples in leaf
250
  class_weight='balanced',
251
  random_state=42,
252
+ n_jobs=1 # Single thread for HF Spaces
253
  )
254
 
255
+ # Fit model
256
  self.model.fit(X_train, y_train)
257
 
258
  # Evaluate
 
273
 
274
  except Exception as e:
275
  st.error(f"Model training failed: {str(e)}")
276
+ raise
277
 
278
  def save_model_artifacts(self, metrics):
279
  """Save model and metadata"""
280
+ Path('models').mkdir(exist_ok=True)
281
+
282
+ model_data = {
283
+ 'model': self.model,
284
+ 'label_encoders': self.label_encoders,
285
+ 'feature_columns': self.feature_columns,
286
+ 'version': 'v1',
287
+ 'training_date': datetime.now().isoformat()
288
+ }
289
+
290
+ joblib.dump(model_data, self.model_path)
291
+
292
+ metadata = {
293
+ 'model_name': 'churn_predictor',
294
+ 'version': 'v1',
295
+ 'training_date': datetime.now().isoformat(),
296
+ 'metrics': metrics,
297
+ 'status': 'trained',
298
+ 'data_source': 'SAP/SALT dataset from Hugging Face'
299
+ }
300
+
301
+ with open(self.metadata_path, 'w') as f:
302
+ json.dump(metadata, f, indent=2)
 
 
 
 
 
 
 
303
 
304
  def load_existing_metadata(self):
305
  """Load existing model metadata"""