import html
import json
import os
import uuid

import numpy as np
import pandas as pd
from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename
app = Flask(__name__)
# Flask rejects request bodies larger than this (16 MB).
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
# Staging directory for uploads, relative to the current working directory.
app.config['UPLOAD_FOLDER'] = 'uploads'
# exist_ok=True replaces the racy "exists? then create" check, which can fail
# when two processes start at the same time.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# --- Helper Functions ---
def analyze_dataframe(df):
    """Summarise *df* into a JSON-serialisable dict for the front end.

    Keys:
        columns: column labels in order.
        row_count: number of rows.
        dtypes: column -> dtype name (string).
        numeric_columns: columns matched by ``select_dtypes(np.number)``.
        categorical_columns: object- and category-dtype columns.
        sample: first five rows as ``{column: value}`` records; missing
            values (NaN / None / NaT) are reported as ``None``.
        missing_values: column -> number of missing cells (int).
    """
    def _json_safe(value):
        # json.dumps writes float NaN as the bare token NaN, which is not valid
        # JSON and breaks JSON.parse in the browser. Report every missing scalar
        # as null instead. is_scalar guards against list-like cell values, for
        # which pd.isna would return an array.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    sample = [
        {col: _json_safe(val) for col, val in record.items()}
        for record in df.head(5).to_dict(orient='records')
    ]
    return {
        "columns": list(df.columns),
        "row_count": len(df),
        "dtypes": {k: str(v) for k, v in df.dtypes.items()},
        "numeric_columns": list(df.select_dtypes(include=[np.number]).columns),
        "categorical_columns": list(df.select_dtypes(include=['object', 'category']).columns),
        "sample": sample,
        "missing_values": {k: int(v) for k, v in df.isnull().sum().items()},
    }
def generate_demo_data(n_rows=100, seed=None):
    """Build a synthetic retail-sales DataFrame with one row per day.

    Args:
        n_rows: number of rows; dates run daily from 2024-01-01.
        seed: optional seed for reproducible output. With ``None`` the global
            ``np.random`` state is used, with the same draw order as before.

    Returns:
        DataFrame with columns 日期 (date), 类别 (one of five categories),
        销售额 (int in [100, 5000)), 利润 (int in [10, 1000)) and
        客户满意度 (float in [1.0, 5.0], rounded to one decimal).
    """
    # The module-level np.random and a seeded RandomState expose the same
    # choice/randint/uniform API.
    rng = np.random if seed is None else np.random.RandomState(seed)
    categories = ['电子产品', '服装', '家居', '图书', '美妆']
    # Keep the draw order (choice, randint, randint, uniform) so a globally
    # seeded run still produces the same data.
    return pd.DataFrame({
        '日期': pd.date_range(start='2024-01-01', periods=n_rows),
        '类别': rng.choice(categories, n_rows),
        '销售额': rng.randint(100, 5000, n_rows),
        '利润': rng.randint(10, 1000, n_rows),
        '客户满意度': rng.uniform(1.0, 5.0, n_rows).round(1),
    })
# Active dataset, a single module-level object used by every route handler.
# A production deployment would keep it per session or in a database.
current_df: "pd.DataFrame | None" = None
@app.route('/')
def index():
    """Serve the single-page front end."""
    page = 'index.html'
    return render_template(page)
@app.route('/api/load_demo', methods=['POST'])
def load_demo():
    """Replace the active dataset with fresh demo data and return its summary."""
    global current_df
    demo_frame = generate_demo_data()
    current_df = demo_frame
    payload = {
        "status": "success",
        "message": "Demo data loaded",
        "summary": analyze_dataframe(demo_frame),
    }
    return jsonify(payload)
@app.route('/api/upload', methods=['POST'])
def upload_file():
    """Make an uploaded CSV/XLSX file the active dataset and return its summary.

    Expects a multipart form field named ``file``.

    Returns:
        ``{"status": "success", "summary": ...}`` on success;
        400 if the field is missing, has no filename, or the extension is not
        .csv/.xlsx; 500 with the exception text if parsing or summarising fails.
    """
    global current_df
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    # Handles both "" (nothing selected) and None (part sent without a filename).
    if not file.filename:
        return jsonify({"error": "No selected file"}), 400

    # Take the extension from the raw client filename. secure_filename() strips
    # non-ASCII characters, so "销售.csv" would become "csv" and be rejected.
    # Lower-casing also accepts names such as "DATA.CSV".
    ext = os.path.splitext(file.filename)[1].lower()
    if ext not in ('.csv', '.xlsx'):
        # Reject before saving so unsupported uploads never reach the disk.
        return jsonify({"error": "Unsupported file format"}), 400

    # A random on-disk name keeps concurrent uploads of same-named files apart.
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], uuid.uuid4().hex + ext)
    file.save(filepath)
    try:
        if ext == '.csv':
            try:
                df = pd.read_csv(filepath)
            except UnicodeDecodeError:
                # Not UTF-8. CSVs saved by Excel on Chinese-locale systems are
                # commonly GBK, which gb18030 (a superset) decodes.
                df = pd.read_csv(filepath, encoding='gb18030')
        else:
            df = pd.read_excel(filepath)
        summary = analyze_dataframe(df)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # The saved copy is only needed for parsing; don't let uploads pile up.
        # Removal is best-effort: a failure here must not mask the response.
        try:
            os.remove(filepath)
        except OSError:
            pass
    # Swap in the new dataset only after it parsed and summarised cleanly.
    current_df = df
    return jsonify({"status": "success", "summary": summary})
