# File size: 6,731 Bytes
# 273c668
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import json
import pandas as pd
import numpy as np
from flask import Flask, render_template, request, jsonify
from werkzeug.utils import secure_filename

# Application object and upload configuration.
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB limit
app.config['UPLOAD_FOLDER'] = 'uploads'

# exist_ok=True replaces the separate exists() check, so a directory created
# between the check and the mkdir call can no longer raise FileExistsError.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# --- Helper Functions ---
def analyze_dataframe(df):
    """Build a JSON-friendly summary of a DataFrame.

    Args:
        df: The pandas DataFrame to describe.

    Returns:
        A dict with the keys ``columns``, ``row_count``, ``dtypes``,
        ``numeric_columns``, ``categorical_columns``, ``sample`` (up to the
        first five rows as records) and ``missing_values`` (null count per
        column). Missing cells in ``sample`` are reported as ``None``.
    """
    def _json_safe(value):
        # NaN / NaT / pd.NA would be serialised as the bare token ``NaN``
        # (or fail outright), which strict JSON parsers reject; ``None``
        # becomes ``null`` instead.
        if value is None or value is pd.NaT or value is pd.NA:
            return None
        if isinstance(value, float) and value != value:
            return None
        return value

    sample = [
        {key: _json_safe(value) for key, value in row.items()}
        for row in df.head(5).to_dict(orient='records')
    ]

    summary = {
        "columns": list(df.columns),
        "row_count": len(df),
        "dtypes": {k: str(v) for k, v in df.dtypes.items()},
        "numeric_columns": list(df.select_dtypes(include=[np.number]).columns),
        "categorical_columns": list(df.select_dtypes(include=['object', 'category']).columns),
        "sample": sample,
        # int() guarantees plain Python integers regardless of how the
        # installed pandas version boxes the per-column counts.
        "missing_values": {k: int(v) for k, v in df.isnull().sum().items()},
    }
    return summary

def generate_demo_data():
    """Create a 100-row synthetic retail-sales DataFrame for the demo.

    Returns:
        A DataFrame with the columns 日期 (100 periods starting 2024-01-01),
        类别 (random pick from five categories), 销售额 (random integers from
        100 up to but excluding 5000), 利润 (random integers from 10 up to
        but excluding 1000) and 客户满意度 (random floats between 1.0 and
        5.0, rounded to one decimal).
    """
    n_rows = 100
    category_pool = ['电子产品', '服装', '家居', '图书', '美妆']

    day_index = pd.date_range(start='2024-01-01', periods=n_rows)

    # The draws below use NumPy's global random state; their order is kept
    # fixed so a seeded run yields the same frame.
    picked_categories = np.random.choice(category_pool, n_rows)
    sales = np.random.randint(100, 5000, n_rows)
    profit = np.random.randint(10, 1000, n_rows)
    satisfaction = np.random.uniform(1.0, 5.0, n_rows).round(1)

    column_names = ['日期', '类别', '销售额', '利润', '客户满意度']
    column_values = [day_index, picked_categories, sales, profit, satisfaction]
    return pd.DataFrame(dict(zip(column_names, column_values)))

# Module-level holder for the currently loaded DataFrame, shared by every
# request handled in this process (demo only; in production use a session or
# a database). It stays None until a dataset has been loaded.
current_df = None

@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page

@app.route('/api/load_demo', methods=['POST'])
def load_demo():
    """Replace the shared dataset with freshly generated demo data.

    Returns:
        A JSON response carrying a success status, a confirmation message
        and the summary of the generated DataFrame.
    """
    global current_df
    demo_frame = generate_demo_data()
    # Publish the frame before summarising it, matching the order in which
    # the shared state becomes visible to other handlers.
    current_df = demo_frame
    payload = {
        "status": "success",
        "message": "Demo data loaded",
        "summary": analyze_dataframe(demo_frame),
    }
    return jsonify(payload)

@app.route('/api/upload', methods=['POST'])
def upload_file():
    """Accept an uploaded CSV/XLSX file and make it the shared dataset.

    The file is expected in the multipart field ``file``.

    Returns:
        On success, a JSON response with ``status`` and the DataFrame
        ``summary``. On failure, a JSON ``error`` message with status 400
        (missing file, empty filename, unsupported extension) or 500 (the
        file could not be parsed or summarised).
    """
    global current_df
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    # Covers both an empty string and a missing (None) filename.
    if not file.filename:
        return jsonify({"error": "No selected file"}), 400

    # Pick the parser from the client-supplied name *before* sanitising it:
    # secure_filename() drops non-ASCII characters, so a name such as
    # "数据.csv" would collapse to "csv" and lose its extension. The check is
    # case-insensitive so ".CSV" is accepted too.
    raw_stem, raw_ext = os.path.splitext(file.filename)
    extension = raw_ext.lower()
    readers = {'.csv': pd.read_csv, '.xlsx': pd.read_excel}
    reader = readers.get(extension)
    if reader is None:
        # Rejected before anything is written to the upload folder.
        return jsonify({"error": "Unsupported file format"}), 400

    # Only the stem needs sanitising; the extension comes from the whitelist
    # above. Fall back to a fixed stem when nothing ASCII survives.
    filename = (secure_filename(raw_stem) or 'upload') + extension
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    try:
        df = reader(filepath)
        summary = analyze_dataframe(df)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

    # Swap the shared dataset only once parsing and summarising succeeded, so
    # a bad upload does not discard the previously loaded data.
    current_df = df
    return jsonify({"status": "success", "summary": summary})

