import json
import os
import uuid

import numpy as np
import pandas as pd
from flask import Flask, jsonify, render_template, request
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB limit
app.config['UPLOAD_FOLDER'] = 'uploads'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Reader for each accepted upload extension (extensions matched lower-cased).
_READERS = {
    '.csv': pd.read_csv,
    '.xlsx': pd.read_excel,
}

# Substrings (matched against lower-cased column names) that mark a column as
# a date/time axis for trend charts.
_DATE_KEYWORDS = ('date', 'time', '日期', '时间')

# Keyword sets used by the /api/chat heuristics (Chinese and English).
_TREND_KEYWORDS = ("trend", "time", "趋势", "时间", "变化")
_DIST_KEYWORDS = ("distribution", "bar", "breakdown", "分布", "柱状", "分类")
_CORR_KEYWORDS = ("correlation", "相关性", "关系")


# --- Helper Functions ---

def _nan_to_none(df):
    """Return an object-dtype copy of ``df`` with every NaN/NaT replaced by None.

    Flask's JSON encoder writes float NaN as the bare token ``NaN``, which is
    not valid JSON and makes a browser's ``JSON.parse`` fail; ``None`` is
    serialised as ``null`` instead.
    """
    return df.astype(object).where(df.notna(), None)


def analyze_dataframe(df):
    """Build a JSON-serialisable summary of ``df`` for the frontend.

    Args:
        df: The DataFrame to summarise.

    Returns:
        A dict with the column names, row count, per-column dtype strings,
        numeric and object/category column lists, the first five rows as
        records (missing values as None), and per-column missing-value counts.
    """
    return {
        "columns": list(df.columns),
        "row_count": len(df),
        "dtypes": {k: str(v) for k, v in df.dtypes.items()},
        "numeric_columns": list(df.select_dtypes(include=[np.number]).columns),
        "categorical_columns": list(df.select_dtypes(include=['object', 'category']).columns),
        "sample": _nan_to_none(df.head(5)).to_dict(orient='records'),
        "missing_values": df.isnull().sum().to_dict(),
    }


def generate_demo_data(periods=100, seed=None):
    """Generate a random retail-sales demo dataset.

    Args:
        periods: Number of rows; one row per day starting 2024-01-01.
        seed: Optional seed for reproducible data; ``None`` gives fresh
            random data on every call.

    Returns:
        A DataFrame with columns 日期 (date), 类别 (category), 销售额 (sales,
        int in [100, 5000)), 利润 (profit, int in [10, 1000)) and 客户满意度
        (customer satisfaction, float in [1.0, 5.0) rounded to 1 decimal).
    """
    rng = np.random.default_rng(seed)
    dates = pd.date_range(start='2024-01-01', periods=periods)
    categories = ['电子产品', '服装', '家居', '图书', '美妆']
    return pd.DataFrame({
        '日期': dates,
        '类别': rng.choice(categories, periods),
        '销售额': rng.integers(100, 5000, periods),
        '利润': rng.integers(10, 1000, periods),
        '客户满意度': rng.uniform(1.0, 5.0, periods).round(1),
    })


def _date_columns(df):
    """Return the columns whose (stringified, lower-cased) name looks like a date/time."""
    return [c for c in df.columns if any(k in str(c).lower() for k in _DATE_KEYWORDS)]


def _trend_reply(df, numeric_cols):
    """Return (text, chart) for a line chart of the first value column summed per date."""
    date_cols = _date_columns(df)
    if not (date_cols and numeric_cols):
        return "我没找到日期列来展示趋势。请尝试询问数据分布情况。", None
    # A numeric column whose name looks like a date must not be aggregated
    # onto itself, so date-like columns are excluded from the value candidates.
    value_cols = [c for c in numeric_cols if c not in date_cols]
    if not value_cols:
        return "没有可以按时间汇总的数值列。", None
    d_col = date_cols[0]
    v_col = value_cols[0]
    try:
        df_agg = df.groupby(d_col)[v_col].sum().reset_index()
        chart = {
            "type": "line",
            "x": df_agg[d_col].astype(str).tolist(),
            "y": df_agg[v_col].tolist(),
            "label": v_col,
            "title": f"{v_col} 趋势分析",
        }
    except Exception as e:
        return f"生成趋势图时出错: {str(e)}", None
    return f"这是 {v_col} 随 {d_col} 变化的趋势图。数据已按时间聚合。", chart


def _distribution_reply(df, numeric_cols, cat_cols):
    """Return (text, chart) for a bar chart of the first numeric column summed per category."""
    if not (cat_cols and numeric_cols):
        return "我需要类别数据来展示分布。", None
    c_col = cat_cols[0]
    v_col = numeric_cols[0]
    try:
        df_agg = df.groupby(c_col)[v_col].sum().reset_index()
        chart = {
            "type": "bar",
            "x": df_agg[c_col].tolist(),
            "y": df_agg[v_col].tolist(),
            "label": v_col,
            "title": f"{c_col} - {v_col} 分布",
        }
    except Exception as e:
        return f"生成分布图时出错: {str(e)}", None
    return f"这是 {v_col} 按 {c_col} 分组的分布情况。", chart


