PD03 commited on
Commit
d131cca
·
verified ·
1 Parent(s): f3782cb

Update agent_tools/ml_tools.py

Browse files
Files changed (1) hide show
  1. agent_tools/ml_tools.py +13 -13
agent_tools/ml_tools.py CHANGED
@@ -10,7 +10,6 @@ import json
10
  from pathlib import Path
11
  from datetime import datetime
12
  import duckdb
13
- import streamlit as st
14
 
15
  # Global model cache for HF Spaces
16
  _model_cache = {}
@@ -41,7 +40,7 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
41
  # Load model
42
  model_data = load_model_with_cache()
43
  if model_data is None:
44
- return json.dumps({"error": "Model not found. Please wait for training to complete."})
45
 
46
  model = model_data['model']
47
  label_encoders = model_data['label_encoders']
@@ -54,20 +53,22 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
54
  CREATE TABLE customers AS
55
  SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
56
  LIMIT 2000
57
- """) # Limit for performance
58
 
59
  conn.execute("""
60
  CREATE TABLE sales_docs AS
61
  SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
62
  LIMIT 5000
63
- """) # Limit for performance
64
 
65
  # Filter customers if specified
66
  if customer_ids:
67
  customer_list = [f"'{cid.strip()}'" for cid in customer_ids.split(',')]
68
  where_clause = f"WHERE c.Customer IN ({','.join(customer_list)})"
 
69
  else:
70
- where_clause = "LIMIT 500" # Further limit for demo
 
71
 
72
  # Get customer data
73
  customer_data = conn.execute(f"""
@@ -81,13 +82,13 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
81
  MIN(s.CreationDate) as first_order_date
82
  FROM customers c
83
  LEFT JOIN sales_docs s ON c.Customer = s.SoldToParty
84
- {where_clause if not customer_ids else ""}
85
  GROUP BY c.Customer, c.CustomerName, c.Country, c.CustomerGroup
86
- {where_clause if customer_ids else ""}
87
  """).df()
88
 
89
  if len(customer_data) == 0:
90
- return json.dumps({"error": "No customers found"})
91
 
92
  # Feature engineering (same as training)
93
  reference_date = pd.to_datetime('2024-12-31')
@@ -114,7 +115,6 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
114
  customer_data[col].fillna('Unknown')
115
  )
116
  except:
117
- # Handle unseen categories
118
  customer_data[f'{col}_encoded'] = 0
119
 
120
  # Make predictions
@@ -133,7 +133,7 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
133
  # High risk customers
134
  high_risk = results[results['churn_probability'] >= risk_threshold].sort_values(
135
  'churn_probability', ascending=False
136
- ).head(20) # Limit results for HF Spaces
137
 
138
  # Generate recommendations
139
  recommendations = []
@@ -153,8 +153,8 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
153
  "high_risk_count": len(high_risk),
154
  "churn_rate_predicted": round(len(high_risk) / len(results) * 100, 2) if len(results) > 0 else 0,
155
  "urgent_actions": recommendations,
156
- "model_performance": f"Accuracy: {model_data.get('accuracy', 'N/A')}",
157
- "hf_spaces_note": "Results limited for demo performance"
158
  })
159
 
160
  except Exception as e:
@@ -163,7 +163,7 @@ def predict_customer_churn_hf(customer_ids: str = None, risk_threshold: float =
163
  except Exception as e:
164
  return json.dumps({
165
  "error": f"Churn analysis failed: {str(e)}",
166
- "suggestion": "Please ensure model is trained"
167
  })
168
 
169
  @tool
 
10
  from pathlib import Path
11
  from datetime import datetime
12
  import duckdb
 
13
 
14
  # Global model cache for HF Spaces
15
  _model_cache = {}
 
40
  # Load model
41
  model_data = load_model_with_cache()
42
  if model_data is None:
43
+ return json.dumps({"error": "Model not found. Please train the model first."})
44
 
45
  model = model_data['model']
46
  label_encoders = model_data['label_encoders']
 
53
  CREATE TABLE customers AS
54
  SELECT * FROM 'hf://datasets/SAP/SALT/I_Customer.parquet'
55
  LIMIT 2000
56
+ """)
57
 
58
  conn.execute("""
59
  CREATE TABLE sales_docs AS
60
  SELECT * FROM 'hf://datasets/SAP/SALT/I_SalesDocument.parquet'
61
  LIMIT 5000
62
+ """)
63
 
64
  # Filter customers if specified
65
  if customer_ids:
66
  customer_list = [f"'{cid.strip()}'" for cid in customer_ids.split(',')]
67
  where_clause = f"WHERE c.Customer IN ({','.join(customer_list)})"
68
+ limit_clause = ""
69
  else:
70
+ where_clause = ""
71
+ limit_clause = "LIMIT 500" # Limit for demo
72
 
73
  # Get customer data
74
  customer_data = conn.execute(f"""
 
82
  MIN(s.CreationDate) as first_order_date
83
  FROM customers c
84
  LEFT JOIN sales_docs s ON c.Customer = s.SoldToParty
85
+ {where_clause}
86
  GROUP BY c.Customer, c.CustomerName, c.Country, c.CustomerGroup
87
+ {limit_clause}
88
  """).df()
89
 
90
  if len(customer_data) == 0:
91
+ return json.dumps({"error": "No customers found for analysis"})
92
 
93
  # Feature engineering (same as training)
94
  reference_date = pd.to_datetime('2024-12-31')
 
115
  customer_data[col].fillna('Unknown')
116
  )
117
  except:
 
118
  customer_data[f'{col}_encoded'] = 0
119
 
120
  # Make predictions
 
133
  # High risk customers
134
  high_risk = results[results['churn_probability'] >= risk_threshold].sort_values(
135
  'churn_probability', ascending=False
136
+ ).head(20)
137
 
138
  # Generate recommendations
139
  recommendations = []
 
153
  "high_risk_count": len(high_risk),
154
  "churn_rate_predicted": round(len(high_risk) / len(results) * 100, 2) if len(results) > 0 else 0,
155
  "urgent_actions": recommendations,
156
+ "model_performance": "Model ready and operational",
157
+ "note": "Results limited for demo performance"
158
  })
159
 
160
  except Exception as e:
 
163
  except Exception as e:
164
  return json.dumps({
165
  "error": f"Churn analysis failed: {str(e)}",
166
+ "suggestion": "Please ensure model is trained and data is available"
167
  })
168
 
169
  @tool