# Trae Assistant
# Enhance: Support CSV/Excel upload, add Chinese localization, fix bugs
# commit: f23dcfb
import os
import json
import random
import pandas as pd
import numpy as np
from flask import Flask, render_template, request, jsonify
from datetime import datetime, timedelta
app = Flask(__name__)
# Default Configuration
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100MB limit
def process_dataframe(df):
"""Common logic to process dataframe and return RFM results."""
# Validation
# Normalize column names to be case-insensitive or handle variations
df.columns = [c.strip() for c in df.columns]
# Map common column names to required ones
col_map = {
'客户ID': 'CustomerID', 'customerid': 'CustomerID', 'customer_id': 'CustomerID',
'订单日期': 'OrderDate', 'orderdate': 'OrderDate', 'order_date': 'OrderDate', 'date': 'OrderDate',
'金额': 'Amount', 'amount': 'Amount', 'total': 'Amount'
}
df = df.rename(columns={c: col_map.get(c.lower(), c) for c in df.columns})
required_cols = ['CustomerID', 'OrderDate', 'Amount']
missing = [c for c in required_cols if c not in df.columns]
if missing:
raise ValueError(f"缺少必要列: {', '.join(missing)}。请确保包含 CustomerID(客户ID), OrderDate(订单日期), Amount(金额)。")
# Run RFM
return calculate_rfm(df)
@app.route('/api/upload', methods=['POST'])
def upload_file():
    """Handle a transaction-file upload and return RFM analysis results.

    Expects a multipart form with a ``file`` field holding CSV/TXT or
    Excel data. Returns JSON with pie/bar/scatter chart series, a
    top-100 customer table and summary totals.

    Client-side problems (missing file, unsupported format, missing
    required columns) answer 400; unexpected failures answer 500.
    """
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    # Compare extensions case-insensitively so e.g. "DATA.CSV" is accepted.
    ext = os.path.splitext(file.filename)[1].lower()
    try:
        if ext in ('.csv', '.txt'):
            try:
                df = pd.read_csv(file)
            except UnicodeDecodeError:
                # Files exported from Chinese-locale Excel are often GBK-encoded.
                file.seek(0)
                df = pd.read_csv(file, encoding='gbk')
        elif ext in ('.xlsx', '.xls'):
            df = pd.read_excel(file)
        else:
            return jsonify({"error": "Unsupported file format. Please upload CSV or Excel."}), 400

        rfm_result = process_dataframe(df)

        # Pie chart: customer count per segment.
        segment_counts = rfm_result['Segment'].value_counts().reset_index()
        segment_counts.columns = ['name', 'value']
        # Bar chart: total revenue per segment.
        segment_monetary = rfm_result.groupby('Segment')['Monetary'].sum().reset_index()
        segment_monetary.columns = ['name', 'value']

        # Scatter: one series per segment, rows of [Recency, Frequency,
        # Monetary, CustomerID, Segment].
        scatter_data = []
        for segment in rfm_result['Segment'].unique():
            seg_df = rfm_result[rfm_result['Segment'] == segment]
            series_data = seg_df[['Recency', 'Frequency', 'Monetary', 'CustomerID', 'Segment']].values.tolist()
            scatter_data.append({
                "name": segment,
                "data": series_data
            })

        # Table: top 100 customers by total spend.
        table_data = rfm_result.sort_values('Monetary', ascending=False).head(100).to_dict(orient='records')
        return jsonify({
            "segments_pie": segment_counts.to_dict(orient='records'),
            "segments_bar": segment_monetary.to_dict(orient='records'),
            "scatter_series": scatter_data,
            "table_data": table_data,
            "summary": {
                "total_customers": len(rfm_result),
                "total_revenue": float(rfm_result['Monetary'].sum()),
                "avg_order_value": float(df['Amount'].mean())
            }
        })
    except ValueError as e:
        # Data-validation failures (e.g. missing required columns raised by
        # process_dataframe) are the client's fault, not a server error.
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/api/analyze', methods=['POST'])
def analyze():
    """Run RFM analysis on JSON transaction records posted in the body.

    Expects a JSON array of objects with CustomerID/OrderDate/Amount
    fields (aliases accepted by process_dataframe also work). Returns
    the same chart/table/summary payload shape as /api/upload.

    Client errors (absent/invalid JSON, missing columns) answer 400;
    unexpected failures answer 500.
    """
    try:
        # silent=True returns None for a missing/malformed body or wrong
        # Content-Type instead of raising, so we can answer a clean 400.
        json_data = request.get_json(silent=True)
        if not json_data:
            return jsonify({"error": "No data provided"}), 400
        df = pd.DataFrame(json_data)

        # Shared normalization + RFM computation.
        rfm_result = process_dataframe(df)

        # Pie chart: customer count per segment.
        segment_counts = rfm_result['Segment'].value_counts().reset_index()
        segment_counts.columns = ['name', 'value']
        # Bar chart: total revenue per segment.
        segment_monetary = rfm_result.groupby('Segment')['Monetary'].sum().reset_index()
        segment_monetary.columns = ['name', 'value']

        # Scatter: one series per segment, rows of [Recency, Frequency,
        # Monetary, CustomerID, Segment].
        scatter_data = []
        for segment in rfm_result['Segment'].unique():
            seg_df = rfm_result[rfm_result['Segment'] == segment]
            series_data = seg_df[['Recency', 'Frequency', 'Monetary', 'CustomerID', 'Segment']].values.tolist()
            scatter_data.append({
                "name": segment,
                "data": series_data
            })

        # Table: top 100 customers by total spend.
        table_data = rfm_result.sort_values('Monetary', ascending=False).head(100).to_dict(orient='records')
        return jsonify({
            "segments_pie": segment_counts.to_dict(orient='records'),
            "segments_bar": segment_monetary.to_dict(orient='records'),
            "scatter_series": scatter_data,
            "table_data": table_data,
            "summary": {
                "total_customers": len(rfm_result),
                "total_revenue": float(rfm_result['Monetary'].sum()),
                "avg_order_value": float(df['Amount'].mean())
            }
        })
    except ValueError as e:
        # Validation problems (e.g. missing columns) are client errors.
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def generate_demo_data(n=500):
    """Build *n* synthetic e-commerce transactions for the demo dashboard.

    Each record maps one of 100 customer ids (C001..C100) to an order
    date within the last 365 days (formatted YYYY-MM-DD) and an amount
    drawn from [10, 500), with a ~10% chance of an extra outlier bump
    of up to +1000. Returns a list of dicts.
    """
    now = datetime.now()
    id_pool = [f"C{i:03d}" for i in range(1, 101)]  # 100 distinct customers

    records = []
    for _ in range(n):
        buyer = random.choice(id_pool)
        order_day = now - timedelta(days=random.randint(0, 365))
        value = random.uniform(10, 500)
        if random.random() > 0.9:  # occasional big-ticket outlier
            value += random.random() * 1000
        records.append({
            "CustomerID": buyer,
            "OrderDate": order_day.strftime("%Y-%m-%d"),
            "Amount": round(value, 2),
        })
    return records
