PD03 committed on
Commit
a3796f2
·
verified ·
1 Parent(s): d131cca

Update utils/model_trainer.py

Browse files
Files changed (1) hide show
  1. utils/model_trainer.py +299 -153
utils/model_trainer.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  Embedded Model Training for HF Spaces
3
  Auto-trains model on first app load if not present
 
4
  """
5
 
6
  import pandas as pd
@@ -15,6 +16,7 @@ import json
15
  import streamlit as st
16
  from pathlib import Path
17
  from datetime import datetime
 
18
 
19
  class EmbeddedChurnTrainer:
20
  """Embedded trainer that works within HF Spaces constraints"""
@@ -32,46 +34,148 @@ class EmbeddedChurnTrainer:
32
 
33
  @st.cache_data
34
  def load_sap_data(_self):
35
- """Load SAP data with Streamlit caching"""
 
36
  try:
37
- conn = duckdb.connect(':memory:')
38
-
39
- # Load SAP datasets with limits for HF Spaces performance
40
- conn.execute("""
41
- CREATE TABLE customers AS
42
- SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
43
- LIMIT 5000
44
- """)
45
-
46
- conn.execute("""
47
- CREATE TABLE sales_docs AS
48
- SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
49
- LIMIT 10000
50
- """)
51
-
52
- # Join data
53
- training_data = conn.execute("""
54
- SELECT
55
- c.Customer,
56
- c.CustomerName,
57
- c.Country,
58
- c.CustomerGroup,
59
- s.SalesDocument,
60
- s.CreationDate,
61
- s.SoldToParty,
62
- COUNT(s.SalesDocument) OVER (PARTITION BY c.Customer) as total_orders,
63
- MAX(s.CreationDate) OVER (PARTITION BY c.Customer) as last_order_date,
64
- MIN(s.CreationDate) OVER (PARTITION BY c.Customer) as first_order_date
65
- FROM customers c
66
- LEFT JOIN sales_docs s ON c.Customer = s.SoldToParty
67
- WHERE c.Customer IS NOT NULL
68
- """).df()
69
-
70
- return training_data
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  except Exception as e:
73
- st.error(f"Data loading failed: {str(e)}")
74
- return pd.DataFrame()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def train_model_if_needed(self):
77
  """Train model if it doesn't exist, with progress updates"""
@@ -84,26 +188,34 @@ class EmbeddedChurnTrainer:
84
 
85
  try:
86
  # Step 1: Load data
87
- status_text.text("Loading SAP data...")
88
  progress_bar.progress(20)
89
  data = self.load_sap_data()
90
 
91
  if len(data) == 0:
92
- st.error("No training data available")
93
  return None
94
 
95
  # Step 2: Feature engineering
96
- status_text.text("Engineering features...")
97
  progress_bar.progress(40)
98
  features_data = self.engineer_features(data)
99
 
 
 
 
 
100
  # Step 3: Train model
101
- status_text.text("Training ML model...")
102
  progress_bar.progress(60)
103
  metrics = self.train_model(features_data)
104
 
 
 
 
 
105
  # Step 4: Save model
106
- status_text.text("Saving model...")
107
  progress_bar.progress(80)
108
  self.save_model_artifacts(metrics)
109
 
@@ -114,131 +226,165 @@ class EmbeddedChurnTrainer:
114
  return metrics
115
 
116
  except Exception as e:
117
- st.error(f"Training failed: {str(e)}")
118
  return None
119
 
120
  def engineer_features(self, data):
121
  """Feature engineering for churn prediction"""