def _correlation_reply(df, numeric_cols):
    """Return (text, chart) for a correlation heatmap over all numeric columns."""
    if len(numeric_cols) < 2:
        return "数值列不足,无法进行相关性分析。", None
    try:
        corr = df[numeric_cols].corr()
        chart = {
            "type": "heatmap",
            "x": numeric_cols,
            "y": numeric_cols,
            # Constant columns produce NaN correlations; send them as null.
            "z": _nan_to_none(corr).values.tolist(),
            "title": "相关性矩阵",
        }
    except Exception as e:
        return f"计算相关性时出错: {str(e)}", None
    return "这是数值变量之间的相关性矩阵。", chart


def _summary_reply(df):
    """Return (text, None) where text embeds ``df.describe()`` as an HTML table."""
    try:
        stats = df.describe().to_html(classes="table table-sm")
    except ValueError as e:  # e.g. describe() on a DataFrame without columns
        return f"生成统计信息时出错: {str(e)}", None
    # The text embeds an HTML table, so the client presumably renders it as
    # HTML; <br> gives the line breaks around the table.
    text = (
        f"我已分析数据集。以下是基本统计信息:<br>{stats}<br>"
        "您可以问我 '展示趋势'、'分析分布' 或 '显示相关性'。"
    )
    return text, None


# Process-wide "current dataset" for the demo. It is shared by every client
# and replaced by each upload/demo load, so it is not suitable for concurrent
# multi-user use; in production key it by session or keep it in a database.
current_df = None


@app.route('/')
def index():
    """Serve the single-page frontend."""
    return render_template('index.html')


@app.route('/api/load_demo', methods=['POST'])
def load_demo():
    """Replace the current dataset with freshly generated demo data and return its summary."""
    global current_df
    current_df = generate_demo_data()
    summary = analyze_dataframe(current_df)
    return jsonify({"status": "success", "message": "Demo data loaded", "summary": summary})


@app.route('/api/upload', methods=['POST'])
def upload_file():
    """Save an uploaded .csv/.xlsx file, load it as the current dataset, return its summary.

    Returns 400 when no file is sent or the extension is unsupported, and 500
    (with the exception text) when the file cannot be parsed.
    """
    global current_df
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    if not file.filename:  # covers both '' and a missing (None) filename
        return jsonify({"error": "No selected file"}), 400

    # Take the extension from the client's name *before* sanitising:
    # secure_filename drops non-ASCII characters, so "销售数据.csv" would
    # become "csv" and lose its extension.
    ext = os.path.splitext(file.filename)[1].lower()
    reader = _READERS.get(ext)
    if reader is None:
        return jsonify({"error": "Unsupported file format"}), 400

    filename = secure_filename(file.filename)
    if not filename.lower().endswith(ext):
        filename = f"upload_{uuid.uuid4().hex}{ext}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    try:
        current_df = reader(filepath)
        summary = analyze_dataframe(current_df)
    except Exception as e:
        app.logger.exception("Failed to load uploaded file %s", filename)
        return jsonify({"error": str(e)}), 500
    return jsonify({"status": "success", "summary": summary})


@app.route('/api/chat', methods=['POST'])
def chat():
    """Answer a chat message about the current dataset using keyword heuristics.

    Expects a JSON body like ``{"message": "..."}``. Trend, distribution and
    correlation keywords (checked in that order) produce a chart spec;
    anything else, including a missing or malformed body, gets a
    summary-statistics reply. Returns 400 when no dataset is loaded.
    """
    # Read the global once so a concurrent upload cannot swap it mid-request.
    df = current_df
    if df is None:
        return jsonify({"error": "No data loaded"}), 400

    # silent=True: a non-JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    message = data.get('message') if isinstance(data, dict) else None
    query = message.lower() if isinstance(message, str) else ''

    numeric_cols = list(df.select_dtypes(include=[np.number]).columns)
    cat_cols = list(df.select_dtypes(include=['object', 'category']).columns)

    if any(k in query for k in _TREND_KEYWORDS):
        text, chart = _trend_reply(df, numeric_cols)
    elif any(k in query for k in _DIST_KEYWORDS):
        text, chart = _distribution_reply(df, numeric_cols, cat_cols)
    elif any(k in query for k in _CORR_KEYWORDS):
        text, chart = _correlation_reply(df, numeric_cols)
    else:
        text, chart = _summary_reply(df)
    return jsonify({"text": text, "chart": chart})


if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it is not forced on for a server bound to all interfaces. Flask's
    # app.run honours the FLASK_DEBUG environment variable (e.g. FLASK_DEBUG=1)
    # when debug is not passed explicitly.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))