# Keywords (English and Chinese) that route a chat message to an analysis.
_TREND_KEYWORDS = ("trend", "time", "趋势", "时间", "变化")
_DIST_KEYWORDS = ("distribution", "bar", "breakdown", "分布", "柱状", "分类")
_CORR_KEYWORDS = ("correlation", "相关性", "关系")
# Substrings that mark a column name as a date/time column.
_DATE_NAME_HINTS = ('date', 'time', '日期', '时间')


def _escape_text(value):
    """Return str(value) with &, < and > HTML-escaped.

    Reply text is sent next to an HTML stats table, so it is presumably
    rendered as HTML (TODO confirm against the front end). Column names and
    error messages come from uploaded files and must not inject markup.
    """
    return html.escape(str(value), quote=False)


def _trend_reply(df, numeric_cols):
    """Return (text, chart) for a line chart of a numeric column summed per date column."""
    no_date_text = "我没找到日期列来展示趋势。请尝试询问数据分布情况。"
    # str() so non-string labels (e.g. int headers from Excel) don't break .lower().
    date_cols = [c for c in df.columns
                 if any(k in str(c).lower() for k in _DATE_NAME_HINTS)]
    if not date_cols:
        return no_date_text, None
    d_col = date_cols[0]
    # A numeric column whose name looks like a date (e.g. "time") can't be both
    # the group key and the summed value; reset_index() would hit a name clash.
    value_cols = [c for c in numeric_cols if c != d_col]
    if not value_cols:
        return no_date_text, None
    v_col = value_cols[0]
    try:
        df_agg = df.groupby(d_col)[v_col].sum().reset_index()
        chart = {
            "type": "line",
            "x": df_agg[d_col].astype(str).tolist(),
            "y": df_agg[v_col].tolist(),
            "label": v_col,
            "title": f"{v_col} 趋势分析",
        }
    except Exception as e:
        return f"生成趋势图时出错: {_escape_text(e)}", None
    text = f"这是 {_escape_text(v_col)} 随 {_escape_text(d_col)} 变化的趋势图。数据已按时间聚合。"
    return text, chart


def _distribution_reply(df, numeric_cols, cat_cols):
    """Return (text, chart) for a bar chart of a numeric column summed per category."""
    if not (cat_cols and numeric_cols):
        return "我需要类别数据来展示分布。", None
    c_col, v_col = cat_cols[0], numeric_cols[0]
    try:
        df_agg = df.groupby(c_col)[v_col].sum().reset_index()
        chart = {
            "type": "bar",
            "x": df_agg[c_col].tolist(),
            "y": df_agg[v_col].tolist(),
            "label": v_col,
            "title": f"{c_col} - {v_col} 分布",
        }
    except Exception as e:
        return f"生成分布图时出错: {_escape_text(e)}", None
    return f"这是 {_escape_text(v_col)} 按 {_escape_text(c_col)} 分组的分布情况。", chart


def _correlation_reply(df, numeric_cols):
    """Return (text, chart) for a heatmap of pairwise correlations of numeric columns."""
    if len(numeric_cols) < 2:
        return "数值列不足,无法进行相关性分析。", None
    try:
        corr = df[numeric_cols].corr()
        # Constant or all-missing columns give NaN correlations. NaN is not
        # valid JSON, so those cells are sent as null.
        z = [[None if pd.isna(v) else v for v in row] for row in corr.values.tolist()]
    except Exception as e:
        return f"计算相关性时出错: {_escape_text(e)}", None
    chart = {
        "type": "heatmap",
        "x": numeric_cols,
        "y": numeric_cols,
        "z": z,
        "title": "相关性矩阵",
    }
    return "这是数值变量之间的相关性矩阵。", chart


def _summary_reply(df):
    """Return (text, None) with describe() statistics embedded as an HTML table."""
    stats = df.describe().to_html(classes="table table-sm")
    text = (
        f"我已分析数据集。以下是基本统计信息:\n{stats}\n"
        "您可以问我 '展示趋势'、'分析分布' 或 '显示相关性'。"
    )
    return text, None


@app.route('/api/chat', methods=['POST'])
def chat():
    """Answer a chat message about the active dataset.

    Expects a JSON body ``{"message": ...}``. The lower-cased message is checked
    against trend, distribution and correlation keywords, in that order.
    Anything else gets a describe() summary.

    Returns:
        ``{"text": ..., "chart": ...}`` (chart may be null), or 400 when no
        dataset has been loaded.
    """
    df = current_df  # read once so the whole request uses the same dataset
    if df is None:
        return jsonify({"error": "No data loaded"}), 400

    # get_json(silent=True) returns None instead of raising for a missing or
    # non-JSON body. A non-object body and a null/non-string message are also
    # treated as an empty query rather than crashing.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    query = str(data.get('message') or '').lower()

    numeric_cols = list(df.select_dtypes(include=[np.number]).columns)
    cat_cols = list(df.select_dtypes(include=['object', 'category']).columns)

    if any(k in query for k in _TREND_KEYWORDS):
        text, chart = _trend_reply(df, numeric_cols)
    elif any(k in query for k in _DIST_KEYWORDS):
        text, chart = _distribution_reply(df, numeric_cols, cat_cols)
    elif any(k in query for k in _CORR_KEYWORDS):
        text, chart = _correlation_reply(df, numeric_cols)
    else:
        text, chart = _summary_reply(df)
    return jsonify({"text": text, "chart": chart})
if __name__ == '__main__':
    # The Werkzeug interactive debugger can run arbitrary code, so it must not
    # be on by default for a server bound to all interfaces. Opt in locally
    # with FLASK_DEBUG=1 (or "true"). PORT overrides the default 7860.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', '7860')),
        debug=os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true'),
    )