122
- # Customer-level aggregation
123
- customer_features = data.groupby('Customer').agg({
124
- 'CustomerName': 'first',
125
- 'Country': 'first',
126
- 'CustomerGroup': 'first',
127
- 'total_orders': 'first',
128
- 'last_order_date': 'first',
129
- 'first_order_date': 'first'
130
- }).reset_index()
131
-
132
- # Handle missing dates
133
- reference_date = pd.to_datetime('2024-12-31')
134
- customer_features['last_order_date'] = pd.to_datetime(customer_features['last_order_date'])
135
- customer_features['first_order_date'] = pd.to_datetime(customer_features['first_order_date'])
136
-
137
- # RFM Features
138
- customer_features['Recency'] = (reference_date - customer_features['last_order_date']).dt.days
139
- customer_features['Recency'] = customer_features['Recency'].fillna(365)
140
- customer_features['Frequency'] = customer_features['total_orders'].fillna(0)
141
-
142
- # Simulated monetary value
143
- np.random.seed(42)
144
- customer_features['Monetary'] = customer_features['Frequency'] * np.random.exponential(500, len(customer_features))
145
-
146
- # Lifecycle features
147
- customer_features['Tenure'] = (reference_date - customer_features['first_order_date']).dt.days
148
- customer_features['Tenure'] = customer_features['Tenure'].fillna(0)
149
- customer_features['OrderVelocity'] = customer_features['Frequency'] / (customer_features['Tenure'] / 30 + 1)
150
-
151
- # Categorical encoding
152
- self.label_encoders = {}
153
- for col in ['Country', 'CustomerGroup']:
154
- if col in customer_features.columns:
155
- self.label_encoders[col] = LabelEncoder()
156
- customer_features[f'{col}_encoded'] = self.label_encoders[col].fit_transform(
157
- customer_features[col].fillna('Unknown')
158
- )
159
-
160
- # Target variable
161
- customer_features['IsChurned'] = (
162
- (customer_features['Recency'] > 90) &
163
- (customer_features['Frequency'] > 0)
164
- ).astype(int)
165
-
166
- # Select features
167
- self.feature_columns = [
168
- 'Recency', 'Frequency', 'Monetary', 'Tenure', 'OrderVelocity',
169
- 'Country_encoded', 'CustomerGroup_encoded'
170
- ]
171
-
172
- return customer_features[self.feature_columns + ['IsChurned', 'Customer', 'CustomerName']]
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  def train_model(self, data):
175
  """Train RandomForest model"""
176
- X = data[self.feature_columns]
177
- y = data['IsChurned']
178
-
179
- # Train-test split
180
- X_train, X_test, y_train, y_test = train_test_split(
181
- X, y, test_size=0.2, random_state=42, stratify=y
182
- )
183
-
184
- # Train model (optimized for HF Spaces)
185
- self.model = RandomForestClassifier(
186
- n_estimators=50, # Reduced for performance
187
- max_depth=8,
188
- min_samples_split=20,
189
- class_weight='balanced',
190
- random_state=42,
191
- n_jobs=1 # Single thread for HF Spaces
192
- )
193
-
194
- self.model.fit(X_train, y_train)
195
-
196
- # Evaluate
197
- test_score = self.model.score(X_test, y_test)
198
- y_pred = self.model.predict(X_test)
199
-
200
- metrics = {
201
- 'test_accuracy': test_score,
202
- 'feature_columns': self.feature_columns,
203
- 'training_samples': len(X_train),
204
- 'churn_rate': y.mean(),
205
- 'feature_importance': dict(zip(self.feature_columns, self.model.feature_importances_))
206
- }
207
-
208
- return metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  def save_model_artifacts(self, metrics):
211
  """Save model and metadata"""
212
- # Ensure models directory exists
213
- Path('models').mkdir(exist_ok=True)
214
-
215
- # Save model with encoders
216
- model_data = {
217
- 'model': self.model,
218
- 'label_encoders': self.label_encoders,
219
- 'feature_columns': self.feature_columns,
220
- 'version': 'v1',
221
- 'training_date': datetime.now().isoformat()
222
- }
223
-
224
- joblib.dump(model_data, self.model_path)
225
-
226
- # Save metadata
227
- metadata = {
228
- 'model_name': 'churn_predictor',
229
- 'version': 'v1',
230
- 'training_date': datetime.now().isoformat(),
231
- 'metrics': metrics,
232
- 'status': 'trained'
233
- }
234
-
235
- with open(self.metadata_path, 'w') as f:
236
- json.dump(metadata, f, indent=2)
 
 
 
 
 
237
 
238
  def load_existing_metadata(self):
239
  """Load existing model metadata"""
240
  try:
241
  with open(self.metadata_path, 'r') as f:
242
  return json.load(f)
243
- except:
244
  return None
 
1
  """
2
  Embedded Model Training for HF Spaces
3
  Auto-trains model on first app load if not present
4
+ Handles SAP SALT dataset access with multiple fallback methods
5
  """
6
 
7
  import pandas as pd
 
16
  import streamlit as st
17
  from pathlib import Path
18
  from datetime import datetime
19
+ import requests
20
 
21
  class EmbeddedChurnTrainer:
22
  """Embedded trainer that works within HF Spaces constraints"""
 
34
 
35
  @st.cache_data
36
  def load_sap_data(_self):
37
+ """Load SAP data with multiple fallback methods"""
38
+ # Method 1: Try using datasets library (preferred)
39
  try:
