Update agent_tools/ml_tools.py
Browse files- agent_tools/ml_tools.py +13 -13
agent_tools/ml_tools.py
CHANGED
|
@@ -10,7 +10,6 @@ import json
|
|
| 10 |
from pathlib import Path
|
| 11 |
from datetime import datetime
|
| 12 |
import duckdb
|
| 13 |
-
import streamlit as st
|
| 14 |
|
| 15 |
# Global model cache for HF Spaces
|
| 16 |
_model_cache = {}
|
|
@@ -41,7 +40,7 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
|
|
| 41 |
# Load model
|
| 42 |
model_data = load_model_with_cache()
|
| 43 |
if model_data is None:
|
| 44 |
-
return json.dumps({"error": "Model not found. Please
|
| 45 |
|
| 46 |
model = model_data['model']
|
| 47 |
label_encoders = model_data['label_encoders']
|
|
@@ -54,20 +53,22 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
|
|
| 54 |
CREATE TABLE customers AS
|
| 55 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
|
| 56 |
LIMIT 2000
|
| 57 |
-
""")
|
| 58 |
|
| 59 |
conn.execute("""
|
| 60 |
CREATE TABLE sales_docs AS
|
| 61 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
|
| 62 |
LIMIT 5000
|
| 63 |
-
""")
|
| 64 |
|
| 65 |
# Filter customers if specified
|
| 66 |
if customer_ids:
|
| 67 |
customer_list = [f"'{cid.strip()}'" for cid in customer_ids.split(',')]
|
| 68 |
where_clause = f"WHERE c.Customer IN ({','.join(customer_list)})"
|
|
|
|
| 69 |
else:
|
| 70 |
-
where_clause = "
|
|
|
|
| 71 |
|
| 72 |
# Get customer data
|
| 73 |
customer_data = conn.execute(f"""
|
|
@@ -81,13 +82,13 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
|
|
| 81 |
MIN(s.CreationDate) as first_order_date
|
| 82 |
FROM customers c
|
| 83 |
LEFT JOIN sales_docs s ON c.Customer = s.SoldToParty
|
| 84 |
-
{where_clause
|
| 85 |
GROUP BY c.Customer, c.CustomerName, c.Country, c.CustomerGroup
|
| 86 |
-
{
|
| 87 |
""").df()
|
| 88 |
|
| 89 |
if len(customer_data) == 0:
|
| 90 |
-
return json.dumps({"error": "No customers found"})
|
| 91 |
|
| 92 |
# Feature engineering (same as training)
|
| 93 |
reference_date = pd.to_datetime('2024-12-31')
|
|
@@ -114,7 +115,6 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
|
|
| 114 |
customer_data[col].fillna('Unknown')
|
| 115 |
)
|
| 116 |
except:
|
| 117 |
-
# Handle unseen categories
|
| 118 |
customer_data[f'{col}_encoded'] = 0
|
| 119 |
|
| 120 |
# Make predictions
|
|
@@ -133,7 +133,7 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
|
|
| 133 |
# High risk customers
|
| 134 |
high_risk = results[results['churn_probability'] >= risk_threshold].sort_values(
|
| 135 |
'churn_probability', ascending=False
|
| 136 |
-
).head(20)
|
| 137 |
|
| 138 |
# Generate recommendations
|
| 139 |
recommendations = []
|
|
@@ -153,8 +153,8 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
|
|
| 153 |
"high_risk_count": len(high_risk),
|
| 154 |
"churn_rate_predicted": round(len(high_risk) / len(results) * 100, 2) if len(results) > 0 else 0,
|
| 155 |
"urgent_actions": recommendations,
|
| 156 |
-
"model_performance":
|
| 157 |
-
"
|
| 158 |
})
|
| 159 |
|
| 160 |
except Exception as e:
|
|
@@ -163,7 +163,7 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
|
|
| 163 |
except Exception as e:
|
| 164 |
return json.dumps({
|
| 165 |
"error": f"Churn analysis failed: {str(e)}",
|
| 166 |
-
"suggestion": "Please ensure model is trained"
|
| 167 |
})
|
| 168 |
|
| 169 |
@tool
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
from datetime import datetime
|
| 12 |
import duckdb
|
|
|
|
| 13 |
|
| 14 |
# Global model cache for HF Spaces
|
| 15 |
_model_cache = {}
|
|
|
|
| 40 |
# Load model
|
| 41 |
model_data = load_model_with_cache()
|
| 42 |
if model_data is None:
|
| 43 |
+
return json.dumps({"error": "Model not found. Please train the model first."})
|
| 44 |
|
| 45 |
model = model_data['model']
|
| 46 |
label_encoders = model_data['label_encoders']
|
|
|
|
| 53 |
CREATE TABLE customers AS
|
| 54 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
|
| 55 |
LIMIT 2000
|
| 56 |
+
""")
|
| 57 |
|
| 58 |
conn.execute("""
|
| 59 |
CREATE TABLE sales_docs AS
|
| 60 |
SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
|
| 61 |
LIMIT 5000
|
| 62 |
+
""")
|
| 63 |
|
| 64 |
# Filter customers if specified
|
| 65 |
if customer_ids:
|
| 66 |
customer_list = [f"'{cid.strip()}'" for cid in customer_ids.split(',')]
|
| 67 |
where_clause = f"WHERE c.Customer IN ({','.join(customer_list)})"
|
| 68 |
+
limit_clause = ""
|
| 69 |
else:
|
| 70 |
+
where_clause = ""
|
| 71 |
+
limit_clause = "LIMIT 500" # Limit for demo
|
| 72 |
|
| 73 |
# Get customer data
|
| 74 |
customer_data = conn.execute(f"""
|
|
|
|
| 82 |
MIN(s.CreationDate) as first_order_date
|
| 83 |
FROM customers c
|
| 84 |
LEFT JOIN sales_docs s ON c.Customer = s.SoldToParty
|
| 85 |
+
{where_clause}
|
| 86 |
GROUP BY c.Customer, c.CustomerName, c.Country, c.CustomerGroup
|
| 87 |
+
{limit_clause}
|
| 88 |
""").df()
|
| 89 |
|
| 90 |
if len(customer_data) == 0:
|
| 91 |
+
return json.dumps({"error": "No customers found for analysis"})
|
| 92 |
|
| 93 |
# Feature engineering (same as training)
|
| 94 |
reference_date = pd.to_datetime('2024-12-31')
|
|
|
|
| 115 |
customer_data[col].fillna('Unknown')
|
| 116 |
)
|
| 117 |
except:
|
|
|
|
| 118 |
customer_data[f'{col}_encoded'] = 0
|
| 119 |
|
| 120 |
# Make predictions
|
|
|
|
| 133 |
# High risk customers
|
| 134 |
high_risk = results[results['churn_probability'] >= risk_threshold].sort_values(
|
| 135 |
'churn_probability', ascending=False
|
| 136 |
+
).head(20)
|
| 137 |
|
| 138 |
# Generate recommendations
|
| 139 |
recommendations = []
|
|
|
|
| 153 |
"high_risk_count": len(high_risk),
|
| 154 |
"churn_rate_predicted": round(len(high_risk) / len(results) * 100, 2) if len(results) > 0 else 0,
|
| 155 |
"urgent_actions": recommendations,
|
| 156 |
+
"model_performance": "Model ready and operational",
|
| 157 |
+
"note": "Results limited for demo performance"
|
| 158 |
})
|
| 159 |
|
| 160 |
except Exception as e:
|
|
|
|
| 163 |
except Exception as e:
|
| 164 |
return json.dumps({
|
| 165 |
"error": f"Churn analysis failed: {str(e)}",
|
| 166 |
+
"suggestion": "Please ensure model is trained and data is available"
|
| 167 |
})
|
| 168 |
|
| 169 |
@tool
|