def calculate_rfm(df):
    """Compute RFM metrics and assign a marketing segment per customer.

    Args:
        df: DataFrame with columns CustomerID, OrderDate, Amount.
            The input frame is NOT modified.

    Returns:
        DataFrame with one row per customer: CustomerID, Recency (days
        since last order relative to the day after the newest order),
        Frequency (order count), Monetary (total spend), R/F/M quintile
        scores (1-5), RFM_Score string and Segment label.
    """
    # Work on a copy so the caller's DataFrame is not mutated by the
    # in-place datetime conversion below.
    df = df.copy()
    df['OrderDate'] = pd.to_datetime(df['OrderDate'])
    # Reference date: one day after the most recent order in the data.
    snapshot_date = df['OrderDate'].max() + timedelta(days=1)

    # One row per customer: days since last order, order count, total spend.
    rfm = df.groupby('CustomerID').agg({
        'OrderDate': lambda x: (snapshot_date - x.max()).days,
        'CustomerID': 'count',
        'Amount': 'sum'
    }).rename(columns={
        'OrderDate': 'Recency',
        'CustomerID': 'Frequency',
        'Amount': 'Monetary'
    })

    # Quintile scores 1-5. Recency is inverted (recent = 5, stale = 1);
    # Frequency and Monetary score higher for larger values.
    try:
        rfm['R'] = pd.qcut(rfm['Recency'], q=5, labels=range(5, 0, -1), duplicates='drop')
        rfm['F'] = pd.qcut(rfm['Frequency'], q=5, labels=range(1, 6), duplicates='drop')
        rfm['M'] = pd.qcut(rfm['Monetary'], q=5, labels=range(1, 6), duplicates='drop')
    except ValueError:
        # qcut raises ValueError when duplicate bin edges are dropped and the
        # remaining bins no longer match the 5 labels (tiny or heavily skewed
        # datasets). Fall back to a neutral middle score for everyone.
        rfm['R'] = 3
        rfm['F'] = 3
        rfm['M'] = 3

    # Cast categorical quintile labels to plain ints for arithmetic/JSON.
    rfm['R'] = rfm['R'].astype(int)
    rfm['F'] = rfm['F'].astype(int)
    rfm['M'] = rfm['M'].astype(int)
    rfm['RFM_Score'] = rfm['R'].astype(str) + rfm['F'].astype(str) + rfm['M'].astype(str)

    def segment_customer(row):
        """Map R and the mean of F/M onto a named marketing segment."""
        r, f, m = row['R'], row['F'], row['M']
        avg_fm = (f + m) / 2
        if r >= 5 and avg_fm >= 5:
            return "Champions (至尊王者)"
        elif r >= 3 and avg_fm >= 4:
            return "Loyal Customers (忠诚客户)"
        elif r >= 4 and avg_fm >= 2:
            return "Potential Loyalist (潜力股)"
        elif r >= 5 and avg_fm == 1:
            return "New Customers (新客)"
        elif r >= 3 and avg_fm <= 2:
            return "Promising (这就去买)"
        elif r <= 2 and avg_fm >= 4:
            return "At Risk (流失预警)"
        elif r <= 2 and avg_fm >= 2:
            return "Hibernating (沉睡客户)"
        else:
            return "Lost (已流失)"

    rfm['Segment'] = rfm.apply(segment_customer, axis=1)

    # Move the CustomerID index into a column so JSON serialization keeps it.
    rfm['CustomerID'] = rfm.index
    return rfm.reset_index(drop=True)
@app.route('/')
def index():
return render_template('index.html')
@app.route('/api/demo-data', methods=['GET'])
def get_demo_data():
    """Return a freshly generated synthetic transaction dataset as JSON."""
    return jsonify(generate_demo_data())
if __name__ == '__main__':
app.run(host='0.0.0.0', port=7860, debug=True)