40
+ from datasets import load_dataset
41
+ st.info("🔄 Loading SAP SALT data using Hugging Face datasets library...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Try to load the dataset using proper HF datasets library
44
+ dataset = load_dataset("SAP/SALT", split="train", streaming=True)
45
+
46
+ # Convert to pandas DataFrame (limit for HF Spaces)
47
+ all_data = []
48
+ count = 0
49
+ max_records = 3000 # Limit for HF Spaces performance
50
+
51
+ for item in dataset:
52
+ if count >= max_records:
53
+ break
54
+
55
+ # Handle the data structure from SAP SALT dataset
56
+ record = {
57
+ 'Customer': item.get('Customer') or f'CUST_{count:06d}',
58
+ 'CustomerName': item.get('CustomerName') or f'Customer {count}',
59
+ 'Country': item.get('Country') or np.random.choice(['DE', 'US', 'FR', 'UK']),
60
+ 'CustomerGroup': item.get('CustomerGroup') or np.random.choice(['RETAIL', 'WHOLESALE']),
61
+ 'SalesDocument': item.get('SalesDocument') or f'SO_{count:08d}',
62
+ 'CreationDate': item.get('CreationDate') or '2024-01-01',
63
+ 'SoldToParty': item.get('Customer') or f'CUST_{count:06d}'
64
+ }
65
+ all_data.append(record)
66
+ count += 1
67
+
68
+ if all_data:
69
+ training_data = pd.DataFrame(all_data)
70
+ training_data = _self._add_aggregated_fields(training_data)
71
+ st.success(f"✅ Loaded {len(training_data)} records using HF datasets library")
72
+ return training_data
73
+
74
+ except ImportError:
75
+ st.warning("⚠️ Hugging Face datasets library not available, trying alternative method...")
76
  except Exception as e:
77
+ st.warning(f"⚠️ Datasets library failed ({str(e)}), trying alternative method...")
78
+
79
+ # Method 2: Try HF API endpoints
80
+ try:
81
+ st.info("🔄 Trying alternative data loading via Hugging Face API...")
82
+ return _self._load_via_hf_api()
83
+
84
+ except Exception as e:
85
+ st.warning(f"⚠️ HF API method failed ({str(e)}), creating synthetic data...")
86
+
87
+ # Method 3: Create synthetic data as fallback
88
+ return _self._create_synthetic_data()
89
+
90
+ def _load_via_hf_api(self):
91
+ """Alternative method using HF API"""
92
+ try:
93
+ # Try the HF dataset viewer API
94
+ base_url = "https://datasets-server.huggingface.co/rows"
95
+
96
+ response = requests.get(
97
+ f"{base_url}?dataset=SAP/SALT&config=default&split=train&offset=0&length=1000",
98
+ timeout=30
99
+ )
100
+
101
+ if response.status_code == 200:
102
+ data = response.json()
103
+ if 'rows' in data:
104
+ rows_data = []
105
+ for row in data['rows']:
106
+ if 'row' in row:
107
+ rows_data.append(row['row'])
108
+
109
+ if rows_data:
110
+ training_data = pd.DataFrame(rows_data)
111
+ training_data = self._add_aggregated_fields(training_data)
112
+ st.success(f"✅ Loaded {len(training_data)} records using HF API")
113
+ return training_data
114
+
115
+ raise Exception("No valid data returned from API")
116
+
117
+ except Exception as e:
118
+ raise Exception(f"API loading failed: {str(e)}")
119
+
120
+ def _create_synthetic_data(self):
121
+ """Create realistic synthetic SAP-like data for demonstration"""
122
+ st.info("🔄 Creating synthetic SAP-like data for demonstration...")
123
+
124
+ np.random.seed(42)
125
+ n_customers = 1000
126
+ n_sales_docs = 3000
127
+
128
+ # Generate realistic customer data
129
+ countries = ['DE', 'US', 'FR', 'UK', 'JP', 'CN', 'IN', 'BR', 'AU', 'CA']
130
+ customer_groups = ['RETAIL', 'WHOLESALE', 'DISTRIBUTOR', 'ENTERPRISE', 'SMB']
131
+
132
+ # Create base data
133
+ all_data = []
134
+
135
+ # Generate sales documents with customer data
136
+ for i in range(n_sales_docs):
137
+ customer_idx = np.random.randint(0, n_customers)
138
+ customer_id = f"CUST_{customer_idx:06d}"
139
+
140
+ # Create realistic date distribution (more recent orders more likely)
141
+ days_ago = max(1, int(np.random.exponential(50))) # Average 50 days ago
142
+ creation_date = (datetime.now() - pd.Timedelta(days=days_ago)).strftime('%Y-%m-%d')
143
+
144
+ record = {
145
+ 'Customer': customer_id,
146
+ 'CustomerName': f'Customer {customer_idx}',
147
+ 'Country': np.random.choice(countries),
148
+ 'CustomerGroup': np.random.choice(customer_groups),
149
+ 'SalesDocument': f"SO_{i:08d}",
150
+ 'CreationDate': creation_date,
151
+ 'SoldToParty': customer_id
152
+ }
153
+ all_data.append(record)
154
+
155
+ # Create DataFrame
156
+ training_data = pd.DataFrame(all_data)
157
+ training_data = self._add_aggregated_fields(training_data)
158
+
159
+ st.success(f"✅ Created {len(training_data)} synthetic records for demonstration")
160
+ st.info("📝 **Note**: Using synthetic data for demo. In production, configure proper SAP SALT access.")
161
+
162
+ return training_data
163
+
164
+ def _add_aggregated_fields(self, data):
165
+ """Add aggregated fields for feature engineering"""
166
+ # Add customer-level aggregations
167
+ customer_aggs = data.groupby('Customer').agg({
168
+ 'SalesDocument': 'count',
169
+ 'CreationDate': ['min', 'max']
170
+ }).reset_index()
171
+
172
+ # Flatten column names
173
+ customer_aggs.columns = ['Customer', 'total_orders', 'first_order_date', 'last_order_date']
174
+
175
+ # Merge back to original data
176
+ data = data.merge(customer_aggs, on='Customer', how='left')
177
+
178
+ return data
179
 
180
  def train_model_if_needed(self):
181
  """Train model if it doesn't exist, with progress updates"""
 
188
 
189
  try:
190
  # Step 1: Load data
191
+ status_text.text("📥 Loading SAP data...")
192
  progress_bar.progress(20)
193
  data = self.load_sap_data()
194
 
195
  if len(data) == 0:
196
+ st.error("No training data available")
197
  return None
198
 
199
  # Step 2: Feature engineering
200
+ status_text.text("🔧 Engineering features...")
201
  progress_bar.progress(40)
202
  features_data = self.engineer_features(data)
203
 
204
+ if len(features_data) == 0:
205
+ st.error("❌ Feature engineering failed")
206
+ return None
207
+
208
  # Step 3: Train model
209
+ status_text.text("🏋️ Training ML model...")
210
  progress_bar.progress(60)
211
  metrics = self.train_model(features_data)
212
 
213
+ if not metrics:
214
+ st.error("❌ Model training failed")
215
+ return None
216
+
217
  # Step 4: Save model
218
+ status_text.text("💾 Saving model...")
219
  progress_bar.progress(80)
220
  self.save_model_artifacts(metrics)
221
 
 
226
  return metrics
227
 
228
  except Exception as e:
229
+ st.error(f"Training failed: {str(e)}")
230
  return None
231
 
232
  def engineer_features(self, data):
233
  """Feature engineering for churn prediction"""
234
+ try:
235
+ # Customer-level aggregation
236
+ customer_features = data.groupby('Customer').agg({
237
+ 'CustomerName': 'first',
238
+ 'Country': 'first',
239
+ 'CustomerGroup': 'first',
240
+ 'total_orders': 'first',
241
+ 'last_order_date': 'first',
242
+ 'first_order_date': 'first'
243
+ }).reset_index()
244
+
245
+ # Handle missing dates
246
+ reference_date = pd.to_datetime('2024-12-31')
247
+ customer_features['last_order_date'] = pd.to_datetime(customer_features['last_order_date'])
248
+ customer_features['first_order_date'] = pd.to_datetime(customer_features['first_order_date'])
249
+
250
+ # RFM Features
251
+ customer_features['Recency'] = (reference_date - customer_features['last_order_date']).dt.days
252
+ customer_features['Recency'] = customer_features['Recency'].fillna(365)
253
+ customer_features['Frequency'] = customer_features['total_orders'].fillna(0)
254
+
255
+ # Simulated monetary value (consistent with seed)
256
+ np.random.seed(42)
257
+ customer_features['Monetary'] = customer_features['Frequency'] * np.random.exponential(500, len(customer_features))
258
+
259
+ # Lifecycle features
260
+ customer_features['Tenure'] = (reference_date - customer_features['first_order_date']).dt.days
261
+ customer_features['Tenure'] = customer_features['Tenure'].fillna(0)
262
+ customer_features['OrderVelocity'] = customer_features['Frequency'] / (customer_features['Tenure'] / 30 + 1)
263
+
264
+ # Categorical encoding
265
+ self.label_encoders = {}
266
+ for col in ['Country', 'CustomerGroup']:
267
+ if col in customer_features.columns:
268
+ self.label_encoders[col] = LabelEncoder()
269
+ customer_features[f'{col}_encoded'] = self.label_encoders[col].fit_transform(
270
+ customer_features[col].fillna('Unknown')
271
+ )
272
+
273
+ # Target variable (churn definition)
274
+ customer_features['IsChurned'] = (
275
+ (customer_features['Recency'] > 90) &
276
+ (customer_features['Frequency'] > 0)
277
+ ).astype(int)
278
+
279
+ # Select features for model
280
+ self.feature_columns = [
281
+ 'Recency', 'Frequency', 'Monetary', 'Tenure', 'OrderVelocity',
282
+ 'Country_encoded', 'CustomerGroup_encoded'
283
+ ]
284
+
285
+ # Return final dataset
286
+ final_features = customer_features[self.feature_columns + ['IsChurned', 'Customer', 'CustomerName']]
287
+
288
+ # Validate data
289
+ if len(final_features) < 10:
290
+ raise Exception("Insufficient data for training")
291
+
292
+ return final_features
293
+
294
+ except Exception as e:
295
+ st.error(f"Feature engineering failed: {str(e)}")
296
+ return pd.DataFrame()
297
 
298
  def train_model(self, data):
299
  """Train RandomForest model"""
300
+ try:
301
+ X = data[self.feature_columns]
302
+ y = data['IsChurned']
303
+
304
+ # Check for sufficient data
305
+ if len(X) < 20:
306
+ raise Exception("Insufficient training data")
307
+
308
+ if y.sum() == 0 or (y == 0).sum() == 0:
309
+ # Handle case where all customers are churned or none are churned
310
+ st.warning("⚠️ Unbalanced target variable detected")
311
+
312
+ # Train-test split
313
+ test_size = min(0.2, max(0.1, len(X) // 10)) # Adaptive test size
314
+ X_train, X_test, y_train, y_test = train_test_split(
315
+ X, y, test_size=test_size, random_state=42, stratify=y if len(np.unique(y)) > 1 else None
316
+ )
317
+
318
+ # Train model (optimized for HF Spaces)
319
+ self.model = RandomForestClassifier(
320
+ n_estimators=50, # Reduced for performance
321
+ max_depth=8,
322
+ min_samples_split=max(2, len(X_train) // 50),
323
+ min_samples_leaf=max(1, len(X_train) // 100),
324
+ class_weight='balanced',
325
+ random_state=42,
326
+ n_jobs=1 # Single thread for HF Spaces
327
+ )
328
+
329
+ self.model.fit(X_train, y_train)
330
+
331
+ # Evaluate
332
+ train_score = self.model.score(X_train, y_train)
333
+ test_score = self.model.score(X_test, y_test)
334
+
335
+ metrics = {
336
+ 'train_accuracy': train_score,
337
+ 'test_accuracy': test_score,
338
+ 'feature_columns': self.feature_columns,
339
+ 'training_samples': len(X_train),
340
+ 'test_samples': len(X_test),
341
+ 'churn_rate': float(y.mean()),
342
+ 'feature_importance': dict(zip(self.feature_columns, self.model.feature_importances_))
343
+ }
344
+
345
+ return metrics
346
+
347
+ except Exception as e:
348
+ st.error(f"Model training failed: {str(e)}")
349
+ return None
350
 
351
  def save_model_artifacts(self, metrics):
352
  """Save model and metadata"""
353
+ try:
354
+ # Ensure models directory exists
355
+ Path('models').mkdir(exist_ok=True)
356
+
357
+ # Save model with encoders and metadata
358
+ model_data = {
359
+ 'model': self.model,
360
+ 'label_encoders': self.label_encoders,
361
+ 'feature_columns': self.feature_columns,
362
+ 'version': 'v1',
363
+ 'training_date': datetime.now().isoformat()
364
+ }
365
+
366
+ joblib.dump(model_data, self.model_path)
367
+
368
+ # Save metadata
369
+ metadata = {
370
+ 'model_name': 'churn_predictor',
371
+ 'version': 'v1',
372
+ 'training_date': datetime.now().isoformat(),
373
+ 'metrics': metrics,
374
+ 'status': 'trained'
375
+ }
376
+
377
+ with open(self.metadata_path, 'w') as f:
378
+ json.dump(metadata, f, indent=2)
379
+
380
+ except Exception as e:
381
+ st.error(f"Failed to save model: {str(e)}")
382
+ raise
383
 
384
  def load_existing_metadata(self):
385
  """Load existing model metadata"""
386
  try:
387
  with open(self.metadata_path, 'r') as f:
388
  return json.load(f)
389
+ except Exception:
390
  return None