ElvisWang111 commited on
Commit
d235bdf
·
verified ·
1 Parent(s): 07f0686

Upload folder using huggingface_hub

Browse files
Files changed (43) hide show
  1. Dockerfile +2 -2
  2. app.py +215 -0
  3. components/__init__.py +0 -0
  4. logo/logo_16_9.png +3 -0
  5. logo/logo_big.png +3 -0
  6. logo/logo_blue_wide.png +3 -0
  7. logo/logo_wide.png +3 -0
  8. logo/sec3//346/212/230/347/272/277/345/233/276.png +0 -0
  9. logo/sec3//347/233/264/346/226/271/345/233/276.png +0 -0
  10. logo/sec3//347/256/261/347/272/277/345/233/276.png +0 -0
  11. logo/sec3//351/245/274/345/233/276.png +0 -0
  12. prompt_engineer/.DS_Store +0 -0
  13. prompt_engineer/call_llm.py +144 -0
  14. prompt_engineer/planner.py +177 -0
  15. prompt_engineer/sec1_call_llm.py +248 -0
  16. prompt_engineer/sec2_call_llm.py +374 -0
  17. prompt_engineer/sec3_call_llm.py +691 -0
  18. prompt_engineer/sec4_call_llm.py +606 -0
  19. prompt_engineer/sec5_call_llm.py +617 -0
  20. utils/content.py +13 -0
  21. utils/sanitize_code.py +47 -0
  22. utils/save_secrets.py +33 -0
  23. utils/spinner_pool.py +25 -0
  24. workflow/.DS_Store +0 -0
  25. workflow/dataloading/dataloading_core.py +287 -0
  26. workflow/dataloading/dataloading_render.py +210 -0
  27. workflow/modeling/model_inference.py +102 -0
  28. workflow/modeling/model_training.py +143 -0
  29. workflow/modeling/modeling_render.py +218 -0
  30. workflow/preprocessing/preprocessing_core.py +112 -0
  31. workflow/preprocessing/preprocessing_render.py +159 -0
  32. workflow/report/report_core.py +46 -0
  33. workflow/report/report_html.py +117 -0
  34. workflow/report/report_markdown.py +55 -0
  35. workflow/report/report_prepare_er.py +102 -0
  36. workflow/report/report_render.py +243 -0
  37. workflow/report/report_utils.py +59 -0
  38. workflow/report/report_word.py +89 -0
  39. workflow/visualization/viz_coding.py +110 -0
  40. workflow/visualization/viz_color.py +58 -0
  41. workflow/visualization/viz_quick_action.py +23 -0
  42. workflow/visualization/viz_render.py +192 -0
  43. workflow/visualization/viz_suggestion.py +38 -0
Dockerfile CHANGED
@@ -11,10 +11,10 @@ RUN pip install --no-cache-dir --upgrade pip setuptools wheel \
11
  && pip install --no-cache-dir -r requirements.txt
12
 
13
  # ========= 拷贝项目文件 =========
14
- COPY tmp/ ./tmp/
15
 
16
  # ========= 暴露 Streamlit 端口 =========
17
  EXPOSE 8501
18
 
19
  # ========= 启动命令 =========
