Update utils/model_trainer.py
Browse files
- utils/model_trainer.py +10 -7
utils/model_trainer.py
CHANGED
|
@@ -22,6 +22,9 @@ class EmbeddedChurnTrainer:
|
|
| 22 |
def __init__(self):
|
| 23 |
self.model_path = Path('models/churn_model_v1.pkl')
|
| 24 |
self.metadata_path = Path('models/model_metadata.json')
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def model_exists(self):
|
| 27 |
"""Check if trained model exists"""
|
|
@@ -33,18 +36,18 @@ class EmbeddedChurnTrainer:
|
|
| 33 |
try:
|
| 34 |
conn = duckdb.connect(':memory:')
|
| 35 |
|
| 36 |
-
# Load SAP datasets
|
| 37 |
conn.execute("""
|
| 38 |
CREATE TABLE customers AS
|
| 39 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
|
| 40 |
LIMIT 5000
|
| 41 |
-
""")
|
| 42 |
|
| 43 |
conn.execute("""
|
| 44 |
CREATE TABLE sales_docs AS
|
| 45 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
|
| 46 |
LIMIT 10000
|
| 47 |
-
""")
|
| 48 |
|
| 49 |
# Join data
|
| 50 |
training_data = conn.execute("""
|
|
@@ -71,7 +74,7 @@ class EmbeddedChurnTrainer:
|
|
| 71 |
return pd.DataFrame()
|
| 72 |
|
| 73 |
def train_model_if_needed(self):
|
| 74 |
-
"""Train model if it doesn't exist, with progress
|
| 75 |
if self.model_exists():
|
| 76 |
return self.load_existing_metadata()
|
| 77 |
|
|
@@ -115,7 +118,7 @@ class EmbeddedChurnTrainer:
|
|
| 115 |
return None
|
| 116 |
|
| 117 |
def engineer_features(self, data):
|
| 118 |
-
"""
|
| 119 |
# Customer-level aggregation
|
| 120 |
customer_features = data.groupby('Customer').agg({
|
| 121 |
'CustomerName': 'first',
|
|
@@ -178,9 +181,9 @@ class EmbeddedChurnTrainer:
|
|
| 178 |
X, y, test_size=0.2, random_state=42, stratify=y
|
| 179 |
)
|
| 180 |
|
| 181 |
-
# Train model
|
| 182 |
self.model = RandomForestClassifier(
|
| 183 |
-
n_estimators=50, # Reduced for
|
| 184 |
max_depth=8,
|
| 185 |
min_samples_split=20,
|
| 186 |
class_weight='balanced',
|
|
|
|
| 22 |
def __init__(self):
|
| 23 |
self.model_path = Path('models/churn_model_v1.pkl')
|
| 24 |
self.metadata_path = Path('models/model_metadata.json')
|
| 25 |
+
self.model = None
|
| 26 |
+
self.label_encoders = {}
|
| 27 |
+
self.feature_columns = []
|
| 28 |
|
| 29 |
def model_exists(self):
|
| 30 |
"""Check if trained model exists"""
|
|
|
|
| 36 |
try:
|
| 37 |
conn = duckdb.connect(':memory:')
|
| 38 |
|
| 39 |
+
# Load SAP datasets with limits for HF Spaces performance
|
| 40 |
conn.execute("""
|
| 41 |
CREATE TABLE customers AS
|
| 42 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
|
| 43 |
LIMIT 5000
|
| 44 |
+
""")
|
| 45 |
|
| 46 |
conn.execute("""
|
| 47 |
CREATE TABLE sales_docs AS
|
| 48 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
|
| 49 |
LIMIT 10000
|
| 50 |
+
""")
|
| 51 |
|
| 52 |
# Join data
|
| 53 |
training_data = conn.execute("""
|
|
|
|
| 74 |
return pd.DataFrame()
|
| 75 |
|
| 76 |
def train_model_if_needed(self):
|
| 77 |
+
"""Train model if it doesn't exist, with progress updates"""
|
| 78 |
if self.model_exists():
|
| 79 |
return self.load_existing_metadata()
|
| 80 |
|
|
|
|
| 118 |
return None
|
| 119 |
|
| 120 |
def engineer_features(self, data):
|
| 121 |
+
"""Feature engineering for churn prediction"""
|
| 122 |
# Customer-level aggregation
|
| 123 |
customer_features = data.groupby('Customer').agg({
|
| 124 |
'CustomerName': 'first',
|
|
|
|
| 181 |
X, y, test_size=0.2, random_state=42, stratify=y
|
| 182 |
)
|
| 183 |
|
| 184 |
+
# Train model (optimized for HF Spaces)
|
| 185 |
self.model = RandomForestClassifier(
|
| 186 |
+
n_estimators=50, # Reduced for performance
|
| 187 |
max_depth=8,
|
| 188 |
min_samples_split=20,
|
| 189 |
class_weight='balanced',
|