Upload folder using huggingface_hub
- Dockerfile +2 -2
- app.py +215 -0
- components/__init__.py +0 -0
- logo/logo_16_9.png +3 -0
- logo/logo_big.png +3 -0
- logo/logo_blue_wide.png +3 -0
- logo/logo_wide.png +3 -0
- logo/sec3/折线图.png +0 -0
- logo/sec3/直方图.png +0 -0
- logo/sec3/箱线图.png +0 -0
- logo/sec3/饼图.png +0 -0
- prompt_engineer/.DS_Store +0 -0
- prompt_engineer/call_llm.py +144 -0
- prompt_engineer/planner.py +177 -0
- prompt_engineer/sec1_call_llm.py +248 -0
- prompt_engineer/sec2_call_llm.py +374 -0
- prompt_engineer/sec3_call_llm.py +691 -0
- prompt_engineer/sec4_call_llm.py +606 -0
- prompt_engineer/sec5_call_llm.py +617 -0
- utils/content.py +13 -0
- utils/sanitize_code.py +47 -0
- utils/save_secrets.py +33 -0
- utils/spinner_pool.py +25 -0
- workflow/.DS_Store +0 -0
- workflow/dataloading/dataloading_core.py +287 -0
- workflow/dataloading/dataloading_render.py +210 -0
- workflow/modeling/model_inference.py +102 -0
- workflow/modeling/model_training.py +143 -0
- workflow/modeling/modeling_render.py +218 -0
- workflow/preprocessing/preprocessing_core.py +112 -0
- workflow/preprocessing/preprocessing_render.py +159 -0
- workflow/report/report_core.py +46 -0
- workflow/report/report_html.py +117 -0
- workflow/report/report_markdown.py +55 -0
- workflow/report/report_prepare_er.py +102 -0
- workflow/report/report_render.py +243 -0
- workflow/report/report_utils.py +59 -0
- workflow/report/report_word.py +89 -0
- workflow/visualization/viz_coding.py +110 -0
- workflow/visualization/viz_color.py +58 -0
- workflow/visualization/viz_quick_action.py +23 -0
- workflow/visualization/viz_render.py +192 -0
- workflow/visualization/viz_suggestion.py +38 -0
Dockerfile
CHANGED
@@ -11,10 +11,10 @@ RUN pip install --no-cache-dir --upgrade pip setuptools wheel \
     && pip install --no-cache-dir -r requirements.txt
 
 # ========= Copy project files =========
-COPY tmp/ ./tmp/
+# COPY tmp/ ./tmp/
 
 # ========= Expose the Streamlit port =========
 EXPOSE 8501
 
 # ========= Startup command =========
-CMD ["streamlit", "run", "
+CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
app.py
ADDED
@@ -0,0 +1,215 @@
import sys, os
import tempfile
import streamlit as st
from pathlib import Path  # used by SECRETS_PATH below; absent from the original imports

from config import MODEL_CONFIGS
from utils.save_secrets import *
from prompt_engineer.sec1_call_llm import DataLoadingAgent
from prompt_engineer.sec2_call_llm import DataPreprocessAgent
from prompt_engineer.sec3_call_llm import VisualizationAgent
from prompt_engineer.sec4_call_llm import ModelingCodingAgent
from prompt_engineer.sec5_call_llm import ReportAgent
from prompt_engineer.planner import PlannerAgent

import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", message="missing ScriptRunContext")

import numpy as np
np.set_printoptions(edgeitems=250, threshold=501)

sys.path.append(os.path.dirname(__file__))

CACHE_FILE = os.path.join(tempfile.gettempdir(), "anystat_cache.pkl")
CACHE_DIR = './cache'
SECRETS_PATH = Path(".streamlit") / "secrets.toml"


# Page configuration
st.set_page_config(
    page_title="AnyStat",
    page_icon="🤖",
    layout="wide"
)


def init_session_state():
    if 'selected_model' not in st.session_state:
        st.session_state.selected_model = "DeepSeek"
    if "api_keys" not in st.session_state:
        st.session_state.api_keys = load_local_api_keys()
    if 'auto_mode' not in st.session_state:
        st.session_state.auto_mode = False

    if 'loading_start_time' not in st.session_state:
        st.session_state.loading_start_time = None
    if 'prep_start_time' not in st.session_state:
        st.session_state.prep_start_time = None
    if 'vis_start_time' not in st.session_state:
        st.session_state.vis_start_time = None
    if 'modeling_start_time' not in st.session_state:
        st.session_state.modeling_start_time = None
    if 'report_start_time' not in st.session_state:
        st.session_state.report_start_time = None

    if 'data_loading_agent' not in st.session_state:
        st.session_state.data_loading_agent = DataLoadingAgent(
            api_keys=st.session_state.api_keys,
            model_configs=MODEL_CONFIGS,
            model=st.session_state.selected_model
        )
    if 'data_preprocess_agent' not in st.session_state:
        st.session_state.data_preprocess_agent = DataPreprocessAgent(
            api_keys=st.session_state.api_keys,
            model_configs=MODEL_CONFIGS,
            model=st.session_state.selected_model
        )
    if 'visualization_agent' not in st.session_state:
        st.session_state.visualization_agent = VisualizationAgent(
            api_keys=st.session_state.api_keys,
            model_configs=MODEL_CONFIGS,
            model=st.session_state.selected_model
        )
    if 'modeling_coding_agent' not in st.session_state:
        st.session_state.modeling_coding_agent = ModelingCodingAgent(
            api_keys=st.session_state.api_keys,
            model_configs=MODEL_CONFIGS,
            model=st.session_state.selected_model
        )
    if 'report_agent' not in st.session_state:
        st.session_state.report_agent = ReportAgent(
            api_keys=st.session_state.api_keys,
            model_configs=MODEL_CONFIGS,
            model=st.session_state.selected_model
        )
    if 'planner_agent' not in st.session_state:
        st.session_state.planner_agent = PlannerAgent(
            api_keys=st.session_state.api_keys,
            model_configs=MODEL_CONFIGS,
            model=st.session_state.selected_model
        )


def on_model_selector_change():
    """
    Callback when the model selector in the sidebar changes.
    """
    st.session_state.selected_model = st.session_state.model_selector


def run_app():
    """
    Main entry point to render the Streamlit app.
    """
    init_session_state()
    with st.sidebar:
        st.subheader("选择大模型")
        models = list(MODEL_CONFIGS.keys())
        st.selectbox(
            "选择要使用的大模型",
            models,
            index=models.index(st.session_state.selected_model),
            key="model_selector",
            on_change=on_model_selector_change,
        )

        st.subheader("API 密钥设置")
        selected = st.session_state.selected_model

        api_key_input = st.text_input(
            f"{selected} API 密钥",
            value=st.session_state.api_keys.get(selected, ""),
            type="password",
            key="api_key_input",
        )

        if st.button("💾 保存密钥", use_container_width=True, key="save_key"):
            # Persisted to utils/.streamlit/secrets.toml
            update_local_api_key(selected, api_key_input)

            st.session_state.api_keys[selected] = api_key_input
            st.success("已保存")
            st.rerun()

        if st.button("🧹 清空数据", use_container_width=True, key="clear_data"):
            st.session_state.data_loading_agent = DataLoadingAgent(
                api_keys=st.session_state.api_keys,
                model_configs=MODEL_CONFIGS,
                model=st.session_state.selected_model
            )
            st.session_state.data_preprocess_agent = DataPreprocessAgent(
                api_keys=st.session_state.api_keys,
                model_configs=MODEL_CONFIGS,
                model=st.session_state.selected_model
            )
            st.session_state.visualization_agent = VisualizationAgent(
                api_keys=st.session_state.api_keys,
                model_configs=MODEL_CONFIGS,
                model=st.session_state.selected_model
            )
            st.session_state.modeling_coding_agent = ModelingCodingAgent(
                api_keys=st.session_state.api_keys,
                model_configs=MODEL_CONFIGS,
                model=st.session_state.selected_model
            )
            st.session_state.report_agent = ReportAgent(
                api_keys=st.session_state.api_keys,
                model_configs=MODEL_CONFIGS,
                model=st.session_state.selected_model
            )
            st.session_state.planner_agent = PlannerAgent(
                api_keys=st.session_state.api_keys,
                model_configs=MODEL_CONFIGS,
                model=st.session_state.selected_model
            )
            st.session_state.auto_mode = False
            st.rerun()

        if st.session_state.data_loading_agent.load_df() is not None:
            planner = st.session_state.planner_agent
            if st.button("🚗 自动模式", use_container_width=True, key="self_driving"):
                planner.self_driving(st.session_state.data_loading_agent.load_df())
                st.session_state.auto_mode = True
                st.rerun()

        st.image(
            "logo/logo_big.png",
            use_container_width=True
        )

    # Define pages
    data_loading = st.Page(
        "workflow/dataloading/dataloading_render.py",
        title="📥 数据导入",
    )
    preprocessing = st.Page(
        "workflow/preprocessing/preprocessing_render.py",
        title="⚙️ 数据预处理",
    )
    visualization = st.Page(
        "workflow/visualization/viz_render.py",
        title="📊 数据可视化",
    )
    report = st.Page(
        "workflow/report/report_render.py",
        title="📝 报告生成",
    )
    coding_modeling = st.Page(
        "workflow/modeling/modeling_render.py",
        title="🧠 建模分析",
    )
    # Navigation
    pg = st.navigation(
        {
            "设置": [data_loading, preprocessing],
            "功能": [visualization, coding_modeling, report],
        }
    )
    pg.run()

if __name__ == "__main__":
    run_app()
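Note that app.py imports MODEL_CONFIGS from a config module that does not appear in this commit's file list. A minimal sketch of the shape the call sites assume (keys are the sidebar display names; the "api_base" and "model_name" fields are read in LLMClient.call; endpoints and model ids below are placeholders, not values from this repo):

# config.py -- hypothetical sketch, not part of this upload
MODEL_CONFIGS = {
    "DeepSeek": {
        "api_base": "https://api.example.com/v1",  # placeholder endpoint
        "model_name": "deepseek-chat",             # placeholder model id
    },
    "智谱AI": {
        # ZhipuAiClient manages its own endpoint, so only the model id is read
        "model_name": "glm-4",                     # placeholder model id
    },
}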
components/__init__.py
ADDED
File without changes

logo/logo_16_9.png
ADDED
Git LFS Details

logo/logo_big.png
ADDED
Git LFS Details

logo/logo_blue_wide.png
ADDED
Git LFS Details

logo/logo_wide.png
ADDED
Git LFS Details

logo/sec3/折线图.png
ADDED

logo/sec3/直方图.png
ADDED

logo/sec3/箱线图.png
ADDED

logo/sec3/饼图.png
ADDED

prompt_engineer/.DS_Store
ADDED
Binary file (6.15 kB)
prompt_engineer/call_llm.py
ADDED
@@ -0,0 +1,144 @@
import re
from openai import OpenAI, OpenAIError
from anthropic import Anthropic, AnthropicError
import requests
import json

import streamlit as st
import pandas as pd
import numpy as np
from config import MODEL_CONFIGS
from typing import IO, List, Dict
from zai import ZhipuAiClient

class LLMClient:
    def __init__(self, model_configs: dict, api_keys: dict, model: str):
        self.model = model
        self.model_configs = model_configs
        self.api_keys = api_keys
        self.memory = []
        self.df = None

    def call(self, prompt) -> str:
        model_name = st.session_state.selected_model
        config = self.model_configs.get(model_name, {})
        api_key = self.api_keys.get(model_name)

        if not api_key:
            return "请先在设置中配置 API 密钥"

        system_msg = (
            "你是一个专业的数据分析助手。"
        )

        try:
            if model_name in ("GPT-4o", "GPT-5", "DeepSeek", "通义千问", "Claude", "豆包"):
                try:
                    client = OpenAI(
                        api_key=api_key,
                        base_url=config["api_base"]
                    )

                    # New-style chat-completions API call
                    resp = client.chat.completions.create(
                        model=config["model_name"],
                        messages=[
                            {"role": "system", "content": system_msg},
                            {"role": "user", "content": prompt},
                        ],
                        stream=False
                    )
                    return resp.choices[0].message.content

                except OpenAIError as e:
                    # Catches every error type defined by the OpenAI SDK
                    st.error(f"API调用失败: {str(e)}")
                    # Log or notify the user here
                    return "调用失败,请检查密钥或网络"
                except Exception as e:
                    # Catch other unexpected exceptions, e.g. network problems
                    st.error(f"发生未知错误: {str(e)}")
                    return "发生未知错误"

            elif model_name == "智谱AI":
                client = ZhipuAiClient(api_key=api_key)
                response = client.chat.completions.create(
                    model=config["model_name"],
                    messages=[{"role": "system", "content": "你是一个专业的数据分析助手。"},
                              {"role": "user", "content": prompt}],
                    thinking={
                        "type": "enabled"
                    }
                )
                if response:
                    print(response.choices[0].message)
                    desc = response.choices[0].message.content if hasattr(response.choices[0].message, "content") else str(response.choices[0].message)
                    return desc.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "").strip()

                st.error(f"智谱调用失败:{response.text}")
                return "调用失败,请检查密钥或网络"

            # elif model_name == "DeepSeek":
            #     client = OpenAI(
            #         api_key=api_key,
            #         base_url=config["api_base"])

            #     resp = client.chat.completions.create(
            #         model=config["model_name"],
            #         messages=[
            #             {"role": "system", "content": system_msg},
            #             {"role": "user", "content": prompt},
            #         ],
            #         stream=False
            #     )
            #     if resp:
            #         return resp.choices[0].message.content
            #     st.error(f"DeepSeek调用失败:{resp.text}")
            #     return "调用失败,请检查密钥或网络"

            else:
                return f"暂不支持模型:{model_name}"

        except Exception as e:
            st.error(f"{model_name} 调用异常:{e}")
            return "大模型调用失败,请检查 API 密钥或网络连接"

    def add_memory(self, entry: Dict[str, str]) -> None:
        self.memory.append(entry)

    def load_memory(self) -> List[Dict[str, str]]:
        return self.memory

    def clear_memory(self) -> None:
        self.memory.clear()

    def add_df(self, input_df) -> None:
        self.df = input_df

    def load_df(self) -> pd.DataFrame:
        return self.df

    def clear_df(self) -> None:
        self.df = None

    def has_df(self) -> bool:
        # True when a DataFrame has been attached; an identity check avoids
        # pandas' elementwise == comparison against None.
        return self.df is not None
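A sketch of how the agents drive this client. It must run inside a Streamlit session, since call() dispatches on st.session_state.selected_model, and the API key below is a placeholder:

import pandas as pd
import streamlit as st
from config import MODEL_CONFIGS
from prompt_engineer.call_llm import LLMClient

client = LLMClient(model_configs=MODEL_CONFIGS,
                   api_keys={"DeepSeek": "sk-..."},   # placeholder key
                   model="DeepSeek")
st.session_state.selected_model = "DeepSeek"          # call() reads this

client.add_df(pd.DataFrame({"x": [1, 2, 3], "y": [2.0, 4.1, 5.9]}))
if client.has_df():
    reply = client.call("请描述该数据集的结构")
    client.add_memory({"role": "assistant", "content": reply})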
prompt_engineer/planner.py
ADDED
@@ -0,0 +1,177 @@
import re
import json
import ast

import streamlit as st
from typing import IO, List

from prompt_engineer.call_llm import LLMClient


class PlannerAgent(LLMClient):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loading_auto = False
        self.prep_auto = False
        self.vis_auto = False
        self.modeling_auto = False
        self.report_auto = False

        self.switched_loading = False
        self.switched_prep = False
        self.switched_vis = False
        self.switched_modeling = False
        self.switched_report = False

    def self_driving(self, df, user_input=None) -> None:
        prompt = (
            f"下面是一个数据集的基本信息,请你根据它和用户的需求,判断需要开启哪些分析步骤:\n\n"
            f"- 数据维度:{df.shape[0]} 行 × {df.shape[1]} 列\n"
            f"- 列名和数据类型:{dict(zip(df.columns.tolist(), df.dtypes.astype(str).tolist()))}\n"
            f"- 前 5 行样本:\n{df.head().to_dict(orient='list')}\n\n"
        )

        if user_input:
            prompt += f"用户的具体需求是:“{user_input}”。\n\n"

        prompt += """
你需要在以下 5 个步骤中,对每个步骤分别判断是否应该开启(True / False):
1. loading_auto —— 是否需要对数据列名进行初步分析?
2. prep_auto —— 是否需要做数据预处理或清洗?
3. vis_auto —— 是否需要做数据可视化?
4. modeling_auto —— 是否需要建模或统计分析?
5. report_auto —— 是否需要生成分析报告?

必须以 **JSON 格式** 输出你的判断结果,如:
{
    "loading_auto": true,
    "prep_auto": false,
    "vis_auto": true,
    "modeling_auto": true,
    "report_auto": true
}

不要输出其他内容。
"""

        plan_text = self.call(prompt)
        print(plan_text)
        try:
            plan_dict = json.loads(plan_text)
        except json.JSONDecodeError:
            # Strip Markdown code fences (```json ... ```) and retry
            plan_text_fixed = re.sub(r"```(?:json)?", "", plan_text).strip()
            plan_dict = json.loads(plan_text_fixed)

        print(plan_dict)
        self.loading_auto = bool(plan_dict.get("loading_auto", False))
        self.prep_auto = bool(plan_dict.get("prep_auto", False))
        self.vis_auto = bool(plan_dict.get("vis_auto", False))
        self.modeling_auto = bool(plan_dict.get("modeling_auto", False))
        # self.modeling_auto = False
        self.report_auto = bool(plan_dict.get("report_auto", False))

    def finish_loading_auto(self) -> None:
        self.switched_loading = True

    def finish_prep_auto(self) -> None:
        self.switched_prep = True

    def finish_vis_auto(self) -> None:
        self.switched_vis = True

    def finish_modeling_auto(self) -> None:
        self.switched_modeling = True

    def finish_report_auto(self) -> None:
        self.switched_report = True


def _extract_first_json(text: str):
    """Extract the first top-level {...} JSON substring from text by brace
    matching; return None if none is found."""
    if not text:
        return None
    start = text.find('{')
    if start == -1:
        return None
    depth = 0
    for i in range(start, len(text)):
        ch = text[i]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return text[start:i+1]
    return None

def _safe_parse_json(text: str):
    """
    Try several strategies to parse LLM output into a dict:
    1) json.loads directly
    2) strip Markdown code fences, then json.loads
    3) extract the first complete brace block, then json.loads
    4) ast.literal_eval as a last resort (accepts Python-dict style)
    Returns (dict_or_None, used_text, error_message_or_None).
    """
    if not text or not text.strip():
        return None, text, "empty"
    # 1) Direct attempt
    try:
        return json.loads(text), text, None
    except Exception:
        pass

    # 2) Remove ```json / ``` fences
    try:
        cleaned = re.sub(r'```json\s*', '', text, flags=re.IGNORECASE)
        cleaned = re.sub(r'```', '', cleaned)
        cleaned = cleaned.strip()
        return json.loads(cleaned), cleaned, None
    except Exception:
        pass

    # 3) Extract the first matching top-level { ... } block
    try:
        sub = _extract_first_json(text)
        if sub:
            return json.loads(sub), sub, None
    except Exception:
        pass

    # 4) ast.literal_eval accepts Python-dict syntax (single quotes, etc.)
    try:
        literal = ast.literal_eval(text)
        if isinstance(literal, dict):
            return literal, text, None
    except Exception:
        pass

    # 5) literal_eval again on the extracted substring (guards against single quotes)
    try:
        sub = _extract_first_json(text)
        if sub:
            literal = ast.literal_eval(sub)
            if isinstance(literal, dict):
                return literal, sub, None
    except Exception:
        pass

    # Finally give up and return None with an error marker
    return None, text, "unable_to_parse"
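A small usage sketch of the fallback parser, with made-up raw outputs of the three shapes it is built to handle (run from the project root with its dependencies installed):

from prompt_engineer.planner import _safe_parse_json

fenced = '```json\n{"loading_auto": true, "prep_auto": false}\n```'
chatty = 'Sure! Here is the plan: {"vis_auto": true, "modeling_auto": false}. Hope it helps.'
pythonic = "{'report_auto': True}"

for raw in (fenced, chatty, pythonic):
    parsed, used, err = _safe_parse_json(raw)
    print(parsed, err)
# {'loading_auto': True, 'prep_auto': False} None   <- fence stripping
# {'vis_auto': True, 'modeling_auto': False} None   <- brace extraction
# {'report_auto': True} None                        <- ast.literal_eval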
prompt_engineer/sec1_call_llm.py
ADDED
@@ -0,0 +1,248 @@
import re
import ast

import streamlit as st
from typing import IO, List

from prompt_engineer.call_llm import LLMClient


class DataLoadingAgent(LLMClient):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.file_name = []
        self.user_input = None
        self.par_content = ""
        self.dfs = None
        self.abstract = None
        self.full = None
        self.finish_auto_task = False

    def finish_auto(self):
        self.finish_auto_task = True

    def save_file_name(self, file_name):
        self.file_name.append(file_name)

    def load_file_name(self):
        return self.file_name

    def save_dfs(self, dfs):
        self.dfs = dfs

    def load_dfs(self):
        return self.dfs

    def clear_file_name(self):
        self.file_name = []

    def read_names_from_file(self, uploaded_names_file, df_head):
        """
        Extract attribute names from an uploaded .names/.arff file.
        Prefer the LLM to read the names off the @attribute lines; if the
        LLM call fails, fall back to regex parsing.
        """
        raw = uploaded_names_file.read().decode('utf-8', errors='ignore')
        try:
            uploaded_names_file.seek(0)
        except Exception:
            pass

        prompt = (
            "下面是上传的 names 和 df_head 文件内容,请仅以 Python 列表格式返回与df_head一一对应的所有属性(attribute)名称,"
            "并保持顺序,不要添加多余文字,请注意,你只需要返回一个列表,不要出现任何markdown语法:\n```\n"
            f"name文件:{raw}\n```"
            f"df_head:{df_head}\n```"
        )
        try:
            response = self.call(prompt)
            # Parse the list literal safely instead of executing it
            names_list = ast.literal_eval(response.strip())
            if isinstance(names_list, list) and all(isinstance(n, str) for n in names_list):
                col_names = names_list
            else:
                raise ValueError("LLM 输出格式不正确")
        except Exception:
            col_names = []
            attr_re = re.compile(
                r"""^@attribute\s+
                    ['"]?([^'"\s]+)['"]?
                    \s+.+
                """,
                re.IGNORECASE | re.VERBOSE
            )
            for line in raw.splitlines():
                line = line.strip()
                if not line:
                    continue
                if line.lower().startswith('@data'):
                    break
                m = attr_re.match(line)
                if m:
                    col_names.append(m.group(1))

        counts: dict[str, int] = {}
        unique_names: List[str] = []
        for name in col_names:
            if name in counts:
                counts[name] += 1
                unique_names.append(f"{name}_{counts[name]}")
            else:
                counts[name] = 0
                unique_names.append(name)

        return unique_names

    def do_data_description(self, df, user_input=None, memory_limit=6):
        recent_memory = self.memory[-memory_limit:] if self.memory else []
        if recent_memory:
            formatted_memory = "\n".join(
                f"{m['role']}: {m['content']}" for m in recent_memory
            )
            memory_block = f"{formatted_memory}"
        else:
            memory_block = ""

        prompt = (
            "你是一名专业的数据分析助手,负责解释数据结构与业务含义。\n"
            f"- 数据维度:{df.shape[0]} 行 × {df.shape[1]} 列\n"
            f"- 列名和数据类型:{dict(zip(df.columns.tolist(), df.dtypes.astype(str).tolist()))}\n"
            f"- 前 5 行样本:\n{df.head().to_dict(orient='list')}\n\n"
            f"""- 数据解释聊天对话:
--- 开始聊天记录 ---
{memory_block}
--- 结束聊天记录 ---"""
        )

        if user_input is not None:
            prompt += f"""
请严格依据用户需求“{user_input}”,对当前数据进行深入、系统的分析。
要求:
1. 分析内容必须与该需求完全对应,不能添加无关推断。
2. 结论要具体、清晰,可直接支持后续报告撰写或建模步骤。
3. 分析语言应专业、简洁,不使用模糊或情绪化表述。
"""
        else:
            prompt += """
以下是一个数据集的基本概览。请帮助我分析它的性质和结构,并回答以下问题:

1. 该数据集可能来源于什么业务或研究场景?
2. 各主要字段分别代表什么含义?若能判断,请说明其单位或数值含义。
3. 数据中是否存在明显异常、异常分布或需要注意的特征?

输出要求:
- 使用自然、流畅的中文描述;
- 采用清晰的分条结构(1、2、3);
- 语言客观简洁,不使用“可能”“也许”“似乎”等模糊词;
- 重点突出数据结构、含义与潜在问题。
"""

        desc = self.call(prompt)

        return desc

    def summary_html(self):
        df = self.load_df()
        df_head = df.head()
        dtype_info = df.dtypes.astype(str)

        prompt = f"""
你正在撰写一份数据分析报告的第一章——《数据概览与数据含义分析》。
请根据以下输入内容,整理关键信息并进行分析说明:
数据格式:
{dtype_info}

前五行数据:
{df_head}

数据解释聊天对话:
--- 开始聊天记录 ---
{self.memory}
--- 结束聊天记录 ---

额外要求:
1. 要用流畅的自然语言
2. 不要滥用形容词和副词,尽量用简单的动词和名词表达意思
3. 不用"可能""也许""似乎""微妙"等模糊表述
""".strip()

        desc = self.call(prompt)

        summary = {
            "title": "数据导入",
            "df": df_head,
            "desc": desc,
        }

        return summary

    def summary_word(self):
        return self.summary_html()

    def check_abstract(self):
        if self.abstract is None:
            df = self.load_df()
            df_head = df.head()
            dtype_info = df.dtypes.astype(str)

            prompt = f"""
这是数据分析的数据导入阶段
数据格式:
{dtype_info}

前五行数据:
{df_head}

数据解释聊天对话:
--- 开始聊天记录 ---
{self.memory}
--- 结束聊天记录 ---

要求:
请基于上述数据与对话内容,生成一段简洁、准确的综合摘要。
摘要需完整呈现核心信息,便于后续自动判断该内容在报告撰写中是否需要被引用。
""".strip()

            desc = self.call(prompt)
            self.abstract = desc

        return self.abstract

    def check_full(self):
        if self.full is None:
            df = self.load_df()
            df_head = df.head()
            dtype_info = df.dtypes.astype(str)

            self.full = (
                f"【阶段说明】这是数据分析流程中的数据导入阶段。\n"
                f"【数据格式】{dtype_info}\n"
                f"【样本预览】\n{df_head}\n"
                f"【分析对话记录】\n{self.memory}"
            )

        return self.full
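The regex fallback and the duplicate-suffix scheme in read_names_from_file are easy to check in isolation. A standalone sketch with a made-up .names snippet:

import re
from typing import List

raw = """@attribute 'age' numeric
@attribute sex {M, F}
@attribute sex {M, F}
@data"""

attr_re = re.compile(r'''^@attribute\s+['"]?([^'"\s]+)['"]?\s+.+''', re.IGNORECASE)
col_names = []
for line in raw.splitlines():
    line = line.strip()
    if line.lower().startswith('@data'):
        break  # attribute section ends at @data
    m = attr_re.match(line)
    if m:
        col_names.append(m.group(1))

# Duplicates get a numeric suffix, as in the agent: the second "sex" becomes "sex_1"
counts: dict = {}
unique_names: List[str] = []
for name in col_names:
    if name in counts:
        counts[name] += 1
        unique_names.append(f"{name}_{counts[name]}")
    else:
        counts[name] = 0
        unique_names.append(name)

print(unique_names)  # ['age', 'sex', 'sex_1']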
prompt_engineer/sec2_call_llm.py
ADDED
@@ -0,0 +1,374 @@
import numpy as np
import pandas as pd

from prompt_engineer.call_llm import LLMClient

class DataPreprocessAgent(LLMClient):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.processed_df = None
        self.code = None
        self.preprocessing_suggestions = None
        self.allowed_libs = [
            "numpy",
            "pandas",
            "sklearn.impute",
            "sklearn.preprocessing",
            "sklearn.compose",
            "sklearn.pipeline"
        ]
        self.par_content = ""
        self.error = None
        self.user_input = None
        self.refined_suggestions = ""
        self.abstract = None
        self.full = None
        self.finish_auto_task = False
        self.debug_num = 0

    def finish_auto(self):
        self.finish_auto_task = True

    def save_code(self, code):
        self.code = code

    def load_code(self):
        return self.code

    def save_user_input(self, user_input):
        self.user_input = user_input

    def load_user_input(self):
        return self.user_input

    def save_error(self, error):
        self.error = error

    def load_error(self):
        return self.error

    def save_preprocessing_suggestions(self, suggestions):
        self.preprocessing_suggestions = suggestions

    def load_preprocessing_suggestions(self):
        return self.preprocessing_suggestions

    def save_processed_df(self, processed_df):
        if not isinstance(processed_df, pd.DataFrame):
            if isinstance(processed_df, np.ndarray):
                processed_df = pd.DataFrame(processed_df)
            else:
                raise TypeError(f"期望 pandas.DataFrame 或 numpy.ndarray,收到 {type(processed_df)}")

        self.processed_df = processed_df

    def load_processed_df(self):
        return self.processed_df

    def load_refined_suggestions(self):
        return self.refined_suggestions

    def save_refined_suggestions(self, refined_suggestions):
        self.refined_suggestions = refined_suggestions

    def refine_suggestions(self, df_head):
        """Distill the LLM's preprocessing recommendations into per-column summaries."""
        suggestion = self.load_preprocessing_suggestions()

        prompt = f"""
请根据以下预处理建议,概括数据集中每一列的推荐预处理方法。

数据示例:
{df_head}

详细预处理建议:
{suggestion}

输出要求(必须严格遵守):
1. 输出格式:列名:推荐预处理方法;每条独立换行。
2. 每列最多给出三个推荐方法,多个方法用逗号分隔。
3. 输出必须为纯文本,不使用任何 Markdown 标记。
4. 每个方法的长度不得超过20个汉字,若包含英文则不超过10个单词。"""

        refined_suggestions = self.call(prompt)
        self.refined_suggestions = refined_suggestions

        return refined_suggestions

    def get_preprocessing_suggestions(
        self,
        user_input=None,
        memory_limit=6,
    ):
        df = self.load_df()

        # Basic statistics
        n_rows, n_cols = df.shape
        dtype_counts = df.dtypes.value_counts().to_dict()
        missing_total = int(df.isnull().sum().sum())
        missing_by_col = df.isnull().mean().mul(100).round(2).to_dict()
        num_cols = df.select_dtypes(include=[np.number]).columns.tolist()

        # Assemble the recent memory snippet
        recent_memory = self.memory[-memory_limit:] if self.memory else []
        if recent_memory:
            formatted_memory = "\n".join(
                f"{m['role']}: {m['content']}" for m in recent_memory
            )
            memory_block = f"{formatted_memory}"
        else:
            memory_block = ""

        prompt = f"""
你是一名资深的数据预处理专家,负责为数据分析报告提供高质量的预处理建议。

=== 数据概览 ===
- 数据规模:{n_rows} 行 × {n_cols} 列
- 数据类型分布:{dtype_counts}
- 缺失值总数:{missing_total}
- 各列缺失率:{missing_by_col}
- 数值型列:{num_cols}
- 历史上下文(仅供参考):{memory_block}
"""

        if user_input is None:
            prompt += """
=== 请对每一列进行逐项分析(注意,是逐列分析) ===
请针对每一列依次说明以下四个方面:

1. **数据类型**:明确该列的数据类型,若存在混合类型或异常值类型,请指出。
2. **缺失值处理建议**:说明该列的缺失值处理策略;若建议调整,请指明具体“缺失值处理 策略”操作。
3. **异常值处理建议**:说明该列的异常检测与处理方案;若需调整,请说明“异常值处理 策略或阈值”操作。
4. **标准化建议**:说明是否建议标准化或缩放,并在需要时指出“标准化处理 策略”操作。

输出格式要求:
- 按“列名 + 分点说明(1–4)”的形式分段输出;
- 每一列独立成段,并以换行分隔;
- 使用清晰、简洁的专业语言。
"""
        else:
            prompt += f"""
=== 用户新需求 ===
{user_input}

请结合以上数据概览与历史上下文,针对该需求,给出下一步操作。
可考虑的操作包括:缺失值处理、异常值检测与修正、标准化或归一化、特征类型调整等。
输出应保持结构化与连贯性,避免重复说明。
"""

        suggestions = self.call(prompt)

        return suggestions

    def code_generation(self, df_head, user_prompt):
        """Build the prompt that asks the LLM to emit code producing process_df (a pandas DataFrame)."""
        allowed = ", ".join(self.allowed_libs)

        prompt = f"""
请**严格只输出纯 Python 代码**,不得包含以下内容:
- 解释性文字、注释、示例;
- Markdown 代码块标记(禁止出现 ``` 或 ```python 等);
- 任何多余输出(如 print、全局变量赋值等)。

=== 运行环境说明 ===
运行环境中已提供以下对象与库:
- pandas DataFrame 变量:`df`
- 库:numpy (np)、SimpleImputer、StandardScaler、MinMaxScaler、RobustScaler、
  OneHotEncoder、OrdinalEncoder、LabelEncoder、FunctionTransformer、
  ColumnTransformer、Pipeline。
若所需功能在这些库中不存在,请自行写 Python code 实现。

=== 生成要求 ===
1. 若有用户需求,请优先满足用户需求(优先级高于 LLM 返回的通用建议)。
2. 若建议指出某列“无需处理”,则对该列不进行任何操作。
3. 禁止导入其他库、禁止文件读写。
4. 所有括号(圆括号、方括号、大括号)必须成对闭合,不得错位或遗漏。
5. 对类别特征,可使用 OneHotEncoder 或 OrdinalEncoder;
   若为单列字符串/类别列,请使用 LabelEncoder 或 OrdinalEncoder,不得 passthrough。
6. 在构建 ColumnTransformer 前,需检测并处理“混合型列”
   —— 即同时包含数值和字符串的列,
   使用 `FunctionTransformer(lambda x: x.astype(str))` 将其统一为字符串类型。
7. ColumnTransformer 的 transformers 中仅包含经过上述处理的列。
8. 使用 OneHotEncoder 时,若输出稀疏矩阵,请确保所有输入特征均为数值类型。
9. 若 df 中存在重复表头(如第 0 行与 header 相同),需自动检测并删除重复表头行。
10. 确保预处理后的 DataFrame 中每一列均有明确列名。
11. 脚本最后仅保留一行结果:
    `process_df = ...`
    不允许出现 print、显示语句或其他多余输出。

=== 输入数据示例 ===
{df_head}

=== 用户指定需求 ===
{user_prompt}

请严格依据以上要求,输出完整且可直接执行的 Python 代码(纯代码块,无额外说明)。
""".strip()

        if self.error is not None:
            if self.debug_num < 5:
                self.debug_num += 1

                prompt += f"""
上次生成的代码运行失败。
【错误信息】:
{self.error}

【原始代码】:
{self.code}

请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因,

要求:
1. 不输出任何分析、解释或说明(包括文字、列表或注释段落);
2. 可在代码内部使用简短注释说明关键修改;
3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整;
4. 若依赖库方法不适用,可自行实现替代函数;
5. 生成的代码必须可独立运行,无语法错误;
6. 保持整体逻辑与原代码意图一致,仅做必要修正。
"""
            else:
                self.debug_num = 0

        if self.user_input is not None:
            prompt += f"用户需求:{self.user_input}。\n请严格遵循并优先执行该需求,其优先级高于所有其他建议或规则。\n"

        if self.refined_suggestions is not None:
            prompt += f"LLM返回的预处理建议:{self.refined_suggestions}"

        raw = self.call(prompt)
        return raw

    def summary_html(self):
        if self.code is None:
            summary = None
            return summary

        else:
            processed_df = self.load_processed_df()
            prompt = f"""
你正在撰写数据分析报告的第二章——《数据预处理与标准化》。
请根据以下输入内容,提炼关键信息并撰写相应分析段落。

- 预处理代码:
{self.code}

- 预处理结果(数据示例):
{processed_df.head()}

{f"- 预处理建议对话记录:{self.load_memory()}" if self.load_memory() else ""}

撰写要求:
1. 使用流畅、自然的中文表达;
2. 语言应简洁、准确,避免过多形容词或副词;
3. 不使用“可能”“也许”“似乎”“微妙”等模糊表述;
4. 不添加大标题,可使用自然段进行叙述;
5. 内容需逻辑清晰,体现代码与结果之间的分析关联。

""".strip()

            desc = self.call(prompt)

            summary = {
                "title": "数据预处理",
                "desc": desc,
                "processed_df": self.processed_df.head(),
                "code": self.code,
            }

            return summary

    def summary_word(self):
        return self.summary_html()

    def check_abstract(self):
        if self.abstract is None:
            processed_df = self.load_processed_df()

            if self.code is None:
                self.abstract = None

            else:
                memory = f"【预处理建议对话记录】\n{self.load_memory()}\n" if self.load_memory() else ""

                prompt = f"""
这是数据分析流程中的“数据预处理与标准化”阶段。

【预处理代码】
{self.code}

【预处理结果(前五行)】
{processed_df.head()}

{memory}
请在确保信息准确完整的前提下,将上述内容概括为一段简洁的文字摘要。
要求:
1. 语言自然流畅,保持客观和专业;
2. 内容应涵盖关键点(包括主要预处理步骤与结果特征);
3. 重点在于“说明核心信息”,而非逐行描述;
4. 生成的摘要应可用于报告编写时判断该部分是否需要引用。
""".strip()

                desc = self.call(prompt)
                self.abstract = desc

        return self.abstract

    def check_full(self):
        if self.full is None:
            processed_df = self.load_processed_df()
            if self.code is None:
                self.full = None
            else:
                content = f"""
【阶段说明】这是数据分析流程中的数据预处理阶段。
【预处理代码】{self.code}
【预处理结果前五行】{processed_df.head()}
""".strip()
                if self.load_memory():
                    content += f"\n【预处理建议聊天对话】{self.load_memory()}"

                self.full = content

        return self.full
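The executor that actually runs the generated script lives under workflow/preprocessing, which is listed in this commit but not shown above. A minimal sketch of the contract the prompt establishes, assuming the generated source arrives as a string: exec it in a namespace preloaded with df and the whitelisted helpers, then read back process_df:

import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import (StandardScaler, MinMaxScaler, RobustScaler,
                                   OneHotEncoder, OrdinalEncoder, LabelEncoder,
                                   FunctionTransformer)
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

def run_generated_code(code: str, df: pd.DataFrame) -> pd.DataFrame:
    # Namespace mirrors the "运行环境说明" above: df plus the allowed libraries
    env = {"df": df.copy(), "np": np, "pd": pd,
           "SimpleImputer": SimpleImputer, "StandardScaler": StandardScaler,
           "MinMaxScaler": MinMaxScaler, "RobustScaler": RobustScaler,
           "OneHotEncoder": OneHotEncoder, "OrdinalEncoder": OrdinalEncoder,
           "LabelEncoder": LabelEncoder, "FunctionTransformer": FunctionTransformer,
           "ColumnTransformer": ColumnTransformer, "Pipeline": Pipeline}
    exec(code, env)  # the generated script must end with `process_df = ...`
    return env["process_df"]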
prompt_engineer/sec3_call_llm.py
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import base64
|
| 3 |
+
import plotly.graph_objs as go
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
+
|
| 6 |
+
from prompt_engineer.call_llm import LLMClient
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
np.set_printoptions(edgeitems=250, threshold=501)
|
| 10 |
+
|
| 11 |
+
class VisualizationAgent(LLMClient):
|
| 12 |
+
|
| 13 |
+
def __init__(self, *args, **kwargs):
|
| 14 |
+
|
| 15 |
+
super().__init__(*args, **kwargs)
|
| 16 |
+
self.cols_wo_id = None
|
| 17 |
+
self.recommendations = None
|
| 18 |
+
self.analysis = []
|
| 19 |
+
self.quick_action = None
|
| 20 |
+
self.data_meaning = ""
|
| 21 |
+
self.allowed_libs = [
|
| 22 |
+
"numpy", "plotly", "plotly.express", "plotly.graph_objects"
|
| 23 |
+
]
|
| 24 |
+
self.code = None
|
| 25 |
+
self.result = None
|
| 26 |
+
self.suggestion = None
|
| 27 |
+
self.user_input = None
|
| 28 |
+
self.fig = []
|
| 29 |
+
self.par_content = ""
|
| 30 |
+
self.error = None
|
| 31 |
+
self.abstract=None
|
| 32 |
+
self.full = None
|
| 33 |
+
self.color = None
|
| 34 |
+
self.finish_auto_task = False
|
| 35 |
+
self.debug_num = 0
|
| 36 |
+
self.refined_suggestions = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def finish_auto(self):
|
| 40 |
+
|
| 41 |
+
self.finish_auto_task = True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def save_user_input(self, user_input):
|
| 45 |
+
|
| 46 |
+
self.user_input = user_input
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_user_input(self):
|
| 50 |
+
|
| 51 |
+
return self.user_input
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def save_color(self, color):
|
| 55 |
+
|
| 56 |
+
self.color = color
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_color(self):
|
| 60 |
+
|
| 61 |
+
return self.color
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def add_fig(self, fig, desc):
|
| 65 |
+
|
| 66 |
+
entry = {"fig": fig, "desc": desc}
|
| 67 |
+
self.fig.append(entry)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_fig(self):
|
| 71 |
+
|
| 72 |
+
return self.fig
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def save_cols_wo_id(self, col):
|
| 76 |
+
|
| 77 |
+
self.cols_wo_id = col
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
    def load_cols_wo_id(self):

        return self.cols_wo_id

    def save_code(self, code):

        self.code = code

    def load_code(self):

        return self.code

    def save_recommendations(self, recommendations):

        self.recommendations = recommendations

    def load_recommendations(self):

        return self.recommendations

    def save_suggestion(self, suggestion):

        self.suggestion = suggestion

    def load_suggestion(self):

        return self.suggestion

    def load_data_meaning(self):

        return self.data_meaning

    def save_error(self, error):

        self.error = error

    def load_error(self):

        return self.error

    def refine_suggestions(self, rec):

        prompt = f"""
请根据以下详细的可视化建议,提取每一列与每个变量组的推荐可视化方法。

详细可视化建议:
{rec}

输出要求(必须严格遵守):
1. 输出为纯文本,每条独立换行,且不得有多余说明。
2. 单变量格式:列名:图表1, 图表2。
3. 多变量格式:关系组:列A,列B:图表1, 图表2。
4. 总体变量格式:总体:图表1, 图表2。
5. 严格不要添加标题、编号、示例或额外解释。
6. 提取可视化方法必须精准。
"""

        refined_suggestions = self.call(prompt)
        self.refined_suggestions = refined_suggestions

        return refined_suggestions

    def get_visualization_recommendations(
        self,
        cols,
        user_input=None,
        memory_limit: int = 6,
    ) -> str:

        dim_info = f"{self.df.shape[0]} 行 x {self.df.shape[1]} 列"

        recent_memory = self.memory[-memory_limit:] if getattr(self, "memory", None) else []
        if recent_memory:
            formatted_memory = "\n".join(
                f"{m['role']}: {m['content']}" for m in recent_memory
            )
            memory_block = formatted_memory
        else:
            memory_block = ""

        if user_input is None:
            prompt = f"""
你是一位资深数据可视化专家,请根据以下信息,为数据分析报告的“可视化设计”章节提供系统、专业的建议。

【数据集信息】
- 数值型变量:{cols}
- 数据维度:{dim_info}
- 历史上下文(仅供参考):{memory_block}

【输出格式】
请严格按照以下结构输出(保持标题和层级一致,不得增减):

一、单变量可视化(Univariate)
1. 针对每个数值型变量,推荐 1–2 种最合适的可视化方法,并简要说明理由。
例如:
- `列1`:推荐“直方图(Histogram)”和“盒须图(Box Plot)”,理由:……

二、多变量关系可视化(Multivariate)
1. 从上述变量中选择 1–3 组值得重点分析的变量组合(每组包含 2–3 个变量),并说明选择理由。
例如:
- 关系组 1:`[列1, 列2]`,理由:……
2. 对每一组变量,推荐最合适的可视化方法,并简要说明。
例如:
- 关系组 1:散点图(Scatter Plot)+ 回归线(Regression Line),理由:……

三、整体分布可视化(Distribution Overview)
1. 针对全数据的总体分布特征,推荐 1–2 种全局可视化方法,并说明用途。
例如:
- 推荐“小提琴图矩阵(Violin Plot Matrix)”,用途:……
- 推荐“热力图(Heatmap)”,用途:……

【执行要求】
1. 若列名无实际意义(如索引、冗余 ID),应自动过滤;
2. 输出内容需保持条理清晰、语言简洁、专业。
""".strip()

        else:
            prompt = f"""
你是一位资深数据可视化专家,请根据以下信息回应并实现用户需求:

【用户需求】
{user_input}

【数据集信息】
- 数值型变量:{cols}
- 数据维度:{dim_info}
- 数据概览(前几行):
{self.df.head().to_string(index=False)}
- 历史上下文(仅供参考):{memory_block}

【执行要求】
1. 若用户明确指定可视化列,仅针对这些列给出建议;
2. 若用户提出特定要求(如图形大小、坐标轴 log 缩放等),必须在输出中体现;
3. 仅响应用户需求,不输出无关内容;
4. 若用户要求对先前内容进行局部修改,应保留未更动部分,仅更新相关建议;
5. 输出内容应结构清晰、逻辑连贯、语言简洁;
6. 禁止输出代码。
""".strip()

        recommendations = self.call(prompt)
        return recommendations

    def desc_fig(self, fig, dtype_info):

        selected = st.session_state.selected_model

        # Vision-capable models receive the rendered figure as a base64 JPEG plus
        # a compact structural summary; other models fall back to fig.to_dict().
        if selected in ("智谱AI", "通义千问", "GPT-4o", "GPT-5", "豆包", "Claude"):
            img_bytes = fig.to_image(format="jpg")
            fig_info = extract_plotly_info(fig)
            base64_string = base64.b64encode(img_bytes).decode('utf-8')

            prompt_payload = [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpg;base64,{base64_string}"}
                },
                {
                    "type": "text",
                    "text": f"""
请综合下方可视化图与变量信息,进行**简洁但深入的分析**。
从分布形态、趋势特征、变量间关系、潜在异常现象、现实含义五个角度,提炼关键洞察。
输出一段不超过 120 字的自然语言分析结论(非摘要)。

【变量信息】
{dtype_info}

【图表结构信息】
{fig_info}

写作要求:
1. 分析需包含对数据异常的识别与说明:
   - 若存在明显异常点、异常段或突变趋势,请指出其特征与潜在影响;
   - 若未发现异常,也需明确说明整体分布稳定或无显著异常;
2. 内容需体现推理与解释性思考,而非表面描述;
3. 使用逻辑清晰、客观专业的语言;
4. 使用动词驱动句式(如“呈现出”“反映出”“揭示出”“说明了”等);
5. 不使用模糊词(如“可能”“似乎”“微妙”等);
6. 不使用标题、列表或格式符号;
7. 若变量含义中存在噪声或重复信息,请自动忽略;
8. 保持语气简洁有力,强调数据特征与分析结论。
""".strip()
                }
            ]

            desc_text = self.call(prompt_payload)

        else:
            prompt = f"""
请综合下方可视化图与变量信息,从数据分布、趋势特征及潜在关系等角度进行分析。
以不超过 100 字的自然语言总结关键发现,突出该变量在整体数据结构中的意义或异常现象。

【变量信息】
{dtype_info}

【图表信息】
{fig.to_dict()}

写作要求:
1. 语言应流畅自然,保持客观、专业;
2. 使用简洁的动词和名词,不滥用形容词或副词;
3. 避免“可能”“也许”“似乎”“微妙”等模糊词;
4. 不添加标题或列表结构;
5. 结合数据含义和图表特征,给出具有洞察力的简要结论;
6. 若变量含义中存在杂乱或重复信息,请自动忽略。
""".strip()

            desc_text = self.call(prompt)

        return desc_text
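
    # The prompt_payload above follows the OpenAI-style multimodal "content"
    # format: a list mixing {"type": "image_url", ...} and {"type": "text", ...}
    # parts. LLMClient.call is assumed to forward such lists unchanged to
    # vision-capable backends; text-only models take the plain-string branch.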

    def summary_html(self) -> dict:

        analysis = self.summary_fig_analysis_list()

        if analysis is None:
            return None
        else:
            analysis = {i: item for i, item in enumerate(analysis)}

            summary = {
                "title": "数据可视化",
                "fig_analysis": analysis,
            }

            return summary

    def summary_word(self) -> dict:

        analysis = self.summary_fig_analysis_list()

        if analysis is None:
            return None
        else:
            summary = {
                "title": "数据可视化",
                "fig_analysis": analysis,
            }

            return summary

    def summary_fig_analysis_list(self) -> list:

        if not self.code:
            return self.analysis

        if self.analysis:
            return self.analysis

        # state_copy = dict(st.session_state)
        selected = st.session_state.get("selected_model", "default")
        # selected = state_copy.get("selected_model", "default")

        # --- Analyse a single figure (runs in a worker thread) ---
        def analyze_one(item, offset):
            fig = item["fig"]
            desc = item["desc"]

            # Restore state if st.session_state must be read here; Streamlit does
            # not guarantee session_state access from worker threads, hence the
            # state_copy experiments above.
            # st.session_state.update(state_copy)
            selected = st.session_state.get("selected_model", "default")
            if not isinstance(fig, go.Figure):
                return None  # skip non-Plotly entries; the caller filters out None results
            if selected in ("智谱AI", "通义千问", "GPT-4o", "GPT-5", "豆包", "Claude"):
                img_bytes = fig.to_image(format="jpg")
                base64_string = base64.b64encode(img_bytes).decode("utf-8")

                fig_info = extract_plotly_info(fig)

                prompt_payload = [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpg;base64,{base64_string}"}
                    },
                    {
                        "type": "text",
                        "text": f"""
你正在撰写数据分析报告的第三章——《数据可视化》。
请针对下方变量,结合其**业务含义、统计特征**与**可视化图表现**,撰写一段专业、逻辑严谨、可直接用于报告正文的分析内容。

【变量信息】
{self.cols_wo_id}

【Plotly 图表结构】
{fig_info}

【基础统计概览】
{desc}

【分析任务】
请在脑中先完成以下推理步骤,然后输出结构化正文:
1. 从图表识别核心模式:整体趋势、峰值、分布形态、异常点或聚集区;
2. 思考该模式与变量业务含义的关系;
3. 判断是否存在异常现象(单点异常、阶段性异常或结构性突变),并说明其潜在影响;
4. 若图中包含其他变量,请分析它们之间的统计或逻辑关联;
5. 将上述洞察整合成逻辑完整、语言自然的段落。

【输出格式(严格遵守)】
输出为纯文本,依次包含以下三部分(不使用 Markdown 或符号):

1. 概述
- 简述变量的定义、业务角色及数据表现的总体趋势;
- 提出该变量在整体数据结构中可能的重要性。

2. 分布与特征分析
- 从统计与图形角度分析其分布特征(集中趋势、离散程度、偏态、峰度、周期性等);
- 若发现异常或突变,请具体说明其表现形式与潜在机制;
- 若与其他变量有关联趋势,指出方向与强度。

3. 实际含义与推论
- 结合业务或研究背景,解释观察到的现象;
- 分析其可能揭示的现实规律、风险或优化方向;
- 若合适,可提出合理推测或后续分析建议(保持客观与逻辑自洽)。

【写作要求】
1. 保持语言正式、专业、逻辑紧密;
2. 句式多样、表达自然,避免模板化表述;
3. 禁用模糊词汇(如“可能”“似乎”“大概”等);
4. 不使用任何标题符号(如 #、** 等);
5. 不输出“AI”“模型”“助手”等字样;
6. 输出为连续正文,不包含解释性语句或附加说明。
""".strip()
                    }
                ]

                analysis_text = self.call(prompt_payload)

            else:
                prompt = f"""
你正在撰写数据分析报告的第三章——《数据可视化》。
请针对下方变量,结合其业务含义与对应的可视化图,撰写一段结构化、专业的分析文字。

【变量信息】
{self.cols_wo_id}

【Plotly 图表信息】
{fig.to_dict()}

【基础统计概览】
{desc}

请严格按照以下格式撰写内容(使用纯文本,不使用 Markdown 语法或符号):

1. 概述
- 说明该变量的含义及其在数据或业务中的作用;
- 简要描述整体分布特征或变量间的主要关联趋势。

2. 分布 / 关联特征
- 从统计角度说明变量的分布特征或相关关系;
- 可引用关键统计量(均值、中位数、四分位数、相关系数等)支持分析。

3. 现实含义
- 结合变量在实际情境中的意义,解释所观察到的分布或关系;
- 指出这些模式可能反映的现实现象或潜在影响(例如:某变量偏高代表风险上升或群体特征差异)。

【写作要求】
1. 使用流畅、自然且正式的中文表达;
2. 语言应客观、简洁,避免冗余修辞;
3. 禁止使用“可能”“也许”“似乎”“微妙”等模糊词;
4. 不使用标题符号(#、** 等);
5. 保持逻辑连贯,分析层次清晰。
""".strip()

                analysis_text = self.call(prompt)
                print(prompt)
            return offset, {"figure": fig, "analysis": analysis_text}

        # --- Run the per-figure analyses in parallel ---
        results = []
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(analyze_one, item, i) for i, item in enumerate(self.fig)]
            for f in as_completed(futures):
                result = f.result()
                if result:
                    results.append(result)

        # --- Restore the original figure order ---
        results.sort(key=lambda x: x[0])
        self.analysis = [r[1] for r in results]

        return self.analysis

    def code_generation(self, df_head: str, user_prompt: str) -> str:
        """Build the LLM prompt that asks for plotting code ending in a single `fig_dict` assignment."""
        allowed = ", ".join(self.allowed_libs)

        prompt = (
            "请**严格只输出纯 Python 代码**,**不要**输出任何解释性文字、注释、示例、markdown code fence(禁止出现 ``` 或 ```python 等)。"
            "运行环境已提供 pandas DataFrame 变量 `df`、numpy(np)、"
            "plotly.express(px)、plotly.graph_objects(go)。\n\n"
            "##严格要求##:\n"
            "1) **严格执行用户需求**:若用户指定了要可视化的列,可能是精确列名,也可能是模糊输入"
            "(如输入 “ordera” 但实际列名为 “ordertypea”),不要凭空产生虚假列名!!!"
            f"请在脚本开头使用 LLM 理解,将用户输入映射到 {df_head} 中最合适的真正列名,或采用更保守的索引(如第0列、第1列,推荐!),再仅对这些列绘制图表;\n"
            """2) **统计并重命名**:所有类别分布图请按下面模板写,**绝不直接用** `index` 作为列名——
# === 模板:统计并绘制 Bar Chart ===
for col in categorical_cols:
    df_counts = df[col] \\
        .value_counts() \\
        .rename_axis(col) \\
        .reset_index(name='count')
    fig = px.bar(
        df_counts,
        x=col,
        y='count',
        title=f'Bar Chart of {col}',
        labels={col: col, 'count': 'Count'}
    )
    fig_dict[f'{col}_bar'] = fig

3) 智能选图:根据数据类型(数值/类别)自动选择合适的图表。
4) 自动检测是否需要按分类列着色,并做两种处理:若存在指定的分类列且需要连续映射,先编码为数值 codes;若需要离散映射,使用 parallel_categories。
5) 如 Plotly Express 中无合适图表,使用 `go.Figure` 自定义。
6) 脚本末尾仅包含 `fig_dict = {...}`,不要 `print`、不要额外全局变量。
7) 任何情况下不得“造”列名或直接写 `'index'`;若要使用索引,必须显式使用 `df.index`。
8) 不要使用文件读写或其他外部 IO。
9) 请只给我 Python 代码,不要给我任何 '''python 等非代码内容的标识符。

"""
            f"示例数据头部:\n{df_head}\n\n"
            f"每一张图的颜色必须从{self.color}中选择。\n\n"
            f"画图建议: {self.refined_suggestions}\n\n"
            "返回:完整 Python 代码(纯代码块)。"
        )

        if self.error is not None:
            if self.debug_num < 5:
                self.debug_num += 1
                prompt += f"""
上次生成的代码运行失败。
【错误信息】:
{self.error}

【原始代码】:
{self.code}

请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因,然后输出修正后的完整代码。

要求:
1. 不输出任何分析、解释或说明(包括文字、列表或注释段落);
2. 可在代码内部使用简短注释说明关键修改;
3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整;
4. 若依赖库方法不适用,可自行实现替代函数;
5. 生成的代码必须可独立运行,无语法错误;
6. 保持整体逻辑与原代码意图一致,仅做必要修正。
"""
            else:
                self.debug_num = 0

        raw = self.call(prompt)

        return raw

    def check_abstract(self):
        if self.abstract is None:
            # Gather every per-figure analysis.
            analysis_list = self.summary_fig_analysis_list()

            if not analysis_list:
                self.abstract = "暂无可视化分析内容。"
                return self.abstract

            # Merge all analyses into a single block of text.
            all_analyses = "\n\n".join([
                f"【变量分析 {i+1}】\n{item['analysis']}"
                for i, item in enumerate(analysis_list)
            ])

            prompt = f"""
请阅读并综合以下多个变量的分析内容:
{all_analyses}

任务:
将这些分析整合为一段结构化、信息充分的**综合语义总结**,供后续大模型自动生成报告目录使用。

目标:
- 输出内容应帮助后续模型理解分析中包含的主题、变量、维度、关系与逻辑顺序;
- 它将作为“目录生成模型”的输入,因此必须让模型能看出报告中应有哪些章节与子章节。

写作要求:
1. **信息保留**:
   - 保留每个变量的关键结论、趋势、特征、显著差异;
   - 明确变量间的联系、对比或影响;
   - 不得省略任何对分析主题有价值的事实。

2. **结构导向**:
   - 按逻辑顺序组织:总体特征 → 各变量分析 → 变量间关系 → 潜在规律;
   - 若存在不同主题(如气象因素、污染物指标、模型结果),应自然体现层次;
   - 语义中隐含章节边界信号(如“首先…其次…最后…”、“在气象变量方面…”、“在建模部分…”等)。

3. **语言风格**:
   - 专业、清晰、客观;
   - 使用完整句表达,不使用列表或编号;
   - 可以稍微详细,不追求简短。

4. **输出格式**:
   - 输出仅为一段完整文字;
   - 不得加入标题、注释、JSON、代码块;
   - 该文字将被直接送入目录生成模型,不对人类展示。

请生成符合上述要求的综合语义总结。
""".strip()

            self.abstract = self.call(prompt)

        return self.abstract

    def check_full(self):
        """
        Return structured content that follows the image-insertion protocol:
        - each analysis block is indexed,
        - image positions are marked with [FIG:index],
        - downstream processing replaces the markers with real images.
        """
        if self.full is None:
            analysis_list = self.summary_fig_analysis_list()

            if not analysis_list:
                self.full = "暂无可视化分析内容。"
                return self.full

            # Build the structured text with image-insertion markers.
            full_parts = ["""【阶段说明】这是数据分析流程中的数据可视化阶段。"""]
            for i, item in enumerate(analysis_list):
                desc = item["analysis"]
                part = f"""
【对图 {i} 的分析】
{desc}
[FIG:{i}]  # 图片插入位置标记
""".strip()
                full_parts.append(part)

            self.full = "\n\n".join(full_parts)

            # Append the protocol description.
            protocol_note = """
---
# 图片插入处理协议说明:
# [FIG:index] 表示图片插入位置
# index 对应分析内容中的索引
# 你在需要放图的地方用 [FIG:index] 代替即可
""".strip()

            self.full = f"{self.full}\n\n{protocol_note}"

        return self.full
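
    # --- [FIG:index] consumption sketch (illustrative only; the real substitution
    # happens in the report stage) ---
    # A downstream renderer can recover the text/figure interleaving with a
    # single regex split:
    #     import re
    #     parts = re.split(r"\[FIG:(\d+)\]", full_text)
    #     # parts alternates text and figure indices: [text, "0", text, "1", ...]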


def extract_plotly_info(fig):
    """
    Extract the key facts from a Plotly Figure (object / dict / string):
    the chart title, the X/Y axis titles, and the set of trace types.
    """
    import ast
    import plotly.graph_objects as go

    if isinstance(fig, go.Figure):
        fig = fig.to_dict()
    elif isinstance(fig, dict):
        pass
    elif isinstance(fig, str):
        clean_str = fig.strip()
        if clean_str.startswith("Figure("):
            clean_str = clean_str[len("Figure("):-1]
        try:
            fig = ast.literal_eval(clean_str)
        except Exception as e:
            raise ValueError(f"无法解析字符串形式的 Figure: {e}")
    else:
        raise TypeError(f"不支持的 fig 类型: {type(fig)}")

    layout = fig.get("layout", {})
    title = layout.get("title", {}).get("text", "")
    xaxis_title = layout.get("xaxis", {}).get("title", {}).get("text", "")
    yaxis_title = layout.get("yaxis", {}).get("title", {}).get("text", "")

    data_list = fig.get("data", [])
    types = list({d.get("type", "") for d in data_list})

    return {
        "title": title or "(无标题)",
        "xaxis": xaxis_title or "(无X轴标题)",
        "yaxis": yaxis_title or "(无Y轴标题)",
        "types": types,
    }
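
# --- Illustrative usage (not part of the pipeline) ---
# extract_plotly_info keeps the multimodal prompts above compact: instead of the
# full fig.to_dict() payload, only title/axis/trace-type metadata is sent.
if __name__ == "__main__":
    import plotly.express as px

    demo_fig = px.histogram(x=[1, 2, 2, 3], title="Demo")
    # Expected shape: {'title': 'Demo', 'xaxis': 'x', 'yaxis': 'count', 'types': ['histogram']}
    print(extract_plotly_info(demo_fig))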
prompt_engineer/sec4_call_llm.py
ADDED
@@ -0,0 +1,606 @@
import streamlit as st

from prompt_engineer.call_llm import LLMClient


class ModelingCodingAgent(LLMClient):

    def __init__(self, *args, **kwargs):

        super().__init__(*args, **kwargs)
        self.allowed_libs = [
            "numpy", "sklearn.model_selection", "sklearn.preprocessing",
            "sklearn.ensemble", "torch", "torchvision", "torchaudio",
            "xgboost", "lightgbm",
        ]
        self.code = None
        self.result = None
        self.suggestion = None
        self.user_selection = None
        self.par_content = ""
        self.inference_code = None
        self.best_model = None
        self.inference_data = None
        self.inference_processed_df = None
        self.abstract = None
        self.full = None
        self.error = None
        self.inference_error = None
        self.target = None
        self.finish_auto_task = False
        self.best_model_gz_bytes = None
        self.debug_num = 0
        self.refined_suggestions = None

    def finish_auto(self):

        self.finish_auto_task = True

    def save_best_model_gz_bytes(self, best_model_gz_bytes):

        self.best_model_gz_bytes = best_model_gz_bytes

    def load_best_model_gz_bytes(self):

        return self.best_model_gz_bytes

    def save_target(self, target):

        self.target = target

    def load_target(self):

        return self.target

    def save_error(self, error):

        self.error = error

    def load_error(self):

        return self.error

    def save_inference_error(self, inference_error):

        self.inference_error = inference_error

    def load_inference_error(self):

        return self.inference_error

    def save_inference_data(self, inference_data):

        self.inference_data = inference_data

    def load_inference_data(self):

        return self.inference_data

    def save_inference_processed_df(self, inference_processed_df):

        self.inference_processed_df = inference_processed_df

    def load_inference_processed_df(self):

        return self.inference_processed_df

    def save_inference_code(self, code):

        self.inference_code = code

    def load_inference_code(self):

        return self.inference_code

    def save_best_model(self, best_model):

        self.best_model = best_model

    def load_best_model(self):

        return self.best_model

    def save_code(self, code):

        self.code = code

    def load_code(self):

        return self.code

    def save_suggestion(self, suggestion):

        self.suggestion = suggestion

    def load_suggestion(self):

        return self.suggestion

    def save_modeling_result(self, result):

        self.result = result

    def load_modeling_result(self):

        return self.result

    def save_user_selection(self, user_selection):

        self.user_selection = user_selection

    def load_user_selection(self):

        return self.user_selection

    def refine_suggestions(self):
        """Distil the LLM's modeling suggestions into clear task instructions for the next coding agent."""

        prompt = f"""
请阅读以下建模建议,并将其转化为对下一个 coding agent 的清晰建模任务指令。

=== 建模建议 ===
{self.suggestion}

=== 输出要求(必须严格遵守) ===
1. 输出为纯文本,不使用任何 Markdown、编号或符号;
2. 指令应简洁明确,便于 coding agent 直接理解并执行;
3. 内容应聚焦于模型构建、训练或评估的具体任务;
4. 避免解释性或分析性语言,仅描述“需要执行的操作”;
5. 输出应覆盖所有关键步骤,使 coding agent 能独立完成建模流程。
""".strip()

        refined_suggestions = self.call(prompt)
        self.refined_suggestions = refined_suggestions

        print(refined_suggestions)

        return refined_suggestions

    def code_generation(self, df_head: str, user_prompt: str) -> str:
        """Build the LLM prompt that asks for training code ending in a JSON-serializable `result_dict`."""
        allowed = ", ".join(self.allowed_libs)

        if self.refined_suggestions is None:
            suggestion = user_prompt
        else:
            suggestion = self.refined_suggestions

        prompt = (
            f"""请**严格只输出纯 Python 代码**,**不要**输出任何解释性文字、注释、示例、markdown code fence(禁止出现 ``` 或 ```python 等)。运行环境已提供 pandas DataFrame 变量 `df`、numpy(np)、train_test_split、StandardScaler、以及用户在 Requirement 中可能提到的任意模型类(例如 RandomForestRegressor、GradientBoostingRegressor、LinearRegression、XGBRegressor、LogisticRegression、SVC 等)。

要求:

1) 使用 80/20 切分(random_state=42),根据用户需求决定是否对数值特征标准化(StandardScaler);如果标准化,务必只应用于数值列,并在训练/测试集上分别执行 fit_transform/transform。
2) **对 Requirement 中列出的所有模型都依次训练和评估**,不得只选随机森林;如果用户在 Requirement 中指定了多个模型名称,脚本必须循环遍历这些模型并分别训练、预测、计算指标。
3) 不要导入任何评价库(如 sklearn.metrics),如需评价请用 numpy 手写实现常见指标(回归:MAE、MSE、R2;分类:accuracy、precision、recall、f1)。
4) **脚本最后必须只输出并赋值一个变量 `result_dict`,且它是一个可以 JSON 序列化的 Python dict。**
推荐 schema(必须包含以下键):
{{
  "dataset": "<可选描述字符串>",
  "models": [
    {{
      "name": "<模型类名>",
      "type": "<regression 或 classification>",
      "metrics": {{ "<指标名>": <float>, ... }}
    }},
    ...
  ],
  "best_model": {{
    "name": "<得分最优的模型类名>",
    "score": <float>
  }},
  "artifacts": {{
    "best_model_b64": "<base64 字符串>",
    "best_model_format": "pickle+gzip"
  }},
  // 如模型过大,可选 "artifact_warning": <int 字节大小>
  // 以及用户在 Requirement 中提出的其他字段
}}
5) 确保所有数值均为 Python 原生类型(float、int),字段名严格为 models、best_model、artifacts;如果用户有额外需求,如记录训练时间、特征重要性等,也请加入 result_dict。
6) 模型导出:训练完毕后,将选定的 best_model 用 pickle 序列化并 gzip 压缩,再 base64 编码;把编码字符串和格式信息填入 result_dict["artifacts"],并确保最终 result_dict 可 JSON 序列化。
7) 脚本末尾仅包含一行 `result_dict = {{...}}`,不要 print、不创建其他全局变量、不读写文件。
8) 如果模型序列化后的字节数超过合理大小,请在 result_dict 中添加 `"artifact_warning": <字节数>`。
9) 不要使用任何外部 IO 或文件操作。
10) 请准确实现 Requirement 中要求的模型,不许添加 Requirement 之外的模型;若先前提供的库中无法直接调用对应模型,请手动实现!

示例数据头部:
{df_head}

Requirement(请根据以下建模任务指令,对所有列出的模型依次执行训练与评估。若某模型在当前环境不可用,请手动实现对应算法或类,使结果完整可复现):
{suggestion}

Allowed libraries: {allowed}。

返回:完整 Python 代码(纯代码块)。"""
        )

        if self.error is not None:
            if self.debug_num < 5:
                self.debug_num += 1
                prompt += f"""
上次生成的代码运行失败。
【错误信息】:
{self.error}

【原始代码】:
{self.code}

请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因,然后输出修正后的完整代码。

要求:
1. 不输出任何分析、解释或说明(包括文字、列表或注释段落);
2. 可在代码内部使用简短注释说明关键修改;
3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整;
4. 若依赖库方法不适用,可自行实现替代函数;
5. 生成的代码必须可独立运行,无语法错误;
6. 保持整体逻辑与原代码意图一致,仅做必要修正。
"""
            else:
                self.debug_num = 0

        raw = self.call(prompt)

        return raw
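
    # --- Metric sketch (illustrative): rule 3 above bans sklearn.metrics, so the
    # generated scripts are expected to hand-roll metrics with numpy, e.g. R²:
    #     import numpy as np
    #     def r2_score(y_true, y_pred):
    #         ss_res = np.sum((y_true - y_pred) ** 2)
    #         ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    #         return 1.0 - ss_res / ss_tot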

    def result_format_prompt(self, result_json: str) -> str:
        """Build the LLM prompt that turns the result_dict JSON into a human-friendly Markdown report."""

        prompt = f"""
下面给出一个 JSON 对象(包含模型评估结果结构)。请将其转换为一份对人类友好的 Markdown 报告,输出要求如下:

=== 输出要求 ===
1. 报告开头需有一行简短的“数据集说明”。
2. 对每个模型,展示以下内容:
   - 模型名称;
   - 模型类型(分类 / 回归);
   - 主要性能指标(如准确率、R²、MAE、MSE 等),每个指标保留 4 位小数;
   - 建议使用表格或分点列表清晰呈现。
3. 明确标出 **best_model**(以粗体高亮显示其名称和最优指标)。
4. 若 JSON 中包含特征工程相关信息,请在“特征工程说明”部分详细描述其具体方法与作用。
5. 输出格式:
   - 只返回 Markdown 文本;
   - 不得使用任何代码块标记(如 ```、```markdown 等);
   - 不输出解释性文字,仅输出最终报告内容(便于直接渲染于 Streamlit)。

=== 输入 JSON ===
{result_json}
""".strip()

        raw = self.call(prompt)

        return raw

    def get_model_suggestion(
        self,
        user_input=None,
        memory_limit: int = 6,  # how many recent memory turns to include
    ) -> str:
        """
        Generate modeling-stage suggestions from the dataset and history,
        folding in the most recent memory turns as auxiliary context.
        """

        # === Load the base data ===
        df = self.load_df()
        df_head = df.head().to_string(index=False)
        columns = df.columns.tolist()
        data_info = f"数据列名: {columns}\n\n数据前5行:\n{df_head}"

        # === Collect the recent memory turns ===
        recent_memory = self.memory[-memory_limit:] if getattr(self, "memory", None) else []
        if recent_memory:
            formatted_memory = "\n".join(
                f"{m['role']}: {m['content']}" for m in recent_memory
            )
            memory_block = formatted_memory  # the prompt below supplies its own header
        else:
            memory_block = ""

        # === Assemble the main prompt ===
        prompt = f"""
你是一位资深的机器学习建模专家,请基于以下信息进行分析与推理,输出针对性建模建议或改进方案。

=== 数据信息 ===
{data_info}

=== 历史上下文(仅供参考) ===
{memory_block}
""".strip()

        # If the user has an explicit modeling target
        if getattr(self, "target", None):
            prompt += f"""

=== 建模目标 ===
{self.target}
(请务必满足该目标,并在回答中明确复述建模意图。)
"""

        # If the user supplied an extra request
        if user_input:
            prompt += f"""

=== 用户当前需求 ===
{user_input}
(请严格满足该需求。若为局部修改,请保留原逻辑,仅更新指定部分。)
"""

        # If training code was generated previously
        train_code = self.load_code()
        if train_code:
            prompt += f"""

=== 历史训练代码 ===
{train_code}

请在充分理解上述代码的基础上,提出 **1–2 条高质量的模型改进建议**。
可从以下角度思考,但不限于此:
- 模型结构优化(如增加层数、调整激活函数、替换模型类型等);
- 特征工程改进(如变量选择、特征交互、归一化策略等);
- 训练流程优化(如正则化、学习率调度、损失函数调整等);
- 超参数调整(如树深度、学习率、batch size 等)。
在给出建议时,请简要说明“为什么”与“预期改进效果”。
"""
        else:
            prompt += """

=== 建模建议任务 ===
请根据数据特征和上下文,推荐 2–3 个适合的模型方案。
要求:
1. 每个模型需包含模型名称、主要原理、适用场景;
2. 指出其在当前任务中的优势与潜在局限;
3. 保持语言专业、简洁,不输出代码。
"""

        # # === 主 prompt 组装 ===
        # prompt = f"""
        # 你是一位资深的机器学习建模专家。
        #
        # 以下是用户的数据信息:
        # {data_info}
        #
        # {memory_block}
        # """.strip()

        # # 若用户有明确建模目标
        # if getattr(self, "target", None):
        #     prompt += f"\n\n建模目标:{self.target}(务必满足,并请在回答中复述)"

        # # 若用户额外输入了需求
        # if user_input:
        #     prompt += f"""\n\n用户的当前需求:{user_input}(务必满足!)
        # 若用户的要求是局部更新,则保留先前内容,仅修改特定部分。"""

        # # 若有之前生成的训练代码
        # train_code = self.load_code()
        # if train_code:
        #     prompt += f"""
        #
        # 用户之前生成的训练代码:
        # {train_code}
        #
        # 请在理解该代码的基础上,提供 **1–2 条模型改进建议**,
        # 可涉及但不限于:
        # - 模型结构调整
        # - 特征工程优化
        # - 模型替换(例如从树模型切换为深度学习模型)
        # - 超参数调整或正则化策略优化
        # """
        # else:
        #     prompt += """
        #
        # 请基于数据特征,推荐 2–3 个合适的模型,
        # 并说明每个模型的适用场景和优劣分析。
        # """

        # # 若存在以往建模结果
        # modeling_result = self.load_modeling_result()
        # if modeling_result:
        #     prompt += f"\n\n用户之前的模型运行结果:\n{modeling_result}"

        # === Call the LLM ===
        raw = self.call(prompt)
        return raw

    def summary_html(self) -> dict:

        if self.code is None:
            return None

        else:
            prompt = f"""
你正在撰写数据分析报告的**第四章:数据建模**。
请根据以下输入内容,综合分析并生成完整的章节正文。
内容需逻辑严谨、表达自然,体现专业的分析与总结能力。

=== 输出结构 ===
请严格按照以下五个小节组织内容:

1. 概述
   - 说明本次建模的目标、研究背景及数据来源的上下文。

2. 方法说明
   - 介绍所采用的模型或算法的核心思想与实现流程;
   - 若涉及特征工程、超参数选择或数据预处理,请一并说明;
   - 可适当涉及模型的数学原理或优化机制,以体现技术深度。

3. 关键代码解读
   - 聚焦核心函数与模块,说明其在建模流程中的作用;
   - 可提及模型结构定义、训练循环、损失函数与评估逻辑;
   - 语言应清晰简练,避免逐行解释。

4. 结果与评估
   - 概述主要性能指标(如 Accuracy、AUC、MSE 等)及结果表现;
   - 分析模型效果是否符合预期,并指出主要优劣与瓶颈。

5. 改进建议
   - 针对模型性能与实验发现,提出具体可行的优化方向;
   - 可从模型结构、特征选择、训练策略或正则化等角度给出建议。

=== 写作要求 ===
1. 使用自然流畅、正式的书面表达;
2. 避免使用模糊或主观词汇(如“可能”“似乎”“微妙”等);
3. 注重逻辑连贯与专业性;
4. 不输出标题、列表标记或额外说明,只生成正文内容。
""".strip()

            if self.code is not None:
                prompt += f"\n\n=== 数据建模代码 ===\n\n{self.code}"
            if self.target is not None:
                prompt += f"\n\n=== 用户建模目标 ===\n\n{self.target}"
            memory = self.load_memory()  # call the accessor: interpolating the bound method would embed its repr
            if memory:
                prompt += f"\n\n=== 数据建模聊天对话 ===\n\n{memory}"
            if self.result is not None:
                prompt += f"\n\n=== 建模运行结果 ===\n\n{self.result}"

            desc = self.call(prompt)

            summary = {
                "title": "建模分析",
                "code": self.code,
                "desc": desc,
                "result": self.result,
            }

            return summary

    def summary_word(self) -> dict:

        return self.summary_html()

    def code_generation_for_inference(self, code, inference_df_head) -> str:
        """Build the LLM prompt that asks for the inference/analysis script."""

        prompt = (
            f"""请生成完整的 Python 推断分析脚本(仅返回代码,不要任何解释文字)。运行环境已提供 pandas DataFrame 变量 `inference_df`、已经 train 好的模型 `model_obj`、numpy(np)、StandardScaler 库、align_features 辅助函数,其余未提及的库请手写实现。要求:

示例数据信息:
{code}, inference_df 前五行: {inference_df_head}(请勿引用 inference_df 中不存在的列)

1) **可用变量说明:**
   - `inference_df`:推断数据集(Pandas DataFrame)
   - `model_obj`:已训练好的模型对象(从best_model.joblib加载)
   - `np`:NumPy库
   - `pd`:Pandas库
   - `StandardScaler`:用于数据标准化的sklearn工具

2) **脚本必须实现的功能:**
   a) 对推断数据进行与训练时完全一致的预处理(例如,缺失值处理、编码转换、标准化等)
   b) **关键步骤:在预测前,必须使用align_features函数处理特征数据,确保特征数量和顺序与训练时一致**
   c) 使用model_obj对预处理并对齐后的特征数据进行预测
   d) 生成详细的推断报告,包含预处理步骤、预测结果分析等

3) **预测结果处理要求:**
   - 将模型输出转换为人类可理解的形式(如概率值、类别标签、数值结果等)
   - **必须生成带预测结果的DataFrame**:将原始或处理后的`inference_df`与预测结果合并,命名为`inference_df_with_predictions`
   - 合并后的DataFrame必须包含原始特征列和一列名为`'prediction'`的预测结果列(模型输出多维时扩展为`prediction_0`, `prediction_1`, ...)

4) **序列化要求(用于前端下载):**
   - 将`inference_df_with_predictions`转换为无索引的CSV格式
   - 对CSV数据进行gzip压缩,然后编码为base64字符串
   - 创建包含以下键的`result_dict['artifacts']`字典:
     * `'predictions_df_b64'`:base64编码的压缩数据
     * `'predictions_df_format'`:固定值'csv+gzip'
     * `'predictions_df_size_bytes'`:压缩后的字节大小(整数)
   - 在`result_dict`中添加`'predictions_df_records'`键,值为`inference_df_with_predictions.to_dict(orient='records')`
   - 确保所有numpy/pandas类型转换为原生Python类型(int/float/str)以保证JSON可序列化

5) **代码结构与输出约束:**
   - 脚本最后**仅**包含一行`result_dict = {{...}}`语句
   - `result_dict`必须是完全JSON可序列化的Python字典
   - 禁止任何外部IO操作(不读写文件)
   - 禁止使用print语句或创建额外的全局变量

6) **生成代码质量要求:**
   - 确保所有变量名称与上述规范严格一致
   - 逻辑清晰,步骤完整,严格按照用户提供的数据和最佳模型文件生成代码
   - 处理可能出现的各种异常情况,提高代码的稳定性和可靠性

返回:完整的Python代码(仅包含代码本身,不要任何解释性文字)。"""
        )

        raw = self.call(prompt)

        return raw

    def check_abstract(self):
        if self.abstract is None:
            if self.code is None:
                self.abstract = None
            else:
                prompt = f"""
这是数据分析流程中的“建模阶段”。

请基于以下信息,在保留所有关键信息的前提下,将内容整理成一段简洁、连贯的文字摘要,用于报告撰写中的建模小节预览。

=== 输入信息 ===
- 用户初始需求:{self.target}
- 建模代码:{self.code}
- 建模阶段的交互记录:{self.load_memory()}
- 建模运行结果:{self.result}

=== 输出要求 ===
1. 以自然流畅的语言撰写一段总结,全面涵盖上述内容中的核心信息;
2. 重点说明建模目标、所用方法、主要实现逻辑与结果特征;
3. 避免逐行描述代码,仅提炼核心思路;
4. 语言应专业、客观,不使用“可能”“似乎”“也许”等模糊表达;
5. 输出仅为一段完整文字(不要标题、编号或列表);
6. 摘要应能让人据此判断该部分是否需要纳入最终报告。
""".strip()

                desc = self.call(prompt)
                self.abstract = desc

        return self.abstract

    def check_full(self):
        if self.full is None:
            if self.code is None:
                self.full = None
            else:
                self.full = f"""
【阶段说明】这是数据分析流程中的数据建模阶段。
【用户初始需求】{self.target}
【数据建模代码】{self.code}
【建模聊天对话】{self.load_memory()}
【建模运行结果】{self.result}
""".strip()

        return self.full
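

# --- Artifact decoding sketch (hypothetical helper, mirroring the "pickle+gzip"
# contract in code_generation above; nothing in the app calls it) ---
def decode_best_model_artifact(best_model_b64: str):
    import base64
    import gzip
    import pickle

    raw_bytes = gzip.decompress(base64.b64decode(best_model_b64))
    # NOTE: pickle.loads executes arbitrary code; only unpickle trusted artifacts.
    return pickle.loads(raw_bytes)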
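
# --- Predictions decoding sketch (hypothetical helper, mirroring the "csv+gzip"
# contract in code_generation_for_inference above) ---
def decode_predictions_artifact(predictions_df_b64: str):
    import base64
    import gzip
    import io

    import pandas as pd

    csv_bytes = gzip.decompress(base64.b64decode(predictions_df_b64))
    return pd.read_csv(io.BytesIO(csv_bytes))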
prompt_engineer/sec5_call_llm.py
ADDED
@@ -0,0 +1,617 @@
import streamlit as st
import openai
import requests
import json
import re
import pandas as pd
import numpy as np

from config import MODEL_CONFIGS
from prompt_engineer.call_llm import LLMClient


class ReportAgent(LLMClient):

    def __init__(self, *args, **kwargs):

        super().__init__(*args, **kwargs)
        self.template = None
        self.name = None
        self.date = None
        self.report_format = None
        self.html = None
        self.word = None
        self.markdown = None
        self.user_input = None
        self.outline = None
        self.outline_length = None
        self.report = None
        self.finish_auto_task = False
        self.gen_mode = None

    def save_gen_mode(self, gen_mode):

        self.gen_mode = gen_mode

    def load_gen_mode(self):

        return self.gen_mode

    def finish_auto(self):

        self.finish_auto_task = True

    def save_user_input(self, user_input):

        self.user_input = user_input

    def load_user_input(self):

        return self.user_input

    def save_outline_length(self, outline_length):

        self.outline_length = outline_length

    def load_outline_length(self):

        return self.outline_length

    def save_outline(self, outline):

        self.outline = outline

    def load_outline(self):

        return self.outline

    def save_template(self, template):

        self.template = template

    def load_template(self):

        return self.template

    def save_word(self, word):

        self.word = word

    def load_word(self):

        return self.word

    def save_html(self, html):

        self.html = html

    def load_html(self):

        return self.html

    def save_markdown(self, markdown):

        self.markdown = markdown

    def load_markdown(self):

        return self.markdown

    def save_report(self, report):

        self.report = report

    def load_report(self):

        return self.report

    def save_report_format(self, report_format):

        self.report_format = report_format

    def load_report_format(self):

        return self.report_format

    def save_date(self, date):

        self.date = date

    def load_date(self):

        return self.date

    def save_name(self, name):

        self.name = name

    def load_name(self):

        return self.name

    def generate_template(self, user_input=None) -> str:
        """
        Ask the LLM for an HTML report template with placeholders for the
        title, abstract, tables, and chart areas.
        """
        prompt = (
            """
我希望你输出一个现代、简洁且美观的 HTML 章节模板,请满足以下要求:

1. 整体配色采用“蓝 – 白”主题:
   - 背景为白色,标题与边框使用深蓝(#1E3A8A)和浅蓝(#3B82F6);
2. 最外层用 `<section class="chapter" id="chapter-{{ num }}">` 包裹;
3. 标题使用 `<h2>{{ title }}</h2>`:
   - 文字颜色:#1E3A8A;
   - 下方装饰性下划线:高度 3px,颜色 #3B82F6,宽度 30%;
4. 正文内容区 `<div class="content">{{ body }}</div>`,支持任意 HTML;
   - **仅对“重点摘录”或“引用”段落加用圆角矩形**,其余普通段落保持标准 `<p>` 样式;
   - 圆角矩形样式:背景 #EFF6FF,padding 12px,border-radius 8px,margin-bottom 16px;
5. 如果有图片列表 `images`:
   - ≤3 张时水平并排;>3 张时自动换行,每行最多 3 张;
   - `<img>` 带 6px 圆角、轻微阴影 `box-shadow:0 2px 6px rgba(0,0,0,0.1)`;
6. 在 `<style>` 中内联基础样式:
   - `.chapter` 外层间距、内边距、最大宽度、白底阴影;
   - `.chapter h2` 字体、颜色、下划线;
   - `.content p` 和 `.content .highlight`(重点段落)样式区分;
   - `.images` 的 flex 布局与 gap;
7. 使用 Jinja2 占位符:
   - 普通段落:`{% for p in paragraphs %}<p>{{ p }}</p>{% endfor %}`;
   - 重点段落数组 `highlights`:`{% for h in highlights %}<div class="highlight">{{ h }}</div>{% endfor %}`;
8. **只输出完整的 `<section>…</section>` 片段**,不要任何解释文字或其他标签。
9. 在模板的 .content 区域加入一个 DataFrame 占位并用 Jinja2 渲染变量 df_html({{ df_html | safe }}),要求输出为响应式 HTML 表格(显示表头、支持横向滚动并在窄屏下自动换行),以便在导出为 PDF 时正确排版。

请直接给出最终的 HTML 模板代码。
"""
        )

        if user_input is not None:
            prompt += f"请根据用户需求进行调整:{user_input}"

        return self.call(prompt)
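
    # --- Rendering sketch (hypothetical; the actual fill happens in the report
    # workflow): the generated fragment is a Jinja2 template, so filling it
    # would look roughly like:
    #     from jinja2 import Template
    #     section_html = Template(section_tpl).render(
    #         num=1, title="概述", paragraphs=[...], highlights=[...],
    #         df_html=df.to_html(index=False),
    #     )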

    def fill_report(self, template: str, content: dict) -> str:
        """
        Splice the section content into the HTML template and have the LLM
        polish the report and add explanatory text.
        """

        prompt = (f"""
下面是章节结构模板:
{template}
请仅输出 `<section>` 里完整的 HTML(包括标题、正文、图片区块),请将重点内容用 highlight 凸显。
对于内容的分析有以下要求:
1. 要用流畅的自然语言;
2. 不要滥用形容词和副词,尽量用简单的动词和名词表达意思;
3. 不用“可能”“也许”“似乎”“微妙”等模糊表述。
请根据以下提供的信息对文章进行深入分析:
""")

        if content.get("title") is not None:
            prompt += f"- title={content['title']}\n"
        if content.get("fig_analysis") is not None:
            prompt += f"- images及其分析(请将image也放入报告中):{content['fig_analysis']}\n"
        if content.get("df") is not None:
            prompt += f"- 表格预览(请将表格也放入报告中,输出美观完整):{content['df']}\n"
        if content.get("code") is not None:
            prompt += f"- 对应部分代码(请将代码中的重点公式与内容进行讲解与分析):{content['code']}\n"
        if content.get("processed_df") is not None:
            prompt += f"- 预处理后的数据预览:{content['processed_df']}\n"
        if content.get("desc") is not None:
            prompt += f"- 具体内容分析:{content['desc']}\n"
        if content.get("header") is not None:
            prompt = f"""
下面是章节结构模板:
{template}
要求:header单独占一页
- 请为我生成封面header:{content['header']}
"""
        if content.get("footer") is not None:
            prompt = f"""
下面是章节结构模板:
{template}
要求:footer单独占一页
- 请为我生成最后一页footer:{content['footer']}
"""

        prompt += "请仅返回 HTML。"

        return self.call(prompt)

    def fill_report_word(self, content: dict, template: str = "") -> str:
        # `template` is optional here; it is only interpolated by the
        # header/footer branches below.

        prompt = (f"""
你是一个资深的数据分析专家,
请仅输出每一章节的完整的 Word 内容(包括标题、正文、图片区块),
对于内容的分析有以下要求:
1. 要用流畅的自然语言;
2. 不要滥用形容词和副词,尽量用简单的动词和名词表达意思;
3. 不用“可能”“也许”“似乎”“微妙”等模糊表述。
请根据以下提供的信息对文章进行深入分析:
""")

        if content.get("title") is not None:
            prompt += f"- title={content['title']}\n"
        if content.get("fig_analysis") is not None:
            prompt += f"- images及其分析(请将image也放入报告中):{content['fig_analysis']}\n"
        if content.get("df") is not None:
            prompt += f"- 表格预览(请将表格也放入报告中,输出美观完整):{content['df']}\n"
        if content.get("code") is not None:
            prompt += f"- 对应部分代码(请将代码中的重点公式与内容进行讲解与分析):{content['code']}\n"
        if content.get("processed_df") is not None:
            prompt += f"- 预处理后的数据预览:{content['processed_df']}\n"
        if content.get("desc") is not None:
            prompt += f"- 具体内容分析:{content['desc']}\n"
        if content.get("header") is not None:
            prompt = f"""
下面是章节结构模板:
{template}
要求:header单独占一页
- 请为我生成封面header:{content['header']}
"""
        if content.get("footer") is not None:
            prompt = f"""
下面是章节结构模板:
{template}
要求:footer单独占一页
- 请为我生成最后一页footer:{content['footer']}
"""

        prompt += "请仅返回 HTML。"

        return self.call(prompt)

    def get_content(self, agent):

        content = agent.summary()

        return content

    def generate_toc_from_summary(self, full_summary) -> str:
        """
        Ask the LLM to derive a hierarchical, annotated table of contents
        from the accumulated summaries (heading depth follows outline_length).
        """

        prompt = f"""
你是一位资深数据分析报告结构设计专家。

请你根据以下报告摘要内容,为该数据分析报告生成**层次清晰、内容具体、贴合数据本身**的目录结构。

【输出要求】
1. 格式:
   - 纯文本输出(不得使用 Markdown、代码块、Python 列表或符号标记)
   - 每行一个目录项,无缩进或前缀符号
   - 示例格式:
     1.概述(说明报告背景与目标)
     2.数据导入(说明数据来源与结构)
     2.1 数据概览(展示核心字段与样本规模)
     2.1.1 租赁数量趋势(分析租赁随时间的变化)
2. 编号规则:
   - 一级标题:1, 2, 3...
   - 二级标题:2.1, 2.2...
   - 三级标题:2.1.1, 2.1.2...
3. 内容说明:
   - 所有标题与说明应以摘要为基础,可在保持主题一致的前提下,适度补充逻辑性或结构性内容。
   - 每个标题后附一句说明,用于指导后续大模型撰写章节内容;
   - 说明须以中文括号“( )”包裹;
   - 每条说明需精准、具体,**明确指示该部分的写作任务、分析角度、数据焦点或方法方向**;
   - 字数不超过 50 字;
   - 上下级说明应保持语义连贯,避免重复;
   - 说明可涉及:
     - 要分析的变量或主题(如“气温”“租赁数量”“污染物浓度”);
     - 要执行的任务(如“展示分布”“分析趋势”“比较模型性能”);
4. 禁止输出任何解释、前言、说明、提示、或多余空行,仅输出目录正文。

【生成逻辑】
1. 依据摘要内容中出现的主题(如数据特征、指标、变量名、任务目标)生成章节标题。
   - 若摘要中提及“租赁数量”“气温”“湿度”“时间”等,请将其体现在相关标题中。
   - 避免使用模糊标题(如“数据分析”“关系探索”“模型评估”等)。
2. 报告可能包含模块:
   “数据导入”、“数据预处理”、“数据可视化”、“建模分析”。
   - 仅生成摘要中实际涉及的模块。
3. 确保章节间语义互斥(正交),避免内容重叠。
4. 根据详细程度动态调整层级:
   - 简要:生成两级标题;
   - 标准:生成三级标题;
   - 详细:生成四级标题。
5. 若摘要涉及具体变量(如“Temperature”、“Rented Bike Count”),
   请在目录中直接引用中文变量名(如“气温”、“租赁数量”),
   以体现报告的“数据感知性”。

用户选择的目录详细程度为:{self.outline_length}

报告摘要如下:
{full_summary}
"""

        toc_response = self.call(prompt)
        return toc_response.strip()

    def selected_photo_update_toc(self, toc, selected_full_contents_vis: str) -> str:
        """
        Update the toc from the full report content selected_full_contents_vis,
        appending a fourth item to each section: the list of figure indices
        that belong to it.
        """
        print(selected_full_contents_vis)

        prompt = f"""
你是一位专业的数据分析报告结构与图文匹配专家。

任务:请你根据报告的目录结构、正文内容和阶段说明,判断每个 [FIG:x] 图像最合适归属的章节。

【输入内容】
1. 目录结构(含标题、层级、内容大纲):
{toc}

2. 报告完整正文(带有 [FIG:x] 图片标记):
{selected_full_contents_vis}

【任务说明】
请你逐一分析每个 [FIG:x] 图像的出现上下文,并结合目录内容,判断该图应归属于哪个章节。
要求同时考虑:
- **语义匹配**:图像内容的主题(如污染物趋势、气象变化、时间分布、模型结果)与章节描述的一致性;
- **上下文位置**:图像在正文中出现时,其前后段落通常属于哪个章节;
- **粒度优先**:若图像语义符合多个章节(如“气象参数”与“气象参数图形分析”),优先归入更具体的章节(层级数字更大);
- **禁止误归**:禁止将图像分配到“概述”“结论”“摘要”等非分析或与图像不相关的章节!
- **全部使用**:所有 [FIG:x] 必须被使用一次,不得遗漏或重复。

【输出格式】
请以 Python 列表形式输出,每项为:
(标题, 层级, 内容大纲, 图编号列表)
要求:
- 图编号按出现顺序排列;
- 若无图片则为空列表 [];
- 层级仅用整数表示(1, 2, 3...);
- 不输出任何解释、注释、Markdown标记。

【示例格式】
[
('概述',1,'说明报告背景与目标',[]),
('数据导入',1,'说明数据来源与结构',[]),
('数据可视化',1,'展示变量特征与关系',[4,5]),
('气象参数分析',2,'研究温度与湿度对污染的影响',[2,3]),
('模型评估',2,'展示预测结果与误差',[6,7])
]

【提示与约束】
1. 若章节间存在嵌套关系,优先分配到最具体的子章节(如 3.1.2 比 3.1 更优)。
"""

        toc_with_figs = self.call(prompt)
        return toc_with_figs.strip()
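
    # --- Parsing sketch (hypothetical): the model returns the list as plain
    # text, so a caller would typically evaluate it safely before use:
    #     import ast
    #     toc_items = ast.literal_eval(toc_with_figs_str)  # -> list of 4-tuples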
    def summarize_all_sections(
        self,
        toc_md: str,
        load_summary: str,
        preproc_summary: str,
        visual_summary: str,
        coding_summary: str
    ) -> str:
        """Aggregate every agent's summary into one narrative wrap-up that follows the toc_md structure."""

        # Step 1: collect each agent's stage summary
        section_summaries = {
            "加载阶段": load_summary,
            "预处理阶段": preproc_summary,
            "可视化分析": visual_summary,
            "模型建构": coding_summary,
        }

        # Step 2: build the LLM prompt
        prompt = f"""你现在是一个经验丰富的数据分析报告撰写助手。

我已经完成了一个数据分析项目的初稿,结构目录如下:
{toc_md}

现在我将为你提供各个章节的内容摘要,请你根据这些内容,用流畅的中文撰写一段总结性描述(可用于报告的导语或结语),要求包括但不限于:

1. 报告分析的主题方向
2. 各章节的核心处理逻辑和大致作用
3. 报告内容的整体风格与结构特性(例如是否包含图表、是否强调建模等)
4. 使用自然语言、风格正式,避免主观判断词汇(如“也许”、“不错”、“感觉”)
5. 最终输出 150~300 字中文总结段落,不需要标题

每个阶段摘要如下:\n\n"""

        for title, content in section_summaries.items():
            if content:
                prompt += f"\n【{title}】\n{content}\n"

        # Ask the LLM for the overall summary
        overall_summary = self.call(prompt)

        return overall_summary

    def update_toc_with_relevant_sections(self, toc, agent_abstracts):
        """For each chapter, decide which analysis-module ids it should consult, appended as a fifth tuple element."""
        prompt = f"""
你是一个专业的数据分析报告规划助手。
我将提供报告目录和各分析模块的摘要,请为每个章节确定应参考的模块编号列表。

报告目录(每个元素为四元组:标题、层级、内容大纲、图编号列表):
{toc}

各数据分析模块摘要如下:
{agent_abstracts}
请根据:
1. 各章节的标题、层级与内容大纲;
2. 各数据处理板块摘要;
3. 各章节的图编号分配情况(报告目录第四项);

合理判断各章节在生成报告时应参考哪些数据处理板块的信息。
输出要求:

- 对每个章节生成一个五元组 (标题, 层级, 内容大纲, 图编号列表, 模块编号列表)
- 标题, 层级, 内容大纲, 图编号列表一定不能改变,只在原有基础上添加第五项
- 模块编号列表为 Python list,例如 [0, 2]
- 若无需参考任何模块,返回 []
- 输出为 Python 列表,不含任何额外说明
示例:
输入:
[
('概述',1,'介绍报告背景与目标',[1]),
('数据可视化',1,'分析空气质量和相关环境变量的可视化图表',[2,3]),
('xxxx关联性分析',2,'分析相对湿度与其他污染物关系',[4,5])
]
输出:
[
('概述',1,'介绍报告背景与目标',[1],[1,2]),
('数据可视化',1,'分析空气质量和相关环境变量的可视化图表',[2,3],[0,1]),
('xxxx关联性分析',2,'分析相对湿度与其他污染物关系',[4,5],[2,3])
]
"""
        toc_with_sections = self.call(prompt)
        print(toc_with_sections)
        return toc_with_sections.strip()
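Because the prompt forbids changing the first four fields, a cheap defensive check on the reply is worth the few lines. A hedged sketch (hypothetical helper, not in the upload):

# --- illustrative helper, not part of the upload ---
import ast

def merge_fifth_item(original_toc, llm_reply: str):
    updated = ast.literal_eval(llm_reply.strip())
    assert len(updated) == len(original_toc), "chapter count changed"
    for old, new in zip(original_toc, updated):
        assert tuple(new[:4]) == tuple(old), "a protected field was altered"
        assert isinstance(new[4], list), "fifth item must be a module-id list"
    return updated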
    def write_section_body(self, toc, t, selected_full_contents, history_content):

        prompt = f"""
你是一个专业的数据分析报告撰写助手。你的任务是基于我提供的参考信息,生成逻辑清晰、结构严谨、内容专业的报告章节。

当前章节信息(四元组:标题、层级、内容大纲、图编号列表):
{t}

报告目录结构(包含所有章节的四元组信息):
{toc}

可参考的分析内容如下:
{selected_full_contents}

此前已生成的章节内容如下(用于保持整体风格一致、避免重复):
{history_content}

写作要求:

一、写作目标
1. 仅撰写当前章节“{t[0]}”的正文内容;
2. 内容必须以“参考信息”为核心依据,可在其逻辑框架内**进行适度拓展与归纳总结**;
3. 允许进行合理的专业性补充(如统计学解释、方法原理、结果含义),但**禁止编造具体数据、图表结果、实验场景或样本特征**;
4. 若参考信息不足,可补充一般性分析思路,但需保持内容通用、客观、抽象,不得具体化为假想数据。

二、语言与结构
1. 文风应正式、专业、学术化;
2. 论述应符合数据分析逻辑:先描述、后解释、再总结;
3. 每一自然段应围绕一个逻辑核心展开(如趋势、对比、相关性、分布特征等)。

三、图表使用规范
1. 正文中仅可使用本章节的图编号 {t[3]};
2. 使用占位符 [FIG:index] 标注图表位置;
3. 在每个占位符下方添加图片标题:
图:图片标题(简要说明图片内容及分析要点)
4. 图片位置与语义保持自然衔接:
- 若图片引出分析 → 放在段落开头;
- 若图片支撑论点 → 放在相关描述句之后;
- 若图片总结结果 → 放在段落结尾;
5. 不得增删或重排图片编号。

四、输出要求
- 仅输出正文内容;
- 不得输出标题、编号、解释文字、Markdown;
- 不使用加粗、斜体、符号修饰或非正文语句;
- 不得出现“我认为”、“请继续”、“综上可见”等主观表达。

五、写作模式
当前模式:{self.outline_length}
- 简要:仅写结论;
- 标准:含逻辑与结论;
- 详细:包含推理与方法,但仍应基于参考信息,不得自由创作。

请严格在以上范围内撰写本章节正文。
"""
# prompt = f"""
|
| 565 |
+
# 你是一个专业的数据分析报告撰写助手。你需要基于我提供的参考信息进行深入分析,生成结构严谨、逻辑清晰的专业报告章节。
|
| 566 |
+
|
| 567 |
+
# 当前需要撰写的章节信息(完整四元组,依次为标题、层级、内容大纲、图编号列表):
|
| 568 |
+
# {t}
|
| 569 |
+
|
| 570 |
+
# 报告目录结构(包含所有章节的四元组信息):
|
| 571 |
+
# {toc}
|
| 572 |
+
|
| 573 |
+
# 可参考的分析内容如下:
|
| 574 |
+
# {selected_full_contents}
|
| 575 |
+
|
| 576 |
+
# 此前已生成的章节内容如下(用于保持整体风格一致,并避免内容重复):
|
| 577 |
+
# {history_content}
|
| 578 |
+
|
| 579 |
+
# 请根据章节标题、层级、内容大纲、参考信息和图编号列表,生成该章节的完整正文。
|
| 580 |
+
|
| 581 |
+
# 正文详细程度有三种模式:
|
| 582 |
+
# - 简要:只包含核心结论与关键点,语言精炼;
|
| 583 |
+
# - 标准:包含主要分析逻辑、步骤与结果;
|
| 584 |
+
# - 详细:展开完整分析、方法论、推理过程与补充说明。
|
| 585 |
+
# 用户当前选择的模式是:{self.outline_length}
|
| 586 |
+
|
| 587 |
+
# 写作要求:
|
| 588 |
+
# 1. **核心任务**:仅撰写当前章节 **“{t[0]}”** 的正文内容,不得涉及其他章节。
|
| 589 |
+
# 2. **图表引用**:正文中引用的图表必须严格对应本章节的图编号(即 {t[3]}),**不得使用或编造其他编号**。
|
| 590 |
+
# 3. **语言规范**:
|
| 591 |
+
# - 语言应专业、准确、逻辑严谨;
|
| 592 |
+
# - 叙述风格应正式、学术化;
|
| 593 |
+
# - 禁止使用口语化或主观色彩表达。
|
| 594 |
+
# 4. **输出要求**:
|
| 595 |
+
# - 仅输出章节正文内容,不得输出 Markdown;
|
| 596 |
+
# - 不得输出任何标题,如:1,一,(1)等;
|
| 597 |
+
# - 禁止加粗、斜体、表情符号或其他符号修饰;
|
| 598 |
+
# - 不得出现非正文短语,如 “我认为”、“请继续”、“感谢阅读”、“---” 等;
|
| 599 |
+
# 5. **图片规范**:
|
| 600 |
+
# - 图片应独立成行,不得嵌入句子内部;
|
| 601 |
+
# - 图片可放置在段落的开头、结尾,或自然停顿处(如句号、分号后),以保持语义连贯;
|
| 602 |
+
# - 使用占位符格式 [FIG:index] 标记图片位置,其中 index 为对应图片的编号;
|
| 603 |
+
# - 在每个 [FIG:index] 占位符后,需紧跟一行图片标题,格式如下:
|
| 604 |
+
# 图:图片标题(简要说明图片内容及分析要点)
|
| 605 |
+
# - 图片插入位置应依据其语义和上下文逻辑确定:
|
| 606 |
+
# · 若图片用于引出分析,应放在段落开头;
|
| 607 |
+
# · 若用于支撑论述,应放在对应描述句之后;
|
| 608 |
+
# · 若总结结果或展示对比,应放在段落结尾;
|
| 609 |
+
# - 请务必确保图片位置与文字逻辑匹配,使图片与正文形成自然的论证衔接;
|
| 610 |
+
# - 请不要删除、合并或重排序图片编号,系统将在后续自动替换为真实图像。
|
| 611 |
+
|
| 612 |
+
# 请直接输出该章节的正文内容,不要有任何其他文字。
|
| 613 |
+
# """
|
| 614 |
+
|
| 615 |
+
        content = self.call(prompt)
        return content
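Downstream, the `[FIG:index]` placeholders this prompt mandates have to be swapped for real images. A minimal sketch of that substitution, assuming a mapping from figure index to image path (the real pipeline performs this step elsewhere; `replace_fig_placeholders` is hypothetical):

# --- illustrative helper, not part of the upload ---
import re

FIG_RE = re.compile(r"\[FIG:(\d+)\]")

def replace_fig_placeholders(body: str, figures: dict) -> str:
    # figures maps an integer index to an image path or URL.
    def _sub(match):
        idx = int(match.group(1))
        path = figures.get(idx)
        return f"![图{idx}]({path})" if path else match.group(0)
    return FIG_RE.sub(_sub, body)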
utils/content.py
ADDED
@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Content:
    """A report fragment: optional text plus an optional figure."""
    text: Optional[str] = None
    fig: Any = None

    def display(self):
        pass
utils/sanitize_code.py
ADDED
@@ -0,0 +1,47 @@
import re
from typing import Any

import numpy as np


def sanitize_code(code: str) -> str:
    """Strip the Markdown code-fence markers an LLM may wrap around generated code."""
    if not isinstance(code, str):
        return ""
    code = code.strip()
    if code.startswith("```") and code.endswith("```"):
        lines = code.splitlines()
        # Drop the leading ``` / ```python and the trailing ```
        if re.match(r"^```(?:python)?", lines[0].strip()):
            lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        code = "\n".join(lines)
    return code


def to_json_serializable(obj: Any) -> Any:
    """Recursively convert objects that may contain numpy types into JSON-serializable ones."""
    if obj is None:
        return None
    if isinstance(obj, (str, bool, int)):
        return obj
    if isinstance(obj, float):
        # Make sure it is a built-in float (JSON-supported)
        return float(obj)
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {str(k): to_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_json_serializable(v) for v in obj]
    # Fallback: try to cast to float, then str
    try:
        return float(obj)
    except Exception:
        try:
            return str(obj)
        except Exception:
            return None
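A quick sanity check of the two helpers (usage sketch only):

# --- usage sketch, not part of the upload ---
import json
import numpy as np

assert sanitize_code("```python\nprint('hi')\n```") == "print('hi')"

payload = {"rmse": np.float32(0.12), "scores": np.array([1, 2])}
json.dumps(to_json_serializable(payload))  # plain json.dumps(payload) would raise TypeError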
utils/save_secrets.py
ADDED
@@ -0,0 +1,33 @@
import toml
from pathlib import Path

# Secrets live in a .streamlit folder next to this module (intended as the project root).
BASE = Path(__file__).parent
SECRETS_DIR = BASE / ".streamlit"
SECRETS_FILE = SECRETS_DIR / "secrets.toml"


def load_local_api_keys() -> dict[str, str]:
    """
    Read the [api_keys] section from .streamlit/secrets.toml.
    Returns an empty dict if the file or the section does not exist.
    """
    if not SECRETS_FILE.exists():
        return {}
    data = toml.load(SECRETS_FILE)
    return data.get("api_keys", {})


def update_local_api_key(model_name: str, api_key: str) -> None:
    """
    Write one model_name: api_key pair into [api_keys] of .streamlit/secrets.toml.
    Creates the file/section if missing and preserves other existing settings.
    """
    SECRETS_DIR.mkdir(exist_ok=True)
    if SECRETS_FILE.exists():
        data = toml.load(SECRETS_FILE)
    else:
        data = {}
    data.setdefault("api_keys", {})[model_name] = api_key
    with SECRETS_FILE.open("w", encoding="utf-8") as f:
        toml.dump(data, f)
utils/spinner_pool.py
ADDED
@@ -0,0 +1,25 @@
import random


def get_spinner_msg(stage="writing"):
    msg_pool = {
        "summarizing": [
            "正在汇总各模块的分析结果...",
            "稍等一下,正在总结前面几个 Agent 的内容...",
            "AI 正在整理前面的分析,请稍候...",
            "正在综合各分析步骤的结论..."
        ],
        "writing": [
            "正在生成各章节内容...",
            "请稍候,系统正在详细撰写报告...",
            "AI 正在逐步生成报告章节...",
            "正在整理并撰写每一章节..."
        ],
        "default": [
            "正在处理数据,请稍候...",
            "AI 正在努力生成结果...",
            "请耐心等待,正在计算中..."
        ]
    }

    pool = msg_pool.get(stage, msg_pool["default"])
    return random.choice(pool)
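Typical call site, paired with Streamlit's spinner (usage sketch; `agent.summarize()` is a hypothetical long-running call):

# --- usage sketch, not part of the upload ---
import streamlit as st
from utils.spinner_pool import get_spinner_msg

with st.spinner(get_spinner_msg("summarizing")):
    result = agent.summarize()  # hypothetical agent call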
workflow/.DS_Store
ADDED
Binary file (6.15 kB).
workflow/dataloading/dataloading_core.py
ADDED
@@ -0,0 +1,287 @@
import csv
import io
import os
from typing import List, Optional

import chardet
import numpy as np
import pandas as pd
from scipy import sparse
from scipy.io import loadmat, arff
import streamlit as st
import streamlit_antd_components as sac


def read_data_from_file(
    uploaded_data_file,
    col_names: Optional[List[str]] = None,
    sep: Optional[str] = None,
    na_values: List[str] = ['?'],
    encoding: Optional[str] = None
) -> pd.DataFrame:
    """
    Read a DataFrame from an uploaded data file.
    - Supports .csv/.data/.txt/.xlsx/.xls/.mat (plus .arff)
    - With col_names=None, uses header=0 (the first row becomes the column names)
    - With col_names given, uses header=None plus names=col_names
    - Text files: auto-detect the encoding, sniff the delimiter, skip bad lines
    - Excel files: pandas.read_excel directly
    - MAT files: scipy.loadmat, take the first real variable, flatten to a 2-D DataFrame
    """
    # Read all bytes, then rewind the stream for any later consumer.
    data_bytes = uploaded_data_file.read()
    try:
        uploaded_data_file.seek(0)
    except Exception:
        pass

    name = uploaded_data_file.name
    ext = os.path.splitext(name)[1].lower()

    # Excel files
    if ext in ('.xlsx', '.xls'):
        excel_kwargs = {}
        if col_names is None:
            excel_kwargs['header'] = 0
        else:
            excel_kwargs['header'] = None
            excel_kwargs['names'] = col_names
        return pd.read_excel(io.BytesIO(data_bytes), **excel_kwargs)

    # ARFF files
    if ext == '.arff':
        text = data_bytes.decode(encoding or 'utf-8', errors='ignore')
        raw_data, meta = arff.loadarff(io.StringIO(text))
        df = pd.DataFrame(raw_data)
        for col in df.select_dtypes([object]).columns:
            if isinstance(df[col].iloc[0], bytes):
                df[col] = df[col].str.decode('utf-8', errors='ignore')
        if col_names is not None and df.shape[1] == len(col_names):
            df.columns = col_names
        return df

    # MAT files
    if ext == '.mat':
        mat = loadmat(io.BytesIO(data_bytes))
        data_keys = [k for k in mat.keys() if not k.startswith('__')]
        if not data_keys:
            raise ValueError('MAT 文件中未发现有效数据变量')
        arr = mat[data_keys[0]]

        # Densify sparse matrices first
        if sparse.issparse(arr):
            arr = arr.toarray()

        arr = np.array(arr)
        if arr.ndim > 2:
            arr = arr.reshape(arr.shape[0], -1)

        df = pd.DataFrame(arr)

        if col_names is not None and df.shape[1] == len(col_names):
            df.columns = col_names

        return df

    # Text files: detect the encoding, then sniff the delimiter on a sample.
    if encoding is None:
        detected = chardet.detect(data_bytes)
        encoding = detected.get('encoding') or 'utf-8'
    sample = data_bytes[:10_000].decode(encoding, errors='ignore')

    try:
        dialect = csv.Sniffer().sniff(sample, delimiters=[',', ';', '\t', '|'])
        detected_sep = dialect.delimiter
        use_whitespace = False
    except csv.Error:
        detected_sep = None
        use_whitespace = True

    read_kwargs = {
        'engine': 'python',
        'encoding': encoding,
        'na_values': na_values,
        'comment': '|',
        'skipinitialspace': True,
        'on_bad_lines': 'skip',
    }
    if use_whitespace:
        read_kwargs['delim_whitespace'] = True
    else:
        read_kwargs['sep'] = detected_sep

    if col_names is None:
        read_kwargs['header'] = 0
    else:
        read_kwargs['header'] = None
        read_kwargs['names'] = col_names

    return pd.read_csv(io.BytesIO(data_bytes), **read_kwargs)
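The function only touches three things on the upload object: `.read()`, `.seek()`, and `.name`, so it is easy to exercise without Streamlit. A usage sketch (`_FakeUpload` is illustrative, not part of the module):

# --- usage sketch, not part of the upload ---
import io

class _FakeUpload(io.BytesIO):
    # read/seek come from BytesIO; only .name has to be added.
    def __init__(self, name: str, payload: bytes):
        super().__init__(payload)
        self.name = name

demo = read_data_from_file(_FakeUpload("demo.csv", b"a;b;c\n1;2;3\n"))
print(demo.columns.tolist())  # ['a', 'b', 'c'] -- the ';' delimiter was sniffed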

def process_complex_data(uploaded_files, dataloadingagent):
    """
    Upload handling:
    - Single file: read as a plain table or MAT file (first row as header)
    - Multiple files: use a .names/.arff header file for the column names when
      present, otherwise infer them; with several data files the user later
      chooses horizontal or vertical concatenation (see load_concat_file)
    """
    if not uploaded_files:
        st.error("请先上传文件")
        return None, None

    names_exts = ('.names', '.arff', '.doc')
    data_exts = ('.data', '.csv', '.txt', '.xlsx', '.xls', '.mat', '.arff', '.tsv', '.dat', '.tst')

    names_files = [f for f in uploaded_files
                   if os.path.splitext(f.name)[1].lower() in names_exts]
    data_files = [f for f in uploaded_files
                  if os.path.splitext(f.name)[1].lower() in data_exts]

    # Single file: read it directly
    if len(uploaded_files) == 1 and uploaded_files[0] in data_files:
        return read_data_from_file(uploaded_files[0], col_names=None), None

    if not data_files:
        raise ValueError(
            "未检测到任何数据文件,请上传支持的格式:.csv/.data/.txt/.xlsx/.xls/.mat/.arff/.tsv/.dat/.tst"
        )

    # 1) If a header file (.names/.arff) exists, read the column names from it
    if names_files:
        header_file = names_files[0]
        # Read a sample via read_data_from_file so the encoding is handled correctly
        sample_df = read_data_from_file(data_files[0], col_names=None)
        col_names = dataloadingagent.read_names_from_file(header_file, sample_df.head())
    else:
        # 2) Otherwise infer the column names from the first data file, with encoding fallbacks
        sample = data_files[0]
        ext0 = os.path.splitext(sample.name)[1].lower()
        try:
            if ext0 in ('.xlsx', '.xls'):
                col_names = list(pd.read_excel(sample, nrows=0))
            elif ext0 == '.mat':
                df_sample = read_data_from_file(sample, col_names=None)
                col_names = list(df_sample.columns)
            else:
                # Text file: detect with chardet first, try that encoding, fall back to latin1
                raw_bytes = sample.read()
                detected = chardet.detect(raw_bytes)
                enc = detected.get('encoding') or 'utf-8'
                try:
                    col_names = list(pd.read_csv(
                        io.BytesIO(raw_bytes),
                        nrows=0,
                        encoding=enc,
                        engine='python'
                    ).columns)
                except UnicodeDecodeError:
                    col_names = list(pd.read_csv(
                        io.BytesIO(raw_bytes),
                        nrows=0,
                        encoding='latin1',
                        engine='python'
                    ).columns)
        finally:
            try:
                sample.seek(0)
            except Exception:
                pass

    # Read every data file with the unified column names
    dfs = [read_data_from_file(f, col_names=col_names) for f in data_files]

    # Multiple data files: concatenate vertically by default; the UI lets the
    # user redo this horizontally via load_concat_file.
    if len(data_files) >= 2:
        big_df = pd.concat(dfs, axis=0, ignore_index=True)
    else:
        big_df = dfs[0]

    return big_df, dfs

def load_from_path(local_path):

    ext = os.path.splitext(local_path)[1].lower()
    if ext in (".csv", ".txt", ".data"):
        df_local = pd.read_csv(local_path)
    elif ext in (".xls", ".xlsx"):
        df_local = pd.read_excel(local_path)
    elif ext == ".json":
        df_local = pd.read_json(local_path)
    elif ext == ".jsonl":
        df_local = pd.read_json(local_path, lines=True)
    elif ext == ".parquet":
        df_local = pd.read_parquet(local_path)
    elif ext in (".pkl", ".pickle"):
        df_local = pd.read_pickle(local_path)
    elif ext == ".feather":
        df_local = pd.read_feather(local_path)
    elif ext == ".arff":
        data, meta = arff.loadarff(local_path)
        df_local = pd.DataFrame(data)
        for col in df_local.select_dtypes([object]).columns:
            if isinstance(df_local[col].iloc[0], bytes):
                df_local[col] = df_local[col].str.decode('utf-8')
    else:
        st.error(f"不支持的文件类型:{ext}")
        df_local = None

    return df_local


def load_concat_file(dfs, agent):

    mode = sac.segmented(
        items=[
            sac.SegmentedItem(label='纵向拼接'),
            sac.SegmentedItem(label='横向拼接'),
        ], label='检测到多个数据文件,请选择拼接方式', size='sm', radius='sm'
    )

    if mode.startswith("横向拼接"):
        dfs_pos = [df.reset_index(drop=True) for df in dfs]
        big_df = pd.concat(dfs_pos, axis=1)

        # De-duplicate column names produced by the horizontal concat
        cols = []
        seen = {}
        for c in big_df.columns:
            if c in seen:
                seen[c] += 1
                cols.append(f"{c}_{seen[c]}")
            else:
                seen[c] = 0
                cols.append(c)
        big_df.columns = cols
        agent.add_df(big_df)
    else:
        big_df = pd.concat(dfs, axis=0, ignore_index=True)
        agent.add_df(big_df)

    csv_bytes = big_df.to_csv(index=False).encode('utf-8')
    st.download_button(
        label="下载文件",
        data=csv_bytes,
        file_name="processed_data.csv",
        mime="text/csv"
    )


class PathFileWrapper:
    """A wrapper to treat a local file path as a Streamlit UploadedFile."""
    def __init__(self, path):
        self.path = path
        self.name = os.path.basename(path)
        self._file = None

    def read(self, *args, **kwargs):
        with open(self.path, 'rb') as f:
            return f.read()

    def seek(self, offset, whence=0):
        # read() reopens the file each time, so seek can be a no-op.
        pass

    def __repr__(self):
        return f"PathFileWrapper(path='{self.path}')"
workflow/dataloading/dataloading_render.py
ADDED
@@ -0,0 +1,210 @@
import os

import pandas as pd
import streamlit as st
import streamlit_antd_components as sac

from workflow.dataloading.dataloading_core import process_complex_data, load_from_path, load_concat_file, PathFileWrapper


def loading_data_file(agent):

    st.info(
        "💡 提示:\n"
        "1. 支持一次上传多个数据文件\n"
        "2. 自动使用大模型分析并处理数据\n"
        "3. 支持多种格式的文件类型上传\n"
    )

    selected_index = sac.tabs([
        sac.TabsItem(label='本地上传'),
        sac.TabsItem(label='路径导入'),
    ], color='#5980AE')

    if selected_index == "本地上传":
        # Upload via the file picker
        uploaded_files = st.file_uploader(
            "选择新文件",
            accept_multiple_files=True,
            help="拖拽或点击上传多个文件",
        )

        if uploaded_files:
            current_memory_file_name = agent.load_file_name()
            new_files = [f for f in uploaded_files if f.name not in current_memory_file_name]
            if new_files:
                try:
                    with st.spinner("正在处理数据..."):
                        df, dfs = process_complex_data(new_files, agent)
                        if df is not None:
                            agent.add_df(df)
                            agent.save_dfs(dfs)
                            for f in new_files:
                                agent.save_file_name(f.name)
                            st.rerun()
                except Exception as err:
                    st.error(f"导入失败:{err}")

    elif selected_index == "路径导入":
        # Import from local file paths
        raw_paths = st.text_area(
            "从路径导入数据 (每行一个文件路径)",
            placeholder="C:\\data\\iris.names\nC:\\data\\iris.data",
            height=100
        )

        if st.button("从路径加载文件", use_container_width=True):
            if raw_paths:
                path_list = [p.strip().strip("'\"") for p in raw_paths.strip().split('\n') if p.strip()]

                valid_paths = [p for p in path_list if os.path.exists(p)]
                invalid_paths = [p for p in path_list if not os.path.exists(p)]

                if invalid_paths:
                    st.warning("路径不存在,已跳过:\n- " + "\n- ".join(invalid_paths))

                if not valid_paths:
                    st.error("未找到任何有效的本地文件路径。")
                else:
                    current_memory_file_name = agent.load_file_name()
                    new_paths = [p for p in valid_paths if p not in current_memory_file_name]

                    if not new_paths:
                        st.info("所有指定的路径文件均已加载。")
                    else:
                        files_to_process = [PathFileWrapper(p) for p in new_paths]
                        try:
                            with st.spinner("正在处理数据..."):
                                df, dfs = process_complex_data(files_to_process, agent)
                                if df is not None:
                                    agent.add_df(df)
                                    agent.save_dfs(dfs)
                                    for p in new_paths:
                                        agent.save_file_name(p)
                                    st.rerun()
                        except Exception as err:
                            st.error(f"本地文件读取失败:{err}")

    dfs = agent.load_dfs()
    if dfs is not None and len(dfs) >= 2:
        load_concat_file(dfs, agent)


def loading_basic_info(agent):

    df = agent.load_df()
    if df is not None:
        r, c = df.shape
        missing = int(df.isnull().sum().sum())
        col1, col2, col3 = st.columns(3)
        col1.metric("行数", r)
        col2.metric("列数", c)
        col3.metric("缺失值总数", missing)

        dtype_info = pd.DataFrame({
            "列名": df.columns,
            "类型": df.dtypes.astype(str),
            "非空": df.count().values,
            "缺失%": (df.isnull().mean() * 100).round(2).values,
        }).reset_index(drop=True)

        selected_index = sac.tabs([
            sac.TabsItem(label='数据类型概览'),
            sac.TabsItem(label='数据预览'),
        ], color='#5980AE')

        if selected_index == "数据类型概览":
            st.dataframe(dtype_info, use_container_width=True)
        elif selected_index == "数据预览":
            if st.button("🎲 随机抽样"):
                display_df = df.sample(10)
                st.dataframe(display_df, use_container_width=True)
            else:
                st.dataframe(df.head(10), use_container_width=True)


def loading_chat(agent, auto=False) -> None:

    df = agent.load_df()
    if df is None:
        return

    with st.chat_message("assistant"):
        st.write(
            "我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
            "请先上传您的数据文件,上传完成后,您可以在下方和我对话,也可以直接点击按钮解析数据含义。"
        )
        analyze_btn = st.button("🔍 解析含义")
        result_placeholder = st.empty()

    # Render the chat history
    chat_history = agent.load_memory()

    for idx, entry in enumerate(chat_history):
        bubble = st.chat_message(entry["role"])
        content = entry["content"]
        if isinstance(content, str):
            bubble.write(content)

    already_generated = any(
        entry["role"] == "assistant" and "含义" in str(entry["content"])
        for entry in chat_history
    )

    if analyze_btn or (auto and not already_generated):
        st.chat_message("user").write("请帮我解析数据含义")
        agent.add_memory({"role": "user", "content": "请帮我解析数据含义"})
        with st.spinner("分析中..."):
            desc = agent.do_data_description(df)

        agent.finish_auto()
        st.chat_message("assistant").write(desc)
        agent.add_memory({"role": "assistant", "content": desc})
        st.rerun()

    # Free-form user input
    user_input = st.chat_input("请输入需求,例如“帮我分析xx列”")
    if user_input:
        st.chat_message("user").write(user_input)
        agent.add_memory({"role": "user", "content": user_input})
        with st.spinner("处理中…"):
            reply = agent.do_data_description(df, user_input)

        st.chat_message("assistant").write(reply)
        agent.add_memory({"role": "assistant", "content": reply})
        st.rerun()


if __name__ == "__main__":

    agent = st.session_state.data_loading_agent
    planner = st.session_state.planner_agent
    auto = planner.loading_auto

    if st.session_state.auto_mode == True:
        if (agent.finish_auto_task == True and planner.switched_prep == False) or planner.prep_auto == False:
            planner.finish_loading_auto()
            st.switch_page("workflow/preprocessing/preprocessing_render.py")

    c1, c2 = st.columns(2)
    with c1:
        st.title("数据导入")
    with c2:
        st.write("")
        st.write("")
        sac.buttons([
            sac.ButtonsItem(label='Github', icon='github', href='https://github.com/ElvisWang1111/AAAAAnystat'),
            sac.ButtonsItem(label='Doc', icon=sac.BsIcon(name='bi bi-file-earmark-post-fill', size=16), href='https://elviswang1111.github.io/anystatweb.github.io/index.html'),
        ], align='end', color='dark', variant='filled', index=None)
    st.markdown("---")

    c = st.columns(2)
    with c[0].expander('数据上传', True):
        loading_data_file(agent)
    with c[1].expander('数据建议', True):
        loading_chat(agent, auto)
    with c[0].expander('数据展示', True):
        loading_basic_info(agent)
workflow/modeling/model_inference.py
ADDED
@@ -0,0 +1,102 @@
import base64
import gzip
import io
import json
import traceback

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import streamlit as st

from workflow.dataloading.dataloading_core import process_complex_data
from utils.sanitize_code import sanitize_code, to_json_serializable


def infer_load_data(agent) -> None:

    uploaded_files = st.file_uploader(
        "选择推理数据集",
        accept_multiple_files=True,
        help="拖拽或点击上传多个文件",
    )

    if uploaded_files:
        try:
            with st.spinner("正在处理数据..."):
                big_df, dfs = process_complex_data(uploaded_files, agent)
                if big_df is not None:
                    agent.save_inference_data(big_df)
                    st.success("导入并处理完成!")
        except Exception as err:
            st.error(f"导入失败:{err}")


def infer_execution(agent):

    inference_df = agent.load_inference_processed_df()
    edited_code = agent.load_inference_code()

    try:
        model_obj = agent.load_best_model()

        # Namespace the generated inference script runs in
        exec_ns = {
            "inference_df": inference_df,
            "model_obj": model_obj,
            "np": np,
            "pd": pd,
            "StandardScaler": StandardScaler
        }

        with st.spinner("正在进行推断分析..."):
            exec(edited_code, exec_ns)

        result_dict = exec_ns.get("result_dict")
        if result_dict is None:
            st.error("脚本未写入 `result_dict`。请确保编辑后的脚本在末尾赋值 result_dict。")
        else:
            art = result_dict.get('artifacts', {})
            b64 = art.pop('predictions_df_b64', None)
            if not art:
                result_dict.pop('artifacts', None)

            serializable = to_json_serializable(result_dict)
            try:
                result_json = json.dumps(serializable, ensure_ascii=False)
            except Exception:
                result_json = json.dumps(serializable, default=str, ensure_ascii=False)

            with st.expander("推理结果", True):
                if b64:
                    try:
                        gz_bytes = base64.b64decode(b64)
                        csv_bytes = gzip.decompress(gz_bytes)

                        df_pred = pd.read_csv(io.BytesIO(csv_bytes))
                        st.success("已加载带预测结果的 DataFrame")
                        st.dataframe(df_pred)

                        st.download_button(
                            label="下载带预测结果(predictions.csv)",
                            data=csv_bytes,
                            file_name="predictions.csv",
                            mime="text/csv"
                        )
                    except Exception as e:
                        st.error(f"解码 predictions_df 失败: {e}")
                        # Fallback: try to recover the table from the records field
                        records = result_dict.get('predictions_df_records')
                        if records:
                            try:
                                df_pred = pd.DataFrame(records)
                                st.dataframe(df_pred)
                            except Exception as e2:
                                st.error(f"从 records 恢复表格失败: {e2}")

    except Exception as e:
        st.error(f"推断失败:{e}")
        st.text(traceback.format_exc())
        agent.save_inference_error(traceback.format_exc())
        # Regenerate the inference script from the saved error context.
        raw = agent.code_generation_for_inference(agent.load_code(), inference_df.head(), auto=True)
        code = sanitize_code(raw)
        agent.save_inference_code(code)
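infer_execution expects the generated script to ship its prediction table as gzip-compressed CSV, base64-encoded under artifacts. A contract sketch of the producing side (illustrative; the real script is LLM-generated):

# --- contract sketch for the generated inference script, not part of the upload ---
import base64
import gzip

import pandas as pd

df_pred = pd.DataFrame({"id": [1, 2], "pred": [0.3, 0.7]})  # stand-in predictions
csv_bytes = df_pred.to_csv(index=False).encode("utf-8")
result_dict = {
    "artifacts": {
        "predictions_df_b64": base64.b64encode(gzip.compress(csv_bytes)).decode("ascii")
    }
}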
workflow/modeling/model_training.py
ADDED
@@ -0,0 +1,143 @@
import base64
import gzip
import json
import pickle
import traceback

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchvision
import xgboost
import lightgbm
from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from utils.sanitize_code import sanitize_code, to_json_serializable


def train_execution(agent):

    code = agent.load_code()
    df = agent.load_df()

    # Namespace the generated training script runs in
    exec_ns = {
        "df": df,
        "np": np,
        "pd": pd,
        "torch": torch,
        "torchvision": torchvision,
        "train_test_split": train_test_split,
        "StandardScaler": StandardScaler,
        "LinearRegression": LinearRegression,
        "RandomForestRegressor": RandomForestRegressor,
        "GradientBoostingRegressor": GradientBoostingRegressor,
        "RandomForestClassifier": RandomForestClassifier,
        "xgboost": xgboost,
        "lightgbm": lightgbm,
    }

    try:
        with st.spinner("正在运行程序..."):
            exec(code, exec_ns)
    except Exception:
        st.error("已保存报错,请重新调用 LLM 生成代码 debug")
        st.text(traceback.format_exc())
        agent.save_error(traceback.format_exc())
        modeling_code_gen(agent, debug=True)
    else:
        result_dict = exec_ns.get("result_dict")
        if result_dict is None:
            st.error(
                "脚本未写入 `result_dict`。请确保编辑后的脚本在末尾赋值 result_dict。"
            )
        else:
            art = result_dict.get('artifacts', {})
            b64 = art.pop('best_model_b64', None)
            artifact_warning = result_dict.pop('artifact_warning', None)

            if not art:
                result_dict.pop('artifacts', None)

            serializable = to_json_serializable(result_dict)
            try:
                result_json = json.dumps(serializable, ensure_ascii=False)
            except Exception:
                result_json = json.dumps(serializable, default=str, ensure_ascii=False)

            with st.spinner("请求 LLM 格式化结果为 Markdown..."):
                formatted = agent.result_format_prompt(result_json)
                agent.save_modeling_result(formatted)

            if b64:
                gz_bytes = base64.b64decode(b64)
                try:
                    agent.save_best_model_gz_bytes(gz_bytes)
                    model_obj = pickle.loads(gzip.decompress(gz_bytes))
                    st.success("最佳模型已加载到内存,可用于即时推理(示例)。")
                    agent.save_best_model(model_obj)
                except Exception as e:
                    st.error(f"加载模型失败:{e}")
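The symmetric producing side of that pickle → gzip → base64 round trip lives in the generated training script. A runnable contract sketch with a toy model standing in for whatever the script actually trains (illustrative only):

# --- contract sketch for the generated training script, not part of the upload ---
import base64
import gzip
import pickle

from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression

X, y = make_regression(n_samples=50, n_features=3, random_state=0)  # toy stand-in
best_model = LinearRegression().fit(X, y)

blob = gzip.compress(pickle.dumps(best_model))                      # pickle -> gzip -> base64
result_dict = {
    "metrics": {"r2": float(best_model.score(X, y))},               # plain types survive JSON
    "artifacts": {"best_model_b64": base64.b64encode(blob).decode("ascii")},
}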
def modeling_code_gen(agent, debug=False, auto=False) -> None:

    df = agent.load_df()
    suggest = agent.load_suggestion()
    chat_history = agent.load_memory()
    already_generated = any(
        entry["role"] == "assistant" and "训练脚本已更新!请重新运行代码!" in str(entry["content"])
        for entry in chat_history
    )

    if suggest is not None:
        if debug == True or (auto and not already_generated):
            with st.spinner("建模 Agent 正在生成训练脚本..."):
                raw = agent.code_generation(
                    df.head().to_string(),
                    suggest,
                )
                code = sanitize_code(raw)
                agent.save_code(code)
                st.chat_message("assistant").write("训练脚本已更新!请重新运行代码!")
                agent.add_memory({"role": "assistant", "content": "训练脚本已更新!请重新运行代码!"})
                st.rerun()

        analyze_btn = st.button("🔧 生成建模代码", key='modeling_code')
        if analyze_btn:
            with st.spinner("建模 Agent 正在生成训练脚本..."):
                raw = agent.code_generation(
                    df.head().to_string(),
                    suggest,
                )
                code = sanitize_code(raw)
                agent.save_code(code)
                st.chat_message("assistant").write("训练脚本已更新!请重新运行代码!")
                agent.add_memory({"role": "assistant", "content": "训练脚本已更新!请重新运行代码!"})
                st.rerun()


def train_download_model(agent):

    model = agent.load_best_model_gz_bytes()
    if model is not None:
        st.download_button(
            label="⬇️ 下载最佳模型",
            data=model,
            file_name="best_model.pkl.gz",
            mime="application/gzip"
        )
workflow/modeling/modeling_render.py
ADDED
@@ -0,0 +1,218 @@
import streamlit as st
import streamlit_antd_components as sac
from streamlit_ace import st_ace

from utils.sanitize_code import sanitize_code
from workflow.modeling.model_training import train_execution, modeling_code_gen, train_download_model
from workflow.modeling.model_inference import infer_load_data, infer_execution
from workflow.preprocessing.preprocessing_core import prep_meta_execution


def modeling_quick_actions(agent):

    st.write("选择一个或多个model:")
    selected_models = sac.chip(
        items=[
            sac.ChipItem(label='线性回归'),
            sac.ChipItem(label='XGBoost'),
            sac.ChipItem(label='随机森林'),
            sac.ChipItem(label='神经网络'),
        ], index=[0, 2], align='center', radius='md', color='#44658C', multiple=True
    )

    df = agent.load_df()

    if st.button("🖋️ 快速建模"):
        if not selected_models:
            st.error("请先选择训练model。")
        else:
            with st.spinner("建模 Agent 正在生成训练脚本..."):
                raw = agent.code_generation(df.head().to_string(), selected_models)
                code = sanitize_code(raw)
                agent.save_code(code)
                agent.save_suggestion(selected_models)
                agent.save_user_selection(selected_models)
                st.success("训练脚本已生成并保存。")
                st.rerun()

    return selected_models


def modeling_execution(agent, auto=False) -> None:

    code = agent.load_code()

    edited = st_ace(
        value=code,
        height=450,
        theme="tomorrow_night",
        language="python",
        auto_update=True
    )

    not_executed = agent.load_modeling_result() is None

    if edited is not None:
        if st.button("▶️ 执行建模", key="modeling_run_code") or (auto and not_executed):
            code = sanitize_code(edited)
            agent.save_code(code)
            train_execution(agent)
            agent.finish_auto()
            st.rerun()

    modeling_result = agent.load_modeling_result()
    if modeling_result is None:
        result_expand = False
    else:
        result_expand = True
        train_download_model(agent)
    with st.expander("训练结果", result_expand):
        if modeling_result:
            st.subheader("训练结果")
            try:
                st.markdown(modeling_result)
            except Exception:
                st.write(modeling_result)


def modeling_inference(agent, preproc_agent):

    infer_load_data(agent)
    inference_processed_data = agent.load_inference_processed_df()
    inference_data = agent.load_inference_data()

    code = agent.load_inference_code()

    if st.button("▶️ 执行推断"):

        with st.spinner("正在对推理数据进行预处理..."):
            inference_data = agent.load_inference_data()
            if preproc_agent.code is not None:
                inference_processed_df = prep_meta_execution(preproc_agent, preproc_agent.code, inference_data)
                inference_data = inference_processed_df
            agent.save_inference_processed_df(inference_data)
            st.write("推断数据预览:")
            st.dataframe(inference_data.head())

        with st.spinner("建模 Agent 正在生成推理脚本..."):
            raw = agent.code_generation_for_inference(agent.load_code(), inference_data.head())
            code = sanitize_code(raw)
            agent.save_inference_code(code)

    if code is not None:
        edited_code = st_ace(
            value=code,
            height=450,
            theme="tomorrow_night",
            language="python",
            auto_update=True
        )
        agent.save_inference_code(sanitize_code(edited_code))
        if st.button("▶️ 运行推理脚本", key="inference_run_code"):
            infer_execution(agent)


def modeling_chat(agent, auto) -> None:

    user_input = st.text_input("建模目标", "默认")
    agent.save_target(user_input)

    with st.chat_message("assistant"):
        st.write(
            "我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
            "您可以在下方输入建模相关问题,或直接点击按钮获取建模建议。"
        )
        analyze_btn = st.button("🔍 建模推荐", key='modeling_suggest')
        result_placeholder = st.empty()

    chat_history = agent.load_memory()

    for idx, entry in enumerate(chat_history):
        bubble = st.chat_message(entry["role"])
        content = entry["content"]
        if isinstance(content, str):
            bubble.write(content)

    already_generated = any(
        entry["role"] == "assistant" and "模" in str(entry["content"])
        for entry in chat_history
    )

    if analyze_btn or (auto and not already_generated):
        st.chat_message("user").write("请帮我获取建模建议")
        agent.add_memory({"role": "user", "content": "请帮我获取建模建议"})
        with st.spinner("分析中..."):
            suggestion = agent.get_model_suggestion()
            agent.save_suggestion(suggestion)
            agent.refine_suggestions()
        st.chat_message("assistant").write(suggestion)
        agent.add_memory({"role": "assistant", "content": suggestion})
        st.chat_message("assistant").write("需要进一步优化?再次点击按钮获取下一条建议")
        agent.add_memory({"role": "assistant", "content": "需要进一步优化?再次点击按钮获取下一条建议"})

    user_input = st.chat_input("请输入您的问题,例如“如何优化这个模型”")
    if user_input:
        st.chat_message("user").write(user_input)
        agent.add_memory({"role": "user", "content": user_input})
        with st.spinner("处理中…"):
            reply = agent.get_model_suggestion(user_input)
            agent.save_suggestion(reply)
            agent.refine_suggestions()
        st.chat_message("assistant").write(reply)
        agent.add_memory({"role": "assistant", "content": reply})
        st.chat_message("assistant").write("需要进一步优化?再次点击按钮获取下一条建议")
        agent.add_memory({"role": "assistant", "content": "需要进一步优化?再次点击按钮获取下一条建议"})


if __name__ == "__main__":

    st.title("数据建模")
    st.markdown("---")

    preproc_agent = st.session_state.data_preprocess_agent
    load_agent = st.session_state.data_loading_agent

    processed_df = preproc_agent.load_processed_df()
    if processed_df is None:
        df = load_agent.load_df()
    else:
        df = processed_df

    if df is None:
        st.warning("⚠️ 请先在数据导入页面加载数据")
        st.stop()

    agent = st.session_state.modeling_coding_agent
    agent.add_df(df)
    planner = st.session_state.planner_agent
    auto = planner.modeling_auto

    if st.session_state.auto_mode == True:
        if (agent.finish_auto_task == True and planner.switched_modeling == False) or planner.modeling_auto == False:
            planner.finish_modeling_auto()
            st.switch_page("workflow/report/report_render.py")

    code = agent.load_code()
    if code is None:
        expand = False
    else:
        expand = True

    inference_model = agent.load_best_model()
    if inference_model is None:
        inference_expand = False
    else:
        inference_expand = True

    c = st.columns(2)
    with c[0].expander('快速建模', True):
        modeling_quick_actions(agent)
    with c[1].expander('建模建议', True):
        modeling_chat(agent, auto)
        modeling_code_gen(agent, auto=auto)
    with c[0].expander('建模执行', expand):
        modeling_execution(agent, auto)
    # with c[0].expander('推断分析', inference_expand):
    #     modeling_inference(agent, preproc_agent)
workflow/preprocessing/preprocessing_core.py
ADDED
@@ -0,0 +1,112 @@
import traceback

import numpy as np
import pandas as pd
import streamlit as st
from streamlit_ace import st_ace
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, OrdinalEncoder, RobustScaler, StandardScaler

from utils.sanitize_code import sanitize_code


def prep_meta_execution(agent, code, df, auto=False):

    edited = st_ace(
        value=code,
        height=400,
        theme="tomorrow_night",
        language="python",
        auto_update=True
    )

    not_generated = agent.load_processed_df() is None
    process_df = None

    if code is not None:
        if st.button("▶️ 执行预处理") or (auto and not_generated):
            code = sanitize_code(edited)
            agent.save_code(code)

            # Namespace the generated preprocessing script runs in
            exec_ns = {
                "df": df,
                "np": np,
                "pd": pd,
                "st": st,
                "SimpleImputer": SimpleImputer,
                "FunctionTransformer": FunctionTransformer,
                "StandardScaler": StandardScaler,
                "MinMaxScaler": MinMaxScaler,
                "RobustScaler": RobustScaler,
                "OneHotEncoder": OneHotEncoder,
                "OrdinalEncoder": OrdinalEncoder,
                "LabelEncoder": LabelEncoder,
                "ColumnTransformer": ColumnTransformer,
                "Pipeline": Pipeline,
            }

            try:
                with st.spinner("正在运行程序..."):
                    exec(code, exec_ns)
            except Exception:
                st.error("已保存报错,请重新调用 LLM 生成代码 debug")
                st.text(traceback.format_exc())
                agent.save_error(traceback.format_exc())
                prep_code_gen(agent, debug=True)
            else:
                process_df = exec_ns.get("process_df")
                if process_df is None:
                    st.error(
                        "脚本未写入 `process_df`。请确保编辑后的脚本在末尾赋值 process_df"
                    )
                else:
                    agent.save_processed_df(process_df)
                    agent.finish_auto()
                    st.rerun()
    return process_df
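For orientation, a minimal example of the script contract this function enforces: the snippet must leave `process_df` defined at the end. In production the namespace already binds `df`, `pd`, `SimpleImputer`, etc.; the imports and the toy frame below only make the sketch runnable on its own:

# --- contract sketch for the generated preprocessing script, not part of the upload ---
import pandas as pd
from sklearn.impute import SimpleImputer

df = pd.DataFrame({"x": [1.0, None, 3.0], "y": ["a", "b", "b"]})  # stand-in for the injected df

num_cols = df.select_dtypes("number").columns
process_df = df.copy()
process_df[num_cols] = SimpleImputer(strategy="median").fit_transform(process_df[num_cols])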

def prep_code_gen(agent, auto=False, debug=False):

    suggest = agent.load_preprocessing_suggestions()
    df = agent.load_df()

    chat_history = agent.load_memory()
    already_generated = any(
        entry["role"] == "assistant" and "预处理脚本已更新!请重新运行代码!" in str(entry["content"])
        for entry in chat_history
    )

    if suggest is not None:

        if debug == True or (auto and not already_generated):
            with st.spinner("预处理 Agent 正在编写脚本..."):
                raw = agent.code_generation(
                    df.head(10).to_string(),
                    suggest,
                )
                code = sanitize_code(raw)
                agent.save_code(code)

            st.chat_message("assistant").write("预处理脚本已更新!请重新运行代码!")
            agent.add_memory({"role": "assistant", "content": "预处理脚本已更新!请重新运行代码!"})

            st.rerun()

        analyze_btn = st.button("🔧 生成预处理代码", key='prep_code')
        if analyze_btn:
            with st.spinner("向 LLM 请求生成预处理脚本..."):
                raw = agent.code_generation(
                    df.head(10).to_string(),
                    suggest,
                )
                code = sanitize_code(raw)
                agent.save_code(code)

            st.chat_message("assistant").write("预处理脚本已更新!请重新运行代码!")
            agent.add_memory({"role": "assistant", "content": "预处理脚本已更新!请重新运行代码!"})

            st.rerun()
workflow/preprocessing/preprocessing_render.py
ADDED
@@ -0,0 +1,159 @@
+import io
+import traceback
+
+import numpy as np
+import pandas as pd
+import streamlit as st
+from streamlit_ace import st_ace
+from sklearn.compose import ColumnTransformer
+from sklearn.impute import SimpleImputer
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder, OrdinalEncoder, RobustScaler, StandardScaler
+
+from utils.sanitize_code import sanitize_code
+from workflow.preprocessing.preprocessing_core import prep_meta_execution, prep_code_gen
+
+
+def prep_basic_info(agent):
+
+    df = agent.load_df()
+
+    # Show basic statistics
+    r, c = df.shape
+    missing = int(df.isnull().sum().sum())
+    col1, col2, col3 = st.columns(3)
+    col1.metric("行数", r)
+    col2.metric("列数", c)
+    col3.metric("缺失值总数", missing)
+
+    dtype_info = pd.DataFrame({
+        '列名': df.columns,
+        '类型': df.dtypes.astype(str),
+        '非空值数量': df.count().values,
+        '缺失值比例(%)': (df.isnull().mean() * 100).round(2).values,
+    })
+    dtype_info = dtype_info.reset_index(drop=True)
+    st.dataframe(dtype_info, use_container_width=True)
+
+
+def prep_execution(agent, auto=False):
+    '''
+    Run preprocessing on the training data
+    '''
+
+    code = agent.load_code()
+    df = agent.load_df()
+
+    process_df = prep_meta_execution(agent, code, df, auto=auto)
+
+
+def prep_result(agent):
+
+    process_df = agent.load_processed_df()
+    df = agent.load_df()
+
+    if process_df is not None:
+        st.write("处理前数据预览:", df.head(10))
+        st.write("处理后数据预览:", process_df.head(10))
+
+        csv_buffer = io.StringIO()
+        process_df.to_csv(csv_buffer, index=False)
+        csv_bytes = csv_buffer.getvalue().encode('utf-8')
+
+        st.download_button(
+            label="⬇️ 下载处理后数据",
+            data=csv_bytes,
+            file_name="processed_data.csv",
+            mime="text/csv",
+        )
+
+
+def prep_chat(agent, auto=False):
+    """Render the conversational suggestion area"""
+
+    df = agent.load_df()  # needed by the suggestion calls below
+
+    with st.chat_message("assistant"):
+        st.write("我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
+                 "您可以在下方输入预处理需求,或直接点击按钮获取预处理建议。")
+        analyze_btn = st.button("🔍 预处理推荐", key='prep_suggest')
+
+    # Render the chat history
+    chat_history = agent.load_memory()
+
+    for idx, entry in enumerate(chat_history):
+        bubble = st.chat_message(entry["role"])
+        content = entry["content"]
+        if isinstance(content, str):
+            bubble.write(content)
+
+    already_generated = any(
+        entry["role"] == "assistant" and "预处理" in str(entry["content"])
+        for entry in chat_history
+    )
+
+    # Auto / manual trigger
+    if analyze_btn or (auto and not already_generated):
+
+        st.chat_message("user").write("请给我预处理建议")
+        agent.add_memory({'role': 'user', 'content': "请给我预处理建议"})
+
+        with st.spinner("生成建议中…"):
+            text = agent.get_preprocessing_suggestions()
+            agent.save_preprocessing_suggestions(text)
+            agent.refine_suggestions(df.head(10).to_string())
+        st.chat_message("assistant").write(text)
+        agent.add_memory({'role': 'assistant', 'content': text})
+
+    # Free-form user input
+    user_input = st.chat_input("请输入您的问题")
+    if user_input:
+        st.chat_message("user").write(user_input)
+        agent.add_memory({'role': 'user', 'content': user_input})
+        agent.save_user_input(user_input)
+        with st.spinner("处理中…"):
+            reply = agent.get_preprocessing_suggestions(user_input)
+            agent.save_preprocessing_suggestions(reply)
+            agent.refine_suggestions(df.head(10).to_string())
+        st.chat_message('assistant').write(reply)
+        agent.add_memory({'role': 'assistant', 'content': reply})
+
+
+if __name__ == '__main__':
+
+    st.title("数据预处理与标准化")
+
+    st.markdown("---")
+
+    data_loading_agent = st.session_state.data_loading_agent
+    df = data_loading_agent.load_df()
+    planner = st.session_state.planner_agent
+    auto = planner.prep_auto
+
+    if df is None:
+        st.warning("⚠️ 请先在数据导入页面加载数据")
+        st.stop()
+
+    agent = st.session_state.data_preprocess_agent
+    agent.add_df(df)
+
+    if st.session_state.auto_mode == True:
+        if (agent.finish_auto_task == True and planner.switched_prep == False) or planner.prep_auto == False:
+            planner.finish_prep_auto()
+            st.switch_page("workflow/visualization/viz_render.py")
+
+    code = agent.load_code()
+    if code is None:
+        code_expand = False
+    else:
+        code_expand = True
+
+    c = st.columns(2)
+    with c[0].expander('预处理展示', True):
+        prep_basic_info(agent)
+    with c[1].expander('预处理建议', True):
+        prep_chat(agent, auto)
+        prep_code_gen(agent, auto=auto)
+    with c[0].expander('预处理执行', code_expand):
+        prep_execution(agent, auto)
+    with c[0].expander('预处理结果', code_expand):
+        prep_result(agent)
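The `prep_basic_info` view above reduces to a small pandas recipe: frame shape, total null count, and a per-column dtype/missing table. The same computation on a toy frame (column names here are invented for illustration):

import pandas as pd

df = pd.DataFrame({"age": [21, None, 35], "city": ["SH", "BJ", None]})

rows, cols = df.shape
missing = int(df.isnull().sum().sum())  # total NaN count across the whole frame

dtype_info = pd.DataFrame({
    "column": df.columns,
    "dtype": df.dtypes.astype(str),
    "non_null": df.count().values,
    "missing_pct": (df.isnull().mean() * 100).round(2).values,
}).reset_index(drop=True)

print(rows, cols, missing)  # 3 2 2
print(dtype_info)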
workflow/report/report_core.py
ADDED
@@ -0,0 +1,46 @@
+class ReportNode:
+    """Document node: either a heading or a paragraph"""
+    def __init__(self, node_type, text, level=0):
+        self.type = node_type  # "heading" or "paragraph"
+        self.text = text
+        self.level = level
+        self.children = []  # child nodes (for the hierarchy)
+
+    def to_dict(self):
+        return {
+            "type": self.type,
+            "text": self.text,
+            "level": self.level,
+            "children": [c.to_dict() for c in self.children]
+        }
+
+
+# Currently only suitable for sequential appends
+class Reportcore:
+    def __init__(self):
+        self.root = ReportNode("root", "", level=-1)  # virtual root node
+        self.current_stack = [self.root]  # manage the current level with a stack
+
+    def add_heading(self, text, level=0):  # levels start at 0
+        """
+        Add a heading; mount it under the right parent based on its level
+        """
+        new_node = ReportNode("heading", text, level)
+
+        # Backtrack to the appropriate parent
+        while self.current_stack and self.current_stack[-1].level >= level:
+            self.current_stack.pop()
+
+        parent = self.current_stack[-1]
+        parent.children.append(new_node)
+        self.current_stack.append(new_node)
+
+    def add_paragraph(self, text):
+        """
+        Add a paragraph under the most recent heading
+        """
+        parent = self.current_stack[-1]
+
+        parent.children.append(ReportNode("paragraph", text, level=parent.level + 1))
+
+    def to_dict(self):
+        return self.root.to_dict()
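A quick usage sketch of the stack-based tree above: each `add_heading` pops the stack until it reaches a strictly shallower parent, so siblings at the same level attach correctly. The paragraph text here is placeholder content:

doc = Reportcore()
doc.add_heading("数据分析报告", 0)
doc.add_heading("数据概览", 1)            # nests under the level-0 heading
doc.add_paragraph("数据概况描述……")       # attaches to 数据概览
doc.add_heading("建模结果", 1)            # pops 数据概览, becomes its sibling
doc.add_paragraph("回归结果见 [FIG:0]。")

tree = doc.to_dict()
print(tree["children"][0]["children"][1]["text"])  # -> 建模结果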
workflow/report/report_html.py
ADDED
@@ -0,0 +1,117 @@
+import io
+import re
+import streamlit as st
+import plotly.io as pio
+from utils.sanitize_code import sanitize_code
+import base64
+
+
+def write_html(agents):
+    report_agent = agents[-1]
+    report_obj = report_agent.load_report()  # Reportcore
+
+    # List of figure analyses
+    analysis_list = agents[2].summary_fig_analysis_list()
+
+    # Give each heading a unique id
+    heading_counter = {"count": 0}
+    def _gen_id(text):
+        heading_counter["count"] += 1
+        return f"sec-{heading_counter['count']}"
+
+    # Walk the tree -> body & TOC
+    toc_items, content_items = [], []
+
+    def _process_node(node):
+        if node.type == "heading":
+            sec_id = _gen_id(node.text)
+            toc_items.append((sec_id, node.text, node.level))
+            h_level = max(node.level, 1)  # the report title is added at level 0; clamp to a valid <h1>
+            content_items.append(
+                f"<h{h_level} id='{sec_id}' class='font-bold text-gray-800 mt-8 mb-4 text-{max(6-node.level,1)}xl'>{node.text}</h{h_level}>"
+            )
+            for ch in node.children:
+                _process_node(ch)
+
+        elif node.type == "paragraph":
+            parts = re.split(r'(\[FIG:\d+\])', node.text)
+            html_parts = []
+            for part in parts:
+                part = part.strip()
+                if not part:
+                    continue
+                if part.startswith("[FIG:") and part.endswith("]"):
+                    idx = int(part[5:-1])
+                    fig_html = ""
+                    if 0 <= idx < len(analysis_list):
+                        fig_obj = analysis_list[idx].get("figure")
+                        try:
+                            buf = io.BytesIO()
+                            pio.write_image(fig_obj, buf, format="png")
+                            data = buf.getvalue()
+                            b64 = base64.b64encode(data).decode("utf-8")
+                            fig_html = f"<div class='flex justify-center my-6'><img src='data:image/png;base64,{b64}' class='rounded-xl shadow-md max-w-3xl w-full'/></div>"
+                        except Exception as e:
+                            fig_html = f"<p class='text-red-500'>[图像插入失败: {e}]</p>"
+                    html_parts.append(fig_html)
+                else:
+                    html_parts.append(f"<p class='text-gray-700 leading-relaxed mb-4'>{part}</p>")
+            content_items.append("".join(html_parts))
+
+        else:  # root
+            for ch in node.children:
+                _process_node(ch)
+
+    _process_node(report_obj.root)
+
+    # TOC HTML
+    toc_html = ["<nav class='space-y-2'>"]
+    for sec_id, text, level in toc_items:
+        indent = "ml-" + str(level * 4)
+        toc_html.append(f"<a href='#{sec_id}' class='block {indent} text-gray-600 hover:text-blue-600 transition-colors'>{text}</a>")
+    toc_html.append("</nav>")
+
+    # Assemble the full HTML
+    html_content = f"""
+    <html>
+    <head>
+    <meta charset="utf-8">
+    <script src="https://cdn.tailwindcss.com"></script>
+    <script>
+    document.addEventListener("DOMContentLoaded", function() {{
+        const sections = document.querySelectorAll("h1, h2, h3, h4, h5, h6");
+        const navLinks = document.querySelectorAll("nav a");
+
+        function onScroll() {{
+            let scrollPos = document.documentElement.scrollTop || document.body.scrollTop;
+            let currentId = "";
+            sections.forEach(sec => {{
+                if (sec.offsetTop - 80 <= scrollPos) {{
+                    currentId = sec.id;
+                }}
+            }});
+            navLinks.forEach(link => {{
+                link.classList.remove("font-bold", "text-blue-600");
+                if (link.getAttribute("href") === "#" + currentId) {{
+                    link.classList.add("font-bold", "text-blue-600");
+                }}
+            }});
+        }}
+        window.addEventListener("scroll", onScroll);
+        onScroll();
+    }});
+    </script>
+    </head>
+    <body class="flex font-sans">
+    <aside class="fixed top-0 left-0 h-screen w-64 bg-gray-100 border-r border-gray-300 p-6 overflow-y-auto">
+    <h2 class="text-xl font-bold mb-4">目录</h2>
+    {''.join(toc_html)}
+    </aside>
+    <main class="ml-64 p-10 w-full max-w-5xl">
+    {''.join(content_items)}
+    </main>
+    </body>
+    </html>
+    """
+
+    report_agent.save_html(html_content)
+    st.success("HTML 报告 (Tailwind 风格) 生成成功 ✅")
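Both the HTML and Markdown writers rely on the same trick: `re.split` with a capturing group keeps the `[FIG:n]` delimiters in the result list, so prose and figure slots can be interleaved in one pass. In isolation:

import re

text = "相关性分析如下。[FIG:3] 可以看到明显的正相关。"
parts = re.split(r"(\[FIG:\d+\])", text)
print(parts)  # ['相关性分析如下。', '[FIG:3]', ' 可以看到明显的正相关。']

for part in parts:
    if part.startswith("[FIG:") and part.endswith("]"):
        print("figure index:", int(part[5:-1]))  # -> 3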
workflow/report/report_markdown.py
ADDED
@@ -0,0 +1,55 @@
+import io
+import re
+import base64
+import plotly.io as pio
+import streamlit as st
+
+
+def write_markdown(agents):
+    report_agent = agents[-1]
+    report_obj = report_agent.load_report()  # Reportcore
+
+    # List of figure analyses
+    analysis_list = agents[2].summary_fig_analysis_list()
+
+    md_parts = []
+
+    def _process_node(node):
+        if node.type == "heading":
+            prefix = "#" * (node.level if node.level > 0 else 1)
+            md_parts.append(f"{prefix} {node.text}\n")
+            for ch in node.children:
+                _process_node(ch)
+
+        elif node.type == "paragraph":
+            parts = re.split(r'(\[FIG:\d+\])', node.text)
+            for part in parts:
+                part = part.strip()
+                if not part:
+                    continue
+
+                if part.startswith("[FIG:") and part.endswith("]"):
+                    idx = int(part[5:-1])
+                    if 0 <= idx < len(analysis_list):
+                        fig_obj = analysis_list[idx].get("figure")
+                        try:
+                            buf = io.BytesIO()
+                            pio.write_image(fig_obj, buf, format="png")
+                            data = buf.getvalue()
+                            b64 = base64.b64encode(data).decode("utf-8")
+                            # Embed the image inline as base64
+                            md_parts.append(
+                                f"![figure](data:image/png;base64,{b64})\n"
+                            )
+                        except Exception as e:
+                            md_parts.append(f"> **图像插入失败**: {e}\n")
+                else:
+                    md_parts.append(f"{part}\n\n")
+
+        else:  # root
+            for ch in node.children:
+                _process_node(ch)
+
+    _process_node(report_obj.root)
+
+    md_content = "\n".join(md_parts)
+    report_agent.save_markdown(md_content)
+    st.success("Markdown 报告生成成功 ✅")
workflow/report/report_prepare_er.py
ADDED
@@ -0,0 +1,102 @@
+import ast
+import io
+import re
+from io import BytesIO
+
+import streamlit as st
+from tqdm import tqdm
+from stqdm import stqdm
+from docx import Document
+from docx.oxml.ns import qn
+from docx.shared import Inches
+from docx.enum.table import WD_TABLE_ALIGNMENT
+from docx.enum.text import WD_ALIGN_PARAGRAPH
+import plotly.express as px
+import plotly.io as pio
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from utils.sanitize_code import sanitize_code
+from workflow.report.report_core import *
+
+
+def report_prepare(agents, parallel=True, max_workers=4):
+    report_agent = agents[-1]
+    toc = report_agent.load_outline()
+    if toc is None:
+        st.error("请先生成目录")
+        return
+
+    toc = sanitize_code(toc)
+
+    # === Collect summaries from each analysis module ===
+    agent_abstracts = {}
+    with st.spinner("正在汇总各分析模块的结果..."):
+        for i in stqdm(range(len(agents) - 1)):
+            agent_abstracts[i] = agents[i].check_abstract()
+
+    # === Update the toc's FIG list ===
+    selected_full_contents_vis = agents[2].check_full()
+    toc = report_agent.selected_photo_update_toc(toc, selected_full_contents_vis)
+    toc = sanitize_code(toc)
+    print(toc)
+    try:
+        toc = ast.literal_eval(toc)
+    except Exception:
+        pass
+
+    # === Update the toc's module-selection list ===
+    with st.spinner("正在匹配各章节所需的分析模块..."):
+        toc_with_choice = report_agent.update_toc_with_relevant_sections(toc, agent_abstracts)
+        toc_with_choice = sanitize_code(toc_with_choice)
+        try:
+            toc_with_choice = ast.literal_eval(toc_with_choice)
+        except Exception:
+            pass
+
+    # === Initialize the report structure ===
+    doc = Reportcore()
+    doc.add_heading('数据分析报告', 0)
+    selected_model = st.session_state.selected_model
+
+    def process_section(idx, t, t_w_c, history_content=""):
+        st.session_state.selected_model = selected_model
+        # t: (title, level, content outline, [figs], [modules])
+        _, _, _, _, choice_list = t_w_c
+        selected_full_contents = {i: agents[i].check_full() for i in choice_list if i < len(agents) - 1}
+        content = report_agent.write_section_body(toc, t, selected_full_contents, history_content)
+        print(idx)
+        return (idx, t, content)
+
+    results = []
+
+    # Serial or parallel
+    if not parallel:
+        with st.spinner("正在串行生成各章节内容(带上下文)..."):
+            history_content = ""
+            for idx, t in stqdm(enumerate(toc)):
+                t_w_c = toc_with_choice[idx]
+                _, _, content = process_section(idx, t, t_w_c, history_content)
+                results.append((idx, t, content))
+                history_content += f"\n\n{t[0]}\n{content}"
+    else:
+        with st.spinner(f"正在并行生成各章节内容({max_workers}线程)..."):
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                print(toc_with_choice)
+                futures = {
+                    executor.submit(process_section, idx, t, toc_with_choice[idx], ""): idx
+                    for idx, t in enumerate(toc)
+                }
+                for future in stqdm(as_completed(futures), total=len(futures)):
+                    try:
+                        results.append(future.result())
+                    except Exception as e:
+                        print(f"章节生成失败: {e}")
+
+    # Sort & write into the report
+    results.sort(key=lambda x: x[0])
+    for _, t, content in results:
+        doc.add_heading(t[0], level=t[1])
+        doc.add_paragraph(content)
+
+    report_agent.save_report(doc)
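The parallel branch above uses a pattern worth calling out: each section is submitted together with its index, `(idx, …)` tuples are collected in whatever order the futures finish, and a single sort at the end restores document order. Stripped to its skeleton:

from concurrent.futures import ThreadPoolExecutor, as_completed

def work(idx, title):
    return (idx, f"body for {title}")

toc = ["引言", "数据概览", "结论"]
results = []
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(work, idx, t): idx for idx, t in enumerate(toc)}
    for future in as_completed(futures):
        results.append(future.result())  # arrives in completion order

results.sort(key=lambda x: x[0])         # back to document order
print([body for _, body in results])

Note that the parallel path passes an empty `history_content`, so, unlike the serial path, sections generated in parallel do not see the text of earlier sections.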
workflow/report/report_render.py
ADDED
@@ -0,0 +1,243 @@
+import datetime
+import io
+from io import BytesIO
+
+from stqdm import stqdm
+import mammoth
+import numpy as np
+import pandas as pd
+import streamlit as st
+import streamlit_antd_components as sac
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from prompt_engineer.sec5_call_llm import *
+from workflow.report.report_utils import html_dowmload
+from workflow.report.report_html import write_html
+from workflow.report.report_word import write_word
+from workflow.report.report_markdown import write_markdown
+from workflow.report.report_prepare_er import report_prepare
+
+
+def report_save(agents, auto):
+
+    report_agent = agents[-1]
+    action = report_agent.load_report_format()
+
+    not_generate = False  # fallback in case no known format is selected yet
+    if report_agent.load_report_format() == 'HTML':
+        not_generate = report_agent.html is None
+    if report_agent.load_report_format() == 'Word':
+        not_generate = report_agent.word is None
+    if report_agent.load_report_format() == 'Markdown':
+        not_generate = report_agent.markdown is None
+
+    mode = report_agent.load_gen_mode()
+    parallel = (mode == "并行")
+
+    if st.button(f"📝 生成 {action} 报告") or (auto and not_generate):
+        with st.spinner(f"正在生成 {action} 报告..."):
+
+            report_prepare(agents, parallel=parallel)
+
+            if report_agent.load_report_format() == 'Word':
+                write_word(agents)
+            elif report_agent.load_report_format() == 'HTML':
+                write_html(agents)
+            elif report_agent.load_report_format() == 'Markdown':
+                write_markdown(agents)
+
+
+def report_basic_info(agent, auto) -> None:
+    outline_length = sac.segmented(
+        items=[
+            sac.SegmentedItem(label='简要'),
+            sac.SegmentedItem(label='标准'),
+            sac.SegmentedItem(label='详细'),
+        ],
+        label='详细程度', index=1, align='center',
+        size='sm', radius='sm', use_container_width=True
+    )
+    agent.save_outline_length(outline_length)
+
+    c1, c2 = st.columns(2)
+    with c1:
+        date = st.date_input("报告日期", datetime.date(2025, 10, 1))
+        agent.save_date(date)
+    with c2:
+        name = st.text_input("报告作者", "Anystat")
+        agent.save_name(name)
+
+    c1, c2 = st.columns([3, 1])
+    with c1:
+        report_format = sac.chip(
+            items=[
+                sac.ChipItem(label='Word', icon=sac.BsIcon(name='file-earmark-word', size=16)),
+                sac.ChipItem(label='HTML', icon=sac.BsIcon(name='filetype-html', size=16)),
+                sac.ChipItem(label='Markdown', icon=sac.BsIcon(name='file-earmark-code', size=16)),
+            ],
+            label='选择报告生成格式', index=[0, 2],
+            align='start', radius='md', multiple=False,
+        )
+        agent.save_report_format(report_format)
+
+    with c2:
+        mode = sac.segmented(
+            items=[
+                sac.SegmentedItem(label='并行'),
+                sac.SegmentedItem(label='串行'),
+            ],
+            label='生成模式', align='end', size='sm',
+            use_container_width=True, radius='md'
+        )
+        agent.save_gen_mode(mode)
+
+    user_input = st.text_input("报告生成要求", "默认")
+    agent.save_user_input(user_input)
+
+    not_generated = agent.load_outline() is None
+
+    # === Generate the TOC (in parallel) ===
+    if st.button("🗒️ 生成目录") or (auto and not_generated):
+        with st.spinner("⏳ 正在自动生成目录结构..."):
+            summaries = []
+
+            # === Save a snapshot of the current Streamlit session state ===
+            session_snapshot = dict(st.session_state)
+
+            def process_summary(idx, sub_agent, session_snapshot):
+                """Run summary_html/summary_word in parallel (with a copied session state)"""
+                # Restore session_state in this worker thread
+                for k, v in session_snapshot.items():
+                    st.session_state[k] = v
+
+                # Actual generation logic
+                if hasattr(sub_agent, "summary_html"):
+                    summary = sub_agent.summary_html()
+                else:
+                    summary = sub_agent.summary_word()
+
+                return idx, summary
+
+            # agents is the module-level list assembled under __main__
+            max_workers = min(6, len(agents) - 1)
+            results = []
+
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                futures = {
+                    executor.submit(process_summary, i, sub_agent, session_snapshot): i
+                    for i, sub_agent in enumerate(agents[:-1])
+                }
+
+                for future in stqdm(as_completed(futures), total=len(futures)):
+                    try:
+                        idx, summary = future.result()
+                        if summary:
+                            results.append((idx, summary))
+                    except Exception as e:
+                        print(f"子模块摘要生成失败: {e}")
+
+            # === Restore the original section order ===
+            results.sort(key=lambda x: x[0])
+            summaries = [summary for _, summary in results if summary]
+
+            # === Generate the TOC ===
+            default_toc = agent.generate_toc_from_summary(summaries)
+            agent.save_outline(default_toc)
+
+
+def report_outline(agents):
+
+    st.subheader("目录结构预览与编辑")
+    load_agent, preproc_agent, visual_agent, coding_agent, report_agent = agents[0], agents[1], agents[2], agents[3], agents[4]
+
+    default_toc = report_agent.load_outline()
+
+    toc_md = st.text_area(
+        "您可以在此处编辑目录结构",
+        value=default_toc,
+        height=260
+    )
+
+    report_agent.save_outline(toc_md)
+
+
+def report_execution(report_agent):
+
+    if report_agent.load_report_format() == 'Word':
+
+        full_report = report_agent.load_word()
+        if full_report is not None:
+            st.download_button(
+                label="⬇️ 下载 Word 报告",
+                data=full_report,
+                file_name="report.docx",
+                mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+            )
+
+    elif report_agent.load_report_format() == 'HTML':
+
+        full_report = report_agent.load_html()
+
+        if full_report is not None:
+            st.download_button(
+                label="⬇️ 下载 HTML 报告",
+                data=full_report.encode("utf-8"),
+                file_name="report.html",
+                mime="text/html",
+            )
+
+            if st.button("⬇️ 下载 PDF 报告"):
+                html_dowmload(full_report)
+
+    elif report_agent.load_report_format() == 'Markdown':
+
+        full_report = report_agent.load_markdown()
+        if full_report is not None:
+
+            # Offer a download button
+            st.download_button(
+                label="⬇️ 下载 Markdown 报告",
+                data=full_report,
+                file_name="report.md",
+                mime="text/markdown"
+            )
+
+
+if __name__ == "__main__":
+
+    st.title("报告生成")
+
+    st.markdown("---")
+
+    load_agent = st.session_state.data_loading_agent
+    preproc_agent = st.session_state.data_preprocess_agent
+    visual_agent = st.session_state.visualization_agent
+    coding_agent = st.session_state.modeling_coding_agent
+    planner = st.session_state.planner_agent
+    auto = planner.report_auto
+
+    processed_df = preproc_agent.load_processed_df()
+    if processed_df is None:
+        df = load_agent.load_df()
+    else:
+        df = processed_df
+
+    if df is None:
+        st.warning("⚠️ 请先在数据导入页面加载数据")
+        st.stop()
+
+    if isinstance(df, np.ndarray):
+        df = pd.DataFrame(df)
+
+    df_shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)
+
+    report_agent = st.session_state.report_agent
+    report_agent.add_df(df_shuffled)
+
+    agents = [load_agent, preproc_agent, visual_agent, coding_agent, report_agent]
+
+    c = st.columns(2)
+    with c[0].expander('报告设置', True):
+        report_basic_info(report_agent, auto)
+
+    with c[1].expander('报告大纲', True):
+        report_outline(agents)
+        report_save(agents, auto)
+        report_execution(report_agent)
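One subtlety in `report_basic_info` above: `st.session_state` is bound to the Streamlit script run, so worker threads receive a plain-dict snapshot taken on the main thread. The same idea, reduced to a framework-free sketch (the `selected_model` value is a placeholder):

from concurrent.futures import ThreadPoolExecutor

session_state = {"selected_model": "some-llm", "lang": "zh"}  # stand-in for st.session_state

def process_summary(idx, snapshot):
    local_state = dict(snapshot)  # every worker starts from the same frozen copy
    return idx, f"summary {idx} via {local_state['selected_model']}"

snapshot = dict(session_state)    # freeze once, on the main thread
with ThreadPoolExecutor(max_workers=2) as ex:
    print(sorted(ex.map(lambda i: process_summary(i, snapshot), range(3))))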
workflow/report/report_utils.py
ADDED
@@ -0,0 +1,59 @@
+import re
+import base64
+
+import streamlit as st
+from playwright.sync_api import sync_playwright
+
+
+def html_to_pdf_bytes_playwright(html: str) -> bytes:
+    with sync_playwright() as p:
+        browser = p.chromium.launch()
+        page = browser.new_page()
+        page.set_content(html, wait_until="load")
+        pdf_bytes = page.pdf(format="A4", print_background=True)
+        browser.close()
+    return pdf_bytes
+
+
+def html_dowmload(full_report):
+
+    try:
+        pdf_bytes = html_to_pdf_bytes_playwright(full_report)
+    except Exception as e:
+        st.error(f"生成 PDF 出错:{e}")
+    else:
+        b64 = base64.b64encode(pdf_bytes).decode("utf-8")
+
+        auto_download_html = f"""
+        <html>
+        <body>
+        <a id="dl_link"
+           href="data:application/pdf;base64,{b64}"
+           download="report.pdf"
+           style="display:none">download</a>
+        <script>
+        (function() {{
+            const a = document.getElementById('dl_link');
+            try {{
+                a.click();
+            }} catch (err) {{
+                // If the automatic click is blocked, replace the page body and expose a manual link
+                document.body.innerHTML =
+                    '<p>自动下载被浏览器阻止,请点击下面链接手动下载:</p>' + a.outerHTML;
+            }}
+        }})();
+        </script>
+        </body>
+        </html>
+        """
+
+        st.components.v1.html(auto_download_html, height=120)
+
+        st.download_button(
+            label="⬇️ 手动下载 PDF(回退)",
+            data=pdf_bytes,
+            file_name="report.pdf",
+            mime="application/pdf",
+        )
+
+        st.success("PDF 已生成(如未自动下载,请使用上方手动下载按钮)。")
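A deployment note for the Playwright path above: `pip install playwright` alone does not ship a browser, so the container also needs something like `playwright install chromium` (plus `--with-deps` on slim Debian images) before `html_to_pdf_bytes_playwright` can launch Chromium. A minimal smoke test under that assumption:

from workflow.report.report_utils import html_to_pdf_bytes_playwright

pdf = html_to_pdf_bytes_playwright("<h1>测试报告</h1><p>Hello PDF</p>")
with open("smoke_test.pdf", "wb") as f:
    f.write(pdf)
print(len(pdf), "bytes")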
workflow/report/report_word.py
ADDED
@@ -0,0 +1,89 @@
+import ast
+import io
+import re
+from io import BytesIO
+
+import streamlit as st
+from stqdm import stqdm
+from docx import Document
+from docx.oxml.ns import qn
+from docx.shared import Inches
+from docx.enum.table import WD_TABLE_ALIGNMENT
+from docx.enum.text import WD_ALIGN_PARAGRAPH
+import plotly.express as px
+import plotly.io as pio
+
+from utils.sanitize_code import sanitize_code
+
+
+def write_word(agents):
+    '''
+    choice: whether to search
+    True: look up the relevant sections from the TOC
+    False: use all sections
+    '''
+    # Fetch the figures
+    analysis_list = agents[2].summary_fig_analysis_list()
+
+    report_agent = agents[-1]
+    report_obj = report_agent.load_report()  # Reportcore
+
+    doc = Document()
+
+    style = doc.styles['Normal']
+
+    style.font.name = 'Times New Roman'
+    style._element.rPr.rFonts.set(qn('w:eastAsia'), '微软雅黑')
+
+    def _insert_figure(fig_obj):
+        if fig_obj is None:
+            return
+        try:
+            img_bytes = io.BytesIO(fig_obj.to_image(format="png"))
+            img_bytes.seek(0)
+
+            paragraph = doc.add_paragraph()
+            run = paragraph.add_run()
+            run.add_picture(img_bytes, width=Inches(4))
+            paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
+        except Exception as e:
+            doc.add_paragraph(f"[图像插入失败: {e}]")
+
+    def _process_node(node):
+        if node.type == "heading":
+            doc.add_heading(node.text, level=node.level)
+            for ch in node.children:
+                _process_node(ch)
+
+        elif node.type == "paragraph":
+            parts = re.split(r'(\[FIG:\d+\])', node.text)
+
+            for part in parts:
+                part = part.strip()
+                if not part:
+                    continue
+                if part.startswith("[FIG:") and part.endswith("]"):
+                    idx = int(part[5:-1])
+                    fig_obj = None
+                    if 0 <= idx < len(analysis_list):
+                        entry = analysis_list[idx]
+                        fig_obj = entry.get("figure")
+                    _insert_figure(fig_obj)
+                else:
+                    doc.add_paragraph(part)
+
+        else:  # root
+            for ch in node.children:
+                _process_node(ch)
+
+    # Start writing from root.children
+    _process_node(report_obj.root)
+
+    buf = io.BytesIO()
+    doc.save(buf)
+    buf.seek(0)
+    report_agent.save_word(buf.getvalue())
+    st.success("Word 报告生成成功")
workflow/visualization/viz_coding.py
ADDED
@@ -0,0 +1,110 @@
+import time
+import traceback
+
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objs as go
+import streamlit as st
+from stqdm import stqdm
+from streamlit_ace import st_ace
+import streamlit_antd_components as sac
+
+from utils.sanitize_code import sanitize_code
+
+
+def vis_code_gen(agent, debug=False, auto=False) -> None:
+
+    df = agent.load_df()
+    suggest = agent.load_suggestion()
+    user_input = agent.load_user_input()
+
+    chat_history = agent.load_memory()
+    already_generated = any(
+        entry["role"] == "assistant" and "可视化脚本已更新!请重新运行代码!" in str(entry["content"])
+        for entry in chat_history
+    )
+
+    if suggest is not None:
+        if debug == True or (auto and not already_generated):
+            with st.spinner("可视化 Agent 正在编写脚本..."):
+                raw = agent.code_generation(
+                    df.head().to_string(),
+                    suggest,
+                )
+                code = sanitize_code(raw)
+                agent.save_code(code)
+                st.chat_message("assistant").write("可视化脚本已更新!请重新运行代码!")
+                agent.add_memory({"role": "assistant", "content": "可视化脚本已更新!请重新运行代码!"})
+                st.rerun()
+
+        analyze_btn = st.button("🔧 生成可视化代码", key="viz_code")
+        if analyze_btn:
+            with st.spinner("可视化 Agent 正在编写脚本..."):
+                raw = agent.code_generation(
+                    df.head().to_string(),
+                    suggest,
+                )
+                code = sanitize_code(raw)
+                agent.save_code(code)
+                st.chat_message("assistant").write("可视化脚本已更新!请重新运行代码!")
+                agent.add_memory({"role": "assistant", "content": "可视化脚本已更新!请重新运行代码!"})
+                st.rerun()
+
+
+def vis_execution(agent, auto=False):
+
+    df = agent.load_df()
+
+    exec_ns = {
+        "df": df,
+        "np": np,
+        "pd": pd,
+        "px": px,
+        "go": go,
+    }
+
+    code = agent.load_code()
+    edited = st_ace(
+        value=code,
+        height=450,
+        theme="tomorrow_night",
+        language="python",
+        auto_update=True
+    )
+    desc_switch = sac.switch(label='附加分析', value=False, off_label='Off')
+    if code is not None:
+        not_executed = agent.load_fig() == []
+        # Run only when the button is clicked, or when auto=True and nothing has executed yet
+        if st.button("▶️ 执行可视化") or (auto and not_executed):
+            code = sanitize_code(edited)
+            agent.save_code(code)
+            try:
+                with st.spinner("正在运行可视化脚本..."):
+                    exec(code, exec_ns)
+            except Exception as exc:
+                st.error("已记录报错内容,正在为您 debug...")
+                st.text(traceback.format_exc())
+                agent.save_error(traceback.format_exc())
+                vis_code_gen(agent, debug=True)
+            else:
+                fig_dict = exec_ns.get("fig_dict")
+                if not fig_dict or not isinstance(fig_dict, dict):
+                    st.error(
+                        "脚本未写入 `fig_dict` 或格式不正确。请确保末尾赋值 `fig_dict`,且它是一个 {列名: 图表} 的 dict。"
+                    )
+                    agent.save_error(traceback.format_exc())
+                    vis_code_gen(agent, debug=True)
+                else:
+                    with st.spinner("正在处理可视化结果..."):
+                        for col_name, fig in stqdm(fig_dict.items()):
+                            dtype_info = ", ".join(
+                                f"{c}: {df[c].dtype}" for c in df.columns
+                            )
+                            if desc_switch == True:
+                                desc = agent.desc_fig(fig, dtype_info)
+                            else:
+                                desc = None
+                            agent.add_fig(fig, desc)
+                        agent.finish_auto()
+                        st.rerun()
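The executed script and `vis_execution` meet at a single convention: the generated code must leave a `fig_dict` variable, mapping column names to Plotly figures, in its namespace (it already has `df`, `np`, `pd`, `px`, and `go` injected). A hand-written script satisfying that contract, with illustrative column names:

import pandas as pd
import plotly.express as px

df = pd.DataFrame({"age": [23, 31, 45, 52], "income": [3.2, 4.1, 6.8, 7.5]})

fig_dict = {
    "age": px.histogram(df, x="age", title="age 的直方图"),
    "income": px.box(df, y="income", title="income 的箱线图"),
}
# vis_execution reads this back via exec_ns.get("fig_dict")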
workflow/visualization/viz_color.py
ADDED
@@ -0,0 +1,58 @@
+import streamlit as st
+
+PALETTES = {
+    "Classic": [
+        "#2B5C8A", "#4F81AF", "#77ACD3", "#D9D5C9", "#F69035"
+    ],
+    "Ocean Breeze": [
+        "#03045E", "#0077B6", "#00B4D8", "#90E0EF", "#CAF0F8"
+    ],
+    "Olive Garden Feast": [
+        "#606C38", "#283618", "#FEFAE0", "#DDA15E", "#BC6C25"
+    ],
+    "Fiery Ocean": [
+        "#780000", "#C1121F", "#FDF0D5", "#003049", "#669BBC"
+    ],
+    "Refreshing Summer Fun": [
+        "#8ECAE6", "#219EBC", "#023047", "#FFB703", "#FB8500"
+    ],
+    "Golden Summer Fields": [
+        "#CCD5AE", "#E9EDC9", "#FEFAE0", "#FAEDCD", "#D4A373"
+    ],
+    "Deep Sea": [
+        "#0D1B2A", "#1B263B", "#415A77", "#778DA9", "#E0E1DD"
+    ],
+    "Bold Berry": [
+        "#F9DBBD", "#FFA5AB", "#DA627D", "#A53860", "#450920"
+    ],
+    "Fresh Greens": [
+        "#D8F3DC", "#95D5B2", "#52B788", "#2D6A4F", "#1B4332"
+    ],
+    # This entry carried a duplicate "Deep Sea" key, which silently shadowed the
+    # palette above; renamed (name chosen here) so both remain selectable.
+    "Soft Neutrals": [
+        "#EDEDE9", "#D6CCC2", "#F5EBE0", "#E3D5CA", "#D5BDAF"
+    ],
+}
+
+
+def vis_palette(agent):
+
+    choice = st.selectbox("请选择配色方案", list(PALETTES.keys()))
+    colors = PALETTES[choice]
+
+    cols = st.columns(len(colors))
+    for col, code in zip(cols, colors):
+        col.markdown(
+            f"""
+            <div style="
+                background-color: {code};
+                height: 30px;
+                border-radius: 4px;
+                margin-bottom: 2px;
+            "></div>
+            <div style="text-align: center; font-size: 10px;">{code}</div>
+            """,
+            unsafe_allow_html=True
+        )
+
+    agent.save_color(colors)
+
+    return colors
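The list returned by `vis_palette` is persisted with `agent.save_color(colors)`; downstream plotting code can hand the same list to Plotly Express as a discrete color sequence. For example, with the "Classic" palette:

import plotly.express as px

colors = ["#2B5C8A", "#4F81AF", "#77ACD3", "#D9D5C9", "#F69035"]  # "Classic"
fig = px.pie(names=["A", "B", "C"], values=[3, 5, 2],
             color_discrete_sequence=colors)
fig.show()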
workflow/visualization/viz_quick_action.py
ADDED
@@ -0,0 +1,23 @@
+import streamlit as st
+import plotly.express as px
+
+
+def plot_for_option(df, option: str, column: str):
+
+    series = df[column]
+
+    if option == "直方图":
+        fig = px.histogram(df, x=column, title=f"{column} 的直方图")
+    elif option == "饼图":
+        counts = series.value_counts().reset_index()
+        counts.columns = [column, 'count']
+        fig = px.pie(counts, names=column, values='count', title=f"{column} 的饼图")
+    elif option == "折线图":
+        fig = px.line(df, y=column, title=f"{column} 的折线图")
+    elif option == "箱线图":
+        fig = px.box(df, y=column, title=f"{column} 的箱线图")
+    else:
+        st.error("未知的图表类型")
+        return
+
+    return fig
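Usage is direct; the option strings are the same Chinese labels the quick-action buttons pass in, and `None` comes back for an unknown chart type:

import pandas as pd
from workflow.visualization.viz_quick_action import plot_for_option

df = pd.DataFrame({"score": [72, 85, 90, 66, 88]})
fig = plot_for_option(df, "直方图", "score")
if fig is not None:
    fig.show()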
workflow/visualization/viz_render.py
ADDED
@@ -0,0 +1,192 @@
+import os
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+import plotly.graph_objs as go
+import streamlit as st
+import streamlit_antd_components as sac
+
+from utils.sanitize_code import sanitize_code
+from workflow.visualization.viz_suggestion import vis_button_suggest, vis_talk_suggest
+from workflow.visualization.viz_coding import vis_execution, vis_code_gen
+from workflow.visualization.viz_quick_action import plot_for_option
+from workflow.visualization.viz_color import vis_palette
+
+
+def vis_quick_actions(agent):
+
+    cols_list = agent.load_df().columns.tolist()
+    options = ["直方图", "饼图", "箱线图", "折线图"]
+
+    selected_col = st.selectbox("选择一个列:", cols_list)
+
+    logo_dir = os.path.join("logo", "sec3")  # portable join; the assets live under logo/sec3/
+    logo_paths = {opt: os.path.join(logo_dir, f"{opt}.png") for opt in options}
+
+    cols = st.columns(len(options))
+
+    fig_placeholder = st.empty()
+
+    for idx, opt in enumerate(options):
+        with cols[idx]:
+            left, center, right = st.columns([1, 8, 1])
+            with center:
+                st.write(opt)
+                path = logo_paths.get(opt)
+                if path and os.path.exists(path):
+                    st.image(Image.open(path), width=80)
+                else:
+                    st.text("Logo 文件未找到")
+
+            if st.button("Try me", key=f"try_{idx}"):
+                fig = plot_for_option(agent.load_df(), opt, selected_col)
+                fig_placeholder.plotly_chart(fig, use_container_width=True)
+
+
+def vis_result(agent) -> None:
+
+    fig_desc_list = agent.load_fig()
+    total = len(fig_desc_list)
+    PAGE_SIZE = 5
+
+    current_page = sac.pagination(
+        total=total,
+        page_size=PAGE_SIZE,
+        align='center',
+        jump=False,
+        show_total=True,
+        variant='filled',
+        color='#44658C'
+    )
+
+    start_idx = (current_page - 1) * PAGE_SIZE
+    end_idx = min(start_idx + PAGE_SIZE, total)
+    page_items = fig_desc_list[start_idx:end_idx]
+
+    for offset, item in enumerate(page_items):
+
+        idx = start_idx + offset
+        fig = item["fig"]
+        desc = item["desc"]
+
+        st.plotly_chart(
+            fig,
+            use_container_width=True,
+            key=f"fig_{idx}"
+        )
+        if desc is not None:
+            st.write(desc)
+
+
+def vis_chat(agent, auto=False):
+
+    msg = st.chat_message("assistant")
+    msg.write(
+        "我是 Anystat 数据分析助手,很高兴为您服务!\n\n"
+        "您可以在下方对话框输入具体可视化需求,"
+        "也可以点击下面的按钮,一键获取可视化建议并绘图。"
+    )
+    analyze_clicked = msg.button("🔍 可视化推荐", key="viz_suggest")
+    reply_placeholder = msg.empty()
+
+    chat_history = agent.load_memory()
+
+    for idx, entry in enumerate(chat_history):
+        bubble = st.chat_message(entry["role"])
+        content = entry["content"]
+        if isinstance(content, str):
+            bubble.write(content)
+        elif isinstance(content, go.Figure):
+            bubble.plotly_chart(
+                content,
+                use_container_width=True,
+                key=f"chart-{idx}"
+            )
+
+    already_generated = any(
+        entry["role"] == "assistant" and "图" in str(entry["content"])
+        for entry in chat_history
+    )
+
+    # Button path
+    if analyze_clicked or (auto and not already_generated):
+        st.chat_message("user").write("请帮我做可视化分析")
+        agent.add_memory({'role': 'user', 'content': "请帮我做可视化分析"})
+        with st.spinner("正在处理您的请求..."):
+            rec = vis_button_suggest(agent)
+            agent.save_suggestion(rec)
+        st.chat_message("assistant").write(rec)
+        agent.add_memory({"role": "assistant", "content": str(rec)})
+
+    # Chat path
+    reply = None
+    user_input = None
+    user_input = st.chat_input("请输入需求,例如'请给我一些可视化建议'")
+    if user_input is not None:
+        st.chat_message("user").write(user_input)
+        with st.spinner("正在处理您的请求..."):
+            agent.save_user_input(user_input)
+            agent.add_memory({"role": "user", "content": user_input})
+            rec = vis_talk_suggest(agent, user_input)
+            agent.save_suggestion(rec)
+        st.chat_message("assistant").write(rec)
+        agent.add_memory({"role": "assistant", "content": str(rec)})
+        st.rerun()
+
+
+if __name__ == "__main__":
+
+    st.title("统计可视化分析")
+    st.markdown("---")
+
+    preproc_agent = st.session_state.data_preprocess_agent
+    load_agent = st.session_state.data_loading_agent
+    planner = st.session_state.planner_agent
+    auto = planner.vis_auto
+
+    processed_df = preproc_agent.load_processed_df()
+    if processed_df is None:
+        df = load_agent.load_df()
+    else:
+        df = processed_df
+
+    if df is None:
+        st.warning("⚠️ 请先在数据导入页面加载数据")
+        st.stop()
+
+    if isinstance(df, np.ndarray):
+        df = pd.DataFrame(df)
+
+    df_shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)
+    agent = st.session_state.visualization_agent
+    agent.add_df(df_shuffled)
+
+    if st.session_state.auto_mode == True:
+        if (agent.finish_auto_task == True and planner.switched_vis == False) or planner.vis_auto == False:
+            planner.finish_vis_auto()
+            st.switch_page("workflow/modeling/modeling_render.py")
+
+    code = agent.load_code()
+    if code is None:
+        code_expand = False
+    else:
+        code_expand = True
+
+    fig = agent.load_fig()
+    if fig is None:
+        fig_expand = False
+    else:
+        fig_expand = True
+
+    c = st.columns(2)
+    # with c[1].expander('快速可视化', False):
+    #     vis_quick_actions(agent)
+    with c[0].expander('配色选择', True):
+        vis_palette(agent)
+    with c[1].expander('可视化建议', True):
+        vis_chat(agent, auto)
+        vis_code_gen(agent, auto=auto)
+    with c[0].expander('可视化执行', code_expand):
+        vis_execution(agent, auto=auto)
+    with c[0].expander('可视化结果', fig_expand):
+        vis_result(agent)
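The paging arithmetic in `vis_result` is easy to get off by one, since `sac.pagination` is 1-based: page n maps to the half-open index window [(n-1)*PAGE_SIZE, min(n*PAGE_SIZE, total)). Checking the edges:

PAGE_SIZE = 5
total = 12  # e.g. 12 saved figures

for page in (1, 2, 3):
    start = (page - 1) * PAGE_SIZE
    end = min(start + PAGE_SIZE, total)
    print(page, list(range(start, end)))
# 1 [0, 1, 2, 3, 4]
# 2 [5, 6, 7, 8, 9]
# 3 [10, 11]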
workflow/visualization/viz_suggestion.py
ADDED
@@ -0,0 +1,38 @@
+import streamlit as st
+
+
+def vis_button_suggest(agent):
+    """
+    Button path: call the LLM for structured visualization recommendations (JSON).
+    """
+    df = agent.load_df()
+    cols_wo_id = agent.load_cols_wo_id()
+
+    if cols_wo_id is None:
+        cols_wo_id = [str(c) for c in df.columns if not str(c).lower().startswith(('id', 'idx', 'index'))]
+        agent.save_cols_wo_id(cols_wo_id)
+
+    rec = agent.get_visualization_recommendations(cols_wo_id)
+
+    agent.save_recommendations(rec)
+    agent.refine_suggestions(rec)
+
+    return rec
+
+
+def vis_talk_suggest(agent, user_input):
+    """
+    Chat path: derive suggestions from the conversation.
+    """
+    df = agent.load_df()
+    cols_wo_id = agent.load_cols_wo_id()
+
+    if cols_wo_id is None:
+        # cast to str first, mirroring the button path, so non-string columns don't break .lower()
+        cols_wo_id = [str(c) for c in df.columns if not str(c).lower().startswith(('id', '编号', '序号', 'index'))]
+        agent.save_cols_wo_id(cols_wo_id)
+
+    rec = agent.get_visualization_recommendations(cols_wo_id, user_input)
+    agent.save_recommendations(rec)
+    agent.refine_suggestions(rec)
+
+    return rec