20
- CMD ["streamlit", "run", "tmp/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
11
  && pip install --no-cache-dir -r requirements.txt
12
 
13
  # ========= 拷贝项目文件 =========
14
+ # COPY tmp/ ./tmp/
15
 
16
  # ========= 暴露 Streamlit 端口 =========
17
  EXPOSE 8501
18
 
19
  # ========= 启动命令 =========
20
+ CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ import tempfile
3
+ import streamlit as st
4
+
5
+ from config import MODEL_CONFIGS
6
+ from utils.save_secrets import *
7
+ from prompt_engineer.sec1_call_llm import DataLoadingAgent
8
+ from prompt_engineer.sec2_call_llm import DataPreprocessAgent
9
+ from prompt_engineer.sec3_call_llm import VisualizationAgent
10
+ from prompt_engineer.sec4_call_llm import ModelingCodingAgent
11
+ from prompt_engineer.sec5_call_llm import ReportAgent
12
+ from prompt_engineer.planner import PlannerAgent
13
+
14
+ import warnings
15
+ warnings.filterwarnings("ignore")
16
+ warnings.filterwarnings("ignore", message="missing ScriptRunContext")
17
+
18
+ import numpy as np
19
+ np.set_printoptions(edgeitems=250, threshold=501)
20
+
21
+ sys.path.append(os.path.dirname(__file__))
22
+
23
+ CACHE_FILE = os.path.join(tempfile.gettempdir(), "anystat_cache.pkl")
24
+ CACHE_DIR = './cache'
25
+ SECRETS_PATH = Path(".streamlit") / "secrets.toml"
26
+
27
+
28
+ # 设置页面配置
29
+ st.set_page_config(
30
+ page_title="AnyStat",
31
+ page_icon="🤖",
32
+ layout="wide"
33
+ )
34
+
35
+
36
+ def init_session_state():
37
+
38
+ if 'selected_model' not in st.session_state:
39
+ st.session_state.selected_model = "DeepSeek"
40
+ if "api_keys" not in st.session_state:
41
+ st.session_state.api_keys = load_local_api_keys()
42
+ if 'auto_mode' not in st.session_state:
43
+ st.session_state.auto_mode = False
44
+
45
+ if 'loading_start_time' not in st.session_state:
46
+ st.session_state.loading_start_time = None
47
+ if 'prep_start_time' not in st.session_state:
48
+ st.session_state.prep_start_time = None
49
+ if 'vis_start_time' not in st.session_state:
50
+ st.session_state.vis_start_time = None
51
+ if 'modeling_start_time' not in st.session_state:
52
+ st.session_state.modeling_start_time = None
53
+ if 'report_start_time' not in st.session_state:
54
+ st.session_state.report_start_time = None
55
+
56
+ if 'data_loading_agent' not in st.session_state:
57
+ st.session_state.data_loading_agent = DataLoadingAgent(
58
+ api_keys=st.session_state.api_keys,
59
+ model_configs=MODEL_CONFIGS,
60
+ model=st.session_state.selected_model
61
+ )
62
+ if 'data_preprocess_agent' not in st.session_state:
63
+ st.session_state.data_preprocess_agent = DataPreprocessAgent(
64
+ api_keys=st.session_state.api_keys,
65
+ model_configs=MODEL_CONFIGS,
66
+ model=st.session_state.selected_model
67
+ )
68
+ if 'visualization_agent' not in st.session_state:
69
+ st.session_state.visualization_agent = VisualizationAgent(
70
+ api_keys=st.session_state.api_keys,
71
+ model_configs=MODEL_CONFIGS,
72
+ model=st.session_state.selected_model
73
+ )
74
+ if 'modeling_coding_agent' not in st.session_state:
75
+ st.session_state.modeling_coding_agent = ModelingCodingAgent(
76
+ api_keys=st.session_state.api_keys,
77
+ model_configs=MODEL_CONFIGS,
78
+ model=st.session_state.selected_model
79
+ )
80
+ if 'report_agent' not in st.session_state:
81
+ st.session_state.report_agent = ReportAgent(
82
+ api_keys=st.session_state.api_keys,
83
+ model_configs=MODEL_CONFIGS,
84
+ model=st.session_state.selected_model
85
+ )
86
+ if 'planner_agent' not in st.session_state:
87
+ st.session_state.planner_agent = PlannerAgent(
88
+ api_keys=st.session_state.api_keys,
89
+ model_configs=MODEL_CONFIGS,
90
+ model=st.session_state.selected_model
91
+ )
92
+
93
+
94
+ def on_model_selector_change():
95
+ """
96
+ Callback when the model selector in the sidebar changes.
97
+ """
98
+ st.session_state.selected_model = st.session_state.model_selector
99
+
100
+
101
+ def run_app():
102
+ """
103
+ Main entry point to render the Streamlit app.
104
+ """
105
+ init_session_state()
106
+ with st.sidebar:
107
+ st.subheader("选择大模型")
108
+ models = list(MODEL_CONFIGS.keys())
109
+ st.selectbox(
110
+ "选择要使用的大模型",
111
+ models,
112
+ index=models.index(st.session_state.selected_model),
113
+ key="model_selector",
114
+ on_change=on_model_selector_change,
115
+ )
116
+
117
+ st.subheader("API 密钥设置")
118
+ selected = st.session_state.selected_model
119
+
120
+ api_key_input = st.text_input(
121
+ f"{selected} API 密钥",
122
+ value=st.session_state.api_keys.get(selected, ""),
123
+ type="password",
124
+ key="api_key_input",
125
+ )
126
+
127
+
128
+ if st.button("💾 保存密钥", use_container_width=True, key="save_key"):
129
+ # 保存在 utils/.streamlit/secrets.toml
130
+ update_local_api_key(selected, api_key_input)
131
+
132
+ st.session_state.api_keys[selected] = api_key_input
133
+ st.success("已保存")
134
+ st.rerun()
135
+
136
+ if st.button("🧹 清空数据", use_container_width=True, key="clear_data"):
137
+
138
+ st.session_state.data_loading_agent = DataLoadingAgent(
139
+ api_keys=st.session_state.api_keys,
140
+ model_configs=MODEL_CONFIGS,
141
+ model=st.session_state.selected_model
142
+ )
143
+ st.session_state.data_preprocess_agent = DataPreprocessAgent(
144
+ api_keys=st.session_state.api_keys,
145
+ model_configs=MODEL_CONFIGS,
146
+ model=st.session_state.selected_model
147
+ )
148
+ st.session_state.visualization_agent = VisualizationAgent(
149
+ api_keys=st.session_state.api_keys,
150
+ model_configs=MODEL_CONFIGS,
151
+ model=st.session_state.selected_model
152
+ )
153
+ st.session_state.modeling_coding_agent = ModelingCodingAgent(
154
+ api_keys=st.session_state.api_keys,
155
+ model_configs=MODEL_CONFIGS,
156
+ model=st.session_state.selected_model
157
+ )
158
+ st.session_state.report_agent = ReportAgent(
159
+ api_keys=st.session_state.api_keys,
160
+ model_configs=MODEL_CONFIGS,
161
+ model=st.session_state.selected_model
162
+ )
163
+ st.session_state.planner_agent = PlannerAgent(
164
+ api_keys=st.session_state.api_keys,
165
+ model_configs=MODEL_CONFIGS,
166
+ model=st.session_state.selected_model
167
+ )
168
+ st.session_state.auto_mode = False
169
+ st.rerun()
170
+
171
+ if st.session_state.data_loading_agent.load_df() is not None:
172
+ planner = st.session_state.planner_agent
173
+ if st.button("🚗 自动模式", use_container_width=True, key="self_driving"):
174
+ planner.self_driving(st.session_state.data_loading_agent.load_df())
175
+ st.session_state.auto_mode = True
176
+ st.rerun()
177
+
178
+ st.image(
179
+ "logo/logo_big.png",
180
+ use_container_width=True
181
+ )
182
+
183
+ # Define pages
184
+ data_loading = st.Page(
185
+ "workflow/dataloading/dataloading_render.py",
186
+ title="📥 数据导入",
187
+ )
188
+ preprocessing = st.Page(
189
+ "workflow/preprocessing/preprocessing_render.py",
190
+ title="⚙️ 数据预处理",
191
+ )
192
+ visualization = st.Page(
193
+ "workflow/visualization/viz_render.py",
194
+ title="📊 数据可视化",
195
+ )
196
+ report = st.Page(
197
+ "workflow/report/report_render.py",
198
+ title="📝 报告生成",
199
+ )
200
+ coding_modeling = st.Page(
201
+ "workflow/modeling/modeling_render.py",
202
+ title="🧠 建模分析",
203
+ )
204
+ # Navigation
205
+ pg = st.navigation(
206
+ {
207
+ "设置": [data_loading, preprocessing],
208
+ "功能": [visualization, coding_modeling, report],
209
+ }
210
+ )
211
+ pg.run()
212
+
213
+ if __name__ == "__main__":
214
+ run_app()
215
+
components/__init__.py ADDED
File without changes
logo/logo_16_9.png ADDED

Git LFS Details

  • SHA256: d9d657c0e416eed4a69ac6da7b7a271239c26e185a0cd778f2072b24db594cf6
  • Pointer size: 131 Bytes
  • Size of remote file: 327 kB
logo/logo_big.png ADDED

Git LFS Details

  • SHA256: 389ffb2d5eec47539b6aee2ef89d4949d3bfa2d94d16c2d7198bd7ef394beb59
  • Pointer size: 131 Bytes
  • Size of remote file: 326 kB
logo/logo_blue_wide.png ADDED

Git LFS Details

  • SHA256: 3b28da164d5e95630bb79aea3e35fc71bfd9f556520c067194d0935bc918f036
  • Pointer size: 131 Bytes
  • Size of remote file: 784 kB
logo/logo_wide.png ADDED

Git LFS Details

  • SHA256: af377c54885f9d9fddaf3632f9cdca087e931a9da509a7dc9908eed858057f0a
  • Pointer size: 131 Bytes
  • Size of remote file: 325 kB
logo/sec3//346/212/230/347/272/277/345/233/276.png ADDED
logo/sec3//347/233/264/346/226/271/345/233/276.png ADDED
logo/sec3//347/256/261/347/272/277/345/233/276.png ADDED
logo/sec3//351/245/274/345/233/276.png ADDED
prompt_engineer/.DS_Store ADDED
Binary file (6.15 kB). View file
 
prompt_engineer/call_llm.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from openai import OpenAI, OpenAIError
3
+ from anthropic import Anthropic, AnthropicError
4
+ import requests
5
+ import json
6
+
7
+ import streamlit as st
8
+ import pandas as pd
9
+ import numpy as np
10
+ from config import MODEL_CONFIGS
11
+ from typing import IO, List, Dict
12
+ from zai import ZhipuAiClient
13
+
14
+ class LLMClient:
15
+ def __init__(self, model_configs: dict, api_keys: dict, model: str):
16
+
17
+ self.model = model
18
+ self.model_configs = model_configs
19
+ self.api_keys = api_keys
20
+ self.memory = []
21
+ self.df = None
22
+
23
+ def call(self, prompt) -> str:
24
+
25
+ model_name = st.session_state.selected_model
26
+ config = self.model_configs.get(model_name, {})
27
+ api_key = self.api_keys.get(model_name)
28
+
29
+ if not api_key:
30
+ return "请先在设置中配置 API 密钥"
31
+
32
+ system_msg = (
33
+ "你是一个专业的数据分析助手。"
34
+ )
35
+
36
+ try:
37
+ if model_name == "GPT-4o" or model_name == "GPT-5" or model_name == "DeepSeek" or model_name == "通义千问" or model_name == "Claude" or model_name == "豆包":
38
+ try:
39
+ client = OpenAI(
40
+ api_key=api_key,
41
+ base_url=config["api_base"]
42
+ )
43
+
44
+ # 使用新的 API 调用方式
45
+ resp = client.chat.completions.create(
46
+ model=config["model_name"],
47
+ messages=[
48
+ {"role": "system", "content": system_msg},
49
+ {"role": "user", "content": prompt},
50
+ ],
51
+ stream = False
52
+ )
53
+ return resp.choices[0].message.content
54
+
55
+ except OpenAIError as e:
56
+ # 这里可以捕获所有OpenAI SDK定义的错误
57
+ st.error(f"API调用失败: {str(e)}")
58
+ # 记录日志或提示用户
59
+ return "调用失败,请检查密钥或网络"
60
+ except Exception as e:
61
+ # 捕获其他非预期的异常,如网络问题
62
+ st.error(f"发生未知错误: {str(e)}")
63
+ return "发生未知错误"
64
+
65
+ elif model_name == "智谱AI":
66
+ client = ZhipuAiClient(api_key=api_key)
67
+ response = client.chat.completions.create(
68
+ model=config["model_name"],
69
+ messages=[{"role": "system", "content": "你是一个专业的数据分析助手。"},
70
+ {"role": "user", "content": prompt}],
71
+ thinking={
72
+ "type":"enabled"
73
+ }
74
+ )
75
+ if response:
76
+ print(response.choices[0].message)
77
+ desc = response.choices[0].message.content if hasattr(response.choices[0].message, "content") else str(response.choices[0].message)
78
+ return desc.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "").strip()
79
+
80
+ st.error(f"智谱调用失败:{response.text}")
81
+ return "调用失败,请检查密钥或网络"
82
+
83
+
84
+ # elif model_name == "DeepSeek":
85
+ # client = OpenAI(
86
+ # api_key=api_key,
87
+ # base_url=config["api_base"])
88
+
89
+ # resp = client.chat.completions.create(
90
+ # model=config["model_name"],
91
+ # messages=[
92
+ # {"role": "system", "content": system_msg},
93
+ # {"role": "user", "content": prompt},
94
+ # ],
95
+ # stream=False
96
+ # )
97
+ # if resp:
98
+ # return resp.choices[0].message.content
99
+ # st.error(f"DeepSeek调用失败:{resp.text}")
100
+ # return "调用失败,请检查密钥或网络"
101
+
102
+ else:
103
+ return f"暂不支持模型:{model_name}"
104
+
105
+ except Exception as e:
106
+ st.error(f"{model_name} 调用异常:{e}")
107
+ return "大模型调用失败,请检查 API 密钥或网络连接"
108
+
109
+
110
+ def add_memory(self, entry: Dict[str, str]) -> None:
111
+
112
+ self.memory.append(entry)
113
+
114
+
115
+ def load_memory(self) -> List[Dict[str, str]]:
116
+
117
+ return self.memory
118
+
119
+
120
+ def clear_memory(self) -> None:
121
+
122
+ self.memory.clear()
123
+
124
+
125
+ def add_df(self, input_df) -> None:
126
+
127
+
128
+
129
+ self.df = input_df
130
+
131
+
132
+ def load_df(self) -> pd.DataFrame:
133
+
134
+ return self.df
135
+
136
+
137
+ def clear_df(self) -> None:
138
+
139
+ self.df = None
140
+
141
+
142
+ def has_df(self) -> bool:
143
+
144
+ return self.df == None
prompt_engineer/planner.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+
4
+ import streamlit as st
5
+ from typing import IO, List
6
+
7
+ from prompt_engineer.call_llm import LLMClient
8
+
9
+
10
+ class PlannerAgent(LLMClient):
11
+
12
+ def __init__(self, *args, **kwargs):
13
+
14
+ super().__init__(*args, **kwargs)
15
+ self.loading_auto = False
16
+ self.prep_auto = False
17
+ self.vis_auto = False
18
+ self.modeling_auto = False
19
+ self.report_auto = False
20
+
21
+ self.switched_loading = False
22
+ self.switched_prep = False
23
+ self.switched_vis = False
24
+ self.switched_modeling = False
25
+ self.switched_report = False
26
+
27
+ def self_driving(self, df, user_input=None) -> str:
28
+
29
+ prompt = (
30
+ f"下面是一个数据集的基本信息,请你根据它和用户的需求,判断需要开启哪些分析步骤:\n\n"
31
+ f"- 数据维度:{df.shape[0]} 行 × {df.shape[1]} 列\n"
32
+ f"- 列名和数据类型:{dict(zip(df.columns.tolist(), df.dtypes.astype(str).tolist()))}\n"
33
+ f"- 前 5 行样本:\n{df.head().to_dict(orient='list')}\n\n"
34
+ )
35
+
36
+ if user_input:
37
+ prompt += f"用户的具体需求是:“{user_input}”。\n\n"
38
+
39
+ prompt += """
40
+ 你需要在以下 5 个步骤中,对每个步骤分别判断是否应该开启(True / False):
41
+ 1. loading_auto —— 是否需要对数据列名进行初步分析?
42
+ 2. prep_auto —— 是否需要做数据预处理或清洗?
43
+ 3. vis_auto —— 是否需要做数据可视化?
44
+ 4. modeling_auto —— 是否需要建模或统计分析?
45
+ 5. report_auto —— 是否需要生成分析报告?
46
+
47
+ 必须以 **JSON 格式** 输出你的判断结果,如:
48
+ {
49
+ "loading_auto": true,
50
+ "prep_auto": false,
51
+ "vis_auto": true,
52
+ "modeling_auto": true,
53
+ "report_auto": true
54
+ }
55
+
56
+ 不要输出其他内容。
57
+ """
58
+
59
+ plan_text = self.call(prompt)
60
+ print(plan_text)
61
+ try:
62
+ plan_dict = json.loads(plan_text)
63
+ except json.JSONDecodeError:
64
+ plan_text_fixed = plan_text.strip().strip('```json').strip('```')
65
+ plan_dict = json.loads(plan_text_fixed)
66
+
67
+ print(plan_dict)
68
+ self.loading_auto = bool(plan_dict.get("loading_auto", False))
69
+ self.prep_auto = bool(plan_dict.get("prep_auto", False))
70
+ self.vis_auto = bool(plan_dict.get("vis_auto", False))
71
+ self.modeling_auto = bool(plan_dict.get("modeling_auto", False))
72
+ # self.modeling_auto = False
73
+ self.report_auto = bool(plan_dict.get("report_auto", False))
74
+
75
+
76
+ def finish_loading_auto(self) -> str:
77
+
78
+ self.switched_loading = True
79
+
80
+
81
+ def finish_prep_auto(self) -> str:
82
+
83
+ self.switched_prep = True
84
+
85
+
86
+ def finish_vis_auto(self) -> str:
87
+
88
+ self.switched_vis = True
89
+
90
+
91
+ def finish_modeling_auto(self) -> str:
92
+
93
+ self.switched_modeling = True
94
+
95
+
96
+ def finish_report_auto(self) -> str:
97
+
98
+ self.switched_report = True
99
+
100
+
101
+ import json
102
+ import ast
103
+ import re
104
+ import traceback
105
+
106
+ def _extract_first_json(text: str):
107
+ """从 text 中提取第一个顶层花括号 JSON 子串(用配对计数法),找不到则返回 None。"""
108
+ if not text:
109
+ return None
110
+ start = text.find('{')
111
+ if start == -1:
112
+ return None
113
+ depth = 0
114
+ for i in range(start, len(text)):
115
+ ch = text[i]
116
+ if ch == '{':
117
+ depth += 1
118
+ elif ch == '}':
119
+ depth -= 1
120
+ if depth == 0:
121
+ return text[start:i+1]
122
+ return None
123
+
124
+ def _safe_parse_json(text: str):
125
+ """
126
+ 尝试多种策略解析 LLM 输出为 dict:
127
+ 1) 直接 json.loads
128
+ 2) 去除 Markdown code fence 后再 loads
129
+ 3) 提取第一个完整花括号块后 loads
130
+ 4) ast.literal_eval 作为最后手段(接受 Python dict 风格)
131
+ 返回 (dict_or_None, used_text, error_message_or_None)
132
+ """
133
+ if not text or not text.strip():
134
+ return None, text, "empty"
135
+ # 1) 直接尝试
136
+ try:
137
+ return json.loads(text), text, None
138
+ except Exception as e1:
139
+ pass
140
+
141
+ # 2) 去掉 ```json / ``` fence
142
+ try:
143
+ cleaned = re.sub(r'```json\s*', '', text, flags=re.IGNORECASE)
144
+ cleaned = re.sub(r'```', '', cleaned)
145
+ cleaned = cleaned.strip()
146
+ return json.loads(cleaned), cleaned, None
147
+ except Exception:
148
+ pass
149
+
150
+ # 3) 提取首个匹配的 { ... } 顶层块
151
+ try:
152
+ sub = _extract_first_json(text)
153
+ if sub:
154
+ return json.loads(sub), sub, None
155
+ except Exception:
156
+ pass
157
+
158
+ # 4) ast.literal_eval 兼容 Python 字典格式(单引号等)
159
+ try:
160
+ literal = ast.literal_eval(text)
161
+ if isinstance(literal, dict):
162
+ return literal, text, None
163
+ except Exception:
164
+ pass
165
+
166
+ # 5) 再次尝试在提取的子串上用 literal_eval(防止单引号)
167
+ try:
168
+ sub = _extract_first_json(text)
169
+ if sub:
170
+ literal = ast.literal_eval(sub)
171
+ if isinstance(literal, dict):
172
+ return literal, sub, None
173
+ except Exception:
174
+ pass
175
+
176
+ # 最后,返回 None 并带上错误信息
177
+ return None, text, "unable_to_parse"
prompt_engineer/sec1_call_llm.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import streamlit as st
4
+ from typing import IO, List
5
+
6
+ from prompt_engineer.call_llm import LLMClient
7
+
8
+
9
+ class DataLoadingAgent(LLMClient):
10
+
11
+ def __init__(self, *args, **kwargs):
12
+
13
+ super().__init__(*args, **kwargs)
14
+ self.file_name = []
15
+ self.user_input = None
16
+ self.par_content = ""
17
+ self.dfs = None
18
+ self.abstract=None
19
+ self.full = None
20
+ self.finish_auto_task = False
21
+
22
+
23
+ def finish_auto(self):
24
+
25
+ self.finish_auto_task = True
26
+
27
+
28
+ def save_file_name(self, file_name):
29
+
30
+ self.file_name.append(file_name)
31
+
32
+
33
+ def load_file_name(self):
34
+
35
+ return self.file_name
36
+
37
+
38
+ def save_dfs(self, dfs):
39
+
40
+ self.dfs = (dfs)
41
+
42
+
43
+ def load_dfs(self):
44
+
45
+ return self.dfs
46
+
47
+
48
+ def clear_file_name(self):
49
+
50
+ self.file_name = []
51
+
52
+
53
+ def read_names_from_file(self, uploaded_names_file, df_head):
54
+ """
55
+ 从上传的 .names/.arff 文件中提取属性名。
56
+ 优先使用 LLM 识别 @attribute 行中的属性名;如果 LLM 调用失败,退回到正则解析。
57
+ """
58
+
59
+ raw = uploaded_names_file.read().decode('utf-8', errors='ignore')
60
+ try:
61
+ uploaded_names_file.seek(0)
62
+ except Exception:
63
+ pass
64
+
65
+ prompt = (
66
+ "下面是上传的 names 和 df_head 文件内容,请仅以 Python 列表格式返回与df_head一一对应的所有属性(attribute)名称,"
67
+ "并保持顺序,不要添加多余文字,请注意,你只需要返回一个列表,不要出现任何markdown语法:\n```\n"
68
+ f"name文件:{raw}\n```"
69
+ f"df_head:{df_head}\n```"
70
+ )
71
+ try:
72
+ response = self.call(prompt)
73
+ names_list = eval(response.strip())
74
+ if isinstance(names_list, list) and all(isinstance(n, str) for n in names_list):
75
+ col_names = names_list
76
+ else:
77
+ raise ValueError("LLM 输出格式不正确")
78
+ except Exception:
79
+
80
+ col_names = []
81
+ attr_re = re.compile(
82
+ r"""^@attribute\s+
83
+ ['"]?([^'"\s]+)['"]?
84
+ \s+.+
85
+ """,
86
+ re.IGNORECASE | re.VERBOSE
87
+ )
88
+ for line in raw.splitlines():
89
+ line = line.strip()
90
+ if not line:
91
+ continue
92
+ if line.lower().startswith('@data'):
93
+ break
94
+ m = attr_re.match(line)
95
+ if m:
96
+ col_names.append(m.group(1))
97
+
98
+ counts: dict[str, int] = {}
99
+ unique_names: List[str] = []
100
+ for name in col_names:
101
+ if name in counts:
102
+ counts[name] += 1
103
+ unique_names.append(f"{name}_{counts[name]}")
104
+ else:
105
+ counts[name] = 0
106
+ unique_names.append(name)
107
+
108
+ return unique_names
109
+
110
+
111
+ def do_data_description(self, df, user_input=None, memory_limit=6):
112
+
113
+ recent_memory = self.memory[-memory_limit:] if self.memory else []
114
+ if recent_memory:
115
+ formatted_memory = "\n".join(
116
+ f"{m['role']}: {m['content']}" for m in recent_memory
117
+ )
118
+ memory_block = f"{formatted_memory}"
119
+ else:
120
+ memory_block = ""
121
+
122
+ prompt = (
123
+ "你是一名专业的数据分析助手,负责解释数据结构与业务含义。\n"
124
+ f"- 数据维度:{df.shape[0]} 行 × {df.shape[1]} 列\n"
125
+ f"- 列名和数据类型:{dict(zip(df.columns.tolist(), df.dtypes.astype(str).tolist()))}\n"
126
+ f"- 前 5 行样本:\n{df.head().to_dict(orient='list')}\n\n"
127
+ f"""- 数据解释聊天对话:
128
+ --- 开始聊天记录 ---
129
+ {memory_block}
130
+ --- 结束聊天记录 ---"""
131
+ )
132
+
133
+ if user_input is not None:
134
+ prompt += f"""
135
+ 请严格依据用户需求“{user_input}”,对当前数据进行深入、系统的分析。
136
+ 要求:
137
+ 1. 分析内容必须与该需求完全对应,不能添加无关推断。
138
+ 2. 结论要具体、清晰,可直接支持后续报告撰写或建模步骤。
139
+ 3. 分析语言应专业、简洁,不使用模糊或情绪化表述。
140
+ """
141
+ else:
142
+ prompt += """
143
+ 以下是一个数据集的基本概览。请帮助我分析它的性质和结构,并回答以下问题:
144
+
145
+ 1. 该数据集可能来源于什么业务或研究场景?
146
+ 2. 各主要字段分别代表什么含义?若能判断,请说明其单位或数值含义。
147
+ 3. 数据中是否存在明显异常、异常分布或需要注意的特征?
148
+
149
+ 输出要求:
150
+ - 使用自然、流畅的中文描述;
151
+ - 采用清晰的分条结构(1、2、3);
152
+ - 语言客观简洁,不使用“可能”“也许”“似乎”等模糊词;
153
+ - 重点突出数据结构、含义与潜在问题。
154
+ """
155
+
156
+ desc = self.call(prompt)
157
+
158
+ return desc
159
+
160
+
161
+ def summary_html(self):
162
+
163
+ df = self.load_df()
164
+ df_head = df.head()
165
+ dtype_info = df.dtypes.astype(str)
166
+
167
+ prompt = f"""
168
+ 你正在撰写一份数据分析报告的第一章——《数据概览与数据含义分析》。
169
+ 请根据以下输入内容,整理关键信息并进行分析说明:
170
+ 数据格式:
171
+ {dtype_info}
172
+
173
+ 前五行数据:
174
+ {df_head}
175
+
176
+ 数据解释聊天对话:
177
+ --- 开始聊天记录 ---
178
+ {self.memory}
179
+ --- 结束聊天记录 ---
180
+
181
+ 额外要求:
182
+ 1. 要用流畅的自然语言
183
+ 2. 不要滥用形容词和副词,尽量用简单的动词和名词表达意思
184
+ 3. 不用"可能""也许""似乎""微妙"等模糊表述
185
+ """.strip()
186
+
187
+ desc = self.call(prompt)
188
+
189
+ summary = {
190
+ "title": "数据导入",
191
+ "df": df_head,
192
+ "desc": desc,
193
+ }
194
+
195
+ return summary
196
+
197
+
198
+ def summary_word(self):
199
+
200
+ return self.summary_html()
201
+
202
+
203
+ def check_abstract(self):
204
+
205
+ if self.abstract is None:
206
+ df = self.load_df()
207
+ df_head = df.head()
208
+ dtype_info = df.dtypes.astype(str)
209
+
210
+ prompt = f"""
211
+ 这是数据分析的数据导入阶段
212
+ 数据格式:
213
+ {dtype_info}
214
+
215
+ 前五行数据:
216
+ {df_head}
217
+
218
+ 数据解释聊天对话:
219
+ --- 开始聊天记录 ---
220
+ {self.memory}
221
+ --- 结束聊天记录 ---
222
+
223
+ 要求:
224
+ 请基于上述数据与对话内容,生成一段简洁、准确的综合摘要。
225
+ 摘要需完整呈现核心信息,便于后续自动判断该内容在报告撰写中是否需要被引用。
226
+ """.strip()
227
+
228
+ desc = self.call(prompt)
229
+ self.abstract = desc
230
+
231
+ return self.abstract
232
+
233
+
234
+ def check_full(self):
235
+
236
+ if self.full is None:
237
+ df = self.load_df()
238
+ df_head = df.head()
239
+ dtype_info = df.dtypes.astype(str)
240
+
241
+ self.full = (
242
+ f"【阶段说明】这是数据分析流程中的数据导入阶段。\n"
243
+ f"【数据格式】{dtype_info}\n"
244
+ f"【样本预览】\n{df_head}\n"
245
+ f"【分析对话记录】\n{self.memory}"
246
+ )
247
+
248
+ return self.full
prompt_engineer/sec2_call_llm.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from prompt_engineer.call_llm import LLMClient
5
+
6
+ class DataPreprocessAgent(LLMClient):
7
+
8
+ def __init__(self, *args, **kwargs):
9
+
10
+ super().__init__(*args, **kwargs)
11
+ self.processed_df = None
12
+ self.code = None
13
+ self.preprocessing_suggestions = None
14
+ self.allowed_libs = [
15
+ "numpy",
16
+ "pandas",
17
+ "sklearn.impute",
18
+ "sklearn.preprocessing",
19
+ "sklearn.compose",
20
+ "sklearn.pipeline"
21
+ ]
22
+ self.par_content = ""
23
+ self.error = None
24
+ self.user_input = None
25
+ self.refined_suggestions = ""
26
+ self.abstract=None
27
+ self.full = None
28
+ self.finish_auto_task = False
29
+ self.debug_num = 0
30
+
31
+
32
+ def finish_auto(self):
33
+
34
+ self.finish_auto_task = True
35
+
36
+
37
+ def save_code(self, code):
38
+
39
+ self.code = code
40
+
41
+
42
+ def load_code(self):
43
+
44
+ return self.code
45
+
46
+
47
+ def save_user_input(self, user_input):
48
+
49
+ self.user_input = user_input
50
+
51
+
52
+ def load_user_input(self):
53
+
54
+ return self.user_input
55
+
56
+
57
+ def save_error(self, error):
58
+
59
+ self.error = error
60
+
61
+
62
+ def load_error(self):
63
+
64
+ return self.error
65
+
66
+
67
+ def save_preprocessing_suggestions(self, suggestions):
68
+
69
+ self.preprocessing_suggestions = suggestions
70
+
71
+
72
+ def load_preprocessing_suggestions(self):
73
+
74
+ return self.preprocessing_suggestions
75
+
76
+
77
+ def save_processed_df(self, processed_df):
78
+
79
+ if not isinstance(processed_df, pd.DataFrame):
80
+ if isinstance(processed_df, np.ndarray):
81
+ processed_df = pd.DataFrame(processed_df)
82
+ else:
83
+ raise TypeError(f"期望 pandas.DataFrame 或 numpy.ndarray,收到 {type(processed_df)}")
84
+
85
+ self.processed_df = processed_df
86
+
87
+
88
+ def load_processed_df(self):
89
+
90
+ return self.processed_df
91
+
92
+
93
+ def load_refined_suggestions(self):
94
+ return self.refined_suggestions
95
+
96
+
97
+ def save_refined_suggestions(self, refined_suggestions):
98
+ self.refined_suggestions = refined_suggestions
99
+
100
+
101
+ def refine_suggestions(self, df_head):
102
+ """将 LLM 返回的预处理推荐进行信息提取"""
103
+
104
+ suggestion = self.load_preprocessing_suggestions()
105
+
106
+ prompt = f"""
107
+ 请根据以下预处理建议,概括数据集中每一列的推荐预处理方法。
108
+
109
+ 数据示例:
110
+ {df_head}
111
+
112
+ 详细预处理建议:
113
+ {suggestion}
114
+
115
+ 输出要求(必须严格遵守):
116
+ 1. 输出格式:列名:推荐预处理方法;每条独立换行。
117
+ 2. 每列最多给出三个推荐方法,多个方法用逗号分隔。
118
+ 3. 输出必须为纯文本,不使用任何 Markdown 标记。
119
+ 4. 每个方法的长度不得超过20个汉字,若包含英文则不超过10个单词。"""
120
+
121
+ refined_suggestions = self.call(prompt)
122
+ self.refined_suggestions = refined_suggestions
123
+
124
+ return refined_suggestions
125
+
126
+
127
+ def get_preprocessing_suggestions(
128
+ self,
129
+ user_input=None,
130
+ memory_limit=6,
131
+ ):
132
+
133
+ df = self.load_df()
134
+
135
+ # 基本统计
136
+ n_rows, n_cols = df.shape
137
+ dtype_counts = df.dtypes.value_counts().to_dict()
138
+ missing_total = int(df.isnull().sum().sum())
139
+ missing_by_col = df.isnull().mean().mul(100).round(2).to_dict()
140
+ num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
141
+
142
+ # 整理 memory 片段
143
+ recent_memory = self.memory[-memory_limit:] if self.memory else []
144
+ if recent_memory:
145
+ formatted_memory = "\n".join(
146
+ f"{m['role']}: {m['content']}" for m in recent_memory
147
+ )
148
+ memory_block = f"{formatted_memory}"
149
+ else:
150
+ memory_block = ""
151
+
152
+ prompt = f"""
153
+ 你是一名资深的数据预处理专家,负责为数据分析报告提供高质量的预处理建议。
154
+
155
+ === 数据概览 ===
156
+ - 数据规模:{n_rows} 行 × {n_cols} 列
157
+ - 数据类型分布:{dtype_counts}
158
+ - 缺失值总数:{missing_total}
159
+ - 各列缺失率:{missing_by_col}
160
+ - 数值型列:{num_cols}
161
+ - 历史上下文(仅供参考):{memory_block}
162
+ """
163
+
164
+ if user_input is None:
165
+ prompt += """
166
+ === 请对每一列进行逐项分析(注意,是逐列分析) ===
167
+ 请针对每一列依次说明以下四个方面:
168
+
169
+ 1. **数据类型**:明确该列的数据类型,若存在混合类型或异常值类型,请指出。
170
+ 2. **缺失值处理建议**:说明该列的缺失值处理策略;若建议调整,请指明具体“缺失值处理 策略”操作。
171
+ 3. **异常值处理建议**:说明该列的异常检测与处理方案;若需调整,请说明“异常值处理 策略或阈��”操作。
172
+ 4. **标准化建议**:说明是否建议标准化或缩放,并在需要时指出“标准化处理 策略”操作。
173
+
174
+ 输出格式要求:
175
+ - 按“列名 + 分点说明(1–4)”的形式分段输出;
176
+ - 每一列独立成段,并以换行分隔;
177
+ - 使用清晰、简洁的专业语言。
178
+ """
179
+ else:
180
+ prompt += f"""
181
+ === 用户新需求 ===
182
+ {user_input}
183
+
184
+ 请结合以上数据概览与历史上下文,针对该需求,给出下一步操作。
185
+ 可考虑的操作包括:缺失值处理、异常值检测与修正、标准化或归一化、特征类型调整等。
186
+ 输出应保持结构化与连贯性,避免重复说明。
187
+ """
188
+
189
+ suggestions = self.call(prompt)
190
+
191
+ return suggestions
192
+
193
+
194
+ def code_generation(self, df_head, user_prompt):
195
+ """生成 LLM prompt:要求 LLM 输出 process_df(pandas DataFrame)。"""
196
+ allowed = ", ".join(self.allowed_libs)
197
+
198
+ prompt = f"""
199
+ 请**严格只输出纯 Python 代码**,不得包含以下内容:
200
+ - 解释性文字、注释、示例;
201
+ - Markdown 代码块标记(禁止出现 ``` 或 ```python 等);
202
+ - 任何多余输出(如 print、全局变量赋值等)。
203
+
204
+ === 运行环境说明 ===
205
+ 运行环境中已提供以下对象与库:
206
+ - pandas DataFrame 变量:`df`
207
+ - 库:numpy (np)、SimpleImputer、StandardScaler、MinMaxScaler、RobustScaler、
208
+ OneHotEncoder、OrdinalEncoder、LabelEncoder、FunctionTransformer、
209
+ ColumnTransformer、Pipeline。
210
+ 若所需功能在这些库中不存在,请自行写 Python code 实现。
211
+
212
+ === 生成要求 ===
213
+ 1. 若有用户需求,请优先满足用户需求(优先级高于 LLM 返回的通用建议)。
214
+ 2. 若建议指出某列“无需处理”,则对该列不进行任何操作。
215
+ 3. 禁止导入其他库、禁止文件读写。
216
+ 4. 所有括号(圆括号、方括号、大括号)必须成对闭合,不得错位或遗漏。
217
+ 5. 对类别特征,可使用 OneHotEncoder 或 OrdinalEncoder;
218
+ 若为单列字符串/类别列,请使用 LabelEncoder 或 OrdinalEncoder,不得 passthrough。
219
+ 6. 在构建 ColumnTransformer 前,需检测并处理“混合型列”
220
+ —— 即同时包含数值和字符串的列,
221
+ 使用 `FunctionTransformer(lambda x: x.astype(str))` 将其统一为字符串类型。
222
+ 7. ColumnTransformer 的 transformers 中仅包含经过上述处理的列。
223
+ 8. 使用 OneHotEncoder 时,若输出稀疏矩阵,请确保所有输入特征均为数值类型。
224
+ 9. 若 df 中存在重复表头(如第 0 行与 header 相同),需自动检测并删除重复表头行。
225
+ 10. 确保预处理后的 DataFrame 中每一列均有明确列名。
226
+ 11. 脚本最后仅保留一行结果:
227
+ `process_df = ...`
228
+ 不允许出现 print、显示语句或其他多余输出。
229
+
230
+ === 输入数据示例 ===
231
+ {df_head}
232
+
233
+ === 用户指定需求 ===
234
+ {user_prompt}
235
+
236
+ 请严格依据以上要求,输出完整且可直接执行的 Python 代码(纯代码块,无额外说明)。
237
+ """.strip()
238
+
239
+ if self.error is not None:
240
+ if self.debug_num < 5 :
241
+ self.debug_num += 1
242
+
243
+ prompt += f"""
244
+ 上次生成的代码运行失败。
245
+ 【错误信息】:
246
+ {self.error}
247
+
248
+ 【原始代码】:
249
+ {self.code}
250
+
251
+ 请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因,
252
+
253
+ 要求:
254
+ 1. 不输出任何分析、解释或说明(包括文字、列表或注释段落);
255
+ 2. 可在代码内部使用简短注释说明关键修改;
256
+ 3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整;
257
+ 4. 若依赖库方法不适用,可自行实现替代函数;
258
+ 5. 生成的代码必须可独立运行,无语法错误;
259
+ 6. 保持整体逻辑与原代码意图一致,仅做必要修正。
260
+ """
261
+
262
+ else:
263
+ self.debug_num = 0
264
+
265
+ if self.user_input is not None:
266
+ prompt += f"用户需求:{self.user_input}。\n请严格遵循并优先执行该需求,其优先级高于所有其他建议或规则。\n"
267
+
268
+ if self.refined_suggestions is not None:
269
+ prompt += f"LLM返回的预处理建议:{self.refined_suggestions}"
270
+
271
+ raw = self.call(prompt)
272
+ return raw
273
+
274
+
275
+ def summary_html(self):
276
+
277
+ if self.code is None:
278
+ summary = None
279
+ return summary
280
+
281
+ else:
282
+ processed_df = self.load_processed_df()
283
+ prompt = f"""
284
+ 你正在撰写数据分析报告的第二章——《数据预处理与标准化》。
285
+ 请根据以下输入内容,提炼关键信息并撰写相应分析段落。
286
+
287
+ - 预处理代码:
288
+ {self.code}
289
+
290
+ - 预处理结果(数据示例):
291
+ {processed_df.head()}
292
+
293
+ {f"- 预处理建议对话记录:{self.load_memory}" if self.load_memory else ""}
294
+
295
+ 撰写要求:
296
+ 1. 使用流畅、自然的中文表达;
297
+ 2. 语言应简洁、准确,避免过多形容词或副词;
298
+ 3. 不使用“可能”“也许”“似乎”“微妙”等模糊表述;
299
+ 4. 不添加大标题,可使用自然段进行叙述;
300
+ 5. 内容需逻辑清晰,体现代码与结果之间的分析关联。
301
+
302
+ """.strip()
303
+
304
+ desc = self.call(prompt)
305
+
306
+ summary = {
307
+ "title": "数据预处理",
308
+ "desc": desc,
309
+ "processed_df": self.processed_df.head(),
310
+ "code": self.code,
311
+ }
312
+
313
+ return summary
314
+
315
+
316
+ def summary_word(self):
317
+
318
+ return self.summary_html()
319
+
320
+
321
+ def check_abstract(self):
322
+
323
+ if self.abstract is None:
324
+
325
+ processed_df = self.load_processed_df()
326
+
327
+ if self.code is None:
328
+ self.abstract = None
329
+
330
+ else:
331
+
332
+ memory = f"【预处理建议对话记录】\n{self.load_memory}\n" if self.load_memory else ""
333
+
334
+ prompt = f"""
335
+ 这是数据分析流程中的“数据预处理与标准化”阶段。
336
+
337
+ 【预处理代码】
338
+ {self.code}
339
+
340
+ 【预处理结果(前五行)】
341
+ {processed_df.head()}
342
+
343
+ {memory}
344
+ 请在确保信息准确完整的前提下,将上述内容概括为一段简洁的文字摘要。
345
+ 要求:
346
+ 1. 语言自然流畅,保持客观和专业;
347
+ 2. 内容应涵盖关键点(包括主要预处理步骤与结果特征);
348
+ 3. 重点在于“说明核心信息”,而非逐行描述;
349
+ 4. 生成的摘要应可用于报告编写时判断该部分是否需要引用。
350
+ """.strip()
351
+
352
+ desc = self.call(prompt)
353
+ self.abstract = desc
354
+
355
+ return self.abstract
356
+
357
+
358
+ def check_full(self):
359
+ if self.full is None:
360
+ processed_df = self.load_processed_df()
361
+ if self.code is None:
362
+ self.full = None
363
+ else:
364
+ content = f"""
365
+ 【阶段说明】这是数据分析流程中的数据预处理阶段。
366
+ 【预处理代码】{self.code}
367
+ 【预处理结果前五行】{processed_df.head()}
368
+ """.strip()
369
+ if self.load_memory is not None:
370
+ content += f"\n【预处理建议聊天对话】{self.load_memory}"
371
+
372
+ self.full = content
373
+
374
+ return self.full
prompt_engineer/sec3_call_llm.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import base64
3
+ import plotly.graph_objs as go
4
+ from concurrent.futures import ThreadPoolExecutor, as_completed
5
+
6
+ from prompt_engineer.call_llm import LLMClient
7
+
8
+ import numpy as np
9
+ np.set_printoptions(edgeitems=250, threshold=501)
10
+
11
+ class VisualizationAgent(LLMClient):
12
+
13
+ def __init__(self, *args, **kwargs):
14
+
15
+ super().__init__(*args, **kwargs)
16
+ self.cols_wo_id = None
17
+ self.recommendations = None
18
+ self.analysis = []
19
+ self.quick_action = None
20
+ self.data_meaning = ""
21
+ self.allowed_libs = [
22
+ "numpy", "plotly", "plotly.express", "plotly.graph_objects"
23
+ ]
24
+ self.code = None
25
+ self.result = None
26
+ self.suggestion = None
27
+ self.user_input = None
28
+ self.fig = []
29
+ self.par_content = ""
30
+ self.error = None
31
+ self.abstract=None
32
+ self.full = None
33
+ self.color = None
34
+ self.finish_auto_task = False
35
+ self.debug_num = 0
36
+ self.refined_suggestions = None
37
+
38
+
39
+ def finish_auto(self):
40
+
41
+ self.finish_auto_task = True
42
+
43
+
44
+ def save_user_input(self, user_input):
45
+
46
+ self.user_input = user_input
47
+
48
+
49
+ def load_user_input(self):
50
+
51
+ return self.user_input
52
+
53
+
54
+ def save_color(self, color):
55
+
56
+ self.color = color
57
+
58
+
59
+ def load_color(self):
60
+
61
+ return self.color
62
+
63
+
64
+ def add_fig(self, fig, desc):
65
+
66
+ entry = {"fig": fig, "desc": desc}
67
+ self.fig.append(entry)
68
+
69
+
70
+ def load_fig(self):
71
+
72
+ return self.fig
73
+
74
+
75
+ def save_cols_wo_id(self, col):
76
+
77
+ self.cols_wo_id = col
78
+
79
+
80
+ def load_cols_wo_id(self):
81
+
82
+ return self.cols_wo_id
83
+
84
+
85
+ def save_code(self, code):
86
+
87
+ self.code = code
88
+
89
+
90
+ def load_code(self):
91
+
92
+ return self.code
93
+
94
+
95
+ def save_recommendations(self, recommendations):
96
+
97
+ self.recommendations = recommendations
98
+
99
+
100
+ def load_recommendations(self):
101
+
102
+ return self.recommendations
103
+
104
+
105
+ def save_suggestion(self, suggestion):
106
+
107
+ self.suggestion = suggestion
108
+
109
+
110
+ def load_suggestion(self):
111
+
112
+ return self.suggestion
113
+
114
+
115
+ def load_data_meaning(self):
116
+
117
+ return self.data_meaning
118
+
119
+
120
+ def save_error(self, error):
121
+
122
+ self.error = error
123
+
124
+
125
+ def load_error(self):
126
+
127
+ return self.error
128
+
129
+
130
+ def refine_suggestions(self, rec):
131
+
132
+ prompt = f"""
133
+ 请根据以下详细的可视化建议,提取每一列与每个变量组的推荐可视化方法。
134
+
135
+ 详细可视化建议:
136
+ {rec}
137
+
138
+ 输出要求(必须严格遵守):
139
+ 1. 输出为纯文本,每条独立换行,且不得有多余说明。
140
+ 2. 单变量格式:列名:图表1, 图表2。
141
+ 3. 多变量格式:关系组:列A,列B:图表1, 图表2。
142
+ 4. 总体变量格式:总体:图表1, 图表2。
143
+ 5. 严格不要添加标题、编号、示例或额外解释。
144
+ 6. 提取可视化方法精准。
145
+ """
146
+
147
+ refined_suggestions = self.call(prompt)
148
+ self.refined_suggestions = refined_suggestions
149
+
150
+ return refined_suggestions
151
+
152
+
153
+ def get_visualization_recommendations(
154
+ self,
155
+ cols,
156
+ user_input=None,
157
+ memory_limit: int = 6,
158
+ ) -> str:
159
+
160
+ dim_info = f"{self.df.shape[0]} 行 x {self.df.shape[1]} 列"
161
+
162
+ recent_memory = self.memory[-memory_limit:] if getattr(self, "memory", None) else []
163
+ if recent_memory:
164
+ formatted_memory = "\n".join(
165
+ f"{m['role']}: {m['content']}" for m in recent_memory
166
+ )
167
+ memory_block = f"{formatted_memory}"
168
+ else:
169
+ memory_block = ""
170
+
171
+ if user_input is None:
172
+ prompt = f"""
173
+ 你是一位资深数据可视化专家,请根据以下信息,为数据分析报告的“可视化设计”章节提供系统、专业的建议。
174
+
175
+ 【数据集信息】
176
+ - 数值型变量:{cols}
177
+ - 数据维度:{dim_info}
178
+ - 历史上下文(仅供参考):{memory_block}
179
+
180
+ 【输出格式】
181
+ 请严格按照以下结构输出(保持标题和层级一致,不得增减):
182
+
183
+ 一、单变量可视化(Univariate)
184
+ 1. 针对每个数值型变量,推荐 1–2 种最合适的可视化方法,并简要说明理由。
185
+ 例如:
186
+ - `列1`:推荐“直方图(Histogram)”和“盒须图(Box Plot)”,理由:……
187
+
188
+ 二、多变量关系可视化(Multivariate)
189
+ 1. 从上述变量中选择 1–3 组值得重点分析的变量组合(每组包含 2–3 个变量),并说明选择理由。
190
+ 例如:
191
+ - 关系组 1:`[列1, 列2]`,理由:……
192
+ 2. 对每一组变量,推荐最合适的可视化方法,并简要说明。
193
+ 例如:
194
+ - 关系��� 1:散点图(Scatter Plot)+ 回归线(Regression Line),理由:……
195
+
196
+ 三、整体分布可视化(Distribution Overview)
197
+ 1. 针对全数据的总体分布特征,推荐 1–2 种全局可视化方法,并说明用途。
198
+ 例如:
199
+ - 推荐“小提琴图矩阵(Violin Plot Matrix)”,用途:……
200
+ - 推荐“热力图(Heatmap)”,用途:……
201
+
202
+ 【执行要求】
203
+ 1. 若列名无实际意义(如索引、冗余 ID),应自动过滤;
204
+ 2. 输出内容需保持条理清晰、语言简洁、专业。
205
+ """.strip()
206
+
207
+ else:
208
+ prompt = f"""
209
+ 你是一位资深数据可视化专家,请根据以下信息,请回应用户需求,实现用户需求:
210
+
211
+ 【用户需求】
212
+ {user_input}
213
+
214
+ 【数据集信息】
215
+ - 数值型变量:{cols}
216
+ - 数据维度:{dim_info}
217
+ - 数据概览(前几行):
218
+ {self.df.head().to_string(index=False)}
219
+ - 历史上下文(仅供参考):{memory_block}
220
+
221
+ 【执行要求】
222
+ 1. 若用户明确指定可视化列,仅针对这些列给出建议;
223
+ 2. 若用户提出特定要求(如图形大小、坐标轴 log 缩放等),必须在输出中体现;
224
+ 3. 仅响应用户需求,不输出无关内容;
225
+ 4. 若用户要求对先前内容进行局部修改,应保留未更动部分,仅更新相关建议;
226
+ 5. 输出内容应结构清晰、逻辑连贯、语言简洁。
227
+ 6. 禁止输出代码。
228
+ """.strip()
229
+
230
+ recommendations = self.call(prompt)
231
+ return recommendations
232
+
233
+
234
+ def desc_fig(self, fig, dtype_info):
235
+
236
+ selected = st.session_state.selected_model
237
+
238
+ if selected == "智谱AI" or selected == "通义千问" or selected == "GPT-4o" or selected == "GPT-5" or selected == "豆包" or selected == "Claude":
239
+ img_bytes = fig.to_image(format="jpg")
240
+ fig_info = extract_plotly_info(fig)
241
+ base64_bytes = base64.b64encode(img_bytes)
242
+ base64_string = base64_bytes.decode('utf-8')
243
+
244
+ prompt_payload = [
245
+ {
246
+ "type": "image_url",
247
+ "image_url": {"url": f"data:image/jpg;base64,{base64_string}"}
248
+ },
249
+ {
250
+ "type": "text",
251
+ "text": f"""
252
+ 请综合下方可视化图与变量信息,进行**简洁但深入的分析**。
253
+ 从分布形态、趋势特征、变量间关系、潜在异常现象、现实含义五个角度,提炼关键洞察。
254
+ 输出一段不超过 120 字的自然语言分析结论(非摘要)。
255
+
256
+ 【变量信息】
257
+ {dtype_info}
258
+
259
+ 【图表结构信息】
260
+ {fig_info}
261
+
262
+ 写作要求:
263
+ 1. 分析需包含对数据异常的识别与说明:
264
+ - 若存在明显异常点、异常段或突变趋势,请指出其特征与潜在影响;
265
+ - 若未发现异常,也需明确说明整体分布稳定或无显著异常;
266
+ 2. 内容需体现推理与解释性思考,而非表面描述;
267
+ 3. 使用逻辑清晰、客观专业的语言;
268
+ 4. 使用动词驱动句式(如“呈现出”“反映出”“揭示出”“说明了”等);
269
+ 5. 不使用模糊词(如“可能”“似乎”“微妙”等);
270
+ 6. 不使用标题、列表或格式符号;
271
+ 7. 若变量含义中存在噪声或重复信息,请自动忽略;
272
+ 8. 保持语气简洁有力,强调数据特征与分析结论。
273
+ """.strip()
274
+ }
275
+ ]
276
+
277
+ desc_fig = self.call(prompt_payload)
278
+
279
+ else:
280
+ prompt = f"""
281
+ 请综合下方可视化图与变量信息,从数据分布、趋势特征及潜在关系等角度进行分析。
282
+ 以不超过 100 字的自然语言总结关键发现,突出该变量在整体数据结构中的意义或异常现象。
283
+
284
+ 【变量信息】
285
+ {dtype_info}
286
+
287
+ 【图表信息】
288
+ {fig.to_dict()}
289
+
290
+ 写作要求:
291
+ 1. 语言应流畅自然,保持客观、专业;
292
+ 2. 使用简洁的动词和名词,不滥用形容词或副词;
293
+ 3. 避免“可能”“也许”“似乎”“微妙”等模糊词;
294
+ 4. 不添加标题或列表结构;
295
+ 5. 结合数据含义和图表特征,给出具有洞察力的简要结论;
296
+ 6. 若变量含义中存在杂乱或重复信息,请自动忽略。
297
+ """.strip()
298
+
299
+ desc_fig = self.call(prompt)
300
+
301
+ return desc_fig
302
+
303
+
304
+ def summary_html(self) -> str:
305
+
306
+ analysis = self.summary_fig_analysis_list()
307
+
308
+ if analysis is None:
309
+
310
+ return None
311
+
312
+ else:
313
+ analysis = {i: item for i, item in enumerate(analysis)}
314
+
315
+ summary = {
316
+ "title": "数据可视化",
317
+ "fig_analysis": analysis,
318
+ }
319
+
320
+ return summary
321
+
322
+
323
+ def summary_word(self) -> str:
324
+
325
+ analysis = self.summary_fig_analysis_list()
326
+
327
+ if analysis is None:
328
+
329
+ return None
330
+
331
+ else:
332
+
333
+ summary = {
334
+ "title": "数据可视化",
335
+ "fig_analysis": analysis,
336
+ }
337
+
338
+ return summary
339
+
340
+
341
+ def summary_fig_analysis_list(self) -> str:
342
+
343
+ if not self.code:
344
+ return self.analysis
345
+
346
+ if self.analysis:
347
+ return self.analysis
348
+
349
+ # state_copy = dict(st.session_state)
350
+ selected = st.session_state.get("selected_model", "default")
351
+ # selected = state_copy.get("selected_model", "default")
352
+
353
+ # --- 定义单个任务 ---
354
+ def analyze_one(item, offset):
355
+ fig = item["fig"]
356
+ desc = item["desc"]
357
+
358
+ # 恢复状态(如果需要访问 st.session_state)
359
+ # st.session_state.update(state_copy)
360
+ selected = st.session_state.get("selected_model", "default")
361
+ if isinstance(fig, go.Figure):
362
+ if selected == "智谱AI" or selected == "通义千问" or selected == "GPT-4o" or selected == "GPT-5" or selected == "豆包" or selected == "Claude":
363
+ img_bytes = fig.to_image(format="jpg")
364
+ base64_string = base64.b64encode(img_bytes).decode("utf-8")
365
+
366
+ fig_info = extract_plotly_info(fig)
367
+
368
+ prompt_payload = [
369
+ {
370
+ "type": "image_url",
371
+ "image_url": {"url": f"data:image/jpg;base64,{base64_string}"}
372
+ },
373
+ {
374
+ "type": "text",
375
+ "text": f"""
376
+ 你正在撰写数据分析报告的第三章——《数据可视化》。
377
+ 请针对下方变量,结合其**业务含义、统计特征**与**可视化图表现**,撰写一段专业、逻辑严谨、可直接用于报告正文的分析内容。
378
+
379
+ 【变量信息】
380
+ {self.cols_wo_id}
381
+
382
+ 【Plotly 图表结构】
383
+ {fig_info}
384
+
385
+ 【基础统计概览】
386
+ {desc}
387
+
388
+ 【分析任务】
389
+ 请在脑中先完成以下推理步骤,然后输出结构化正文:
390
+ 1. 从图表识别核心模式:整体趋势、峰值、分布形态、异常点或聚集区;
391
+ 2. 思考该模式与变量业务含义的关系;
392
+ 3. 判断是否存在异常现象(单点异常、阶段性异常或结构性突变),并说明其潜在影响;
393
+ 4. 若图中包含其他变量,请分析它们之间的统计或逻辑关联;
394
+ 5. 将上述洞察整合成逻辑完整、语言自然的段落。
395
+
396
+ 【输出格式(严格遵守)】
397
+ 输出为纯文本,依次包含以下三部分(不使用 Markdown 或符号):
398
+
399
+ 1. 概述
400
+ - 简述变量的定义、业务角色及数据表现的总体趋势;
401
+ - 提出该变量在整体数据结构中可能的重要性。
402
+
403
+ 2. 分布与特征分析
404
+ - 从统计与图形角度分析其分布特征(集中趋势、离散程度、偏态、峰度、周期性等);
405
+ - 若发现异常或突变,请具体说明其表现形式与潜在机制;
406
+ - 若与其他变量有关联趋势,指出方向与强度。
407
+
408
+ 3. 实际含义与推论
409
+ - 结合业务或研究背景,解释观察到的现象;
410
+ - 分析其可能揭示的现实规律、风险或优化方向;
411
+ - 若合适,可提出合理推测或后续分析建议(保持客观与逻辑自洽)。
412
+
413
+ 【写作要求】
414
+ 1. 保持语言正式、专业、逻辑紧密;
415
+ 2. 句式多样、表达自然,避免模板化表述���
416
+ 3. 禁用模糊词汇(如“可能”“似乎”“大概”等);
417
+ 4. 不使用任何标题符号(如 #、** 等);
418
+ 5. 不输出“AI”“模型”“助手”等字样;
419
+ 6. 输出为连续正文,不包含解释性语句或附加说明。
420
+ """.strip()
421
+ }
422
+ ]
423
+
424
+ analysis_text = self.call(prompt_payload)
425
+
426
+ else:
427
+
428
+ prompt = f"""
429
+ 你正在撰写数据分析报告的第三章——《数据可视化》。
430
+ 请针对下方变量,结合其业务含义与对应的可视化图,撰写一段结构化、专业的分析文字。
431
+
432
+ 【变量信息】
433
+ {self.cols_wo_id}
434
+
435
+ 【Plotly 图表信息】
436
+ {fig.to_dict()}
437
+
438
+ 【基础统计概览】
439
+ {desc}
440
+
441
+ 请严格按照以下格式撰写内容(使用纯文本,不使用 Markdown 语法或符号):
442
+
443
+ 1. 概述
444
+ - 说明该变量的含义及其在数据或业务中的作用;
445
+ - 简要描述整体分布特征或变量间的主要关联趋势。
446
+
447
+ 2. 分布 / 关联特征
448
+ - 从统计角度说明变量的分布特征或相关关系;
449
+ - 可引用关键统计量(均值、中位数、四分位数、相关系数等)支持分析。
450
+
451
+ 3. 现实含义
452
+ - 结合变量在实际情境中的意义,解释所观察到的分布或关系;
453
+ - 指出这些模式可能反映的现实现象或潜在影响(例如:某变量偏高代表风险上升或群体特征差异)。
454
+
455
+ 【写作要求】
456
+ 1. 使用流畅、自然且正式的中文表达;
457
+ 2. 语言应客观、简洁,避免冗余修辞;
458
+ 3. 禁止使用“可能”“也许”“似乎”“微妙”等模糊词;
459
+ 4. 不使用标题符号(#、** 等);
460
+ 5. 保持逻辑连贯,分析层次清晰。
461
+ """.strip()
462
+
463
+ analysis_text = self.call(prompt)
464
+ print(prompt)
465
+ return offset, {"figure": fig, "analysis": analysis_text}
466
+
467
+ # --- 并行执行 ---
468
+ results = []
469
+ with ThreadPoolExecutor(max_workers=4) as executor:
470
+ futures = [executor.submit(analyze_one, item, i) for i, item in enumerate(self.fig)]
471
+ for f in as_completed(futures):
472
+ result = f.result()
473
+ if result:
474
+ results.append(result)
475
+
476
+ # --- 按原顺序排序 ---
477
+ results.sort(key=lambda x: x[0])
478
+ self.analysis = [r[1] for r in results]
479
+
480
+ return self.analysis
481
+
482
+
483
+ def code_generation(self, df_head: str, user_prompt: str) -> str:
484
+ """生成 LLM prompt:要求 LLM 输出 result_dict(可 JSON 序列化)。"""
485
+ allowed = ", ".join(self.allowed_libs)
486
+
487
+ prompt = (
488
+ "请**严格只输出纯 Python 代码**,**不要**输出任何解释性文字、注释、示例、markdown code fence(禁止出现 ``` 或 ```python 等)"
489
+ "运行环境已提供 pandas DataFrame 变量 `df`、numpy(np)、"
490
+ "plotly.express(px)、plotly.graph_objects(go)。\n\n"
491
+ "##严格要求##:\n"
492
+ "1) **严格执行用户需求**:若用户指定了要可视化的列,可能是精确列名,也可能是模糊输入"
493
+ "(如输入 “ordera” 但实际列名为 “ordertypea”),不要凭空产生虚假列名!!!"
494
+ f"请在脚本开头使用 LLM 理解将用户输入映射到 {df_head} 中最合适的真正列名,或采用更保守的索引(如第0列,第1列 推荐!),再仅对这些列绘制图表;\n"
495
+ """2) **统计并重命名**:所有类别分布图请按下面模板写,**绝不直接用** `index` 作为列名——
496
+ # === 模板:统计并绘制 Bar Chart ===
497
+ for col in categorical_cols:
498
+ df_counts = df[col] \\
499
+ .value_counts() \\
500
+ .rename_axis(col) \\
501
+ .reset_index(name='count')
502
+ fig = px.bar(
503
+ df_counts,
504
+ x=col,
505
+ y='count',
506
+ title=f'Bar Chart of {col}',
507
+ labels={col: col, 'count': 'Count'}
508
+ )
509
+ fig_dict[f'{col}_bar'] = fig
510
+
511
+ 3) 智能选图:根据数据类型(数值/类别)自动选择合适的图表。
512
+ 4) 自动检测是否需要按分类列着色,并做两种处理:若存在指定的分类列且想连续映射,先编码为数值 codes;如要离散映射,使用 parallel_categories
513
+ 5) 如 Plotly Express 中无合适图表,使用 `go.Figure` 自定义。
514
+ 6) 脚本末尾仅包含 `fig_dict = {...}`,不要 `print`、不要额外全局变量。
515
+ 7) 任何情况下不得“造”列名或直接写 `'index'`;若要使用索引,必须显式使用 `df.index`。
516
+ 8) 不要使用文件读写或其他外部 IO。
517
+ 9) 请只给我python代码,不要给我任何'''python等非代码内容的标识符。"""
518
+ f"示例数据头部:\n{df_head}\n\n"
519
+ f"每一张图的颜色必须从{self.color}中,选择\n\n"
520
+ f"画图建议: {self.refined_suggestions}\n\n"
521
+ "返回:完整 Python 代码(纯代码块)。"
522
+ )
523
+
524
+ if self.error is not None:
525
+ if self.debug_num < 5 :
526
+ self.debug_num += 1
527
+ prompt += f"""
528
+ 上次生成的代码运行失败。
529
+ 【错误信息】:
530
+ {self.error}
531
+
532
+ 【原始代码】:
533
+ {self.code}
534
+
535
+ 请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因,
536
+
537
+ 要求:
538
+ 1. 不输出任何分析、解释或说明(包括文字、列表或注释段落);
539
+ 2. 可在代码内部使用简短注释说明关键修改;
540
+ 3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整;
541
+ 4. 若依赖库方法不适用,可自行实现替代函数;
542
+ 5. 生成的代码必须可独立运行,无语法错误;
543
+ 6. 保持整体逻辑与原代码意图一致,仅做必要修正。
544
+ """
545
+ else:
546
+ self.debug_num = 0
547
+
548
+ raw = self.call(prompt)
549
+
550
+ return raw
551
+
552
+
553
+ def check_abstract(self):
554
+ if self.abstract is None:
555
+ # 获取所有分析内容
556
+ analysis_list = self.summary_fig_analysis_list()
557
+
558
+ if not analysis_list :
559
+ self.abstract = "暂无可视化分析内容。"
560
+ return self.abstract
561
+
562
+ # 合并所有分析内容为一个整体文本
563
+ all_analyses = "\n\n".join([
564
+ f"【变量分析 {i+1}】\n{item['analysis']}"
565
+ for i, item in enumerate(analysis_list)
566
+ ])
567
+
568
+ prompt = f"""
569
+ 请阅读并综合以下多个变量的分析内容:
570
+ {all_analyses}
571
+
572
+ 任务:
573
+ 将这些分析整合为一段结构化、信息充分的**综合语义总结**,供后续大模型自动生成报告目录使用。
574
+
575
+ 目标:
576
+ - 输出内容应帮助后续模型理解分析中包含的主题、变量、维度、关系与逻辑顺序;
577
+ - 它将作为“目录生成模型”的输入,因此必须让模型能看出报告中应有哪些章节与子章节。
578
+
579
+ 写作要求:
580
+ 1. **信息保留**:
581
+ - 保留每个变量的关键结论、趋势、特征、显著差异;
582
+ - 明确变量间的联系、对比或影响;
583
+ - 不得省略任何对分析主题有价值的事实。
584
+
585
+ 2. **结构导向**:
586
+ - 按逻辑顺序组织:总体特征 → 各变量分析 → 变量间关系 → 潜在规律;
587
+ - 若存在不同主题(如气象因素、污染物指标、模型结果),应自然体现层次;
588
+ - 语义中隐含章节边界信号(如“首先…其次…最后…”、“在气象变量方面…”、“在建模部分…”等)。
589
+
590
+ 3. **语言风格**:
591
+ - 专业、清晰、客观;
592
+ - 使用完整句表达,不使用列表或编号;
593
+ - 可以稍微详细,不追求简短。
594
+
595
+ 4. **输出格式**:
596
+ - 输出仅为一段完整文字;
597
+ - 不得加入标题、注释、JSON、代码块;
598
+ - 该文字将被直接送入目录生成模型,不对人类展示。
599
+
600
+ 请生成符合上述要求的综合语义总结。
601
+ """.strip()
602
+
603
+ self.abstract = self.call(prompt)
604
+
605
+ return self.abstract
606
+
607
+
608
+ def check_full(self):
609
+ """
610
+ 返回结构化的内容,遵守图片插入协议:
611
+ - 每个分析内容前标注索引
612
+ - 图片插入位置用 [FIG:index] 表示
613
+ - 后续处理时可根据此协议替换为实际图像
614
+ """
615
+ if self.full is None:
616
+ analysis_list = self.summary_fig_analysis_list()
617
+
618
+ if not analysis_list :
619
+ self.full = "暂无可视化分析内容。"
620
+ return self.full
621
+
622
+ # 构造结构化文本:带图片插入标记
623
+ full_parts = ["""【阶段说明】这是数据分析流程中的数据可视化阶段。"""]
624
+ for i, item in enumerate(analysis_list):
625
+ desc = item["analysis"]
626
+ part = f"""
627
+ 【对图 {i}的分析】
628
+ {desc}
629
+ [FIG:{i}] # 图片插入位置标记
630
+ """.strip()
631
+ full_parts.append(part)
632
+
633
+ self.full = "\n\n".join(full_parts)
634
+
635
+ # 添加协议说明
636
+ protocol_note = """
637
+ ---
638
+ # 图片插入处理协议说明:
639
+ # [FIG:index] 表示图片插入位置
640
+ # index 对应分析内容中的索引
641
+ # 你在需要放图的地方用 [FIG:index] 代替即可
642
+ """.strip()
643
+
644
+ self.full = f"{self.full}\n\n{protocol_note}"
645
+
646
+ return self.full
647
+
648
+
649
+ def extract_plotly_info(fig):
650
+ """
651
+ 从 Plotly Figure(对象 / dict / 字符串)中提取关键信息:
652
+ - 图标题
653
+ - X/Y 轴标题
654
+ - 图类型
655
+ - 颜色信息
656
+ - trace 数量
657
+ """
658
+ import ast
659
+ import plotly.graph_objects as go
660
+
661
+ if isinstance(fig, go.Figure):
662
+ fig = fig.to_dict()
663
+ elif isinstance(fig, dict):
664
+ pass
665
+ elif isinstance(fig, str):
666
+ clean_str = fig.strip()
667
+ if clean_str.startswith("Figure("):
668
+ clean_str = clean_str[len("Figure("):-1]
669
+ try:
670
+ fig = ast.literal_eval(clean_str)
671
+ except Exception as e:
672
+ raise ValueError(f"无法解析字符串形式的 Figure: {e}")
673
+ else:
674
+ raise TypeError(f"不支持的 fig 类型: {type(fig)}")
675
+
676
+ layout = fig.get("layout", {})
677
+ title = layout.get("title", {}).get("text", "")
678
+ xaxis_title = layout.get("xaxis", {}).get("title", {}).get("text", "")
679
+ yaxis_title = layout.get("yaxis", {}).get("title", {}).get("text", "")
680
+
681
+ data_list = fig.get("data", [])
682
+ types = list({d.get("type", "") for d in data_list})
683
+
684
+
685
+ return {
686
+ "title": title or "(无标题)",
687
+ "xaxis": xaxis_title or "(无X轴标题)",
688
+ "yaxis": yaxis_title or "(无Y轴标题)",
689
+ "types": types,
690
+
691
+ }
prompt_engineer/sec4_call_llm.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from prompt_engineer.call_llm import LLMClient
4
+
5
+
6
+ class ModelingCodingAgent(LLMClient):
7
+
8
+ def __init__(self, *args, **kwargs):
9
+
10
+ super().__init__(*args, **kwargs)
11
+ self.allowed_libs = [
12
+ "numpy", "sklearn.model_selection", "sklearn.preprocessing", "sklearn.ensemble", 'torch', 'torchvision', 'torchaudio', 'xgboost', 'lightgbm'
13
+ ]
14
+ self.code = None
15
+ self.result = None
16
+ self.suggestion = None
17
+ self.user_selection = None
18
+ self.par_content = ""
19
+ self.inference_code = None
20
+ self.best_model = None
21
+ self.inference_data = None
22
+ self.inference_processed_df = None
23
+ self.abstract=None
24
+ self.full = None
25
+ self.error = None
26
+ self.inference_error = None
27
+ self.target = None
28
+ self.finish_auto_task = False
29
+ self.best_model_gz_bytes = None
30
+ self.debug_num = 0
31
+ self.refined_suggestions = None
32
+
33
+ def finish_auto(self):
34
+
35
+ self.finish_auto_task = True
36
+
37
+
38
+ def save_best_model_gz_bytes(self, best_model_gz_bytes):
39
+
40
+ self.best_model_gz_bytes = best_model_gz_bytes
41
+
42
+
43
+ def load_best_model_gz_bytes(self):
44
+
45
+ return self.best_model_gz_bytes
46
+
47
+
48
+ def save_target(self, target):
49
+
50
+ self.target = target
51
+
52
+
53
+ def load_target(self):
54
+
55
+ return self.target
56
+
57
+
58
+ def save_error(self, error):
59
+
60
+ self.error = error
61
+
62
+
63
+ def load_error(self):
64
+
65
+ return self.error
66
+
67
+
68
+ def save_inference_error(self, inference_error):
69
+
70
+ self.inference_error = inference_error
71
+
72
+
73
+ def load_inference_error(self):
74
+
75
+ return self.inference_error
76
+
77
+
78
+ def save_inference_data(self, inference_data):
79
+
80
+ self.inference_data = inference_data
81
+
82
+
83
+ def load_inference_data(self):
84
+
85
+ return self.inference_data
86
+
87
+
88
+ def save_inference_processed_df(self, inference_processed_df):
89
+
90
+ self.inference_processed_df = inference_processed_df
91
+
92
+
93
+ def load_inference_processed_df(self):
94
+
95
+ return self.inference_processed_df
96
+
97
+
98
+ def save_inference_code(self, code):
99
+
100
+ self.inference_code = code
101
+
102
+
103
+ def load_inference_code(self):
104
+
105
+ return self.inference_code
106
+
107
+
108
+ def save_best_model(self, best_model):
109
+
110
+ self.best_model = best_model
111
+
112
+
113
+ def load_best_model(self):
114
+
115
+ return self.best_model
116
+
117
+
118
+ def save_code(self, code):
119
+
120
+ self.code = code
121
+
122
+
123
+ def load_code(self):
124
+
125
+ return self.code
126
+
127
+
128
+ def save_suggestion(self, suggestion):
129
+
130
+ self.suggestion = suggestion
131
+
132
+
133
+ def load_suggestion(self):
134
+
135
+ return self.suggestion
136
+
137
+
138
+ def save_modeling_result(self, result):
139
+
140
+ self.result = result
141
+
142
+
143
+ def load_modeling_result(self):
144
+
145
+ return self.result
146
+
147
+
148
+ def save_user_selection(self, user_selection):
149
+
150
+ self.user_selection = user_selection
151
+
152
+
153
+ def load_user_selection(self):
154
+
155
+ return self.user_selection
156
+
157
+
158
+ def refine_suggestions(self):
159
+ """将 LLM 返回的预处理推荐进行信息提取"""
160
+
161
+ prompt = f"""
162
+ 请阅读以下建模建议,并将其转化为对下一个 coding agent 的清晰建模任务指令。
163
+
164
+ === 建模建议 ===
165
+ {self.suggestion}
166
+
167
+ === 输出要求(必须严格遵守) ===
168
+ 1. 输出为纯文本,不使用任何 Markdown、编号或符号;
169
+ 2. 指令应简洁明确,便于 coding agent 直接理解并执行;
170
+ 3. 内容应聚焦于模型构建、训练或评估的具体任务;
171
+ 4. 避免解释性或分析性语言,仅描述“需要执行的操作”;
172
+ 5. 输出应覆盖所有关键步骤,使 coding agent 能独立完成建模流程。
173
+ """.strip()
174
+
175
+ refined_suggestions = self.call(prompt)
176
+ self.refined_suggestions = refined_suggestions
177
+
178
+ print(refined_suggestions)
179
+
180
+ return refined_suggestions
181
+
182
+
183
+ def code_generation(self, df_head: str, user_prompt: str) -> str:
184
+ """生成 LLM prompt:要求 LLM 输出 result_dict(可 JSON 序列化)。"""
185
+ allowed = ", ".join(self.allowed_libs)
186
+
187
+ if self.refined_suggestions is None:
188
+ suggestion = user_prompt
189
+ else:
190
+ suggestion = self.refined_suggestions
191
+
192
+ prompt = (
193
+ f"""请**严格只输出纯 Python 代码**,**不要**输出任何解释性文字、注释、示例、markdown code fence(禁止出现 ``` 或 ```python 等)。运行环境已提供 pandas DataFrame 变量 `df`、numpy(np)、train_test_split、StandardScaler、以及用户在 Requirement 中可能提到的任意模型类(例如 RandomForestRegressor、GradientBoostingRegressor、LinearRegression、XGBRegressor、LogisticRegression、SVC 等)。
194
+
195
+ 要求:
196
+
197
+ 1) 使��� 80/20 切分(random_state=42),根据用户需求决定是否对数值特征标准化(StandardScaler),如果标准化,务必只应用于数值列并在训练/测试集上分别执行 fit_transform/transform。
198
+ 2) **对 Requirement 中列出的所有模型都依次训练和评估**,不得只选随机森林;如果用户在 Requirement 中指定了多个模型名称,脚本必须循环遍历这些模型并分别训练、预测、计算指标。
199
+ 3) 不要导入任何评价库(如 sklearn.metrics),如需评价请用 numpy 手写实现常见指标(回归:MAE、MSE、R2;分类:accuracy、precision、recall、f1)。
200
+ 4) **脚本最后必须只输出并赋值一个变量 `result_dict`,且它是一个可以 JSON 序列化的 Python dict。**
201
+ 推荐 schema(必须包含以下键):
202
+ {{
203
+ "dataset": "<可选描述字符串>",
204
+ "models": [
205
+ {{
206
+ "name": "<模型类名>",
207
+ "type": "<regression 或 classification>",
208
+ "metrics": {{ "<指标名>": <float>, ... }}
209
+ }},
210
+ ...
211
+ ],
212
+ "best_model": {{
213
+ "name": "<得分最优的模型类名>",
214
+ "score": <float>
215
+ }},
216
+ "artifacts": {{
217
+ "best_model_b64": "<base64 字符串>",
218
+ "best_model_format": "pickle+gzip"
219
+ }},
220
+ // 如模型过大,可选 "artifact_warning": <int 字节大小>
221
+ // 以及用户在 Requirement 中提出的其他字段
222
+ }}
223
+ 5) 确保所有数值均为 Python 原生类型(float、int),字段名严格为 models、best_model、artifacts;如果用户有额外需求,如记录训练时间、特征重要性等,也请加入 result_dict。
224
+ 6) 模型导出:训练完毕后,将选定的 best_model 用 pickle 序列化并 gzip 压缩,再 base64 编码;把编码字符串和格式信息填入 result_dict["artifacts"],并确保最终 result_dict 可 JSON 序列化。
225
+ 7) 脚本末尾仅包含一行 `result_dict = {{...}}`,不要 print、不创建其他全局变量、不读写文件。
226
+ 8) 如果模型序列化后的字节数超过合理大小,请在 result_dict 中添加 `"artifact_warning": <字节数>`。
227
+ 9) 不要使用任何外部 IO 或文件操作。
228
+ 10) 请准确实现Requirement中要求的模型,不许添加Requirement之外的模型,若先提供的库中无法直接调用对应模型,请手动实现!
229
+
230
+ 示例数据头部:
231
+ {df_head}
232
+
233
+ Requirement(请根据以下建模任务指令,对所有列出的模型依次执行训练与评估。若某模型在当前环境不可用,请手动实现对应算法或类,使结果完整可复现):
234
+ {suggestion}
235
+
236
+ Allowed libraries: {allowed}。
237
+
238
+ 返回:完整 Python 代码(纯代码块)。"""
239
+ )
240
+
241
+ if self.error is not None:
242
+ if self.debug_num < 5 :
243
+ self.debug_num += 1
244
+ prompt += f"""
245
+ 上次生成的代码运行失败。
246
+ 【错误信息】:
247
+ {self.error}
248
+
249
+ 【原始代码】:
250
+ {self.code}
251
+
252
+ 请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因,
253
+
254
+ 要求:
255
+ 1. 不输出任何分析、解释或说明(包括文字、列表或注释段落);
256
+ 2. 可在代码内部使用简短注释说明关键修改;
257
+ 3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整;
258
+ 4. 若依赖库方法不适用,可自行实现替代函数;
259
+ 5. 生成的代码必须可独立运行,无语法错误;
260
+ 6. 保持整体逻辑与原代码意图一致,仅做必要修正。
261
+ """
262
+ else:
263
+ self.debug_num = 0
264
+
265
+ raw = self.call(prompt)
266
+
267
+ return raw
268
+
269
+
270
+ def result_format_prompt(self, result_json: str) -> str:
271
+ """生成 LLM prompt:要求 LLM 输出 result_dict(可 JSON 序列化)。"""
272
+
273
+ prompt = f"""
274
+ 下面给出一个 JSON 对象(包含模型评估结果结构)。请将其转换为一份对人类友好的 Markdown 报告,输出要求如下:
275
+
276
+ === 输出要求 ===
277
+ 1. 报告开头需有一行简短的“数据集说明”。
278
+ 2. 对每个模型,展示以下内容:
279
+ - 模型名称;
280
+ - 模型类型(分类 / 回归);
281
+ - 主要性能指标(如准确率、R²、MAE、MSE 等),每个指标保留 4 位小数;
282
+ - 建议使用表格或分点列表清晰呈现。
283
+ 3. 明确标出 **best_model**(以粗体高亮显示其名称和最优指标)。
284
+ 4. 若 JSON 中包含特征工程相关信息,���在“特征工程说明”部分详细描述其具体方法与作用。
285
+ 5. 输出格式:
286
+ - 只返回 Markdown 文本;
287
+ - 不得使用任何代码块标记(如 ```、```markdown 等);
288
+ - 不输出解释性文字,仅输出最终报告内容(便于直接渲染于 Streamlit)。
289
+
290
+ === 输入 JSON ===
291
+ {result_json}
292
+ """.strip()
293
+
294
+ raw = self.call(prompt)
295
+
296
+ return raw
297
+
298
+
299
+ def get_model_suggestion(
300
+ self,
301
+ user_input=None,
302
+ memory_limit: int = 6, # 控制引入的 memory 轮数
303
+ ) -> str:
304
+ """
305
+ 根据数据集与历史上下文,生成建模阶段的智能建议。
306
+ 自动整合 memory(最近几轮对话)作为辅助上下文。
307
+ """
308
+
309
+ # === 加载基础数据 ===
310
+ df = self.load_df()
311
+ df_head = df.head().to_string(index=False)
312
+ columns = df.columns.tolist()
313
+ data_info = f"数据列名: {columns}\n\n数据前5行:\n{df_head}"
314
+
315
+ # === 整理 memory 片段 ===
316
+ recent_memory = self.memory[-memory_limit:] if getattr(self, "memory", None) else []
317
+ if recent_memory:
318
+ formatted_memory = "\n".join(
319
+ f"{m['role']}: {m['content']}" for m in recent_memory
320
+ )
321
+ memory_block = f"\n=== 历史上下文(仅供参考) ===\n{formatted_memory}\n"
322
+ else:
323
+ memory_block = ""
324
+
325
+ # === 主 prompt 组装 ===
326
+ prompt = f"""
327
+ 你是一位资深的机器学习建模专家,请基于以下信息进行分析与推理,输出针对性建模建议或改进方案。
328
+
329
+ === 数据信息 ===
330
+ {data_info}
331
+
332
+ === 历史上下文(仅供参考) ===
333
+ {memory_block}
334
+ """.strip()
335
+
336
+ # 若用户有明确建模目标
337
+ if getattr(self, "target", None):
338
+ prompt += f"""
339
+
340
+ === 建模目标 ===
341
+ {self.target}
342
+ (请务必满足该目标,并在回答中明确复述建模意图。)
343
+ """
344
+
345
+ # 若用户额外输入了需求
346
+ if user_input:
347
+ prompt += f"""
348
+
349
+ === 用户当前需求 ===
350
+ {user_input}
351
+ (请严格满足该需求。若为局部修改,请保留原逻辑,仅更新指定部分。)
352
+ """
353
+
354
+ # 若有之前生成的训练代码
355
+ train_code = self.load_code()
356
+ if train_code:
357
+ prompt += f"""
358
+
359
+ === 历史训练代码 ===
360
+ {train_code}
361
+
362
+ 请在充分理解上述代码的基础上,提出 **1–2 条高质量的模型改进建议**。
363
+ 可从以下角度思考,但不限于此:
364
+ - 模型结构优化(如增加层数、调整激活函数、替换模型类型等);
365
+ - 特征工程改进(如变量选择、特征交互、归一化策略等);
366
+ - 训练流程优化(如正则化、学习率调度、损失函数调整等);
367
+ - 超参数调整(如树深度、学习率、batch size 等)。
368
+ 在给出建议时,请简要说明“为什么”与“预期改进效果”。
369
+ """
370
+ else:
371
+ prompt += """
372
+
373
+ === 建模建议任务 ===
374
+ 请根据数据特征和上下文,推荐 2–3 个适合的模型方案。
375
+ 要求:
376
+ 1. 每个模型需包含模型名称、主要原理、适用场景;
377
+ 2. 指出其在当前任务中的优势与潜在局限;
378
+ 3. 保持语言专业、简洁,不输出代码。
379
+ """
380
+
381
+
382
+ # # === 主 prompt 组装 ===
383
+ # prompt = f"""
384
+ # 你是一位资深的机器学习建模专家。
385
+
386
+ # 以下是用户的数据信息:
387
+ # {data_info}
388
+
389
+ # {memory_block}
390
+ # """.strip()
391
+
392
+ # # 若用户有明确建模目标
393
+ # if getattr(self, "target", None):
394
+ # prompt += f"\n\n建模目标:{self.target}(务必满足,并请在回答中复述)"
395
+
396
+ # # 若用户额外输入了需求
397
+ # if user_input:
398
+ # prompt += f"""\n\n用户的当前需求:{user_input}(务必满足!)
399
+ # 若用户的要求是局部更新,则保留先前内容,仅修改特定部分。"""
400
+
401
+ # # 若有之前生成的训练代码
402
+ # train_code = self.load_code()
403
+ # if train_code:
404
+ # prompt += f"""
405
+
406
+ # 用户之前生成的训练代码:
407
+ # {train_code}
408
+
409
+ # 请在理解该代码的基础上,提供 **1–2 条模型改进建议**,
410
+ # 可涉及但不限于:
411
+ # - 模型结构调整
412
+ # - 特征工程优化
413
+ # - 模型替换(例如从树模型切换为深度学习模型)
414
+ # - 超参数调整或正则化策略优化
415
+ # """
416
+ # else:
417
+ # prompt += """
418
+
419
+ # 请基于数据特征,推荐 2–3 个合适的模型,
420
+ # 并说明每个模型的适用场景和优劣分析。
421
+ # """
422
+
423
+ # # 若存在以往建模结果
424
+ # modeling_result = self.load_modeling_result()
425
+ # if modeling_result:
426
+ # prompt += f"\n\n用户之前的模型运行结果:\n{modeling_result}"
427
+
428
+ # === 调用 LLM ===
429
+ raw = self.call(prompt)
430
+ return raw
431
+
432
+
433
+
434
+ def summary_html(self) -> str:
435
+
436
+ if self.code is None:
437
+
438
+ summary = None
439
+
440
+ return summary
441
+
442
+ else:
443
+
444
+ prompt = f"""
445
+ 你正在撰写数据分析报告的**第四章:数据建模**。
446
+ 请根据以下输入内容,综合分析并生成完整的章节正文。
447
+ 内容需逻辑严谨、表达自然,体现专业的分析与总结能力。
448
+
449
+ === 输出结构 ===
450
+ 请严格按照以下五个小节组织内容:
451
+
452
+ 1. 概述
453
+ - 说明本次建模的目标、研究背景及数据来源的上下文。
454
+
455
+ 2. 方法说明
456
+ - 介绍所采用的模型或算法的核心思想与实现流程;
457
+ - 若涉及特征工程、超参数选择或数据预处理,请一并说明;
458
+ - 可适当涉及模型的数学原理或优化机制,以体现技术深度。
459
+
460
+ 3. 关键代码解读
461
+ - 聚焦核心函数与模块,说明其在建模流程中的作用;
462
+ - 可提及模型结构定义、训练循环、损失函数与评估逻辑;
463
+ - 语言应清晰简练,避免逐行解释。
464
+
465
+ 4. 结果与评估
466
+ - 概述主要性能指标(如 Accuracy、AUC、MSE 等)及结果表现;
467
+ - 分析模型效果是否符合预期,并指出主要优劣与瓶颈。
468
+
469
+ 5. 改进建议
470
+ - 针对模型性能与实验发现,提出具体可行的优化方向;
471
+ - 可从模型结构、特征选择、训练策略或正则化等角度给出建议。
472
+
473
+ === 写作要求 ===
474
+ 1. 使用自然流畅、正式的书面表达;
475
+ 2. 避免使用模糊或主观词汇(如“可能”“似乎”“微妙”等);
476
+ 3. 注重逻辑连贯与专业性;
477
+ 4. 不输出标题、列表标记或额外说明,只生成正文内容。
478
+ """.strip()
479
+
480
+ if self.code is not None:
481
+ prompt += f"=== 数据建模代码 ===\n\n{self.code}"
482
+ if self.target is not None:
483
+ prompt += f"=== 用户建模目标 ===\n\n{self.target}"
484
+ if self.load_memory is not None:
485
+ prompt += f"=== 数据建模聊天对话 ===\n\n{self.load_memory}"
486
+ if self.result is not None:
487
+ prompt += f"=== 建模运行结果 ===\n\n{self.result}"
488
+
489
+ desc = self.call(prompt)
490
+
491
+ summary = {
492
+ "title": "建模分析",
493
+ "code": self.code,
494
+ "desc": desc,
495
+ "result": self.result,
496
+ }
497
+
498
+ return summary
499
+
500
+
501
+ def summary_word(self) -> str:
502
+
503
+ return self.summary_html()
504
+
505
+
506
+ def code_generation_for_inference(self, code, inference_df_head) -> str:
507
+ """生成 LLM prompt:要求 LLM 输出推断分析代码。"""
508
+
509
+ prompt = (
510
+ f"""请生成完整的 Python 推断分析脚本(仅返回代码,不要任何解释文字)。运行环境已提供 pandas DataFrame 变量 `inference_df`、已经 train 好的模型 `model_obj`、numpy(np)、StandardScaler 库、align_features 辅助函数,其余未提及的库请手写实现。要求:
511
+
512
+ 示例数据信息:
513
+ {code}, inference_df 前五行: {inference_df_head}(请勿引入不存在 inference_df 中的变量)
514
+
515
+ 1) **可用变量说明:**
516
+ - `inference_df`:推断数据集(Pandas DataFrame)
517
+ - `model_obj`:已训练好的模型对象(从best_model.joblib加载)
518
+ - `np`:NumPy库
519
+ - `pd`:Pandas库
520
+ - `StandardScaler`:用于数据标准化的sklearn工具
521
+
522
+ 2) **脚本必须实现的功能:**
523
+ a) 对推断数据进行与训练时完全一致的预处理(例如,缺失值处理、编码转换、标准化等)
524
+ b) **关键步骤:在预测前,必须使用align_features函数处理特征数据,确保特征数量和顺序与训练时一致**
525
+ c) 使用model_obj对预处理并对齐后的特征数据进行预测
526
+ d) 生成详细的推断报告,包含预处理步骤、预测结果分析等
527
+
528
+ 3) **预测结果处理要求:**
529
+ - 将模型输出转换为人类可理解的形式(如概率值、类别标签、数值结果等)
530
+ - **必须生成带预测结果的DataFrame**:将��始或处理后的`inference_df`与预测结果合并,命名为`inference_df_with_predictions`
531
+ - 合并后的DataFrame必须包含原始特征列和一列名为`'prediction'`的预测结果列(模型输出多维时扩展为`prediction_0`, `prediction_1`, ...)
532
+
533
+ 4) **序列化要求(用于前端下载):**
534
+ - 将`inference_df_with_predictions`转换为无索引的CSV格式
535
+ - 对CSV数据进行gzip压缩,然后编码为base64字符串
536
+ - 创建包含以下键的`result_dict['artifacts']`字典:
537
+ * `'predictions_df_b64'`:base64编码的压缩数据
538
+ * `'predictions_df_format'`:固定值'csv+gzip'
539
+ * `'predictions_df_size_bytes'`:压缩后的字节大小(整数)
540
+ - 在`result_dict`中添加`'predictions_df_records'`键,值为`inference_df_with_predictions.to_dict(orient='records')`
541
+ - 确保所有numpy/pandas类型转换为原生Python类型(int/float/str)以保证JSON可序列化
542
+
543
+ 5) **代码结构与输出约束:**
544
+ - 脚本最后**仅**包含一行`result_dict = {...}`语句
545
+ - `result_dict`必须是完全JSON可序列化的Python字典
546
+ - 禁止任何外部IO操作(不读写文件)
547
+ - 禁止使用print语句或创建额外的全局变量
548
+
549
+ 8) **生成代码质量要求:**
550
+ - 确保所有变量名称与上述规范严格一致
551
+ - 逻辑清晰,步骤完整,严格按照用户提供的数据和最佳模型文件生成代码
552
+ - 处理可能出现的各种异常情况,提高代码的稳定性和可靠性
553
+
554
+ 返回:完整的Python代码(仅包含代码本身,不要任何解释性文字)。"""
555
+ )
556
+
557
+ raw = self.call(prompt)
558
+
559
+ return raw
560
+
561
+
562
+ def check_abstract(self):
563
+ if self.abstract is None:
564
+ if self.code is None:
565
+ self.abstract = None
566
+ else:
567
+ prompt = f"""
568
+ 这是数据分析流程中的“建模阶段”。
569
+
570
+ 请基于以下信息,在保留所有关键信息的前提下,将内容整理成一段简洁、连贯的文字摘要,用于报告撰写中的建模小节预览。
571
+
572
+ === 输入信息 ===
573
+ - 用户初始需求:{self.target}
574
+ - 建模代码:{self.code}
575
+ - 建模阶段的交互记录:{self.load_memory}
576
+ - 建模运行结果:{self.result}
577
+
578
+ === 输出要求 ===
579
+ 1. 以自然流畅的语言撰写一段总结,全面涵盖上述内容中的核心信息;
580
+ 2. 重点说明建模目标、所用方法、主要实现逻辑与结果特征;
581
+ 3. 避免逐行描述代码,仅提炼核心思路;
582
+ 4. 语言应专业、客观,不使用“可能”“似乎”“也许”等模糊表达;
583
+ 5. 输出仅为一段完整文字(不要标题、编号或列表);
584
+ 6. 摘要应能让人据此判断该部分是否需要纳入最终报告。
585
+ """.strip()
586
+
587
+ desc = self.call(prompt)
588
+ self.abstract = desc
589
+
590
+ return self.abstract
591
+
592
+
593
+ def check_full(self):
594
+ if self.full is None:
595
+ if self.code is None:
596
+ self.full = None
597
+ else:
598
+ self.full = f"""
599
+ 【阶段说明】这是数据分析流程中的数据建模阶段。
600
+ 【用户初始需求】{self.target}
601
+ 【数据建模代码】{self.code}
602
+ 【建模聊天对话】{self.load_memory}
603
+ 【建模运行结果】{self.result}
604
+ """.strip()
605
+
606
+ return self.full
prompt_engineer/sec5_call_llm.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import openai
3
+ import requests
4
+ import json
5
+ import re
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+ from config import MODEL_CONFIGS
10
+ from prompt_engineer.call_llm import LLMClient
11
+
12
+
13
+ class ReportAgent(LLMClient):
14
+
15
+ def __init__(self, *args, **kwargs):
16
+
17
+ super().__init__(*args, **kwargs)
18
+ self.template = None
19
+ self.name = None
20
+ self.date = None
21
+ self.report_format = None
22
+ self.html = None
23
+ self.word = None
24
+ self.markdown = None
25
+ self.user_input = None
26
+ self.outline = None
27
+ self.outline_length = None
28
+ self.report= None
29
+ self.finish_auto_task = False
30
+ self.gen_mode = None
31
+
32
+
33
+ def save_gen_mode(self, gen_mode):
34
+
35
+ self.gen_mode = gen_mode
36
+
37
+
38
+ def load_gen_mode(self):
39
+
40
+ return self.gen_mode
41
+
42
+
43
+ def finish_auto(self):
44
+
45
+ self.finish_auto_task = True
46
+
47
+
48
+ def save_user_input(self, user_input):
49
+
50
+ self.user_input = user_input
51
+
52
+
53
+ def load_user_input(self):
54
+
55
+ return self.user_input
56
+
57
+
58
+ def save_outline_length(self, outline_length):
59
+
60
+ self.outline_length = outline_length
61
+
62
+
63
+ def load_outline_length(self):
64
+
65
+ return self.outline_length
66
+
67
+
68
+ def save_outline(self, outline):
69
+
70
+ self.outline = outline
71
+
72
+
73
+ def load_outline(self):
74
+
75
+ return self.outline
76
+
77
+
78
+ def save_template(self, template):
79
+
80
+ self.template = template
81
+
82
+
83
+ def load_template(self):
84
+
85
+ return self.template
86
+
87
+
88
+ def save_word(self, word):
89
+
90
+ self.word = word
91
+
92
+
93
+ def load_word(self):
94
+
95
+ return self.word
96
+
97
+
98
+ def save_html(self, html):
99
+
100
+ self.html = html
101
+
102
+
103
+ def load_html(self):
104
+
105
+ return self.html
106
+
107
+
108
+ def save_markdown(self, markdown):
109
+
110
+ self.markdown = markdown
111
+
112
+
113
+ def load_markdown(self):
114
+
115
+ return self.markdown
116
+
117
+
118
+ def save_report(self, report):
119
+
120
+ self.report = report
121
+
122
+
123
+ def load_report(self):
124
+
125
+ return self.report
126
+
127
+
128
+ def save_report_format(self, report_format):
129
+
130
+ self.report_format = report_format
131
+
132
+
133
+ def load_report_format(self):
134
+
135
+ return self.report_format
136
+
137
+
138
+ def save_date(self, date):
139
+
140
+ self.date = date
141
+
142
+
143
+ def load_date(self):
144
+
145
+ return self.date
146
+
147
+
148
+ def save_name(self, name):
149
+
150
+ self.name = name
151
+
152
+
153
+ def load_name(self):
154
+
155
+ return self.name
156
+
157
+
158
+ def generate_template(self, user_input = None) -> str:
159
+ """
160
+ 调用 LLM 生成一个带有占位符的 HTML 报告模板,
161
+ 包含标题、摘要、表格和图表区域等。
162
+ """
163
+ prompt = (
164
+ """
165
+ 我希望你输出一个现代、简洁且美观的 HTML 章节模板,请满足以下要求:
166
+
167
+ 1. 整体配色采用“蓝 – 白”主题:
168
+ - 背景为白色,标题与边框使用深蓝(#1E3A8A)和浅蓝(#3B82F6);
169
+ 2. 最外层用 `<section class="chapter" id="chapter-{{ num }}">` 包裹;
170
+ 3. 标题使用 `<h2>{{ title }}</h2>`:
171
+ - 文字颜色:#1E3A8A;
172
+ - 下方装饰性下划线:高度 3px,颜色 #3B82F6,宽度 30%;
173
+ 4. 正文内容区 `<div class="content">{{ body }}</div>`,支持任意 HTML;
174
+ - **仅对“重点摘录”或“引用”段落加用圆角矩形**,其余普通段落保持标准 `<p>` 样式;
175
+ - 圆角矩形样式:背景 #EFF6FF,padding 12px,border-radius 8px,margin-bottom 16px;
176
+ 5. 如果有图片列表 `images`:
177
+ - ≤3 张时水平并排;>3 张时自动换行,每行最多 3 张;
178
+ - `<img>` 带 6px 圆角、轻微阴影 `box-shadow:0 2px 6px rgba(0,0,0,0.1)`;
179
+ 6. 在 `<style>` 中内联基础样式:
180
+ - `.chapter` 外层间距、内边距、最大宽度、白底阴影;
181
+ - `.chapter h2` 字体、颜色、下划线;
182
+ - `.content p` 和 `.content .highlight`(重点段落)样式区分;
183
+ - `.images` 的 flex 布局与 gap;
184
+ 7. 使用 Jinja2 占位符:
185
+ - 普通段落:`{% for p in paragraphs %}<p>{{ p }}</p>{% endfor %}`;
186
+ - 重点段落数组 `highlights`:`{% for h in highlights %}<div class="highlight">{{ h }}</div>{% endfor %}`;
187
+ 8. **只输出完整的 `<section>…</section>` 片段**,不要任何解释文字或其他标签。
188
+ 9. 在模板的 .content 区域加入一个 DataFrame 占位并用 Jinja2 渲染变量 df_html({{ df_html | safe }}),要求输出为响应式 HTML 表格(显示表头、支持横向滚动并在窄屏下自动换行),以便在导出为 PDF 时正确排版。
189
+
190
+ 请直接给出最终的 HTML 模板代码。
191
+ """
192
+ )
193
+
194
+ if user_input is not None:
195
+ prompt += f"请根据用户需求进行调整{user_input}"
196
+
197
+ return self.call(prompt)
198
+
199
+
200
+ def fill_report(self, template: str, content: str) -> str:
201
+ """
202
+ 将 DataFrame 转为 HTML 表格,拼接进模板,
203
+ 并让 LLM 对报告进行润色、补充解释文字。
204
+ """
205
+
206
+ prompt = (f"""
207
+ 下面是章节结构模板:
208
+ {template}
209
+ 请仅输出 `<section>` 里完整的 HTML(包括标题、正文、图片区块),请将重点内容用highlight凸显,
210
+ 对于内容的分析具有一下要求:
211
+ 1. 要用流畅的自然语言
212
+ 2. 不要滥用形容词和副词,尽量用简单的动词和名词表达意思
213
+ 3. 不用"可能""也许""似乎""微妙"等模糊表述
214
+ 请根据一下提供的信息对文章进行深入分析:
215
+ """)
216
+
217
+ if content.get("title") is not None:
218
+ prompt += f"- title={content['title']}\n"
219
+ if content.get("fig_analysis") is not None:
220
+ prompt += f"- images及其分析(请将image也放入报告中):{content['fig_analysis']}\n"
221
+ if content.get("df") is not None:
222
+ prompt += f"- 表格预览(请将表格也放入报告中,输出美观完整):{content['df']}\n"
223
+ if content.get("code") is not None:
224
+ prompt += f"- 对应部分代码(请将代码中的重点公式与内容进行讲解与分析):{content['code']}\n"
225
+ if content.get("processed_df") is not None:
226
+ prompt += f"- 预处理后的数据预览:{content['processed_df']}\n"
227
+ if content.get("desc") is not None:
228
+ prompt += f"- 具体内容分析:{content['desc']}\n"
229
+ if content.get("header") is not None:
230
+ prompt = f"""
231
+ 下面是章节结构模板:
232
+ {template}
233
+ 要求:header单独占一页
234
+ - 请为我生成封面header:{content['header']}
235
+ """
236
+ if content.get("footer") is not None:
237
+ prompt = f"""
238
+ 下面是章节结构模板:
239
+ {template}
240
+ 要求:footer单独占一页
241
+ - 请为我生成最后一页footer:{content['footer']}
242
+ """
243
+
244
+ prompt += "请仅返回提供html"
245
+
246
+ return self.call(prompt)
247
+
248
+
249
+ def fill_report_word(self, content: str) -> str:
250
+
251
+
252
+ prompt = (f"""
253
+ 你是一个资深的数据分析专家,
254
+ 请仅输出每一章节的完整的word内容(包括标题、正文、图片区块),
255
+ 对于内容的分析具有一下要求:
256
+ 1. 要用流畅的自然语言
257
+ 2. 不要滥用形容词和副词,尽量用简单的动词和名词表达意思
258
+ 3. 不用"可能""也许""似乎""微妙"等模糊表述
259
+ 请根据一下提供的信息对文章进行深入分析:
260
+ """)
261
+
262
+ if content.get("title") is not None:
263
+ prompt += f"- title={content['title']}\n"
264
+ if content.get("fig_analysis") is not None:
265
+ prompt += f"- images及其分析(请将image也放入报告中):{content['fig_analysis']}\n"
266
+ if content.get("df") is not None:
267
+ prompt += f"- 表格预览(请将表格也放入报告中,输出美观完整):{content['df']}\n"
268
+ if content.get("code") is not None:
269
+ prompt += f"- 对应部分代码(请将代码中的重点公式与内容进行讲解与分析):{content['code']}\n"
270
+ if content.get("processed_df") is not None:
271
+ prompt += f"- 预处理后的数据预览:{content['processed_df']}\n"
272
+ if content.get("desc") is not None:
273
+ prompt += f"- 具体内容分析:{content['desc']}\n"
274
+ if content.get("header") is not None:
275
+ prompt = f"""
276
+ 下面是章节结构模板:
277
+ {template}
278
+ 要求:header单独占一页
279
+ - 请为我生成封面header:{content['header']}
280
+ """
281
+ if content.get("footer") is not None:
282
+ prompt = f"""
283
+ 下面是章节结构模板:
284
+ {template}
285
+ 要求:footer单独占一页
286
+ - 请为我生成最后一页footer:{content['footer']}
287
+ """
288
+
289
+ prompt += "请仅返回提供html"
290
+
291
+ return self.call(prompt)
292
+
293
+
294
+ def get_content(self, agent):
295
+
296
+ content = agent.summary()
297
+
298
+ return content
299
+
300
+ def generate_toc_from_summary(self, full_summary) -> str:
301
+ """
302
+ 调用大模型,根据已有 summary 内容自动生成带有分级结构与内容大纲的目录(最多 2 级标题)
303
+ """
304
+
305
+ prompt = f"""
306
+ 你是一位资深数据分析报告结构设计专家。
307
+
308
+ 请你根据以下报告摘要内容,为该数据分析报告生成**层次清晰、内容具体、贴合数据本身**的目录结构。
309
+
310
+ 【输出要求】
311
+ 1. 格式:
312
+ - 纯文本输出(不得使用 Markdown、代码块、Python 列表或符号标记)
313
+ - 每行一个目录项,无缩进或前缀符号
314
+ - 示例格式:
315
+ 1.概述(说明报告背景与目标)
316
+ 2.数据导入(说明数据来源与结构)
317
+ 2.1 数据概览(展示核心字段与样本规模)
318
+ 2.1.1 租赁数量趋势(分析租赁随时间的变化)
319
+ 2. 编号规则:
320
+ - 一级标题:1, 2, 3...
321
+ - 二级标题:2.1, 2.2...
322
+ - 三级标题:2.1.1, 2.1.2...
323
+ 3. 内容说明:
324
+ - 所有标题与说明应以摘要为基础,可在保持主题一致的前提下,适度补充逻辑性或结构性内容。
325
+ - 每个标题后附一句说明,用于指导后续大模型撰写章节内容;
326
+ - 说明须以中文括号“( )”包裹;
327
+ - 每条说明需精准、具体,**明确指示该部分的写作任务、分析角度、数据焦点或方法方向**;
328
+ - 字数不超过 50 字;
329
+ - 上下级说明应保持语义连贯,避免重复;
330
+ - 说明可涉及:
331
+ - 要分析的变量或主题(如“气温”“租赁数量”“污染物浓度”);
332
+ - 要执行的任务(如“展示分布”“分析趋势”“比较模型性能”);
333
+ 4. 禁止输出任何解释、前言、说明、提示、或多余空行,仅输出目录正文。
334
+
335
+ 【生成逻辑】
336
+ 1. 依据摘要内容中出现的主题(如数据特征、指标、变量名、任务目标)生成章节标题。
337
+ - 若摘要中提及 “租赁数量”“气温”“湿度”“时间”等,请将其体现在相关标题中。
338
+ - 避免使用模糊标题(如“数据分析”“关系探索”“模型评估”等)。
339
+ 2. 报告可能包含模块:
340
+ “数据导入”、“数据预处理”、“数据可视化”、“建模分析”。
341
+ - 仅生成摘要中实际涉及的模块。
342
+ 3. 确保章节间语义互斥(正交),避免内容重叠。
343
+ 4. 根据详细程度动态调整层级:
344
+ - 简要:生成两级标题;
345
+ - 标准:生成三级标题;
346
+ - 详细:生成四级标题。
347
+ 5. 若摘要涉及具体变量(如“Temperature”、“Rented Bike Count”),
348
+ 请在目录中直接引用中文变量名(如“气温”、“租赁数量”),
349
+ 以体现报告的“数据感知性”。
350
+
351
+ 用户选择的目录详细程度为:{self.outline_length}
352
+
353
+ 报告摘要如下:
354
+ {full_summary}
355
+ """
356
+
357
+
358
+ toc_response = self.call(prompt)
359
+ return toc_response.strip()
360
+
361
+
362
+ def selected_photo_update_toc(self, toc, selected_full_contents_vis: str) -> list:
363
+ """
364
+ 根据完整报告内容 selected_full_contents_vis,更新 toc,在每个小节增加第四项:对应的图像编号列表。
365
+ """
366
+ print(selected_full_contents_vis)
367
+
368
+ prompt = f"""
369
+ 你是一位专业的数据分析报告结构与图文匹配专家。
370
+
371
+ 任务:请你根据报告的目录结构和,正文内容和阶段说明,判断每个 [FIG:x] 图像最合适归属的章节。
372
+
373
+ 【输入内容】
374
+ 1. 目录结构(含标题、层级、内容大纲):
375
+ {toc}
376
+
377
+ 2. 报告完整正文(带有 [FIG:x] 图片标记):
378
+ {selected_full_contents_vis}
379
+
380
+ 【任务说明】
381
+ 请你逐一分析每个 [FIG:x] 图像的出现上下文,并结合目录内容,判断该图应归属于哪个章节。
382
+ 要求同时考虑:
383
+ - **语义匹配**:图像内容的主题(如污染物趋势、气象变化、时间分布、模型结果)与章节描述的一致性;
384
+ - **上下文位置**:图像在正文中出现时,其前后段落通常属于哪个章节;
385
+ - **粒度优先**:若图像语义符合多个章节(如“气象参数”与“气象参数图形分析”),优先归入更具体的章节(层级数字更大);
386
+ - **禁止误归**:禁止将图像分配到“概述”“结论”“摘要”等非分析或与图像不相关的章节!
387
+ - **全部使用**:所有 [FIG:x] 必须被使用一次,不得遗漏或重复。
388
+
389
+ 【输出格式】
390
+ 请以 Python 列表形式输出,每项为:
391
+ (标题, 层级, 内容大纲, 图编号列表)
392
+ 要求:
393
+ - 图编号按出现顺序排列;
394
+ - 若无图片则为空列表 [];
395
+ - 层级仅用整数表示(1, 2, 3...);
396
+ - 不输出任何解释、注释、Markdown标记。
397
+
398
+ 【示例格式】
399
+ [
400
+ ('概述',1,'说明报告背景与目标',[]),
401
+ ('数据导入',1,'说明数据来源与结构',[]),
402
+ ('数据可视化',1,'展示变量特征与关系',[4,5]),
403
+ ('气象参数分析',2,'研究温度与湿度对污染的影响',[2,3]),
404
+ ('模型评估',2,'展示预测结果与误差',[6,7])
405
+ ]
406
+
407
+ 【提示与约束】
408
+ 1. 若章节间存在嵌套关系,优先分配���最具体的子章节(如 3.1.2 比 3.1 更优)。
409
+ """
410
+
411
+ toc_with_figs = self.call(prompt)
412
+ return toc_with_figs.strip()
413
+
414
+
415
+ def summarize_all_sections(
416
+ self,
417
+ toc_md: str,
418
+ load_summary: str,
419
+ preproc_summary: str,
420
+ visual_summary: str,
421
+ coding_summary: str
422
+ ) -> str:
423
+ """
424
+ 汇总所有 agent 的 summary,并根据 toc_md 结构进行文字性总结
425
+ """
426
+
427
+ # Step 1:拼接所有 agent 的摘要
428
+ section_summaries = {
429
+ "加载阶段": load_summary,
430
+ "预处理阶段": preproc_summary,
431
+ "可视化分析": visual_summary,
432
+ "模型建构": coding_summary,
433
+ }
434
+
435
+ # Step 2:构建大模型 prompt
436
+ prompt = f"""你现在是一个经验丰富的数据分析报告撰写助手。
437
+
438
+ 我已经完成了一个数据分析项目的初稿,结构目录如下:
439
+ {toc_md}
440
+
441
+ 现在我将为你提供各个章节的内容摘要,请你根据这些内容,用流畅的中文撰写一段总结性描述(可用于报告的导语或结语),要求包括但不限于:
442
+
443
+ 1. 报告分析的主题方向
444
+ 2. 各章节的核心处理逻辑和大致作用
445
+ 3. 报告内容的整体风格与结构特性(例如是否包含图表、是否强调建模等)
446
+ 4. 使用自然语言、风格正式,避免主观判断词汇(如“也许”、“不错”、“感觉”)
447
+ 5. 最终输出 150~300 字中文总结段落,不需要标题
448
+
449
+ 每个阶段摘要如下:\n\n"""
450
+
451
+ for title, content in section_summaries.items():
452
+ if content:
453
+ prompt += f"\n【{title}】\n{content}\n"
454
+
455
+ # 调用大模型总结
456
+ overall_summary = self.call(prompt)
457
+
458
+ return overall_summary
459
+
460
+
461
+ def update_toc_with_relevant_sections(self, toc, agent_abstracts):
462
+ """
463
+ 根据 toc 和各模块摘要,为每个章节生成应参考的模块编号列表,
464
+ 并将结果添加为第五项。
465
+ """
466
+ prompt = f"""
467
+ 你是一个专业的数据分析报告规划助手。
468
+ 我将提供报告目录和各分析模块的摘要,请为每个章节确定应参考的模块编号列表。
469
+
470
+ 报告目录(每个元素为四元组:标题、层级、内容大纲、图编号列表):
471
+ {toc}
472
+
473
+ 各数据分析模块摘要如下:
474
+ {agent_abstracts}
475
+ 请根据:
476
+ 1. 各章节的标题、层级与内容大纲;
477
+ 2. 各数据处理板块摘要;
478
+ 3. 各章节的图编号分配情况(报告目录第四项);
479
+
480
+ 合理判断各章节在生成报告时应参考哪些数据处理板块的信息。
481
+ 输出要求:
482
+
483
+ - 对每个章节生成一个五元组 (标题, 层级, 内容大纲, 图编号列表, 模块编号列表)
484
+ - 标题, 层级, 内容大纲, 图编号列表一定不能改变,只在原有基础上添加第五项
485
+ - 模块编号列表为 Python list,例如 [0, 2]
486
+ - 若无需参考任何模块,返回 []
487
+ - 输出为 Python 列表,不含任何额外说明
488
+ 示例:
489
+ 输入:
490
+ [
491
+ ('概述',1,'介绍报告背景与目标',[1]),
492
+ ('数据可视化',1,'分析空气质量和相关环境变量的可视化图表',[2,3]),
493
+ ('xxxx关联性分析',2,'分析相对湿度与其他污染物关系',[4,5])
494
+ ]
495
+ 输出:
496
+ [
497
+ ('概述',1,'介绍报告背景与目标',[1],[1,2]),
498
+ ('数据可视化',1,'分析空气质量和相关环境变量的可视化图表',[2,3],[0,1]),
499
+ ('xxxx关联性分析',2,'分析相对湿度与其他污染物关系',[4,5],[2,3])
500
+ ]
501
+ """
502
+ toc_with_sections = self.call(prompt)
503
+ print(toc_with_sections)
504
+ return toc_with_sections.strip()
505
+
506
+
507
+ def write_section_body(self, toc, t, selected_full_contents, history_content):
508
+
509
+ prompt = f"""
510
+ 你是一个专业的数据分析报告撰写助手。你的任务是基于我提供的参考信息,生成逻辑清晰、结构严谨、内容专业的报告章节。
511
+
512
+ 当前章节信息(四元组:标题、层级、内容大纲、图编号列表):
513
+ {t}
514
+
515
+ 报告目录结构(包含所有章节的四元组信息):
516
+ {toc}
517
+
518
+ 可参考的分析内容如下:
519
+ {selected_full_contents}
520
+
521
+ 此前已生成的章节内容如下(用于保持整体风格一致、避免重复):
522
+ {history_content}
523
+
524
+ 写作要求:
525
+
526
+ 一、写作目标
527
+ 1. 仅撰写当前章节“{t[0]}”的正文内容;
528
+ 2. 内容必须以“参考信息”为核心依据,可在其逻辑框架内**进行适度拓展与归纳总结**;
529
+ 3. 允许进行合理的专业性补充(如统计学解释、方法原理、结果含义),但**禁止编造具体数据、图表结果、实验场景或样本特征**;
530
+ 4. 若参考信息不足,可补充一般性分析思路,但需保持内容通用、客观、抽象,不得具体化为假想数据。
531
+
532
+ 二、语言与结构
533
+ 1. 文风应正式、专业、学术化;
534
+ 2. 论述应符合数据分析逻辑:先描述、后解释、再总结;
535
+ 3. 每一自然段应围绕一个逻辑核心展开(如趋势、对比、相关性、分布特征等)。
536
+
537
+ 三、图表使用规范
538
+ 1. 正文中仅可使用本章节的图编号 {t[3]};
539
+ 2. 使用占位符 [FIG:index] 标注图表位置;
540
+ 3. 在每个占位符下方添加图片标题:
541
+ 图:图片标题(简要说明图片内容及分析要点)
542
+ 4. 图片位置与语义保持自然衔接:
543
+ - 若图片引出分析 → 放在段落开头;
544
+ - 若图片支撑论点 → 放在相关描述句之后;
545
+ - 若图片总结结果 → 放在段落结尾;
546
+ 5. 不得增删或重排图片编号。
547
+
548
+ 四、输出要求
549
+ - 仅输出正文内容;
550
+ - 不得输出标题、编号、解释文字、Markdown;
551
+ - 不使用加粗、斜体、符号修饰或非正文语句;
552
+ - 不得出现“我认为”、“请继续”、“综上可见”等主观表达。
553
+
554
+ 五、写作模式
555
+ 当前模式:{self.outline_length}
556
+ - 简要:仅写结论;
557
+ - 标准:含逻辑与结论;
558
+ - 详细:包含推理与方法,但仍应基于参考信息,不得自由创作。
559
+
560
+ 请严格在以上范围内撰写本章节正文。
561
+ """
562
+
563
+
564
+ # prompt = f"""
565
+ # 你是一个专业的数据分析报告撰写助手。你需要基于我提供的参考信息进行深入分析,生成结构严谨、逻辑清晰的专业报告章节。
566
+
567
+ # 当前需要撰写的章节信息(完整四元组,依次为标题、层级、内容大纲、图编号列表):
568
+ # {t}
569
+
570
+ # 报告目录结构(包含所有章节的四元组信息):
571
+ # {toc}
572
+
573
+ # 可参考的分析内容如下:
574
+ # {selected_full_contents}
575
+
576
+ # 此前已生成的章节内容如下(用于保持整体风格一致,并避免内容重复):
577
+ # {history_content}
578
+
579
+ # 请根据章节标题、层级、内容大纲、参考信息和图编号列表,生成该章节的完整正文。
580
+
581
+ # 正文详细程度有三种模式:
582
+ # - 简要:只包含核心结论与关键点,语言精炼;
583
+ # - 标准:包含主要分析逻辑、步骤与结果;
584
+ # - 详细:展开完整分析、方法论、推理过程与补充说明。
585
+ # 用户当前选择的模式是:{self.outline_length}
586
+
587
+ # 写作要求:
588
+ # 1. **核心任务**:仅撰写当前章节 **“{t[0]}”** 的正文内容,不得涉及其他章节。
589
+ # 2. **图表引用**:正文中引用的图表必须严格对应本章节的图编号(即 {t[3]}),**不得使用或编造其他编号**。
590
+ # 3. **语言规范**:
591
+ # - 语言应专业、准确、逻辑严谨;
592
+ # - 叙述风格应正式、学术化;
593
+ # - 禁止使用口语化或主观色彩表达。
594
+ # 4. **输出要求**:
595
+ # - 仅输出章节正文内容,不得输出 Markdown;
596
+ # - 不得输出任何标题,如:1,一,(1)等;
597
+ # - 禁止加粗、斜体、表情符号或其他符号修饰;
598
+ # - 不得出现非正文短语,如 “我认为”、“请继续”、“感谢阅读”、“---” 等;
599
+ # 5. **图片规范**:
600
+ # - 图片应独立成行,不得嵌入句子内部;
601
+ # - 图片可放置在段落的开头、结尾,或自然停顿处(如句号、分号后),以保持语义连贯;
602
+ # - 使用占位符格式 [FIG:index] 标记图片位置,其中 index 为对应图片的编号;
603
+ # - 在每个 [FIG:index] 占位符后,需紧跟一行图片标题,格式如下:
604
+ # 图:图片标题(简要说明图片内容及分析要点)
605
+ # - 图片插入位置应依据其语义和上下文逻辑确定:
606
+ # · 若图片用于引出分析,应放在段落开头;
607
+ # · 若用于支撑论述,应放在对应描述句之后;
608
+ # · 若总结结果或展示对比,应放在段落结尾;
609
+ # - 请务必确保图片位置与文字逻辑匹配,使图片与正文形成自然的论证衔接;
610
+ # - 请不要删除、合并或重排序图片编号,系统将在后续自动替换为真实图像。
611
+
612
+ # 请直接输出该章节的正文内容,不要有任何其他文字。
613
+ # """
614
+
615
+ content = self.call(prompt)
616
+
617
+ return content
utils/content.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class Content:
6
+ def __init__(self, text: str = None, fig = None):
7
+ self.text = text
8
+ self.fig = fig
9
+
10
+ def display(self):
11
+ pass
12
+
13
+
utils/sanitize_code.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Any
3
+ import numpy as np
4
+
5
+
6
+ def sanitize_code(code: str) -> str:
7
+ """清理可能包含的 Markdown 代码块标记。"""
8
+ if not isinstance(code, str):
9
+ return ""
10
+ code = code.strip()
11
+ if code.startswith("```") and code.endswith("```"):
12
+ lines = code.splitlines()
13
+ # 去掉首尾 ``` 或 ```python
14
+ if re.match(r"^```(?:python)?", lines[0].strip()):
15
+ lines = lines[1:]
16
+ if lines and lines[-1].strip() == "```":
17
+ lines = lines[:-1]
18
+ code = "\n".join(lines)
19
+ return code
20
+
21
+
22
+ def to_json_serializable(obj: Any) -> Any:
23
+ """将可能含 numpy 类型的对象转换为可 JSON 序列化类型(递归)。"""
24
+ if obj is None:
25
+ return None
26
+ if isinstance(obj, (str, bool, int)):
27
+ return obj
28
+ if isinstance(obj, float):
29
+ # 确保是内置 float(JSON 支持)
30
+ return float(obj)
31
+ if isinstance(obj, np.generic):
32
+ return obj.item()
33
+ if isinstance(obj, np.ndarray):
34
+ return obj.tolist()
35
+ if isinstance(obj, dict):
36
+ return {str(k): to_json_serializable(v) for k, v in obj.items()}
37
+ if isinstance(obj, (list, tuple)):
38
+ return [to_json_serializable(v) for v in obj]
39
+ # fallback: try to cast to float / str
40
+ try:
41
+ return float(obj)
42
+ except Exception:
43
+ try:
44
+ return str(obj)
45
+ except Exception:
46
+ return None
47
+
utils/save_secrets.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import toml
2
+ from pathlib import Path
3
+
4
+ # 我们把 secrets 放在项目根目录下的 .streamlit 文件夹
5
+ BASE = Path(__file__).parent
6
+ SECRETS_DIR = BASE / ".streamlit"
7
+ SECRETS_FILE = SECRETS_DIR / "secrets.toml"
8
+
9
+
10
+ def load_local_api_keys() -> dict[str, str]:
11
+ """
12
+ 从项目目录的 .streamlit/secrets.toml 中读取 [api_keys] 部分。
13
+ 如果文件或该节不存在,返回空字典。
14
+ """
15
+ if not SECRETS_FILE.exists():
16
+ return {}
17
+ data = toml.load(SECRETS_FILE)
18
+ return data.get("api_keys", {})
19
+
20
+
21
+ def update_local_api_key(model_name: str, api_key: str) -> None:
22
+ """
23
+ 将一对 model_name: api_key 写入 .streamlit/secrets.toml 的 [api_keys]。
24
+ 如果文件或该节不存在,会自动创建;保留其它已有设置。
25
+ """
26
+ SECRETS_DIR.mkdir(exist_ok=True)
27
+ if SECRETS_FILE.exists():
28
+ data = toml.load(SECRETS_FILE)
29
+ else:
30
+ data = {}
31
+ data.setdefault("api_keys", {})[model_name] = api_key
32
+ with SECRETS_FILE.open("w", encoding="utf-8") as f:
33
+ toml.dump(data, f)
utils/spinner_pool.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ def get_spinner_msg(stage="writing"):
4
+ msg_pool = {
5
+ "summarizing": [
6
+ "正在汇总各模块的分析结果...",
7
+ "稍等一下,正在总结前面几个 Agent 的内容...",
8
+ "AI 正在整理前面的分析,请稍候...",
9
+ "正在综合各分析步骤的结论..."
10
+ ],
11
+ "writing": [
12
+ "正在生成各章节内容...",
13
+ "请稍候,系统正在详细撰写报告...",
14
+ "AI 正在逐步生成报告章节...",
15
+ "正在整理并撰写每一章节..."
16
+ ],
17
+ "default": [
18
+ "正在处理数据,请稍候...",
19
+ "AI 正在努力生成结果...",
20
+ "请耐心等待,正在计算中..."
21
+ ]
22
+ }
23
+
24
+ pool = msg_pool.get(stage, msg_pool["default"])
25
+ return random.choice(pool)
workflow/.DS_Store ADDED
Binary file (6.15 kB). View file
 
