PD03 committed on
Commit c829fa9 · verified · 1 parent: d92d528

Create utils/model_trainer.py

Files changed (1): utils/model_trainer.py (+241, -0)
utils/model_trainer.py ADDED
@@ -0,0 +1,241 @@
"""
Embedded model training for HF Spaces.
Auto-trains the model on first app load if it is not present.
"""

import json
from datetime import datetime
from pathlib import Path

import duckdb
import joblib
import numpy as np
import pandas as pd
import streamlit as st
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


class EmbeddedChurnTrainer:
    """Embedded trainer that works within HF Spaces constraints."""

    def __init__(self):
        self.model_path = Path('models/churn_model_v1.pkl')
        self.metadata_path = Path('models/model_metadata.json')

    def model_exists(self):
        """Check whether a trained model and its metadata already exist."""
        return self.model_path.exists() and self.metadata_path.exists()

    @st.cache_data
    def load_sap_data(_self):
        """Load SAP data with Streamlit caching (_self is excluded from hashing)."""
        try:
            conn = duckdb.connect(':memory:')

            # Load SAP datasets; LIMIT keeps memory use within HF Spaces constraints
            conn.execute("""
                CREATE TABLE customers AS
                SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
                LIMIT 5000
            """)

            conn.execute("""
                CREATE TABLE sales_docs AS
                SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
                LIMIT 10000
            """)

            # Join customers to their sales documents, with per-customer order
            # statistics computed as window aggregates
            training_data = conn.execute("""
                SELECT
                    c.Customer,
                    c.CustomerName,
                    c.Country,
                    c.CustomerGroup,
                    s.SalesDocument,
                    s.CreationDate,
                    s.SoldToParty,
                    COUNT(s.SalesDocument) OVER (PARTITION BY c.Customer) AS total_orders,
                    MAX(s.CreationDate) OVER (PARTITION BY c.Customer) AS last_order_date,
                    MIN(s.CreationDate) OVER (PARTITION BY c.Customer) AS first_order_date
                FROM customers c
                LEFT JOIN sales_docs s ON c.Customer = s.SoldToParty
                WHERE c.Customer IS NOT NULL
            """).df()

            return training_data

        except Exception as e:
            st.error(f"Data loading failed: {e}")
            return pd.DataFrame()

    def train_model_if_needed(self):
        """Train the model if it doesn't exist, showing a progress bar."""
        if self.model_exists():
            return self.load_existing_metadata()

        # Show training progress
        progress_bar = st.progress(0)
        status_text = st.empty()

        try:
            # Step 1: Load data
            status_text.text("Loading SAP data...")
            progress_bar.progress(20)
            data = self.load_sap_data()

            if len(data) == 0:
                st.error("No training data available")
                return None

            # Step 2: Feature engineering
            status_text.text("Engineering features...")
            progress_bar.progress(40)
            features_data = self.engineer_features(data)

            # Step 3: Train model
            status_text.text("Training ML model...")
            progress_bar.progress(60)
            metrics = self.train_model(features_data)

            # Step 4: Save model
            status_text.text("Saving model...")
            progress_bar.progress(80)
            self.save_model_artifacts(metrics)

            # Complete
            progress_bar.progress(100)
            status_text.text("✅ Model training complete!")

            return metrics

        except Exception as e:
            st.error(f"Training failed: {e}")
            return None

    def engineer_features(self, data):
        """Streamlined feature engineering for HF Spaces."""
        # Collapse the joined rows to one record per customer
        customer_features = data.groupby('Customer').agg({
            'CustomerName': 'first',
            'Country': 'first',
            'CustomerGroup': 'first',
            'total_orders': 'first',
            'last_order_date': 'first',
            'first_order_date': 'first'
        }).reset_index()

        # Parse dates; customers with no orders end up as NaT here
        reference_date = pd.to_datetime('2024-12-31')
        customer_features['last_order_date'] = pd.to_datetime(customer_features['last_order_date'])
        customer_features['first_order_date'] = pd.to_datetime(customer_features['first_order_date'])

        # RFM features: days since last order, and total order count
        customer_features['Recency'] = (reference_date - customer_features['last_order_date']).dt.days
        customer_features['Recency'] = customer_features['Recency'].fillna(365)
        customer_features['Frequency'] = customer_features['total_orders'].fillna(0)

        # Simulated monetary value (the join above carries no revenue column)
        np.random.seed(42)
        customer_features['Monetary'] = customer_features['Frequency'] * np.random.exponential(500, len(customer_features))

        # Lifecycle features: tenure in days and orders per ~30-day window
        customer_features['Tenure'] = (reference_date - customer_features['first_order_date']).dt.days
        customer_features['Tenure'] = customer_features['Tenure'].fillna(0)
        customer_features['OrderVelocity'] = customer_features['Frequency'] / (customer_features['Tenure'] / 30 + 1)

        # Categorical encoding
        self.label_encoders = {}
        for col in ['Country', 'CustomerGroup']:
            if col in customer_features.columns:
                self.label_encoders[col] = LabelEncoder()
                customer_features[f'{col}_encoded'] = self.label_encoders[col].fit_transform(
                    customer_features[col].fillna('Unknown')
                )

        # Target: churned = has ordered before, but not in the last 90 days
        customer_features['IsChurned'] = (
            (customer_features['Recency'] > 90) &
            (customer_features['Frequency'] > 0)
        ).astype(int)

        # Select features
        self.feature_columns = [
            'Recency', 'Frequency', 'Monetary', 'Tenure', 'OrderVelocity',
            'Country_encoded', 'CustomerGroup_encoded'
        ]

        return customer_features[self.feature_columns + ['IsChurned', 'Customer', 'CustomerName']]

    def train_model(self, data):
        """Train a RandomForest classifier and return its evaluation metrics."""
        X = data[self.feature_columns]
        y = data['IsChurned']

        # Train-test split, stratified to preserve the churn ratio
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        # Train model
        self.model = RandomForestClassifier(
            n_estimators=50,          # Reduced for HF Spaces performance
            max_depth=8,
            min_samples_split=20,
            class_weight='balanced',  # Churners are usually the minority class
            random_state=42,
            n_jobs=1                  # Single thread for HF Spaces
        )

        self.model.fit(X_train, y_train)

        # Evaluate
        test_score = self.model.score(X_test, y_test)

        # Cast NumPy scalars to plain floats so the metrics stay JSON-serializable
        metrics = {
            'test_accuracy': float(test_score),
            'feature_columns': self.feature_columns,
            'training_samples': len(X_train),
            'churn_rate': float(y.mean()),
            'feature_importance': {
                name: float(importance)
                for name, importance in zip(self.feature_columns, self.model.feature_importances_)
            }
        }

        return metrics

    def save_model_artifacts(self, metrics):
        """Save the model bundle and its metadata."""
        # Ensure models directory exists
        Path('models').mkdir(exist_ok=True)

        # Save the model together with its encoders and feature list
        model_data = {
            'model': self.model,
            'label_encoders': self.label_encoders,
            'feature_columns': self.feature_columns,
            'version': 'v1',
            'training_date': datetime.now().isoformat()
        }

        joblib.dump(model_data, self.model_path)

        # Save metadata
        metadata = {
            'model_name': 'churn_predictor',
            'version': 'v1',
            'training_date': datetime.now().isoformat(),
            'metrics': metrics,
            'status': 'trained'
        }

        with open(self.metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

    def load_existing_metadata(self):
        """Load existing model metadata, or None if it is missing or corrupt."""
        try:
            with open(self.metadata_path, 'r') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            return None
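For context, a minimal usage sketch of how a Space's entry script might wire the trainer up so the model is built on first load. This is not part of the commit; the app filename, the import path utils.model_trainer, and the displayed metrics are assumptions:

# app.py (hypothetical) -- train on first visit, reuse the saved model after.
import streamlit as st

from utils.model_trainer import EmbeddedChurnTrainer

trainer = EmbeddedChurnTrainer()

# First visit: trains and writes models/churn_model_v1.pkl plus metadata.
# Later visits: model_exists() is True, so the saved metadata is returned.
result = trainer.train_model_if_needed()

if result is not None:
    # A fresh train returns the metrics dict directly; a cached load returns
    # the metadata file, which nests the same dict under 'metrics'.
    metrics = result.get('metrics', result)
    st.metric("Test accuracy", f"{metrics['test_accuracy']:.2%}")
    st.metric("Observed churn rate", f"{metrics['churn_rate']:.2%}")
else:
    st.warning("Model training failed; see the error messages above.")

One thing the sketch has to work around: train_model_if_needed returns two different shapes (the raw metrics on a fresh train, the full metadata file on a cached load), which result.get('metrics', result) normalizes.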