def _trend_reply(df, numeric_cols):
    """Return ``(text, chart)`` for a trend question: first numeric column summed per date value."""
    # str(c) keeps non-string column labels from raising on .lower().
    date_cols = [
        c for c in df.columns
        if any(k in str(c).lower() for k in ('date', 'time', '日期', '时间'))
    ]
    if not (date_cols and numeric_cols):
        return "我没找到日期列来展示趋势。请尝试询问数据分布情况。", None

    d_col = date_cols[0]
    v_col = numeric_cols[0]
    try:
        df_agg = df.groupby(d_col)[v_col].sum().reset_index()
        chart = {
            "type": "line",
            # Dates are stringified so they can be serialised for the frontend.
            "x": df_agg[d_col].astype(str).tolist(),
            "y": df_agg[v_col].tolist(),
            "label": v_col,
            "title": f"{v_col} 趋势分析"
        }
    except Exception as e:
        return f"生成趋势图时出错: {str(e)}", None
    return f"这是 {v_col} 随 {d_col} 变化的趋势图。数据已按时间聚合。", chart


def _distribution_reply(df, numeric_cols, cat_cols):
    """Return ``(text, chart)`` for a distribution question: first numeric column summed per category."""
    if not (cat_cols and numeric_cols):
        return "我需要类别数据来展示分布。", None

    c_col = cat_cols[0]
    v_col = numeric_cols[0]
    try:
        df_agg = df.groupby(c_col)[v_col].sum().reset_index()
        chart = {
            "type": "bar",
            "x": df_agg[c_col].tolist(),
            "y": df_agg[v_col].tolist(),
            "label": v_col,
            "title": f"{c_col} - {v_col} 分布"
        }
    except Exception as e:
        return f"生成分布图时出错: {str(e)}", None
    return f"这是 {v_col} 按 {c_col} 分组的分布情况。", chart


def _correlation_reply(df, numeric_cols):
    """Return ``(text, chart)`` for a correlation question: matrix over all numeric columns."""
    if len(numeric_cols) < 2:
        return "数值列不足,无法进行相关性分析。", None

    try:
        corr = df[numeric_cols].corr()
        # Undefined coefficients come back as NaN, which is not valid JSON;
        # send them as null instead.
        matrix = [
            [None if value != value else value for value in row]
            for row in corr.values.tolist()
        ]
        # Simplified heatmap data structure
        chart = {
            "type": "heatmap",
            "x": numeric_cols,
            "y": numeric_cols,
            "z": matrix,
            "title": "相关性矩阵"
        }
    except Exception as e:
        return f"计算相关性时出错: {str(e)}", None
    return "这是数值变量之间的相关性矩阵。", chart


def _summary_reply(df):
    """Return ``(text, chart)`` for any other question: an HTML table of describe() statistics."""
    try:
        stats = df.describe().to_html(classes="table table-sm")
    except Exception as e:
        # e.g. describe() refuses a frame that has no columns at all.
        return f"生成统计信息时出错: {str(e)}", None
    text = f"我已分析数据集。以下是基本统计信息:<br>{stats}<br>您可以问我 '展示趋势'、'分析分布' 或 '显示相关性'。"
    return text, None


@app.route('/api/chat', methods=['POST'])
def chat():
    """Answer a chat message about the loaded dataset using keyword heuristics.

    Expects a JSON object body with a ``message`` field. The lower-cased
    message is matched against trend, distribution and correlation keywords
    (English and Chinese), in that order; anything else gets a general
    statistical summary.

    Returns:
        A JSON response with ``text`` (reply, may contain HTML) and ``chart``
        (chart specification dict or null). Status 400 with an ``error``
        message when no data is loaded or the body is not a JSON object.
    """
    df = current_df
    if df is None:
        return jsonify({"error": "No data loaded"}), 400

    # silent=True turns an unparsable or non-JSON body into None rather than
    # an exception, so the client always gets a JSON error back.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    # Tolerate a missing, null or non-string message.
    query = str(data.get('message') or '').lower()

    # Simple "Agent" Heuristics
    numeric_cols = list(df.select_dtypes(include=[np.number]).columns)
    cat_cols = list(df.select_dtypes(include=['object', 'category']).columns)

    # Keyword matching for Chinese and English
    is_trend = any(k in query for k in ["trend", "time", "趋势", "时间", "变化"])
    is_dist = any(k in query for k in ["distribution", "bar", "breakdown", "分布", "柱状", "分类"])
    is_corr = any(k in query for k in ["correlation", "相关性", "关系"])

    if is_trend:
        text, chart = _trend_reply(df, numeric_cols)
    elif is_dist:
        text, chart = _distribution_reply(df, numeric_cols, cat_cols)
    elif is_corr:
        text, chart = _correlation_reply(df, numeric_cols)
    else:
        text, chart = _summary_reply(df)

    return jsonify({"text": text, "chart": chart})

if __name__ == '__main__':
    # The server listens on all interfaces, so the interactive debugger must
    # not be enabled unconditionally: it allows arbitrary code execution for
    # anyone who can reach the port. Opt in explicitly via FLASK_DEBUG.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)