workflow/dataloading/dataloading_core.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import io
3
+ import os
4
+ from typing import List, Optional
5
+
6
+ import chardet
7
+ import numpy as np
8
+ import pandas as pd
9
+ from scipy import sparse
10
+ from scipy.io import loadmat, arff
11
+ import streamlit as st
12
+ import streamlit_antd_components as sac
13
+
14
+
15
+ def read_data_from_file(
16
+ uploaded_data_file,
17
+ col_names: Optional[List[str]] = None,
18
+ sep: Optional[str] = None,
19
+ na_values: List[str] = ['?'],
20
+ encoding: Optional[str] = None
21
+ ) -> pd.DataFrame:
22
+ """
23
+ 从上传的数据文件读取 DataFrame。
24
+ - 支持 .csv/.data/.txt/.xlsx/.xls/.mat
25
+ - col_names=None 时使用 header=0(文件首行做列名)
26
+ - col_names 不为 None 时使用 header=None 并指定 names=col_names
27
+ - 文本文件:自动探测编码、嗅探分隔符,跳过坏行
28
+ - Excel 文件:直接使用 pandas.read_excel
29
+ - MAT 文件:使用 scipy.loadmat,提取第一个主要变量,转为 DataFrame,并保证一维列
30
+ """
31
+ # 读取所有字节
32
+ data_bytes = uploaded_data_file.read()
33
+ # 重置流位置
34
+ try:
35
+ uploaded_data_file.seek(0)
36
+ except Exception:
37
+ pass
38
+
39
+ name = uploaded_data_file.name
40
+ ext = os.path.splitext(name)[1].lower()
41
+
42
+ # Excel 文件处理
43
+ if ext in ('.xlsx', '.xls'):
44
+ excel_kwargs = {}
45
+ if col_names is None:
46
+ excel_kwargs['header'] = 0
47
+ else:
48
+ excel_kwargs['header'] = None
49
+ excel_kwargs['names'] = col_names
50
+ return pd.read_excel(io.BytesIO(data_bytes), **excel_kwargs)
51
+
52
+ # ARFF 文件特殊处理
53
+ if ext == '.arff':
54
+ text = data_bytes.decode(encoding or 'utf-8', errors='ignore')
55
+ raw_data, meta = arff.loadarff(io.StringIO(text))
56
+ df = pd.DataFrame(raw_data)
57
+ for col in df.select_dtypes([object]).columns:
58
+ if isinstance(df[col].iloc[0], bytes):
59
+ df[col] = df[col].str.decode('utf-8', errors='ignore')
60
+ if col_names is not None and df.shape[1] == len(col_names):
61
+ df.columns = col_names
62
+ return df
63
+
64
+ # —— MAT 文件特殊处理 —— #
65
+ if ext == '.mat':
66
+ mat = loadmat(io.BytesIO(data_bytes))
67
+ data_keys = [k for k in mat.keys() if not k.startswith('__')]
68
+ if not data_keys:
69
+ raise ValueError('MAT 文件中未发现有效数据变量')
70
+ arr = mat[data_keys[0]]
71
+
72
+ # —— 先处理稀疏矩阵 —— #
73
+ if sparse.issparse(arr):
74
+ arr = arr.toarray()
75
+
76
+ arr = np.array(arr)
77
+ if arr.ndim > 2:
78
+ arr = arr.reshape(arr.shape[0], -1)
79
+
80
+ df = pd.DataFrame(arr)
81
+
82
+ if col_names is not None and df.shape[1] == len(col_names):
83
+ df.columns = col_names
84
+
85
+ return df
86
+
87
+ if encoding is None:
88
+ detected = chardet.detect(data_bytes)
89
+ encoding = detected.get('encoding', 'utf-8')
90
+ sample = data_bytes[:10_000].decode(encoding, errors='ignore')
91
+
92
+ try:
93
+ dialect = csv.Sniffer().sniff(sample, delimiters=[',',';','\t','|'])
94
+ detected_sep = dialect.delimiter
95
+ use_whitespace = False
96
+ except csv.Error:
97
+ detected_sep = None
98
+ use_whitespace = True
99
+
100
+ read_kwargs = {
101
+ 'engine': 'python',
102
+ 'encoding': encoding,
103
+ 'na_values': na_values,
104
+ 'comment': '|',
105
+ 'skipinitialspace': True,
106
+ 'on_bad_lines': 'skip',
107
+ }
108
+ if use_whitespace:
109
+ read_kwargs['delim_whitespace'] = True
110
+ else:
111
+ read_kwargs['sep'] = detected_sep
112
+
113
+ if col_names is None:
114
+ read_kwargs['header'] = 0
115
+ else:
116
+ read_kwargs['header'] = None
117
+ read_kwargs['names'] = col_names
118
+
119
+ return pd.read_csv(io.BytesIO(data_bytes), **read_kwargs)
120
+
121
+
122
+ def process_complex_data(uploaded_files, dataloadingagent):
123
+ """
124
+ 上传处理逻辑:
125
+ - 单文件:当作普通表格或 MAT 文件读(第一行当表头)
126
+ - 多文件:若有 .names/.arff 表头文件,则用其列名;否则推断列名
127
+ 并在存在多个数据文件时,通过用户选择进行横向或纵向拼接
128
+ """
129
+ if not uploaded_files:
130
+ st.error("请先上传文件")
131
+ return None, None
132
+
133
+ names_exts = ('.names', '.arff', '.doc')
134
+ data_exts = ('.data', '.csv', '.txt', '.xlsx', '.xls', '.mat', '.arff', '.tsv', '.dat', '.tst')
135
+
136
+ names_files = [f for f in uploaded_files
137
+ if os.path.splitext(f.name)[1].lower() in names_exts]
138
+ data_files = [f for f in uploaded_files
139
+ if os.path.splitext(f.name)[1].lower() in data_exts]
140
+
141
+ # 单文件直接读取
142
+ if len(uploaded_files) == 1 and uploaded_files[0] in data_files:
143
+ return read_data_from_file(uploaded_files[0], col_names=None), None
144
+
145
+ if not data_files:
146
+ raise ValueError(
147
+ "未检测到任何数据文件,请上传支持的格式:.csv/.data/.txt/.xlsx/.xls/.mat/.arff/.tsv/.dat/.tst"
148
+ )
149
+
150
+ # 1) 如果存在表头文件 (.names/.arff),读取���名
151
+ if names_files:
152
+ header_file = names_files[0]
153
+ # 使用 read_data_from_file 读取 sample,以确保正确处理编码
154
+ sample_df = read_data_from_file(data_files[0], col_names=None)
155
+ col_names = dataloadingagent.read_names_from_file(header_file, sample_df.head())
156
+ else:
157
+ # 2) 否则从第一个数据文件推断列名,加入编码容错
158
+ sample = data_files[0]
159
+ ext0 = os.path.splitext(sample.name)[1].lower()
160
+ try:
161
+ if ext0 in ('.xlsx', '.xls'):
162
+ col_names = list(pd.read_excel(sample, nrows=0))
163
+ elif ext0 == '.mat':
164
+ df_sample = read_data_from_file(sample, col_names=None)
165
+ col_names = list(df_sample.columns)
166
+ else:
167
+ # 文本文件推断列名,带上 encoding 参数
168
+ # 先通过 chardet 检测,再尝试 utf-8,失败则 latin1
169
+ raw_bytes = sample.read()
170
+ detected = chardet.detect(raw_bytes)
171
+ enc = detected.get('encoding', 'utf-8')
172
+ try:
173
+ col_names = list(pd.read_csv(
174
+ io.BytesIO(raw_bytes),
175
+ nrows=0,
176
+ encoding=enc,
177
+ engine='python'
178
+ ).columns)
179
+ except UnicodeDecodeError:
180
+ col_names = list(pd.read_csv(
181
+ io.BytesIO(raw_bytes),
182
+ nrows=0,
183
+ encoding='latin1',
184
+ engine='python'
185
+ ).columns)
186
+ finally:
187
+ try: sample.seek(0)
188
+ except: pass
189
+
190
+ # 读取所有数据文件并统一列名
191
+ dfs = [read_data_from_file(f, col_names=col_names) for f in data_files]
192
+
193
+ # 若多个数据文件,弹出拼接模式选择
194
+ if len(data_files) >= 2:
195
+
196
+ big_df = pd.concat(dfs, axis=0, ignore_index=True)
197
+
198
+ else:
199
+ big_df = dfs[0]
200
+
201
+ return big_df, dfs
202
+
203
+
204
+ def load_from_path(local_path):
205
+
206
+ ext = os.path.splitext(local_path)[1].lower()
207
+ if ext in (".csv", ".txt", ".data"):
208
+ df_local = pd.read_csv(local_path)
209
+ elif ext in (".xls", ".xlsx"):
210
+ df_local = pd.read_excel(local_path)
211
+ elif ext == ".json":
212
+ df_local = pd.read_json(local_path)
213
+ elif ext == ".jsonl":
214
+ df_local = pd.read_json(local_path, lines=True)
215
+ elif ext == ".parquet":
216
+ df_local = pd.read_parquet(local_path)
217
+ elif ext in (".pkl", ".pickle"):
218
+ df_local = pd.read_pickle(local_path)
219
+ elif ext == ".feather":
220
+ df_local = pd.read_feather(local_path)
221
+ elif ext == ".arff":
222
+ data, meta = arff.loadarff(local_path)
223
+ df_local = pd.DataFrame(data)
224
+ for col in df_local.select_dtypes([object]).columns:
225
+ if isinstance(df_local[col].iloc[0], bytes):
226
+ df_local[col] = df_local[col].str.decode('utf-8')
227
+ else:
228
+ st.error(f"不支持的文件类型:{ext}")
229
+ df_local = None
230
+
231
+ return df_local
232
+
233
+
234
+ def load_concat_file(dfs, agent):
235
+
236
+ mode = sac.segmented(
237
+ items=[
238
+ sac.SegmentedItem(label='纵向拼接'),
239
+ sac.SegmentedItem(label='横向拼接'),
240
+ ], label='检测到多个数据文件,请选择拼接方式', size='sm', radius='sm'
241
+ )
242
+
243
+ if mode.startswith("横向拼接"):
244
+ dfs_pos = [df.reset_index(drop=True) for df in dfs]
245
+ big_df = pd.concat(dfs_pos, axis=1)
246
+
247
+ cols = []
248
+ seen = {}
249
+ for c in big_df.columns:
250
+ if c in seen:
251
+ seen[c] += 1
252
+ cols.append(f"{c}_{seen[c]}")
253
+ else:
254
+ seen[c] = 0
255
+ cols.append(c)
256
+ big_df.columns = cols
257
+ agent.add_df(big_df)
258
+ else:
259
+ big_df = pd.concat(dfs, axis=0, ignore_index=True)
260
+ agent.add_df(big_df)
261
+
262
+ csv_bytes = big_df.to_csv(index=False).encode('utf-8')
263
+ st.download_button(
264
+ label="下载文件",
265
+ data=csv_bytes,
266
+ file_name="processed_data.csv",
267
+ mime="text/csv"
268
+ )
269
+
270
+
271
+ class PathFileWrapper:
272
+ """A wrapper to treat a local file path as a Streamlit UploadedFile."""
273
+ def __init__(self, path):
274
+ self.path = path
275
+ self.name = os.path.basename(path)
276
+ self._file = None
277
+
278
+ def read(self, *args, **kwargs):
279
+ with open(self.path, 'rb') as f:
280
+ return f.read()
281
+
282
+ def seek(self, offset, whence=0):
283
+
284
+ pass
285
+
286
+ def __repr__(self):
287
+ return f"PathFileWrapper(path='{self.path}')"
workflow/dataloading/dataloading_render.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional
3
+
4
+ import pandas as pd
5
+ import streamlit as st
6
+ import streamlit_antd_components as sac
7
+
8
+ from workflow.dataloading.dataloading_core import process_complex_data, load_from_path, load_concat_file, PathFileWrapper
9
+
10
+
11
+ def loading_data_file(agent):
12
+
13
+ st.info(
14
+ "💡 提示:\n"
15
+ "1. 支持一次上传多个数据文件\n"
16
+ "2. 自动使用大模型分析并处理数据\n"
17
+ "3. 支持多种格式的文件类型上传\n"
18
+ )
19
+
20
+ selected_index = sac.tabs([
21
+ sac.TabsItem(label='本地上传'),
22
+ sac.TabsItem(label='路径导入'),
23
+ ], color='#5980AE',)
24
+
25
+ if selected_index == "本地上传":
26
+ # 点击上传文件
27
+ uploaded_files = st.file_uploader(
28
+ "选择新文件",
29
+ accept_multiple_files=True,
30
+ help="拖拽或点击上传多个文件",
31
+ )
32
+
33
+ if uploaded_files:
34
+ current_memory_file_name = agent.load_file_name()
35
+ new_files = [f for f in uploaded_files if f.name not in current_memory_file_name]
36
+ if new_files:
37
+ try:
38
+ with st.spinner("正在处理数据..."):
39
+ df, dfs = process_complex_data(new_files, agent)
40
+ if df is not None:
41
+ agent.add_df(df)
42
+ agent.save_dfs(dfs)
43
+ for f in new_files:
44
+ agent.save_file_name(f.name)
45
+ st.rerun()
46
+ except Exception as err:
47
+ st.error(f"导入失败:{err}")
48
+
49
+ elif selected_index == "路径导入":
50
+ # 路径上传文件
51
+ raw_paths = st.text_area(
52
+ "从路径导入数据 (每行一个文件路径)",
53
+ placeholder= "C:\\data\\iris.names\nC:\\data\\iris.data",
54
+ height=100
55
+ )
56
+
57
+ if st.button("从路径加载文件", use_container_width=True):
58
+ if raw_paths:
59
+
60
+ path_list = [p.strip().strip("'\"") for p in raw_paths.strip().split('\n') if p.strip()]
61
+
62
+ valid_paths = [p for p in path_list if os.path.exists(p)]
63
+ invalid_paths = [p for p in path_list if not os.path.exists(p)]
64
+
65
+ if invalid_paths:
66
+ st.warning(f"路径不存在,已跳过:\n- " + "\n- ".join(invalid_paths))
67
+
68
+ if not valid_paths:
69
+ st.error("未找到任何有效的本地文件路径。")
70
+ else:
71
+ current_memory_file_name = agent.load_file_name()
72
+ new_paths = [p for p in valid_paths if p not in current_memory_file_name]
73
+
74
+ if not new_paths:
75
+ st.info("所有指定的路径文件均已加载。")
76
+ else:
77
+ files_to_process = [PathFileWrapper(p) for p in new_paths]
78
+ try:
79
+ with st.spinner("正在处理数据..."):
80
+ df, dfs = process_complex_data(files_to_process, agent)
81
+ if df is not None:
82
+ agent.add_df(df)
83
+ agent.save_dfs(dfs)
84
+ for p in new_paths:
85
+ agent.save_file_name(p)
86
+ st.rerun()
87
+ except Exception as err:
88
+ st.error(f"本地文件读取失败:{err}")
89
+
90
+ dfs = agent.load_dfs()
91
+ if dfs is not None and len(dfs) >= 2:
92
+ load_concat_file(dfs, agent)
93
+
94
+
95
+ def loading_basic_info(agent):
96
+
97
+ df = agent.load_df()
98
+ if df is not None:
99
+ r, c = df.shape
100
+ missing = int(df.isnull().sum().sum())
101
+ col1, col2, col3 = st.columns(3)
102
+ col1.metric("行数", r)
103
+ col2.metric("列数", c)
104
+ col3.metric("缺失值总数", missing)
105
+
106
+ dtype_info = pd.DataFrame({
107
+ "列名": df.columns,
108
+ "类型": df.dtypes.astype(str),
109
+ "非空": df.count().values,
110
+ "缺失%": (df.isnull().mean() * 100).round(2).values,
111
+ }).reset_index(drop=True)
112
+
113
+ selected_index = sac.tabs([
114
+ sac.TabsItem(label='数据类型概览'),
115
+ sac.TabsItem(label='数据预览'),
116
+ ],color='#5980AE',)
117
+
118
+ if selected_index == "数据类型概览":
119
+ st.dataframe(dtype_info, use_container_width=True)
120
+ elif selected_index == "数据预览":
121
+ if st.button("🎲 随机抽样"):
122
+ display_df = df.sample(10)
123
+ st.dataframe(display_df, use_container_width=True)
124
+ else:
125
+ st.dataframe(df.head(10), use_container_width=True)
126
+
127
+
128
+ def loading_chat(agent, auto=False) -> None:
129
+
130
+ df = agent.load_df()
131
+ if df is None:
132
+ return
133
+
134
+ with st.chat_message("assistant"):
135
+ st.write(
136
+ "我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
137
+ "请先上传您的数据文件,上传完成后,您可以在下方和我对话,也可以直接点击按钮解析数据含义。"
138
+ )
139
+ analyze_btn = st.button("🔍 解析含义")
140
+ result_placeholder = st.empty()
141
+
142
+ # 渲染历史对话
143
+ chat_history = agent.load_memory()
144
+
145
+ for idx, entry in enumerate(chat_history):
146
+ bubble = st.chat_message(entry["role"])
147
+ content = entry["content"]
148
+ if isinstance(content, str):
149
+ bubble.write(content)
150
+
151
+ already_generated = any(
152
+ entry["role"] == "assistant" and "含义" in str(entry["content"])
153
+ for entry in chat_history
154
+ )
155
+
156
+ if analyze_btn or (auto and not already_generated):
157
+ st.chat_message("user").write("请帮我解析数据含义")
158
+ agent.add_memory({"role": "user", "content": "请帮我解析数据含义"})
159
+ with st.spinner("分析中..."):
160
+ desc = agent.do_data_description(df)
161
+
162
+ agent.finish_auto()
163
+ st.chat_message("assistant").write(desc)
164
+ agent.add_memory({"role": "assistant", "content": desc})
165
+ st.rerun()
166
+
167
+ # 用户自定义输入
168
+ user_input = st.chat_input("请输入需求,例如“帮我分析xx列”")
169
+ if user_input:
170
+ st.chat_message("user").write(user_input)
171
+ agent.add_memory({"role": "user", "content": user_input})
172
+ with st.spinner("处理中…"):
173
+ reply = agent.do_data_description(df, user_input)
174
+
175
+ st.chat_message("assistant").write(reply)
176
+ agent.add_memory({"role": "assistant", "content": reply})
177
+ st.rerun()
178
+
179
+
180
+ if __name__ == "__main__":
181
+
182
+ agent = st.session_state.data_loading_agent
183
+ planner = st.session_state.planner_agent
184
+ auto = planner.loading_auto
185
+
186
+ if st.session_state.auto_mode == True:
187
+ if (agent.finish_auto_task == True and planner.switched_prep == False) or planner.prep_auto == False:
188
+ planner.finish_loading_auto()
189
+ st.switch_page("workflow/preprocessing/preprocessing_render.py")
190
+
191
+ c1,c2 = st.columns(2)
192
+ with c1:
193
+ st.title("数据导入")
194
+ with c2:
195
+ st.write("")
196
+ st.write("")
197
+ sac.buttons([
198
+ sac.ButtonsItem(label='Github', icon='github', href='https://github.com/ElvisWang1111/AAAAAnystat'),
199
+ sac.ButtonsItem(label='Doc', icon=sac.BsIcon(name='bi bi-file-earmark-post-fill', size=16), href='https://elviswang1111.github.io/anystatweb.github.io/index.html'),
200
+ ], align='end', color='dark', variant='filled', index=None)
201
+ st.markdown("---")
202
+
203
+ c = st.columns(2)
204
+ with c[0].expander('数据上传', True):
205
+ loading_data_file(agent)
206
+ with c[1].expander('数据建议', True):
207
+ loading_chat(agent, auto)
208
+ with c[0].expander('数据展示', True):
209
+ loading_basic_info(agent)
210
+
workflow/modeling/model_inference.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import gzip
3
+ import io
4
+ import json
5
+ import traceback
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ from sklearn.preprocessing import StandardScaler
10
+ import streamlit as st
11
+
12
+ from workflow.dataloading.dataloading_core import process_complex_data
13
+ from utils.sanitize_code import sanitize_code, to_json_serializable
14
+
15
+
16
+ def infer_load_data(agent) -> None:
17
+
18
+ uploaded_files = st.file_uploader(
19
+ "选择推理数据集",
20
+ accept_multiple_files=True,
21
+ help="拖拽或点击上传多个文件",
22
+ )
23
+
24
+ if uploaded_files:
25
+ try:
26
+ with st.spinner("正在处理数据..."):
27
+ big_df, dfs = process_complex_data(uploaded_files, agent)
28
+ if big_df is not None:
29
+ agent.save_inference_data(big_df)
30
+ st.success("导入并处理完成!")
31
+ except Exception as err:
32
+ st.error(f"导入失败:{err}")
33
+
34
+
35
+ def infer_execution(agent):
36
+
37
+ inference_df = agent.load_inference_processed_df()
38
+ edited_code = agent.load_inference_code()
39
+
40
+ try:
41
+ model_obj = agent.load_best_model()
42
+
43
+ exec_ns = {
44
+ "inference_df": inference_df,
45
+ 'model_obj': model_obj,
46
+ "np": np,
47
+ "pd": pd,
48
+ "StandardScaler": StandardScaler
49
+ }
50
+
51
+ with st.spinner("正在进行推断分析..."):
52
+ exec(edited_code, exec_ns)
53
+
54
+ result_dict = exec_ns.get("result_dict")
55
+ if result_dict is None:
56
+ st.error("脚本未写入 `result_dict`。请确保编辑后的脚本在末尾赋值 result_dict。")
57
+ else:
58
+ art = result_dict.get('artifacts', {})
59
+ b64 = art.pop('predictions_df_b64', None)
60
+ if not art:
61
+ result_dict.pop('artifacts', None)
62
+
63
+ serializable = to_json_serializable(result_dict)
64
+ try:
65
+ result_json = json.dumps(serializable, ensure_ascii=False)
66
+ except Exception:
67
+ result_json = json.dumps(serializable, default=str, ensure_ascii=False)
68
+
69
+ with st.expander("推理结果", True):
70
+ if b64:
71
+ try:
72
+ gz_bytes = base64.b64decode(b64)
73
+ csv_bytes = gzip.decompress(gz_bytes)
74
+
75
+ df_pred = pd.read_csv(io.BytesIO(csv_bytes))
76
+ st.success("已加载带预测结果的 DataFrame")
77
+ st.dataframe(df_pred)
78
+
79
+ st.download_button(
80
+ label="下载带预测结果(predictions.csv)",
81
+ data=csv_bytes,
82
+ file_name="predictions.csv",
83
+ mime="text/csv"
84
+ )
85
+ except Exception as e:
86
+ st.error(f"解码 predictions_df 失败: {e}")
87
+ # 兜底:尝试从 records 字段恢复
88
+ records = result_dict.get('predictions_df_records')
89
+ if records:
90
+ try:
91
+ df_pred = pd.DataFrame(records)
92
+ st.dataframe(df_pred)
93
+ except Exception as e2:
94
+ st.error(f"从 records 恢复表格失败: {e2}")
95
+
96
+ except Exception as e:
97
+ st.error(f"推断失败:{e}")
98
+ st.text(traceback.format_exc())
99
+ agent.save_inference_error(traceback.format_exc())
100
+ raw = agent.code_generation_for_inference(agent.load_code(), inference_data.head(), auto=True)
101
+ code = sanitize_code(raw)
102
+ agent.save_inference_code(code)
workflow/modeling/model_training.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import json
3
+ import traceback
4
+ import base64
5
+ import gzip
6
+ import pickle
7
+ import time
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import streamlit as st
12
+ import streamlit_antd_components as sac
13
+ from streamlit_ace import st_ace
14
+ import torch
15
+ import torchvision
16
+ import xgboost
17
+ import lightgbm
18
+ from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier, RandomForestRegressor
19
+ from sklearn.linear_model import LinearRegression
20
+ from sklearn.model_selection import train_test_split
21
+ from sklearn.preprocessing import StandardScaler
22
+
23
+ from utils.sanitize_code import sanitize_code, to_json_serializable
24
+
25
+
26
+ def train_execution(agent):
27
+
28
+ code = agent.load_code()
29
+ df = agent.load_df()
30
+
31
+ torch = importlib.import_module("torch")
32
+ torchvision = importlib.import_module("torchvision")
33
+
34
+ exec_ns = {
35
+ "df": df,
36
+ "np": np,
37
+ "pd": pd,
38
+ "torch": torch,
39
+ "torchvision": torchvision,
40
+ "train_test_split": train_test_split,
41
+ "StandardScaler": StandardScaler,
42
+ "LinearRegression": LinearRegression,
43
+ "RandomForestRegressor": RandomForestRegressor,
44
+ "GradientBoostingRegressor": GradientBoostingRegressor,
45
+ "RandomForestClassifier": RandomForestClassifier,
46
+ "xgboost": xgboost,
47
+ "lightgbm": lightgbm,
48
+ }
49
+
50
+ try:
51
+ with st.spinner("正在运行程序..."):
52
+ exec(code, exec_ns)
53
+ except Exception as exc:
54
+ st.error(f"已保存报错,请重新调用llm生成代码debug")
55
+ # st.error(f"脚本执行失败:{exc}")
56
+ st.text(traceback.format_exc())
57
+ agent.save_error(traceback.format_exc())
58
+ modeling_code_gen(agent, debug=True)
59
+ else:
60
+ result_dict = exec_ns.get("result_dict")
61
+ if result_dict is None:
62
+ st.error(
63
+ "脚本未写入 `result_dict`。请确保编辑后的脚本在末尾赋值 result_dict。"
64
+ )
65
+ else:
66
+ art = result_dict.get('artifacts', {})
67
+ b64 = art.pop('best_model_b64', None)
68
+ artifact_warning = result_dict.pop('artifact_warning', None)
69
+
70
+ if not art:
71
+ result_dict.pop('artifacts', None)
72
+
73
+ serializable = to_json_serializable(result_dict)
74
+ try:
75
+ result_json = json.dumps(serializable, ensure_ascii=False)
76
+ except Exception:
77
+ result_json = json.dumps(serializable, default=str, ensure_ascii=False)
78
+
79
+ with st.spinner("请求 LLM 格式化结果为 Markdown..."):
80
+ formatted = agent.result_format_prompt(result_json)
81
+ agent.save_modeling_result(formatted)
82
+
83
+ if b64:
84
+
85
+ gz_bytes = base64.b64decode(b64)
86
+
87
+ try:
88
+ agent.save_best_model_gz_bytes(gz_bytes)
89
+ model_obj = pickle.loads(gzip.decompress(gz_bytes))
90
+ st.success("最佳模型已加载到内存,可用于即时推理(示例)。")
91
+ agent.save_best_model(model_obj)
92
+ except Exception as e:
93
+ st.error(f"加载模型失败:{e}")
94
+
95
+
96
+ def modeling_code_gen(agent, debug = False, auto = False, ) -> None:
97
+
98
+ df = agent.load_df()
99
+ suggest = agent.load_suggestion()
100
+ print(suggest)
101
+ chat_history = agent.load_memory()
102
+ already_generated = any(
103
+ entry["role"] == "assistant" and "训练脚本已更新!请重新运行代码!" in str(entry["content"])
104
+ for entry in chat_history
105
+ )
106
+
107
+ if suggest is not None:
108
+ if debug == True or (auto and not already_generated):
109
+ with st.spinner("建模 Agent 正在生成训练脚本..."):
110
+ raw = agent.code_generation(
111
+ df.head().to_string(),
112
+ suggest,
113
+ )
114
+ code = sanitize_code(raw)
115
+ agent.save_code(code)
116
+ st.chat_message("assistant").write("训练脚本已更新!请重新运行代码!")
117
+ agent.add_memory({"role": "assistant", "content": "训练脚本已更新!请重新运行代码!"})
118
+ st.rerun()
119
+
120
+ analyze_btn = st.button("🔧 生成建模代码", key='modeling_code')
121
+ if analyze_btn:
122
+ with st.spinner("建模 Agent 正在生成训练脚本..."):
123
+ raw = agent.code_generation(
124
+ df.head().to_string(),
125
+ suggest,
126
+ )
127
+ code = sanitize_code(raw)
128
+ agent.save_code(code)
129
+ st.chat_message("assistant").write("训练脚本已更新!请重新运行代码!")
130
+ agent.add_memory({"role": "assistant", "content": "训练脚本已更新!请重新运行代码!"})
131
+ st.rerun()
132
+
133
+
134
+ def train_download_model(agent):
135
+
136
+ model = agent.load_best_model_gz_bytes()
137
+ if model is not None:
138
+ st.download_button(
139
+ label="⬇�� 下载最佳模型",
140
+ data=model,
141
+ file_name="best_model.pkl.gz",
142
+ mime="application/gzip"
143
+ )
workflow/modeling/modeling_render.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit_antd_components as sac
3
+ from streamlit_ace import st_ace
4
+
5
+ from utils.sanitize_code import sanitize_code
6
+ from workflow.modeling.model_training import train_execution, modeling_code_gen, train_download_model
7
+ from workflow.modeling.model_inference import infer_load_data, infer_execution
8
+ from workflow.preprocessing.preprocessing_core import prep_meta_execution
9
+
10
+
11
+ def modeling_quick_actions(agent):
12
+
13
+ st.write("选择一个或多个model:")
14
+ selected_models = sac.chip(
15
+ items=[
16
+ sac.ChipItem(label='线性回归'),
17
+ sac.ChipItem(label='XGBoost'),
18
+ sac.ChipItem(label='随机森林'),
19
+ sac.ChipItem(label='神经网络'),
20
+ ], index=[0, 2], align='center', radius='md', color='#44658C', multiple=True
21
+ )
22
+
23
+ df = agent.load_df()
24
+
25
+ if st.button("🖋️ 快速建模"):
26
+ if not selected_models:
27
+ st.error("请先选择训练model。")
28
+ else:
29
+ with st.spinner("建模 Agent 正在生成训练脚本..."):
30
+ raw = agent.code_generation(df.head().to_string(), selected_models)
31
+ code = sanitize_code(raw)
32
+ agent.save_code(code)
33
+ agent.save_suggestion(selected_models)
34
+ agent.save_user_selection(selected_models)
35
+ st.success("训练脚本已生成并保存。")
36
+ st.rerun()
37
+
38
+ return selected_models
39
+
40
+
41
+ def modeling_execution(agent, auto = False) -> None:
42
+
43
+ code = agent.load_code()
44
+
45
+ edited = st_ace(
46
+ value=code,
47
+ height=450,
48
+ theme="tomorrow_night",
49
+ language="python",
50
+ auto_update=True
51
+ )
52
+
53
+ not_executed = agent.load_modeling_result() == None
54
+
55
+ if edited is not None:
56
+ if st.button("▶️ 执行建模", key="modeling_run_code") or (auto and not_executed):
57
+ code = sanitize_code(edited)
58
+ agent.save_code(code)
59
+ train_execution(agent)
60
+ agent.finish_auto()
61
+ st.rerun()
62
+
63
+ modeling_result = agent.load_modeling_result()
64
+ if modeling_result is None:
65
+ result_expand = False
66
+ else:
67
+ result_expand = True
68
+ train_download_model(agent)
69
+ with st.expander("训练结果", result_expand):
70
+ if modeling_result:
71
+ st.subheader("训练结果")
72
+ try:
73
+ st.markdown(modeling_result)
74
+ except Exception:
75
+ st.write(modeling_result)
76
+
77
+
78
+ def modeling_inference(agent, preproc_agent):
79
+
80
+ infer_load_data(agent)
81
+ inference_processed_data = agent.load_inference_processed_df()
82
+ inference_data = agent.load_inference_data()
83
+
84
+ code = agent.load_inference_code()
85
+
86
+ if st.button("▶️ 执行推断"):
87
+
88
+ with st.spinner("正在对推理数据进行预处理..."):
89
+
90
+ inference_data = agent.load_inference_data()
91
+ if preproc_agent.code is not None:
92
+ inference_processed_df = prep_meta_execution(preproc_agent, preproc_agent.code, inference_data)
93
+ inference_data = inference_processed_df
94
+ agent.save_inference_processed_df(inference_data)
95
+ st.write("推断数据预览:")
96
+ st.dataframe(inference_data.head())
97
+
98
+ with st.spinner("建模 Agent 正在生成推理脚本..."):
99
+
100
+ raw = agent.code_generation_for_inference(agent.load_code(), inference_data.head())
101
+ code = sanitize_code(raw)
102
+ agent.save_inference_code(code)
103
+
104
+ if code is not None:
105
+ edited_code = st_ace(
106
+ value=code,
107
+ height=450,
108
+ theme="tomorrow_night",
109
+ language="python",
110
+ auto_update=True
111
+ )
112
+ agent.save_inference_code(code)
113
+ if st.button("▶️ 执行建模"):
114
+ infer_execution(agent)
115
+
116
+
117
+ def modeling_chat(agent, auto) -> None:
118
+
119
+ user_input = st.text_input("建模目标", "默认")
120
+ agent.save_target(user_input)
121
+
122
+ with st.chat_message("assistant"):
123
+ st.write(
124
+ "我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
125
+ "您可以在下方输入建模相关问题,或直接点击按钮获取建模建议。"
126
+ )
127
+ analyze_btn = st.button("🔍 建模推荐", key='modeling_suggest')
128
+ result_placeholder = st.empty()
129
+
130
+ chat_history = agent.load_memory()
131
+
132
+ for idx, entry in enumerate(chat_history):
133
+ bubble = st.chat_message(entry["role"])
134
+ content = entry["content"]
135
+ if isinstance(content, str):
136
+ bubble.write(content)
137
+
138
+ already_generated = any(
139
+ entry["role"] == "assistant" and "模" in str(entry["content"])
140
+ for entry in chat_history
141
+ )
142
+
143
+ if analyze_btn or (auto and not already_generated):
144
+ st.chat_message("user").write("请帮我获取��模建议")
145
+ agent.add_memory({"role": "user", "content": "请帮我获取建模建议"})
146
+ with st.spinner("分析中..."):
147
+ suggestion = agent.get_model_suggestion()
148
+ agent.save_suggestion(suggestion)
149
+ agent.refine_suggestions()
150
+ st.chat_message("assistant").write(suggestion)
151
+ agent.add_memory({"role": "assistant", "content": suggestion})
152
+ st.chat_message("assistant").write("需要进一步优化?再次点击按钮获取下一条建议")
153
+ agent.add_memory({"role": "assistant", "content": "需要进一步优化?再次点击按钮获取下一条建议"})
154
+
155
+ user_input = st.chat_input("请输入您的问题,例如“如何优化这个模型”")
156
+ if user_input:
157
+ st.chat_message("user").write(user_input)
158
+ agent.add_memory({"role": "user", "content": user_input})
159
+ with st.spinner("处理中…"):
160
+ reply = agent.get_model_suggestion(user_input)
161
+ agent.save_suggestion(reply)
162
+ agent.refine_suggestions()
163
+ st.chat_message("assistant").write(reply)
164
+ agent.add_memory({"role": "assistant", "content": reply})
165
+ st.chat_message("assistant").write("需要进一步优化?再次点击按钮获取下一条建议")
166
+ agent.add_memory({"role": "assistant", "content": "需要进一步优化?再次点击按钮获取下一条建议"})
167
+
168
+
169
+ if __name__ == "__main__":
170
+
171
+ st.title("数据建模")
172
+ st.markdown("---")
173
+
174
+ preproc_agent = st.session_state.data_preprocess_agent
175
+ load_agent = st.session_state.data_loading_agent
176
+
177
+ processed_df = preproc_agent.load_processed_df()
178
+ if processed_df is None:
179
+ df = load_agent.load_df()
180
+ else:
181
+ df = processed_df
182
+
183
+ if df is None:
184
+ st.warning("⚠️ 请先在数据导入页面加载数据")
185
+ st.stop()
186
+
187
+ agent = st.session_state.modeling_coding_agent
188
+ agent.add_df(df)
189
+ planner = st.session_state.planner_agent
190
+ auto = planner.modeling_auto
191
+
192
+ if st.session_state.auto_mode == True:
193
+ if (agent.finish_auto_task == True and planner.switched_modeling == False) or planner.modeling_auto == False:
194
+ planner.finish_modeling_auto()
195
+ st.switch_page("workflow/report/report_render.py")
196
+
197
+ code = agent.load_code()
198
+ if code is None:
199
+ expand = False
200
+ else:
201
+ expand = True
202
+
203
+ inference_model = agent.load_best_model()
204
+ if inference_model is None:
205
+ inference_expand = False
206
+ else:
207
+ inference_expand = True
208
+
209
+ c = st.columns(2)
210
+ with c[0].expander('快速建模', True):
211
+ modeling_quick_actions(agent)
212
+ with c[1].expander('建模建议', True):
213
+ modeling_chat(agent, auto)
214
+ modeling_code_gen(agent, auto=auto)
215
+ with c[0].expander('建模执行', expand):
216
+ modeling_execution(agent, auto)
217
+ # with c[0].expander('推断分析', inference_expand):
218
+ # modeling_inference(agent, preproc_agent)
workflow/preprocessing/preprocessing_core.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import traceback
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import streamlit as st
7
+ from streamlit_ace import st_ace
8
+ from sklearn.compose import ColumnTransformer
9
+ from sklearn.impute import SimpleImputer
10
+ from sklearn.pipeline import Pipeline
11
+ from sklearn.preprocessing import FunctionTransformer
12
+ from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, OrdinalEncoder, RobustScaler, StandardScaler
13
+
14
+ from utils.sanitize_code import sanitize_code
15
+
16
+
17
+ def prep_meta_execution(agent, code, df, auto=False):
18
+
19
+ edited = st_ace(
20
+ value=code,
21
+ height=400,
22
+ theme="tomorrow_night",
23
+ language="python",
24
+ auto_update=True
25
+ )
26
+
27
+ not_generated = agent.load_processed_df() is None
28
+
29
+ if code is not None:
30
+ if st.button("▶️ 执行预处理") or (auto and not_generated):
31
+ code = sanitize_code(edited)
32
+ agent.save_code(code)
33
+
34
+ exec_ns = {
35
+ "df": df,
36
+ "np": np,
37
+ "pd": pd,
38
+ "st": st,
39
+ "SimpleImputer": SimpleImputer,
40
+ "FunctionTransformer": FunctionTransformer,
41
+ "StandardScaler": StandardScaler,
42
+ "MinMaxScaler": MinMaxScaler,
43
+ "RobustScaler": RobustScaler,
44
+ "OneHotEncoder": OneHotEncoder,
45
+ "OrdinalEncoder": OrdinalEncoder,
46
+ "LabelEncoder": LabelEncoder,
47
+ "ColumnTransformer": ColumnTransformer,
48
+ "Pipeline": Pipeline,
49
+ }
50
+
51
+ try:
52
+ with st.spinner("正在运行程序..."):
53
+ exec(code, exec_ns)
54
+ except Exception as exc:
55
+ st.error(f"已保存报错,请重新调用llm生成代码debug")
56
+ st.text(traceback.format_exc())
57
+ agent.save_error(traceback.format_exc())
58
+ prep_code_gen(agent, debug=True)
59
+ else:
60
+ process_df = exec_ns.get("process_df")
61
+ if process_df is None:
62
+ st.error(
63
+ "脚本未写入 `process_df`。请确保编辑后的脚本在末尾赋值 process_df"
64
+ )
65
+ else:
66
+ agent.save_processed_df(process_df)
67
+ agent.finish_auto()
68
+ st.rerun()
69
+ return process_df
70
+
71
+
72
+ def prep_code_gen(agent, auto = False, debug = False):
73
+
74
+ suggest = agent.load_preprocessing_suggestions()
75
+ df = agent.load_df()
76
+
77
+ chat_history = agent.load_memory()
78
+ already_generated = any(
79
+ entry["role"] == "assistant" and "预处理脚本已更新!请重新运行代码!" in str(entry["content"])
80
+ for entry in chat_history
81
+ )
82
+
83
+ if suggest is not None:
84
+
85
+ if debug == True or (auto and not already_generated):
86
+ with st.spinner("预处理 Agent 正在编写脚本..."):
87
+ raw = agent.code_generation(
88
+ df.head(10).to_string(),
89
+ suggest,
90
+ )
91
+ code = sanitize_code(raw)
92
+ agent.save_code(code)
93
+
94
+ st.chat_message("assistant").write("预处理脚本已更新!请重新运行代码!")
95
+ agent.add_memory({"role": "assistant", "content": "预处理脚本已更新!请重新运行代码!"})
96
+
97
+ st.rerun()
98
+
99
+ analyze_btn = st.button("🔧 生成预处理代码", key='prep_code')
100
+ if analyze_btn:
101
+ with st.spinner("向 LLM 请求生成预处理脚本..."):
102
+ raw = agent.code_generation(
103
+ df.head(10).to_string(),
104
+ suggest,
105
+ )
106
+ code = sanitize_code(raw)
107
+ agent.save_code(code)
108
+
109
+ st.chat_message("assistant").write("预处理脚本已更新!请重新运行代码!")
110
+ agent.add_memory({"role": "assistant", "content": "预处理脚本已更新!请重新运行代码!"})
111
+
112
+ st.rerun()
workflow/preprocessing/preprocessing_render.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import traceback
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import streamlit as st
7
+ from streamlit_ace import st_ace
8
+ from sklearn.compose import ColumnTransformer
9
+ from sklearn.impute import SimpleImputer
10
+ from sklearn.pipeline import Pipeline
11
+ from sklearn.preprocessing import FunctionTransformer
12
+ from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, OrdinalEncoder, RobustScaler, StandardScaler
13
+
14
+ from utils.sanitize_code import sanitize_code
15
+ from workflow.preprocessing.preprocessing_core import prep_meta_execution, prep_code_gen
16
+
17
+
18
+ def prep_basic_info(agent):
19
+
20
+ df = agent.load_df()
21
+
22
+ # 展示基本统计
23
+ r, c = df.shape
24
+ missing = int(df.isnull().sum().sum())
25
+ col1, col2, col3 = st.columns(3)
26
+ col1.metric("行数", r)
27
+ col2.metric("列数", c)
28
+ col3.metric("缺失值总数", missing)
29
+
30
+ dtype_info = pd.DataFrame({
31
+ '列名': df.columns,
32
+ '类型': df.dtypes.astype(str),
33
+ '非空值数量': df.count().values,
34
+ '缺失值比例(%)': (df.isnull().mean() * 100).round(2).values,
35
+ })
36
+ dtype_info = dtype_info.reset_index(drop=True)
37
+ st.dataframe(dtype_info, use_container_width=True)
38
+
39
+
40
+ def prep_execution(agent, auto=False):
41
+ '''
42
+ training data进行预处理
43
+ '''
44
+
45
+ code = agent.load_code()
46
+ df = agent.load_df()
47
+
48
+ process_df = prep_meta_execution(agent, code, df, auto=auto)
49
+
50
+
51
+ def prep_result(agent):
52
+
53
+ process_df = agent.load_processed_df()
54
+ df = agent.load_df()
55
+
56
+ if process_df is not None:
57
+ st.write("处理前数据预览:", df.head(10))
58
+ st.write("处理后数据预览:", process_df.head(10))
59
+
60
+ csv_buffer = io.StringIO()
61
+ process_df.to_csv(csv_buffer, index=False)
62
+ csv_bytes = csv_buffer.getvalue().encode('utf-8')
63
+
64
+ st.download_button(
65
+ label="⬇️ 下载处理后数据",
66
+ data=csv_bytes,
67
+ file_name="processed_data.csv",
68
+ mime="text/csv",
69
+ )
70
+
71
+
72
+ def prep_chat(agent, auto=False):
73
+ """渲染对话式建议区"""
74
+
75
+ with st.chat_message("assistant"):
76
+ st.write("我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
77
+ "您可以在下方输入预处理需求,或直接点击按钮获取预处理建议。")
78
+ analyze_btn = st.button("🔍 预处理推荐", key='prep_suggest')
79
+
80
+ # 对话历史渲染
81
+ chat_history = agent.load_memory()
82
+
83
+ for idx, entry in enumerate(chat_history):
84
+ bubble = st.chat_message(entry["role"])
85
+ content = entry["content"]
86
+ if isinstance(content, str):
87
+ bubble.write(content)
88
+
89
+ already_generated = any(
90
+ entry["role"] == "assistant" and "预处理" in str(entry["content"])
91
+ for entry in chat_history
92
+ )
93
+
94
+ # 自动/手动触发
95
+ if analyze_btn or (auto and not already_generated):
96
+
97
+ st.chat_message("user").write("请给我预处理建议")
98
+ agent.add_memory({'role': 'user', 'content': "请给我预处理建议"})
99
+
100
+ with st.spinner("生成建议中…"):
101
+ text = agent.get_preprocessing_suggestions()
102
+ agent.save_preprocessing_suggestions(text)
103
+ agent.refine_suggestions(df.head(10).to_string())
104
+ st.chat_message("assistant").write(text)
105
+ agent.add_memory({'role': 'assistant', 'content': text})
106
+
107
+ # 用户自然语言交互
108
+ user_input = st.chat_input("请输入您的问题")
109
+ if user_input:
110
+ st.chat_message("user").write(user_input)
111
+ agent.add_memory({'role': 'user', 'content': user_input})
112
+ agent.save_user_input(user_input)
113
+ with st.spinner("处理中…"):
114
+ reply = agent.get_preprocessing_suggestions(user_input)
115
+ agent.save_preprocessing_suggestions(reply)
116
+ agent.refine_suggestions(df.head(10).to_string())
117
+ st.chat_message('assistant').write(reply)
118
+ agent.add_memory({'role': 'assistant', 'content': reply})
119
+
120
+
121
+ if __name__ == '__main__':
122
+
123
+ st.title("数据预处理与标准化")
124
+
125
+ st.markdown("---")
126
+
127
+ data_loading_agent = st.session_state.data_loading_agent
128
+ df = data_loading_agent.load_df()
129
+ planner = st.session_state.planner_agent
130
+ auto = planner.prep_auto
131
+
132
+ if df is None:
133
+ st.warning("⚠️ 请先在数据导入页面加载数据")
134
+ st.stop()
135
+
136
+ agent = st.session_state.data_preprocess_agent
137
+ agent.add_df(df)
138
+
139
+ if st.session_state.auto_mode == True:
140
+ if (agent.finish_auto_task == True and planner.switched_prep == False) or planner.prep_auto == False:
141
+ planner.finish_prep_auto()
142
+ st.switch_page("workflow/visualization/viz_render.py")
143
+
144
+ code = agent.load_code()
145
+ if code is None:
146
+ code_expand = False
147
+ else:
148
+ code_expand = True
149
+
150
+ c = st.columns(2)
151
+ with c[0].expander('预处理展示', True):
152
+ prep_basic_info(agent)
153
+ with c[1].expander('预处理建议', True):
154
+ prep_chat(agent, auto)
155
+ prep_code_gen(agent, auto=auto)
156
+ with c[0].expander('预处理执行', code_expand):
157
+ prep_execution(agent, auto)
158
+ with c[0].expander('预处理结果', code_expand):
159
+ prep_result(agent)
workflow/report/report_core.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ReportNode:
2
+ """文档节点:可以是 heading 或 paragraph"""
3
+ def __init__(self, node_type, text, level=0):
4
+ self.type = node_type # "heading" 或 "paragraph"
5
+ self.text = text
6
+ self.level = level
7
+ self.children = [] # 子节点(用于分层)
8
+
9
+ def to_dict(self):
10
+ return {
11
+ "type": self.type,
12
+ "text": self.text,
13
+ "level": self.level,
14
+ "children": [c.to_dict() for c in self.children]
15
+ }
16
+
17
+ # 现在只适合于顺序添加
18
+ class Reportcore:
19
+ def __init__(self):
20
+ self.root = ReportNode("root", "", level=-1) # 虚拟根节点
21
+ self.current_stack = [self.root] # 用栈管理当前层级
22
+
23
+ def add_heading(self, text, level=0):# 从0开始
24
+ """
25
+ 添加标题,根据 level 自动挂载到合适的父节点
26
+ """
27
+ new_node = ReportNode("heading", text, level)
28
+
29
+ # 回溯到合适的父节点
30
+ while self.current_stack and self.current_stack[-1].level >= level:
31
+ self.current_stack.pop()
32
+
33
+ parent = self.current_stack[-1]
34
+ parent.children.append(new_node)
35
+ self.current_stack.append(new_node)
36
+
37
+ def add_paragraph(self, text):
38
+ """
39
+ 添加段落,挂在当前最后一个 heading 下
40
+ """
41
+ parent = self.current_stack[-1]
42
+
43
+ parent.children.append(ReportNode("paragraph", text, level=parent.level + 1))
44
+
45
+ def to_dict(self):
46
+ return self.root.to_dict()
workflow/report/report_html.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import re
3
+ import streamlit as st
4
+ import plotly.io as pio
5
+ from utils.sanitize_code import sanitize_code
6
+ import base64
7
+
8
+ def write_html(agents):
9
+ report_agent = agents[-1]
10
+ report_obj = report_agent.load_report() # Reportcore
11
+
12
+ # 图像分析列表
13
+ analysis_list = agents[2].summary_fig_analysis_list()
14
+
15
+ # 给 heading 加唯一 id
16
+ heading_counter = {"count": 0}
17
+ def _gen_id(text):
18
+ heading_counter["count"] += 1
19
+ return f"sec-{heading_counter['count']}"
20
+
21
+ # 遍历树 → 正文 & TOC
22
+ toc_items, content_items = [], []
23
+
24
+ def _process_node(node):
25
+ if node.type == "heading":
26
+ sec_id = _gen_id(node.text)
27
+ toc_items.append((sec_id, node.text, node.level))
28
+ content_items.append(
29
+ f"<h{node.level} id='{sec_id}' class='font-bold text-gray-800 mt-8 mb-4 text-{max(6-node.level,1)}xl'>{node.text}</h{node.level}>"
30
+ )
31
+ for ch in node.children:
32
+ _process_node(ch)
33
+
34
+ elif node.type == "paragraph":
35
+ parts = re.split(r'(\[FIG:\d+\])', node.text)
36
+ html_parts = []
37
+ for part in parts:
38
+ part = part.strip()
39
+ if not part:
40
+ continue
41
+ if part.startswith("[FIG:") and part.endswith("]"):
42
+ idx = int(part[5:-1])
43
+ fig_html = ""
44
+ if 0 <= idx < len(analysis_list):
45
+ fig_obj = analysis_list[idx].get("figure")
46
+ try:
47
+ buf = io.BytesIO()
48
+ pio.write_image(fig_obj, buf, format="png")
49
+ data = buf.getvalue()
50
+ b64 = base64.b64encode(data).decode("utf-8")
51
+ fig_html = f"<div class='flex justify-center my-6'><img src='data:image/png;base64,{b64}' class='rounded-xl shadow-md max-w-3xl w-full'/></div>"
52
+ except Exception as e:
53
+ fig_html = f"<p class='text-red-500'>[图像插入失败: {e}]</p>"
54
+ html_parts.append(fig_html)
55
+ else:
56
+ html_parts.append(f"<p class='text-gray-700 leading-relaxed mb-4'>{part}</p>")
57
+ content_items.append("".join(html_parts))
58
+
59
+ else: # root
60
+ for ch in node.children:
61
+ _process_node(ch)
62
+
63
+ _process_node(report_obj.root)
64
+
65
+ # TOC HTML
66
+ toc_html = ["<nav class='space-y-2'>"]
67
+ prev_level = -1
68
+ for sec_id, text, level in toc_items:
69
+ indent = "ml-" + str(level * 4)
70
+ toc_html.append(f"<a href='#{sec_id}' class='block {indent} text-gray-600 hover:text-blue-600 transition-colors'>{text}</a>")
71
+ toc_html.append("</nav>")
72
+
73
+ # 拼接完整 HTML
74
+ html_content = f"""
75
+ <html>
76
+ <head>
77
+ <meta charset="utf-8">
78
+ <script src="https://cdn.tailwindcss.com"></script>
79
+ <script>
80
+ document.addEventListener("DOMContentLoaded", function() {{
81
+ const sections = document.querySelectorAll("h1, h2, h3, h4, h5, h6");
82
+ const navLinks = document.querySelectorAll("nav a");
83
+
84
+ function onScroll() {{
85
+ let scrollPos = document.documentElement.scrollTop || document.body.scrollTop;
86
+ let currentId = "";
87
+ sections.forEach(sec => {{
88
+ if (sec.offsetTop - 80 <= scrollPos) {{
89
+ currentId = sec.id;
90
+ }}
91
+ }});
92
+ navLinks.forEach(link => {{
93
+ link.classList.remove("font-bold", "text-blue-600");
94
+ if (link.getAttribute("href") === "#" + currentId) {{
95
+ link.classList.add("font-bold", "text-blue-600");
96
+ }}
97
+ }});
98
+ }}
99
+ window.addEventListener("scroll", onScroll);
100
+ onScroll();
101
+ }});
102
+ </script>
103
+ </head>
104
+ <body class="flex font-sans">
105
+ <aside class="fixed top-0 left-0 h-screen w-64 bg-gray-100 border-r border-gray-300 p-6 overflow-y-auto">
106
+ <h2 class="text-xl font-bold mb-4">目录</h2>
107
+ {''.join(toc_html)}
108
+ </aside>
109
+ <main class="ml-64 p-10 w-full max-w-5xl">
110
+ {''.join(content_items)}
111
+ </main>
112
+ </body>
113
+ </html>
114
+ """
115
+
116
+ report_agent.save_html(html_content)
117
+ st.success("HTML 报告 (Tailwind 风格) 生成成功 ✅")
workflow/report/report_markdown.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import re
3
+ import base64
4
+ import plotly.io as pio
5
+ import streamlit as st
6
+ def write_markdown(agents):
7
+ report_agent = agents[-1]
8
+ report_obj = report_agent.load_report() # Reportcore
9
+
10
+ # 图像分析列表
11
+ analysis_list = agents[2].summary_fig_analysis_list()
12
+
13
+ md_parts = []
14
+
15
+ def _process_node(node):
16
+ if node.type == "heading":
17
+ prefix = "#" * (node.level if node.level > 0 else 1)
18
+ md_parts.append(f"{prefix} {node.text}\n")
19
+ for ch in node.children:
20
+ _process_node(ch)
21
+
22
+ elif node.type == "paragraph":
23
+ parts = re.split(r'(\[FIG:\d+\])', node.text)
24
+ for part in parts:
25
+ part = part.strip()
26
+ if not part:
27
+ continue
28
+
29
+ if part.startswith("[FIG:") and part.endswith("]"):
30
+ idx = int(part[5:-1])
31
+ if 0 <= idx < len(analysis_list):
32
+ fig_obj = analysis_list[idx].get("figure")
33
+ try:
34
+ buf = io.BytesIO()
35
+ pio.write_image(fig_obj, buf, format="png")
36
+ data = buf.getvalue()
37
+ b64 = base64.b64encode(data).decode("utf-8")
38
+ # 🔹 直接内嵌 base64
39
+ md_parts.append(
40
+ f"![Figure {idx}](data:image/png;base64,{b64})\n"
41
+ )
42
+ except Exception as e:
43
+ md_parts.append(f"> **图像插入失败**: {e}\n")
44
+ else:
45
+ md_parts.append(f"{part}\n\n")
46
+
47
+ else: # root
48
+ for ch in node.children:
49
+ _process_node(ch)
50
+
51
+ _process_node(report_obj.root)
52
+
53
+ md_content = "\n".join(md_parts)
54
+ report_agent.save_markdown(md_content)
55
+ st.success("Markdown 报告生成成功 ✅")
workflow/report/report_prepare_er.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import io
3
+ import re
4
+ from io import BytesIO
5
+
6
+ import streamlit as st
7
+ from tqdm import tqdm
8
+ from stqdm import stqdm
9
+ from docx import Document
10
+ from docx.oxml.ns import qn
11
+ from docx.shared import Inches
12
+ from docx.enum.table import WD_TABLE_ALIGNMENT
13
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
14
+ import plotly.express as px
15
+ import plotly.io as pio
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
17
+
18
+ from utils.sanitize_code import sanitize_code
19
+ from workflow.report.report_core import *
20
+
21
+
22
+ def report_prepare(agents, parallel=True, max_workers=4):
23
+ report_agent = agents[-1]
24
+ toc = report_agent.load_outline()
25
+ if toc is None:
26
+ st.error("请先生成目录")
27
+ return
28
+
29
+ toc = sanitize_code(toc)
30
+
31
+ # === 汇总各分析模块的摘要 ===
32
+ agent_abstracts = {}
33
+ with st.spinner("正在汇总各分析模块的结果..."):
34
+ for i in stqdm(range(len(agents) - 1)):
35
+ agent_abstracts[i] = agents[i].check_abstract()
36
+
37
+ # === 更新 toc 的 FIG 列表 ===
38
+ selected_full_contents_vis = agents[2].check_full()
39
+ toc = report_agent.selected_photo_update_toc(toc, selected_full_contents_vis)
40
+ toc = sanitize_code(toc)
41
+ print(toc)
42
+ try:
43
+ toc = ast.literal_eval(toc)
44
+ except Exception:
45
+ pass
46
+
47
+ # === 更新 toc 的 模块选择 列表 ===
48
+ with st.spinner("正在匹配各章节所需的分析模块..."):
49
+ toc_with_choice = report_agent.update_toc_with_relevant_sections(toc, agent_abstracts)
50
+ toc_with_choice = sanitize_code(toc_with_choice)
51
+ try:
52
+ toc_with_choice = ast.literal_eval(toc_with_choice)
53
+ except Exception:
54
+ pass
55
+
56
+ # === 初始化报告结构 ===
57
+ doc = Reportcore()
58
+ doc.add_heading('数据分析报告', 0)
59
+ selected_model = st.session_state.selected_model
60
+
61
+ def process_section(idx, t,t_w_c, history_content=""):
62
+ st.session_state.selected_model = selected_model
63
+ # t: ('标题', 层级, 内容大纲, [figs], [modules])
64
+ _, _, _, _, choice_list = t_w_c
65
+ selected_full_contents = {i: agents[i].check_full() for i in choice_list if i < len(agents) - 1}
66
+ content = report_agent.write_section_body(toc, t, selected_full_contents, history_content)
67
+ print(idx)
68
+ return (idx, t, content)
69
+
70
+ results = []
71
+
72
+ # 串行或并行
73
+ if not parallel:
74
+ with st.spinner("正在串行生成各章节内容(带上下文)..."):
75
+ history_content = ""
76
+ for idx, t in stqdm(enumerate(toc)):
77
+ t_w_c= toc_with_choice[idx]
78
+ _, _, content = process_section(idx, t,t_w_c, history_content)
79
+ results.append((idx, t, content))
80
+ history_content += f"\n\n{t[0]}\n{content}"
81
+ else:
82
+ with st.spinner(f"正在并行生成各章节内容({max_workers}线程)..."):
83
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
84
+ print(toc_with_choice)
85
+ # print(f"idx={idx}, len={len(toc_with_choice)}")
86
+ futures = {
87
+ executor.submit(process_section, idx, t, toc_with_choice[idx], ""): idx
88
+ for idx, t in enumerate(toc)
89
+ }
90
+ for future in stqdm(as_completed(futures), total=len(futures)):
91
+ try:
92
+ results.append(future.result())
93
+ except Exception as e:
94
+ print(f"章节生成失败: {e}")
95
+
96
+ # 排序 & 写入报告
97
+ results.sort(key=lambda x: x[0])
98
+ for _, t, content in results:
99
+ doc.add_heading(t[0], level=t[1])
100
+ doc.add_paragraph(content)
101
+
102
+ report_agent.save_report(doc)
workflow/report/report_render.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import io
3
+ from io import BytesIO
4
+
5
+ from stqdm import stqdm
6
+ import mammoth
7
+ import numpy as np
8
+ import pandas as pd
9
+ import streamlit as st
10
+ import streamlit_antd_components as sac
11
+ from concurrent.futures import ThreadPoolExecutor, as_completed
12
+
13
+ from prompt_engineer.sec5_call_llm import *
14
+ from workflow.report.report_utils import html_dowmload
15
+ from workflow.report.report_html import write_html
16
+ from workflow.report.report_word import write_word
17
+ from workflow.report.report_markdown import write_markdown
18
+ from workflow.report.report_prepare_er import report_prepare
19
+
20
+
21
+ def report_save(agents, auto):
22
+
23
+ report_agent = agents[-1]
24
+ action = report_agent.load_report_format()
25
+
26
+ if report_agent.load_report_format() == 'HTML':
27
+ not_generate = report_agent.html == None
28
+ if report_agent.load_report_format() == 'Word':
29
+ not_generate = report_agent.word == None
30
+ if report_agent.load_report_format() == 'Markdown':
31
+ not_generate = report_agent.markdown == None
32
+
33
+ mode = report_agent.load_gen_mode()
34
+ parallel = (mode == "并行")
35
+
36
+ if st.button(f"📝 生成 {action} 报告") or (auto and not_generate):
37
+ with st.spinner(f"正在生成 {action} 报告..."):
38
+
39
+ report_prepare(agents, parallel=parallel)
40
+
41
+ if report_agent.load_report_format() == 'Word':
42
+ write_word(agents)
43
+ elif report_agent.load_report_format() == 'HTML':
44
+ write_html(agents)
45
+ elif report_agent.load_report_format() == 'Markdown':
46
+ write_markdown(agents)
47
+
48
+
49
+ def report_basic_info(agent, auto) -> None:
50
+ outline_length = sac.segmented(
51
+ items=[
52
+ sac.SegmentedItem(label='简要'),
53
+ sac.SegmentedItem(label='标准'),
54
+ sac.SegmentedItem(label='详细'),
55
+ ],
56
+ label='详细程度', index=1, align='center',
57
+ size='sm', radius='sm', use_container_width=True
58
+ )
59
+ agent.save_outline_length(outline_length)
60
+
61
+ c1, c2 = st.columns(2)
62
+ with c1:
63
+ date = st.date_input("报告日期", datetime.date(2025, 10, 1))
64
+ agent.save_date(date)
65
+ with c2:
66
+ name = st.text_input("报告作者", "Anystat")
67
+ agent.save_name(name)
68
+
69
+ c1, c2 = st.columns([3, 1])
70
+ with c1:
71
+ report_format = sac.chip(
72
+ items=[
73
+ sac.ChipItem(label='Word', icon=sac.BsIcon(name='file-earmark-word', size=16)),
74
+ sac.ChipItem(label='HTML', icon=sac.BsIcon(name='filetype-html', size=16)),
75
+ sac.ChipItem(label='Markdown', icon=sac.BsIcon(name='file-earmark-code', size=16)),
76
+ ],
77
+ label='选择报告生成格式', index=[0, 2],
78
+ align='start', radius='md', multiple=False,
79
+ )
80
+ agent.save_report_format(report_format)
81
+
82
+ with c2:
83
+ mode = sac.segmented(
84
+ items=[
85
+ sac.SegmentedItem(label='并行'),
86
+ sac.SegmentedItem(label='串行'),
87
+ ],
88
+ label='生成模式', align='end', size='sm',
89
+ use_container_width=True, radius='md'
90
+ )
91
+ agent.save_gen_mode(mode)
92
+
93
+ user_input = st.text_input("报告生成要求", "默认")
94
+ agent.save_user_input(user_input)
95
+
96
+ not_generated = report_agent.load_outline() is None
97
+
98
+ # === 并行生成目录 ===
99
+ if st.button("🗒️ 生成目录") or (auto and not_generated):
100
+ with st.spinner("⏳ 正在自动生成目录结构..."):
101
+ summaries = []
102
+
103
+ # === 保存当前 Streamlit 状态副本 ===
104
+ session_snapshot = dict(st.session_state)
105
+
106
+ def process_summary(idx, sub_agent, session_snapshot):
107
+ """并行执行 summary_html/summary_word(带状态复制)"""
108
+ # 恢复 session_state
109
+ for k, v in session_snapshot.items():
110
+ st.session_state[k] = v
111
+
112
+ # 实际生成逻辑
113
+ if hasattr(sub_agent, "summary_html"):
114
+ summary = sub_agent.summary_html()
115
+ else:
116
+ summary = sub_agent.summary_word()
117
+
118
+ return idx, summary
119
+
120
+ max_workers = min(6, len(agents) - 1)
121
+ results = []
122
+
123
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
124
+ futures = {
125
+ executor.submit(process_summary, i, sub_agent, session_snapshot): i
126
+ for i, sub_agent in enumerate(agents[:-1])
127
+ }
128
+
129
+ for future in stqdm(as_completed(futures), total=len(futures)):
130
+ try:
131
+ idx, summary = future.result()
132
+ if summary:
133
+ results.append((idx, summary))
134
+ except Exception as e:
135
+ print(f"子模块摘要生成失败: {e}")
136
+
137
+ # === 恢复章节原顺序 ===
138
+ results.sort(key=lambda x: x[0])
139
+ summaries = [summary for _, summary in results if summary]
140
+
141
+ # === 生成目录 ===
142
+ default_toc = report_agent.generate_toc_from_summary(summaries)
143
+ report_agent.save_outline(default_toc)
144
+
145
+
146
+ def report_outline(agents):
147
+
148
+ st.subheader("目录结构预览与编辑")
149
+ load_agent, preproc_agent, visual_agent, coding_agent, report_agent = agents[0], agents[1], agents[2], agents[3], agents[4]
150
+
151
+ default_toc = report_agent.load_outline()
152
+
153
+ toc_md = st.text_area(
154
+ "您可以在此处编辑目录结构",
155
+ value=default_toc,
156
+ height=260
157
+ )
158
+
159
+ report_agent.save_outline(toc_md)
160
+
161
+
162
+ def report_execution(report_agent):
163
+
164
+ if report_agent.load_report_format() == 'Word':
165
+
166
+ full_report = report_agent.load_word()
167
+ if full_report is not None:
168
+ st.download_button(
169
+ label="⬇️ 下载 Word 报告",
170
+ data=full_report,
171
+ file_name="report.docx",
172
+ mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
173
+ )
174
+
175
+ elif report_agent.load_report_format() == 'HTML':
176
+
177
+ full_report = report_agent.load_html()
178
+
179
+ if full_report is not None:
180
+ st.download_button(
181
+ label="⬇️ 下载 HTML 报告",
182
+ data=full_report.encode("utf-8"),
183
+ file_name="report.html",
184
+ mime="text/html",
185
+ )
186
+
187
+ if st.button("⬇️ 下载 PDF 报告"):
188
+ html_dowmload(full_report)
189
+ elif report_agent.load_report_format() == 'Markdown':
190
+
191
+ full_report = report_agent.load_markdown()
192
+ if full_report is not None:
193
+
194
+ # 提供下载按钮
195
+ st.download_button(
196
+ label="⬇️ 下载 Markdown 报告",
197
+ data=full_report,
198
+ file_name="report.md",
199
+ mime="text/markdown"
200
+ )
201
+
202
+
203
+ if __name__ == "__main__":
204
+
205
+ st.title("报告生成")
206
+
207
+ st.markdown("---")
208
+
209
+ load_agent = st.session_state.data_loading_agent
210
+ preproc_agent = st.session_state.data_preprocess_agent
211
+ visual_agent = st.session_state.visualization_agent
212
+ coding_agent = st.session_state.modeling_coding_agent
213
+ planner = st.session_state.planner_agent
214
+ auto = planner.report_auto
215
+
216
+ processed_df = preproc_agent.load_processed_df()
217
+ if processed_df is None:
218
+ df = load_agent.load_df()
219
+ else:
220
+ df = processed_df
221
+
222
+ if df is None:
223
+ st.warning("⚠️ 请先在数据导入页面加载数据")
224
+ st.stop()
225
+
226
+ if isinstance(df, np.ndarray):
227
+ df = pd.DataFrame(df)
228
+
229
+ df_shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)
230
+
231
+ report_agent = st.session_state.report_agent
232
+ report_agent.add_df(df_shuffled)
233
+
234
+ agents = [load_agent, preproc_agent, visual_agent, coding_agent, report_agent]
235
+
236
+ c = st.columns(2)
237
+ with c[0].expander('报告设置', True):
238
+ report_basic_info(report_agent, auto)
239
+
240
+ with c[1].expander('报告大纲', True):
241
+ report_outline(agents)
242
+ report_save(agents, auto)
243
+ report_execution(report_agent)
workflow/report/report_utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import base64
3
+
4
+ import streamlit as st
5
+ from playwright.sync_api import sync_playwright
6
+
7
+
8
+ def html_to_pdf_bytes_playwright(html: str) -> bytes:
9
+ with sync_playwright() as p:
10
+ browser = p.chromium.launch()
11
+ page = browser.new_page()
12
+ page.set_content(html, wait_until="load")
13
+ pdf_bytes = page.pdf(format="A4", print_background=True)
14
+ browser.close()
15
+ return pdf_bytes
16
+
17
+
18
+ def html_dowmload(full_report):
19
+
20
+ try:
21
+ pdf_bytes = html_to_pdf_bytes_playwright(full_report)
22
+ except Exception as e:
23
+ st.error(f"生成 PDF 出错:{e}")
24
+ else:
25
+ b64 = base64.b64encode(pdf_bytes).decode("utf-8")
26
+
27
+ auto_download_html = f"""
28
+ <html>
29
+ <body>
30
+ <a id="dl_link"
31
+ href="data:application/pdf;base64,{b64}"
32
+ download="report.pdf"
33
+ style="display:none">download</a>
34
+ <script>
35
+ (function() {{
36
+ const a = document.getElementById('dl_link');
37
+ try {{
38
+ a.click();
39
+ }} catch (err) {{
40
+ // 如果自动点击被阻止,替换页面内容并露出手动链接
41
+ document.body.innerHTML =
42
+ '<p>自动下载被浏览器阻止,请点击下面链接手动下载:</p>' + a.outerHTML;
43
+ }}
44
+ }})();
45
+ </script>
46
+ </body>
47
+ </html>
48
+ """
49
+
50
+ st.components.v1.html(auto_download_html, height=120)
51
+
52
+ st.download_button(
53
+ label="⬇️ 手动下载 PDF(回退)",
54
+ data=pdf_bytes,
55
+ file_name="report.pdf",
56
+ mime="application/pdf",
57
+ )
58
+
59
+ st.success("PDF 已生成(如未自动下载,请使用上方手动下载按钮)。")
workflow/report/report_word.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import io
3
+ import re
4
+ from io import BytesIO
5
+
6
+ import streamlit as st
7
+ from stqdm import stqdm
8
+ from docx import Document
9
+ from docx.oxml.ns import qn
10
+ from docx.shared import Inches
11
+ from docx.enum.table import WD_TABLE_ALIGNMENT
12
+ from docx.enum.text import WD_ALIGN_PARAGRAPH
13
+ import plotly.express as px
14
+ import plotly.io as pio
15
+
16
+ from utils.sanitize_code import sanitize_code
17
+
18
+
19
+
20
+ def write_word(agents):
21
+ '''
22
+ choice:是否要搜索
23
+ True:根据目录搜索相关章节
24
+ False:全部章节
25
+ '''
26
+ # 拿图
27
+ analysis_list = agents[2].summary_fig_analysis_list()
28
+
29
+ report_agent = agents[-1]
30
+ report_obj = report_agent.load_report() # Reportcore
31
+
32
+ doc = Document()
33
+
34
+ style = doc.styles['Normal']
35
+
36
+ style.font.name = 'Times New Roman'
37
+ style._element.rPr.rFonts.set(qn('w:eastAsia'), '微软雅黑')
38
+
39
+ def _insert_figure(fig_obj):
40
+ if fig_obj is None:
41
+ return
42
+ try:
43
+ img_bytes = io.BytesIO()
44
+ img_bytes = io.BytesIO(fig_obj.to_image(format="png"))
45
+ # pio.write_image(fig_obj, img_bytes, format='png')
46
+ img_bytes.seek(0)
47
+
48
+ paragraph = doc.add_paragraph()
49
+ run = paragraph.add_run()
50
+ run.add_picture(img_bytes, width=Inches(4))
51
+ paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
52
+ except Exception as e:
53
+ doc.add_paragraph(f"[图像插入失败: {e}]")
54
+
55
+ def _process_node(node):
56
+ if node.type == "heading":
57
+ doc.add_heading(node.text, level=node.level)
58
+ for ch in node.children:
59
+ _process_node(ch)
60
+
61
+ elif node.type == "paragraph":
62
+ parts = re.split(r'(\[FIG:\d+\])', node.text)
63
+
64
+ for part in parts:
65
+ part = part.strip()
66
+ if not part:
67
+ continue
68
+ if part.startswith("[FIG:") and part.endswith("]"):
69
+ idx = int(part[5:-1])
70
+ fig_obj = None
71
+ if 0 <= idx < len(analysis_list):
72
+ entry = analysis_list[idx]
73
+ fig_obj = entry.get("figure")
74
+ _insert_figure(fig_obj)
75
+ else:
76
+ doc.add_paragraph(part)
77
+
78
+ else: # root
79
+ for ch in node.children:
80
+ _process_node(ch)
81
+
82
+ # 从 root.children 开始写
83
+ _process_node(report_obj.root)
84
+
85
+ buf = io.BytesIO()
86
+ doc.save(buf)
87
+ buf.seek(0)
88
+ report_agent.save_word(buf.getvalue())
89
+ st.success("Word 报告生成成功")
workflow/visualization/viz_coding.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import traceback
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import plotly.express as px
7
+ import plotly.graph_objs as go
8
+ import streamlit as st
9
+ from stqdm import stqdm
10
+ from streamlit_ace import st_ace
11
+ import streamlit_antd_components as sac
12
+
13
+ from utils.sanitize_code import sanitize_code
14
+
15
+
16
+ def vis_code_gen(agent, debug = False, auto = False) -> None:
17
+
18
+ df = agent.load_df()
19
+ suggest = agent.load_suggestion()
20
+ user_input = agent.load_user_input()
21
+
22
+ chat_history = agent.load_memory()
23
+ already_generated = any(
24
+ entry["role"] == "assistant" and "训练脚本已更新!请重新运行代码!" in str(entry["content"])
25
+ for entry in chat_history
26
+ )
27
+
28
+ if suggest is not None:
29
+ if debug == True or (auto and not already_generated):
30
+ with st.spinner("可视化 Agent 正在编写脚本..."):
31
+ raw = agent.code_generation(
32
+ df.head().to_string(),
33
+ suggest,
34
+ )
35
+ code = sanitize_code(raw)
36
+ agent.save_code(code)
37
+ st.chat_message("assistant").write("训练脚本已更新!请重新运行代码!")
38
+ agent.add_memory({"role": "assistant", "content": "训练脚本已更新!请重新运行代码!"})
39
+ st.rerun()
40
+
41
+ analyze_btn = st.button("🔧 生成可视化代码", key="viz_code")
42
+ if analyze_btn:
43
+ with st.spinner("可视化 Agent 正在编写脚本..."):
44
+ raw = agent.code_generation(
45
+ df.head().to_string(),
46
+ suggest,
47
+ )
48
+ code = sanitize_code(raw)
49
+ agent.save_code(code)
50
+ st.chat_message("assistant").write("训练脚本已更新!请重新运行代码!")
51
+ agent.add_memory({"role": "assistant", "content": "训练脚本已更新!请重新运行代码!"})
52
+ st.rerun()
53
+
54
+
55
+ def vis_execution(agent, auto = False):
56
+
57
+ df = agent.load_df()
58
+
59
+ exec_ns = {
60
+ "df": df,
61
+ "np": np,
62
+ "pd": pd,
63
+ "px": px,
64
+ "go": go,
65
+ }
66
+
67
+ code = agent.load_code()
68
+ edited = st_ace(
69
+ value=code,
70
+ height=450,
71
+ theme="tomorrow_night",
72
+ language="python",
73
+ auto_update=True
74
+ )
75
+ desc_switch = sac.switch(label='附加分析', value=False, off_label='Off')
76
+ if code is not None:
77
+ not_executed = agent.load_fig() == []
78
+ # 当点击按钮,或者 auto=True 且尚未执行过时才执行
79
+ if st.button("▶️ 执行可视化") or (auto and not_executed):
80
+ code = sanitize_code(edited)
81
+ agent.save_code(code)
82
+ try:
83
+ with st.spinner("正在运行可视化脚本..."):
84
+ exec(code, exec_ns)
85
+ except Exception as exc:
86
+ st.error(f"已记录报错内容,正在为您debug...")
87
+ st.text(traceback.format_exc())
88
+ agent.save_error(traceback.format_exc())
89
+ vis_code_gen(agent, debug=True)
90
+ else:
91
+ fig_dict = exec_ns.get("fig_dict")
92
+ if not fig_dict or not isinstance(fig_dict, dict):
93
+ st.error(
94
+ "脚本未写入 `fig_dict` 或格式不正确。请确保末尾赋值 `fig_dict`,且它是一个 {列名: 图表} 的 dict。"
95
+ )
96
+ agent.save_error(traceback.format_exc())
97
+ vis_code_gen(agent, debug=True)
98
+ else:
99
+ with st.spinner("正在处理可视化结果..."):
100
+ for col_name, fig in stqdm(fig_dict.items()):
101
+ dtype_info = ", ".join(
102
+ f"{c}: {df[c].dtype}" for c in df.columns
103
+ )
104
+ if desc_switch == True:
105
+ desc = agent.desc_fig(fig, dtype_info)
106
+ else:
107
+ desc = None
108
+ agent.add_fig(fig, desc)
109
+ agent.finish_auto()
110
+ st.rerun()
workflow/visualization/viz_color.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ PALETTES = {
4
+ "Classic": [
5
+ "#2B5C8A", "#4F81AF", "#77ACD3", "#D9D5C9", "#F69035"
6
+ ],
7
+ "Ocean Breeze": [
8
+ "#03045E", "#0077B6", "#00B4D8", "#90E0EF", "#CAF0F8"
9
+ ],
10
+ "Olive Garden Feast": [
11
+ "#606C38", "#283618", "#FEFAE0", "#DDA15E", "#BC6C25"
12
+ ],
13
+ "Fiery Ocean": [
14
+ "#780000", "#C1121F", "#FDF0D5", "#003049", "#669BBC"
15
+ ],
16
+ "Refreshing Summer Fun": [
17
+ "#8ECAE6", "#219EBC", "#023047", "#FFB703", "#FB8500"
18
+ ],
19
+ "Golden Summer Fields": [
20
+ "#CCD5AE", "#E9EDC9", "#FEFAE0", "#FAEDCD", "#D4A373"
21
+ ],
22
+ "Deep Sea": [
23
+ "#0D1B2A", "#1B263B", "#415A77", "#778DA9", "#E0E1DD"
24
+ ],
25
+ "Bold Berry": [
26
+ "#F9DBBD", "#FFA5AB", "#DA627D", "#A53860", "#450920"
27
+ ],
28
+ "Fresh Greens": [
29
+ "#D8F3DC", "#95D5B2", "#52B788", "#2D6A4F", "#1B4332"
30
+ ],
31
+ "Deep Sea": [
32
+ "#EDEDE9", "#D6CCC2", "#F5EBE0", "#E3D5CA", "#D5BDAF"
33
+ ],
34
+ }
35
+
36
+ def vis_palette(agent):
37
+
38
+ choice = st.selectbox("请选择配色方案", list(PALETTES.keys()))
39
+ colors = PALETTES[choice]
40
+
41
+ cols = st.columns(len(colors))
42
+ for col, code in zip(cols, colors):
43
+ col.markdown(
44
+ f"""
45
+ <div style="
46
+ background-color: {code};
47
+ height: 30px;
48
+ border-radius: 4px;
49
+ margin-bottom: 2px;
50
+ "></div>
51
+ <div style="text-align: center; font-size: 10px;">{code}</div>
52
+ """,
53
+ unsafe_allow_html=True
54
+ )
55
+
56
+ agent.save_color(colors)
57
+
58
+ return colors
workflow/visualization/viz_quick_action.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import plotly.express as px
3
+
4
+
5
+ def plot_for_option(df, option: str, column: str):
6
+
7
+ series = df[column]
8
+
9
+ if option == "直方图":
10
+ fig = px.histogram(df, x=column, title=f"{column} 的直方图")
11
+ elif option == "饼图":
12
+ counts = series.value_counts().reset_index()
13
+ counts.columns = [column, 'count']
14
+ fig = px.pie(counts, names=column, values='count', title=f"{column} 的饼图")
15
+ elif option == "折线图":
16
+ fig = px.line(df, y=column, title=f"{column} 的折线图")
17
+ elif option == "箱线图":
18
+ fig = px.box(df, y=column, title=f"{column} 的箱线图")
19
+ else:
20
+ st.error("未知的图表类型")
21
+ return
22
+
23
+ return fig
workflow/visualization/viz_render.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from PIL import Image
6
+ import plotly.graph_objs as go
7
+ import streamlit as st
8
+ import streamlit_antd_components as sac
9
+
10
+ from utils.sanitize_code import sanitize_code
11
+ from workflow.visualization.viz_suggestion import vis_button_suggest, vis_talk_suggest
12
+ from workflow.visualization.viz_coding import vis_execution, vis_code_gen
13
+ from workflow.visualization.viz_quick_action import plot_for_option
14
+ from workflow.visualization.viz_color import vis_palette
15
+
16
+ def vis_quick_actions(agent):
17
+
18
+ cols_list = agent.load_df().columns.tolist()
19
+ options = ["直方图", "饼图", "箱线图", "折线图"]
20
+
21
+ selected_col = st.selectbox("选择一个列:", cols_list)
22
+
23
+ logo_dir = r"logo\sec3"
24
+ logo_paths = {opt: os.path.join(logo_dir, f"{opt}.png") for opt in options}
25
+
26
+ cols = st.columns(len(options))
27
+
28
+ fig_placeholder = st.empty()
29
+
30
+ for idx, opt in enumerate(options):
31
+ with cols[idx]:
32
+ left, center, right = st.columns([1, 8, 1])
33
+ with center:
34
+ st.write(opt)
35
+ path = logo_paths.get(opt)
36
+ if path and os.path.exists(path):
37
+ st.image(Image.open(path), width=80)
38
+ else:
39
+ st.text("Logo 文件未找到")
40
+
41
+ if st.button("Try me", key=f"try_{idx}"):
42
+ fig = plot_for_option(agent.load_df(), opt, selected_col)
43
+ fig_placeholder.plotly_chart(fig, use_container_width=True)
44
+
45
+
46
+ def vis_result(agent) -> None:
47
+
48
+ fig_desc_list = agent.load_fig()
49
+ total = len(fig_desc_list)
50
+ PAGE_SIZE = 5
51
+
52
+ current_page = sac.pagination(
53
+ total=total,
54
+ page_size=PAGE_SIZE,
55
+ align='center',
56
+ jump=False,
57
+ show_total=True,
58
+ variant='filled',
59
+ color='#44658C'
60
+ )
61
+
62
+ start_idx = (current_page - 1) * PAGE_SIZE
63
+ end_idx = min(start_idx + PAGE_SIZE, total)
64
+ page_items = fig_desc_list[start_idx:end_idx]
65
+
66
+ for offset, item in enumerate(page_items):
67
+
68
+ idx = start_idx + offset
69
+ fig = item["fig"]
70
+ desc = item["desc"]
71
+
72
+ st.plotly_chart(
73
+ fig,
74
+ use_container_width=True,
75
+ key=f"fig_{idx}"
76
+ )
77
+ if desc is not None:
78
+ st.write(desc)
79
+
80
+
81
+ def vis_chat(agent, auto = False):
82
+
83
+ msg = st.chat_message("assistant")
84
+ msg.write(
85
+ "我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
86
+ "您可以在下方对话框输入具体可视化需求,"
87
+ "也可以点击下面的按钮,一键获取可视化建议并绘图。"
88
+ )
89
+ analyze_clicked = msg.button("🔍 可视化推荐", key="viz_suggest")
90
+ reply_placeholder = msg.empty()
91
+
92
+ chat_history = agent.load_memory()
93
+
94
+ for idx, entry in enumerate(chat_history):
95
+ bubble = st.chat_message(entry["role"])
96
+ content = entry["content"]
97
+ if isinstance(content, str):
98
+ bubble.write(content)
99
+ elif isinstance(content, go.Figure):
100
+ bubble.plotly_chart(
101
+ content,
102
+ use_container_width=True,
103
+ key=f"chart-{idx}"
104
+ )
105
+
106
+ already_generated = any(
107
+ entry["role"] == "assistant" and "图" in str(entry["content"])
108
+ for entry in chat_history
109
+ )
110
+
111
+ # 按钮路径
112
+ if analyze_clicked or (auto and not already_generated):
113
+ st.chat_message("user").write("请帮我做可视化分析")
114
+ agent.add_memory({'role': 'user', 'content': "请帮我做可视化分析"})
115
+ with st.spinner("正在处理您的请求..."):
116
+ rec = vis_button_suggest(agent)
117
+ agent.save_suggestion(rec)
118
+ st.chat_message("assistant").write(rec)
119
+ agent.add_memory({"role": "assistant", "content": str(rec)})
120
+
121
+ # 对话路径
122
+ reply = None
123
+ user_input = None
124
+ user_input = st.chat_input("请输入需求,例如'请给我一些可视化建议'")
125
+ if user_input is not None:
126
+ st.chat_message("user").write(user_input)
127
+ with st.spinner("正在处理您的请求..."):
128
+ agent.save_user_input(user_input)
129
+ agent.add_memory({"role": "user", "content": user_input})
130
+ rec = vis_talk_suggest(agent, user_input)
131
+ agent.save_suggestion(rec)
132
+ st.chat_message("assistant").write(rec)
133
+ agent.add_memory({"role": "assistant", "content": str(rec)})
134
+ st.rerun()
135
+
136
+
137
+ if __name__ == "__main__":
138
+
139
+ st.title("统计可视化分析")
140
+ st.markdown("---")
141
+
142
+ preproc_agent = st.session_state.data_preprocess_agent
143
+ load_agent = st.session_state.data_loading_agent
144
+ planner = st.session_state.planner_agent
145
+ auto = planner.vis_auto
146
+
147
+ processed_df = preproc_agent.load_processed_df()
148
+ if processed_df is None:
149
+ df = load_agent.load_df()
150
+ else:
151
+ df = processed_df
152
+
153
+ if df is None:
154
+ st.warning("⚠️ ���先在数据导入页面加载数据")
155
+ st.stop()
156
+
157
+ if isinstance(df, np.ndarray):
158
+ df = pd.DataFrame(df)
159
+
160
+ df_shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)
161
+ agent = st.session_state.visualization_agent
162
+ agent.add_df(df_shuffled)
163
+
164
+ if st.session_state.auto_mode == True:
165
+ if (agent.finish_auto_task == True and planner.switched_vis == False) or planner.vis_auto == False:
166
+ planner.finish_vis_auto()
167
+ st.switch_page("workflow/modeling/modeling_render.py")
168
+
169
+ code = agent.load_code()
170
+ if code is None:
171
+ code_expand = False
172
+ else:
173
+ code_expand = True
174
+
175
+ fig = agent.load_fig()
176
+ if fig is None:
177
+ fig_expand = False
178
+ else:
179
+ fig_expand = True
180
+
181
+ c = st.columns(2)
182
+ # with c[1].expander('快速可视化', False):
183
+ # vis_quick_actions(agent)
184
+ with c[0].expander('配色选择', True):
185
+ vis_palette(agent)
186
+ with c[1].expander('可视化建议', True):
187
+ vis_chat(agent, auto)
188
+ vis_code_gen(agent, auto = auto)
189
+ with c[0].expander('可视化执行', code_expand):
190
+ vis_execution(agent, auto = auto)
191
+ with c[0].expander('可视化结果', fig_expand):
192
+ vis_result(agent)
workflow/visualization/viz_suggestion.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def vis_button_suggest(agent):
5
+ """
6
+ 按钮路径:调用 LLM 获取结构化的可视化推荐(JSON)。
7
+ """
8
+ df = agent.load_df()
9
+ cols_wo_id = agent.load_cols_wo_id()
10
+
11
+ if cols_wo_id is None:
12
+ cols_wo_id = [str(c) for c in df.columns if not str(c).lower().startswith(('id', 'idx', 'index'))]
13
+ agent.save_cols_wo_id(cols_wo_id)
14
+
15
+ rec = agent.get_visualization_recommendations(cols_wo_id)
16
+
17
+ agent.save_recommendations(rec)
18
+ agent.refine_suggestions(rec)
19
+
20
+ return rec
21
+
22
+
23
+ def vis_talk_suggest(agent, user_input):
24
+ """
25
+ 对话路径:根据对话获取建议
26
+ """
27
+ df = agent.load_df()
28
+ cols_wo_id = agent.load_cols_wo_id()
29
+
30
+ if cols_wo_id is None:
31
+ cols_wo_id = [c for c in df.columns if not c.lower().startswith(('id', '编号', '序号', 'index'))]
32
+ agent.save_cols_wo_id(cols_wo_id)
33
+
34
+ rec = agent.get_visualization_recommendations(cols_wo_id, user_input)
35
+ agent.save_recommendations(rec)
36
+ agent.refine_suggestions(rec)
37
+
38
+ return rec