diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..833cdec748eafe9807e9301755d636768993fecc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +logo/logo_16_9.png filter=lfs diff=lfs merge=lfs -text +logo/logo_big.png filter=lfs diff=lfs merge=lfs -text +logo/logo_blue_wide.png filter=lfs diff=lfs merge=lfs -text +logo/logo_wide.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ca4928df04532cbbc7b2b2e8661be837567950d0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +utils/.streamlit/ +__pycache__/ +*.py[cod] +*.pyo \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..48e0ab957af70d43817963026d9f69c05ad9777c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,38 @@ + +# FROM python:3.12-slim + +# # ========= 设置工作目录 ========= +# WORKDIR /app + +# # ========= 安装依赖 ========= +# COPY ./requirements.txt ./requirements.txt +# RUN pip install --no-cache-dir --upgrade pip setuptools wheel \ +# && pip install --no-cache-dir -r requirements.txt + +# # ========= 拷贝项目文件 ========= +# COPY . /app + +# # ========= 暴露 Streamlit 端口 ========= +# EXPOSE 8501 + +# # ========= 健康检查 ========= +# HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health || exit 1 + +# # ========= 启动命令 ========= +# CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"] +# ========= 基础镜像 ========= +FROM python:3.12-slim + +# ========= 改变工作目录 ========= +WORKDIR /tmp + +# ========= 安装依赖 ========= +COPY ./requirements.txt ./requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# ========= 拷贝项目文件 ========= +COPY . /tmp + +# ========= 暴露端口并启动 ========= +EXPOSE 8501 +CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"] diff --git a/README.md b/README.md index 3cb8dcc142a42daf657d3e1c534bebd18d8e9660..f163b832dc74cbcd6978a9e05d5e2107e1830743 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,126 @@ --- -title: Autostat -emoji: 🏢 +title: Anystat +emoji: 🚀 colorFrom: red -colorTo: gray -sdk: gradio -sdk_version: 5.49.1 -app_file: app.py +colorTo: red +sdk: docker +app_port: 8501 +tags: +- streamlit pinned: false +short_description: testtesttest --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + +
+ + Anystat: Statistical Analysis, Instantly. + + +
+ + + + +Anystat,致力于成为用户数据分析的 copilot 。 + +我们正在寻求一个入门友好、覆盖数据分析端到端流程、可通过与用户多轮交互持续优化效果,并具备承载未来五年 LLM 技术迭代能力的数据分析 Agent 框架,助你高效推进每一步分析任务。 + +## News + +- (待添加) + +## 功能特点 + +- **全流程覆盖,模块化重构数据分析。** Anystat 覆盖导入、预处理、可视化、建模与报告生成五个流程。针对每一流程内任务的采用模块化设计,专职 Agent 负责,将 Agent 的能力无缝融入数据分析。 +- **编写代码,释放数据分析潜能。** Coding 兼容工具调用与自主开发。 Agent 不仅能精准理解用户需求,灵活调用现有工具,还可根据需求自主编写新工具,兼顾稳定性与灵活性,承载未来模型能力的溢出。 +- **自动模式,让AI主导数据分析。** 面向小白用户,简单上手操作。只需上传数据,剩下交给 Agent 负责。内置 Planning Agent 自动分解任务、智能分工。一键实现高质量数据分析报告。 +- **专业报告,一键生成完整分析。** 多智能体协作自动生成初步目录,用户可灵活调整。 Report Agent 基于最终目录,从概要到细节一键输出图文并茂的专业级数据分析报告。 + +## 快速开始 + +### 从Github开始 + + > 请确保您的计算机上已安装了 Python3.9 及以上的版本,推荐 Python 版本在 3.11 及以上以获得更好体验。 + > 支持 Windows/MacOS/Linux 环境。 + + 1. **克隆项目到本地** + + ```bash + gh repo clone ElvisWang1111/AAAAAnystat + cd (to working directory of Anystat) + ``` + + 2. **环境配置** + + ```bash + conda create --name anystat + conda activate anystat + ``` + + 3. **安装依赖**: + + ```bash + pip install -r requirements.txt + pip install playwright + playwright install + ``` + + 4. **启动应用** + + ```bash + streamlit run app.py + ``` + +### 通过发行包安装(Windows) + + 程序链接:待加 + +### 通过脚本安装(Mac) + +#### 预先准备 + + 请先下载 *Anaconda* 和 *Miniconda*,用于创建独立的 Python 环境。下载请访问[Anaconda官网](https://www.anaconda.com/download)。 + +#### 一键配置 + + 打开 Anystat 程序所在目录,在该目录下打开终端(命令行),执行以下命令: + + ```bash + bash setup.sh + ``` + + 完成后将输出启动提示,执行 + + ```bash + conda activate anystat_env + streamlit run app.py + ``` + + 即可访问 Anystat Agent。 + +### 直接访问 Web 端服务器资源 + + 点击[Anystat Web](https://modelscope.cn/studios/boyuanwang/teststat/summary)以直接使用 Anystat。 + +> 更详细的教程详见[**Anystat Doc**](https://elviswang1111.github.io/anystatweb.github.io/index.html)。 + +## 相关链接 + +1. [Anystat Doc](https://elviswang1111.github.io/anystatweb.github.io/index.html) + +2. API key 获取网址: + - [Deepseek](https://platform.deepseek.com/api_keys) + - [ChatGPT](https://platform.openai.com/docs/overview) + - [通义千问](https://bailian.console.aliyun.com/?spm=5176.29597918.J_SEsSjsNv72yRuRFS2VknO.2.54d87b08CphuY5&tab=api#/api) + - [智谱 AI](https://docs.bigmodel.cn/cn/guide/develop/http/introduction) + +## 许可 + +本项目基于 MIT 许可证开源,详见 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ef24edc636aaf6361f4faa27e6e1393171bd4778 --- /dev/null +++ b/app.py @@ -0,0 +1,215 @@ +import sys, os +import tempfile +import streamlit as st + +from config import MODEL_CONFIGS +from utils.save_secrets import * +from prompt_engineer.sec1_call_llm import DataLoadingAgent +from prompt_engineer.sec2_call_llm import DataPreprocessAgent +from prompt_engineer.sec3_call_llm import VisualizationAgent +from prompt_engineer.sec4_call_llm import ModelingCodingAgent +from prompt_engineer.sec5_call_llm import ReportAgent +from prompt_engineer.planner import PlannerAgent + +import warnings +warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore", message="missing ScriptRunContext") + +import numpy as np +np.set_printoptions(edgeitems=250, threshold=501) + +sys.path.append(os.path.dirname(__file__)) + +CACHE_FILE = os.path.join(tempfile.gettempdir(), "anystat_cache.pkl") +CACHE_DIR = './cache' +SECRETS_PATH = Path(".streamlit") / "secrets.toml" + + +# 设置页面配置 +st.set_page_config( + page_title="AnyStat", + page_icon="🤖", + layout="wide" +) + + +def init_session_state(): + + if 'selected_model' not in st.session_state: + st.session_state.selected_model = "DeepSeek" + if "api_keys" not in st.session_state: + st.session_state.api_keys = load_local_api_keys() + if 'auto_mode' not in st.session_state: + st.session_state.auto_mode = False + + if 'loading_start_time' not in st.session_state: + st.session_state.loading_start_time = None + if 'prep_start_time' not in st.session_state: + st.session_state.prep_start_time = None + if 'vis_start_time' not in st.session_state: + st.session_state.vis_start_time = None + if 'modeling_start_time' not in st.session_state: + st.session_state.modeling_start_time = None + if 'report_start_time' not in st.session_state: + st.session_state.report_start_time = None + + if 'data_loading_agent' not in st.session_state: + st.session_state.data_loading_agent = DataLoadingAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + if 'data_preprocess_agent' not in st.session_state: + st.session_state.data_preprocess_agent = DataPreprocessAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + if 'visualization_agent' not in st.session_state: + st.session_state.visualization_agent = VisualizationAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + if 'modeling_coding_agent' not in st.session_state: + st.session_state.modeling_coding_agent = ModelingCodingAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + if 'report_agent' not in st.session_state: + st.session_state.report_agent = ReportAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + if 'planner_agent' not in st.session_state: + st.session_state.planner_agent = PlannerAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + + +def on_model_selector_change(): + """ + Callback when the model selector in the sidebar changes. + """ + st.session_state.selected_model = st.session_state.model_selector + + +def run_app(): + """ + Main entry point to render the Streamlit app. + """ + init_session_state() + with st.sidebar: + st.subheader("选择大模型") + models = list(MODEL_CONFIGS.keys()) + st.selectbox( + "选择要使用的大模型", + models, + index=models.index(st.session_state.selected_model), + key="model_selector", + on_change=on_model_selector_change, + ) + + st.subheader("API 密钥设置") + selected = st.session_state.selected_model + + api_key_input = st.text_input( + f"{selected} API 密钥", + value=st.session_state.api_keys.get(selected, ""), + type="password", + key="api_key_input", + ) + + + if st.button("💾 保存密钥", use_container_width=True, key="save_key"): + # 保存在 utils/.streamlit/secrets.toml + update_local_api_key(selected, api_key_input) + + st.session_state.api_keys[selected] = api_key_input + st.success("已保存") + st.rerun() + + if st.button("🧹 清空数据", use_container_width=True, key="clear_data"): + + st.session_state.data_loading_agent = DataLoadingAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + st.session_state.data_preprocess_agent = DataPreprocessAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + st.session_state.visualization_agent = VisualizationAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + st.session_state.modeling_coding_agent = ModelingCodingAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + st.session_state.report_agent = ReportAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + st.session_state.planner_agent = PlannerAgent( + api_keys=st.session_state.api_keys, + model_configs=MODEL_CONFIGS, + model=st.session_state.selected_model + ) + st.session_state.auto_mode = False + st.rerun() + + if st.session_state.data_loading_agent.load_df() is not None: + planner = st.session_state.planner_agent + if st.button("🚗 自动模式", use_container_width=True, key="self_driving"): + planner.self_driving(st.session_state.data_loading_agent.load_df()) + st.session_state.auto_mode = True + st.rerun() + + st.image( + "logo/logo_big.png", + use_container_width=True + ) + + # Define pages + data_loading = st.Page( + "workflow/dataloading/dataloading_render.py", + title="📥 数据导入", + ) + preprocessing = st.Page( + "workflow/preprocessing/preprocessing_render.py", + title="⚙️ 数据预处理", + ) + visualization = st.Page( + "workflow/visualization/viz_render.py", + title="📊 数据可视化", + ) + report = st.Page( + "workflow/report/report_render.py", + title="📝 报告生成", + ) + coding_modeling = st.Page( + "workflow/modeling/modeling_render.py", + title="🧠 建模分析", + ) + # Navigation + pg = st.navigation( + { + "设置": [data_loading, preprocessing], + "功能": [visualization, coding_modeling, report], + } + ) + pg.run() + +if __name__ == "__main__": + run_app() + diff --git a/best_model.joblib b/best_model.joblib new file mode 100644 index 0000000000000000000000000000000000000000..17c1ac058fa94cb968a79f67470afc74febe0ad3 --- /dev/null +++ b/best_model.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da59e1539546964f9e34dbf3b5bc61a14f2ba5c4eceb24a08758ab26fdb4cf37 +size 30632 diff --git a/components/__init__.py b/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..999f2a2f48450d8216ff4c01b52eb585de7d30e7 --- /dev/null +++ b/config.py @@ -0,0 +1,35 @@ +# 大模型配置 +MODEL_CONFIGS = { + + ######################################################################################### + #注意:ChatGPT和Claude的模型请求url非官方,后续需将API-key和"api_base"替换为正版key和官方地址# + ######################################################################################### + "GPT-4o": { + "api_base": "https://turingai.plus/v1", + "model_name": "gpt-4o", + }, + "GPT-5": { + "api_base": "https://turingai.plus/v1", + "model_name": "gpt-5", + }, + "Claude": { + "api_base": "https://turingai.plus/v1", + "model_name": "claude-sonnet-4-5-20250929", + }, + "通义千问": { + "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "model_name": "qwen-max", + }, + "DeepSeek": { + "api_base": "https://api.deepseek.com/v1", + "model_name": "deepseek-chat", + }, + "智谱AI": { + "api_base": "https://open.bigmodel.cn/api/paas/v4/chat/completions", + "model_name": "glm-4v-plus-0111", + }, + "豆包": { + "api_base": "https://ark.cn-beijing.volces.com/api/v3/", + "model_name": "doubao-seed-1-6-251015", + } +} \ No newline at end of file diff --git a/doc.zip b/doc.zip new file mode 100644 index 0000000000000000000000000000000000000000..d446061eff6d7359d7f4daea0c0ec740c438aa44 --- /dev/null +++ b/doc.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac7c135f6e1234166738832a3a6b0286721cf296083ce9c3397869289dd41713 +size 1851148 diff --git a/logo/logo_16_9.png b/logo/logo_16_9.png new file mode 100644 index 0000000000000000000000000000000000000000..b8fcb6a2b8f7bc7a6ddfe6be80af43c25b9d4900 --- /dev/null +++ b/logo/logo_16_9.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9d657c0e416eed4a69ac6da7b7a271239c26e185a0cd778f2072b24db594cf6 +size 327466 diff --git a/logo/logo_big.png b/logo/logo_big.png new file mode 100644 index 0000000000000000000000000000000000000000..6a758c9ffc9dbfb267d967fdf74ee99d4625b9f8 --- /dev/null +++ b/logo/logo_big.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:389ffb2d5eec47539b6aee2ef89d4949d3bfa2d94d16c2d7198bd7ef394beb59 +size 325971 diff --git a/logo/logo_blue_wide.png b/logo/logo_blue_wide.png new file mode 100644 index 0000000000000000000000000000000000000000..c2021b22e9a08a2813c1662c4146bc9555eba82b --- /dev/null +++ b/logo/logo_blue_wide.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b28da164d5e95630bb79aea3e35fc71bfd9f556520c067194d0935bc918f036 +size 783752 diff --git a/logo/logo_wide.png b/logo/logo_wide.png new file mode 100644 index 0000000000000000000000000000000000000000..d7ea45169d24501776108be018bd36edbd92864e --- /dev/null +++ b/logo/logo_wide.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af377c54885f9d9fddaf3632f9cdca087e931a9da509a7dc9908eed858057f0a +size 325496 diff --git "a/logo/sec3/\346\212\230\347\272\277\345\233\276.png" "b/logo/sec3/\346\212\230\347\272\277\345\233\276.png" new file mode 100644 index 0000000000000000000000000000000000000000..b90402f59b2d78126098b5e4b01da17d73fa9a41 Binary files /dev/null and "b/logo/sec3/\346\212\230\347\272\277\345\233\276.png" differ diff --git "a/logo/sec3/\347\233\264\346\226\271\345\233\276.png" "b/logo/sec3/\347\233\264\346\226\271\345\233\276.png" new file mode 100644 index 0000000000000000000000000000000000000000..80cdfa1e189683d06ef11048c0bd50f9473a924c Binary files /dev/null and "b/logo/sec3/\347\233\264\346\226\271\345\233\276.png" differ diff --git "a/logo/sec3/\347\256\261\347\272\277\345\233\276.png" "b/logo/sec3/\347\256\261\347\272\277\345\233\276.png" new file mode 100644 index 0000000000000000000000000000000000000000..4c535726863b8b2a59d101acabf4d6467601061d Binary files /dev/null and "b/logo/sec3/\347\256\261\347\272\277\345\233\276.png" differ diff --git "a/logo/sec3/\351\245\274\345\233\276.png" "b/logo/sec3/\351\245\274\345\233\276.png" new file mode 100644 index 0000000000000000000000000000000000000000..747db7cdf81fff69e10a30dd0ff9c94b1c849570 Binary files /dev/null and "b/logo/sec3/\351\245\274\345\233\276.png" differ diff --git a/model.pkl.gz b/model.pkl.gz new file mode 100644 index 0000000000000000000000000000000000000000..a6a0995537cd040554d8efb179a0de92d7abd354 --- /dev/null +++ b/model.pkl.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee8eeed067caf9c95c4a171fc261a6b094aff415a9414df922d461f1923e8c04 +size 1409879 diff --git a/prompt_engineer/.DS_Store b/prompt_engineer/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/prompt_engineer/.DS_Store differ diff --git a/prompt_engineer/call_llm.py b/prompt_engineer/call_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..12436501fe75cc9d33e7eb530202ae90d0f56980 --- /dev/null +++ b/prompt_engineer/call_llm.py @@ -0,0 +1,144 @@ +import re +from openai import OpenAI, OpenAIError +from anthropic import Anthropic, AnthropicError +import requests +import json + +import streamlit as st +import pandas as pd +import numpy as np +from config import MODEL_CONFIGS +from typing import IO, List, Dict +from zai import ZhipuAiClient + +class LLMClient: + def __init__(self, model_configs: dict, api_keys: dict, model: str): + + self.model = model + self.model_configs = model_configs + self.api_keys = api_keys + self.memory = [] + self.df = None + + def call(self, prompt) -> str: + + model_name = st.session_state.selected_model + config = self.model_configs.get(model_name, {}) + api_key = self.api_keys.get(model_name) + + if not api_key: + return "请先在设置中配置 API 密钥" + + system_msg = ( + "你是一个专业的数据分析助手。" + ) + + try: + if model_name == "GPT-4o" or model_name == "GPT-5" or model_name == "DeepSeek" or model_name == "通义千问" or model_name == "Claude" or model_name == "豆包": + try: + client = OpenAI( + api_key=api_key, + base_url=config["api_base"] + ) + + # 使用新的 API 调用方式 + resp = client.chat.completions.create( + model=config["model_name"], + messages=[ + {"role": "system", "content": system_msg}, + {"role": "user", "content": prompt}, + ], + stream = False + ) + return resp.choices[0].message.content + + except OpenAIError as e: + # 这里可以捕获所有OpenAI SDK定义的错误 + st.error(f"API调用失败: {str(e)}") + # 记录日志或提示用户 + return "调用失败,请检查密钥或网络" + except Exception as e: + # 捕获其他非预期的异常,如网络问题 + st.error(f"发生未知错误: {str(e)}") + return "发生未知错误" + + elif model_name == "智谱AI": + client = ZhipuAiClient(api_key=api_key) + response = client.chat.completions.create( + model=config["model_name"], + messages=[{"role": "system", "content": "你是一个专业的数据分析助手。"}, + {"role": "user", "content": prompt}], + thinking={ + "type":"enabled" + } + ) + if response: + print(response.choices[0].message) + desc = response.choices[0].message.content if hasattr(response.choices[0].message, "content") else str(response.choices[0].message) + return desc.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "").strip() + + st.error(f"智谱调用失败:{response.text}") + return "调用失败,请检查密钥或网络" + + + # elif model_name == "DeepSeek": + # client = OpenAI( + # api_key=api_key, + # base_url=config["api_base"]) + + # resp = client.chat.completions.create( + # model=config["model_name"], + # messages=[ + # {"role": "system", "content": system_msg}, + # {"role": "user", "content": prompt}, + # ], + # stream=False + # ) + # if resp: + # return resp.choices[0].message.content + # st.error(f"DeepSeek调用失败:{resp.text}") + # return "调用失败,请检查密钥或网络" + + else: + return f"暂不支持模型:{model_name}" + + except Exception as e: + st.error(f"{model_name} 调用异常:{e}") + return "大模型调用失败,请检查 API 密钥或网络连接" + + + def add_memory(self, entry: Dict[str, str]) -> None: + + self.memory.append(entry) + + + def load_memory(self) -> List[Dict[str, str]]: + + return self.memory + + + def clear_memory(self) -> None: + + self.memory.clear() + + + def add_df(self, input_df) -> None: + + + + self.df = input_df + + + def load_df(self) -> pd.DataFrame: + + return self.df + + + def clear_df(self) -> None: + + self.df = None + + + def has_df(self) -> bool: + + return self.df == None \ No newline at end of file diff --git a/prompt_engineer/planner.py b/prompt_engineer/planner.py new file mode 100644 index 0000000000000000000000000000000000000000..e24fad37caead35d7ef738c3a820dfd93431680d --- /dev/null +++ b/prompt_engineer/planner.py @@ -0,0 +1,177 @@ +import re +import json + +import streamlit as st +from typing import IO, List + +from prompt_engineer.call_llm import LLMClient + + +class PlannerAgent(LLMClient): + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.loading_auto = False + self.prep_auto = False + self.vis_auto = False + self.modeling_auto = False + self.report_auto = False + + self.switched_loading = False + self.switched_prep = False + self.switched_vis = False + self.switched_modeling = False + self.switched_report = False + + def self_driving(self, df, user_input=None) -> str: + + prompt = ( + f"下面是一个数据集的基本信息,请你根据它和用户的需求,判断需要开启哪些分析步骤:\n\n" + f"- 数据维度:{df.shape[0]} 行 × {df.shape[1]} 列\n" + f"- 列名和数据类型:{dict(zip(df.columns.tolist(), df.dtypes.astype(str).tolist()))}\n" + f"- 前 5 行样本:\n{df.head().to_dict(orient='list')}\n\n" + ) + + if user_input: + prompt += f"用户的具体需求是:“{user_input}”。\n\n" + + prompt += """ + 你需要在以下 5 个步骤中,对每个步骤分别判断是否应该开启(True / False): + 1. loading_auto —— 是否需要对数据列名进行初步分析? + 2. prep_auto —— 是否需要做数据预处理或清洗? + 3. vis_auto —— 是否需要做数据可视化? + 4. modeling_auto —— 是否需要建模或统计分析? + 5. report_auto —— 是否需要生成分析报告? + + 必须以 **JSON 格式** 输出你的判断结果,如: + { + "loading_auto": true, + "prep_auto": false, + "vis_auto": true, + "modeling_auto": true, + "report_auto": true + } + + 不要输出其他内容。 + """ + + plan_text = self.call(prompt) + print(plan_text) + try: + plan_dict = json.loads(plan_text) + except json.JSONDecodeError: + plan_text_fixed = plan_text.strip().strip('```json').strip('```') + plan_dict = json.loads(plan_text_fixed) + + print(plan_dict) + self.loading_auto = bool(plan_dict.get("loading_auto", False)) + self.prep_auto = bool(plan_dict.get("prep_auto", False)) + self.vis_auto = bool(plan_dict.get("vis_auto", False)) + self.modeling_auto = bool(plan_dict.get("modeling_auto", False)) + # self.modeling_auto = False + self.report_auto = bool(plan_dict.get("report_auto", False)) + + + def finish_loading_auto(self) -> str: + + self.switched_loading = True + + + def finish_prep_auto(self) -> str: + + self.switched_prep = True + + + def finish_vis_auto(self) -> str: + + self.switched_vis = True + + + def finish_modeling_auto(self) -> str: + + self.switched_modeling = True + + + def finish_report_auto(self) -> str: + + self.switched_report = True + + +import json +import ast +import re +import traceback + +def _extract_first_json(text: str): + """从 text 中提取第一个顶层花括号 JSON 子串(用配对计数法),找不到则返回 None。""" + if not text: + return None + start = text.find('{') + if start == -1: + return None + depth = 0 + for i in range(start, len(text)): + ch = text[i] + if ch == '{': + depth += 1 + elif ch == '}': + depth -= 1 + if depth == 0: + return text[start:i+1] + return None + +def _safe_parse_json(text: str): + """ + 尝试多种策略解析 LLM 输出为 dict: + 1) 直接 json.loads + 2) 去除 Markdown code fence 后再 loads + 3) 提取第一个完整花括号块后 loads + 4) ast.literal_eval 作为最后手段(接受 Python dict 风格) + 返回 (dict_or_None, used_text, error_message_or_None) + """ + if not text or not text.strip(): + return None, text, "empty" + # 1) 直接尝试 + try: + return json.loads(text), text, None + except Exception as e1: + pass + + # 2) 去掉 ```json / ``` fence + try: + cleaned = re.sub(r'```json\s*', '', text, flags=re.IGNORECASE) + cleaned = re.sub(r'```', '', cleaned) + cleaned = cleaned.strip() + return json.loads(cleaned), cleaned, None + except Exception: + pass + + # 3) 提取首个匹配的 { ... } 顶层块 + try: + sub = _extract_first_json(text) + if sub: + return json.loads(sub), sub, None + except Exception: + pass + + # 4) ast.literal_eval 兼容 Python 字典格式(单引号等) + try: + literal = ast.literal_eval(text) + if isinstance(literal, dict): + return literal, text, None + except Exception: + pass + + # 5) 再次尝试在提取的子串上用 literal_eval(防止单引号) + try: + sub = _extract_first_json(text) + if sub: + literal = ast.literal_eval(sub) + if isinstance(literal, dict): + return literal, sub, None + except Exception: + pass + + # 最后,返回 None 并带上错误信息 + return None, text, "unable_to_parse" \ No newline at end of file diff --git a/prompt_engineer/sec1_call_llm.py b/prompt_engineer/sec1_call_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a145c33f647f114a86fbd4cef24054e2cbb0ef --- /dev/null +++ b/prompt_engineer/sec1_call_llm.py @@ -0,0 +1,248 @@ +import re + +import streamlit as st +from typing import IO, List + +from prompt_engineer.call_llm import LLMClient + + +class DataLoadingAgent(LLMClient): + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.file_name = [] + self.user_input = None + self.par_content = "" + self.dfs = None + self.abstract=None + self.full = None + self.finish_auto_task = False + + + def finish_auto(self): + + self.finish_auto_task = True + + + def save_file_name(self, file_name): + + self.file_name.append(file_name) + + + def load_file_name(self): + + return self.file_name + + + def save_dfs(self, dfs): + + self.dfs = (dfs) + + + def load_dfs(self): + + return self.dfs + + + def clear_file_name(self): + + self.file_name = [] + + + def read_names_from_file(self, uploaded_names_file, df_head): + """ + 从上传的 .names/.arff 文件中提取属性名。 + 优先使用 LLM 识别 @attribute 行中的属性名;如果 LLM 调用失败,退回到正则解析。 + """ + + raw = uploaded_names_file.read().decode('utf-8', errors='ignore') + try: + uploaded_names_file.seek(0) + except Exception: + pass + + prompt = ( + "下面是上传的 names 和 df_head 文件内容,请仅以 Python 列表格式返回与df_head一一对应的所有属性(attribute)名称," + "并保持顺序,不要添加多余文字,请注意,你只需要返回一个列表,不要出现任何markdown语法:\n```\n" + f"name文件:{raw}\n```" + f"df_head:{df_head}\n```" + ) + try: + response = self.call(prompt) + names_list = eval(response.strip()) + if isinstance(names_list, list) and all(isinstance(n, str) for n in names_list): + col_names = names_list + else: + raise ValueError("LLM 输出格式不正确") + except Exception: + + col_names = [] + attr_re = re.compile( + r"""^@attribute\s+ + ['"]?([^'"\s]+)['"]? + \s+.+ + """, + re.IGNORECASE | re.VERBOSE + ) + for line in raw.splitlines(): + line = line.strip() + if not line: + continue + if line.lower().startswith('@data'): + break + m = attr_re.match(line) + if m: + col_names.append(m.group(1)) + + counts: dict[str, int] = {} + unique_names: List[str] = [] + for name in col_names: + if name in counts: + counts[name] += 1 + unique_names.append(f"{name}_{counts[name]}") + else: + counts[name] = 0 + unique_names.append(name) + + return unique_names + + + def do_data_description(self, df, user_input=None, memory_limit=6): + + recent_memory = self.memory[-memory_limit:] if self.memory else [] + if recent_memory: + formatted_memory = "\n".join( + f"{m['role']}: {m['content']}" for m in recent_memory + ) + memory_block = f"{formatted_memory}" + else: + memory_block = "" + + prompt = ( + "你是一名专业的数据分析助手,负责解释数据结构与业务含义。\n" + f"- 数据维度:{df.shape[0]} 行 × {df.shape[1]} 列\n" + f"- 列名和数据类型:{dict(zip(df.columns.tolist(), df.dtypes.astype(str).tolist()))}\n" + f"- 前 5 行样本:\n{df.head().to_dict(orient='list')}\n\n" + f"""- 数据解释聊天对话: + --- 开始聊天记录 --- + {memory_block} + --- 结束聊天记录 ---""" + ) + + if user_input is not None: + prompt += f""" + 请严格依据用户需求“{user_input}”,对当前数据进行深入、系统的分析。 + 要求: + 1. 分析内容必须与该需求完全对应,不能添加无关推断。 + 2. 结论要具体、清晰,可直接支持后续报告撰写或建模步骤。 + 3. 分析语言应专业、简洁,不使用模糊或情绪化表述。 + """ + else: + prompt += """ + 以下是一个数据集的基本概览。请帮助我分析它的性质和结构,并回答以下问题: + + 1. 该数据集可能来源于什么业务或研究场景? + 2. 各主要字段分别代表什么含义?若能判断,请说明其单位或数值含义。 + 3. 数据中是否存在明显异常、异常分布或需要注意的特征? + + 输出要求: + - 使用自然、流畅的中文描述; + - 采用清晰的分条结构(1、2、3); + - 语言客观简洁,不使用“可能”“也许”“似乎”等模糊词; + - 重点突出数据结构、含义与潜在问题。 + """ + + desc = self.call(prompt) + + return desc + + + def summary_html(self): + + df = self.load_df() + df_head = df.head() + dtype_info = df.dtypes.astype(str) + + prompt = f""" + 你正在撰写一份数据分析报告的第一章——《数据概览与数据含义分析》。 + 请根据以下输入内容,整理关键信息并进行分析说明: + 数据格式: + {dtype_info} + + 前五行数据: + {df_head} + + 数据解释聊天对话: + --- 开始聊天记录 --- + {self.memory} + --- 结束聊天记录 --- + + 额外要求: + 1. 要用流畅的自然语言 + 2. 不要滥用形容词和副词,尽量用简单的动词和名词表达意思 + 3. 不用"可能""也许""似乎""微妙"等模糊表述 + """.strip() + + desc = self.call(prompt) + + summary = { + "title": "数据导入", + "df": df_head, + "desc": desc, + } + + return summary + + + def summary_word(self): + + return self.summary_html() + + + def check_abstract(self): + + if self.abstract is None: + df = self.load_df() + df_head = df.head() + dtype_info = df.dtypes.astype(str) + + prompt = f""" + 这是数据分析的数据导入阶段 + 数据格式: + {dtype_info} + + 前五行数据: + {df_head} + + 数据解释聊天对话: + --- 开始聊天记录 --- + {self.memory} + --- 结束聊天记录 --- + + 要求: + 请基于上述数据与对话内容,生成一段简洁、准确的综合摘要。 + 摘要需完整呈现核心信息,便于后续自动判断该内容在报告撰写中是否需要被引用。 + """.strip() + + desc = self.call(prompt) + self.abstract = desc + + return self.abstract + + + def check_full(self): + + if self.full is None: + df = self.load_df() + df_head = df.head() + dtype_info = df.dtypes.astype(str) + + self.full = ( + f"【阶段说明】这是数据分析流程中的数据导入阶段。\n" + f"【数据格式】{dtype_info}\n" + f"【样本预览】\n{df_head}\n" + f"【分析对话记录】\n{self.memory}" + ) + + return self.full diff --git a/prompt_engineer/sec2_call_llm.py b/prompt_engineer/sec2_call_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..862e6caec24c7dd9c4b0d6b0e78d4b1287ee0e0c --- /dev/null +++ b/prompt_engineer/sec2_call_llm.py @@ -0,0 +1,374 @@ +import numpy as np +import pandas as pd + +from prompt_engineer.call_llm import LLMClient + +class DataPreprocessAgent(LLMClient): + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.processed_df = None + self.code = None + self.preprocessing_suggestions = None + self.allowed_libs = [ + "numpy", + "pandas", + "sklearn.impute", + "sklearn.preprocessing", + "sklearn.compose", + "sklearn.pipeline" + ] + self.par_content = "" + self.error = None + self.user_input = None + self.refined_suggestions = "" + self.abstract=None + self.full = None + self.finish_auto_task = False + self.debug_num = 0 + + + def finish_auto(self): + + self.finish_auto_task = True + + + def save_code(self, code): + + self.code = code + + + def load_code(self): + + return self.code + + + def save_user_input(self, user_input): + + self.user_input = user_input + + + def load_user_input(self): + + return self.user_input + + + def save_error(self, error): + + self.error = error + + + def load_error(self): + + return self.error + + + def save_preprocessing_suggestions(self, suggestions): + + self.preprocessing_suggestions = suggestions + + + def load_preprocessing_suggestions(self): + + return self.preprocessing_suggestions + + + def save_processed_df(self, processed_df): + + if not isinstance(processed_df, pd.DataFrame): + if isinstance(processed_df, np.ndarray): + processed_df = pd.DataFrame(processed_df) + else: + raise TypeError(f"期望 pandas.DataFrame 或 numpy.ndarray,收到 {type(processed_df)}") + + self.processed_df = processed_df + + + def load_processed_df(self): + + return self.processed_df + + + def load_refined_suggestions(self): + return self.refined_suggestions + + + def save_refined_suggestions(self, refined_suggestions): + self.refined_suggestions = refined_suggestions + + + def refine_suggestions(self, df_head): + """将 LLM 返回的预处理推荐进行信息提取""" + + suggestion = self.load_preprocessing_suggestions() + + prompt = f""" + 请根据以下预处理建议,概括数据集中每一列的推荐预处理方法。 + + 数据示例: + {df_head} + + 详细预处理建议: + {suggestion} + + 输出要求(必须严格遵守): + 1. 输出格式:列名:推荐预处理方法;每条独立换行。 + 2. 每列最多给出三个推荐方法,多个方法用逗号分隔。 + 3. 输出必须为纯文本,不使用任何 Markdown 标记。 + 4. 每个方法的长度不得超过20个汉字,若包含英文则不超过10个单词。""" + + refined_suggestions = self.call(prompt) + self.refined_suggestions = refined_suggestions + + return refined_suggestions + + + def get_preprocessing_suggestions( + self, + user_input=None, + memory_limit=6, + ): + + df = self.load_df() + + # 基本统计 + n_rows, n_cols = df.shape + dtype_counts = df.dtypes.value_counts().to_dict() + missing_total = int(df.isnull().sum().sum()) + missing_by_col = df.isnull().mean().mul(100).round(2).to_dict() + num_cols = df.select_dtypes(include=[np.number]).columns.tolist() + + # 整理 memory 片段 + recent_memory = self.memory[-memory_limit:] if self.memory else [] + if recent_memory: + formatted_memory = "\n".join( + f"{m['role']}: {m['content']}" for m in recent_memory + ) + memory_block = f"{formatted_memory}" + else: + memory_block = "" + + prompt = f""" + 你是一名资深的数据预处理专家,负责为数据分析报告提供高质量的预处理建议。 + + === 数据概览 === + - 数据规模:{n_rows} 行 × {n_cols} 列 + - 数据类型分布:{dtype_counts} + - 缺失值总数:{missing_total} + - 各列缺失率:{missing_by_col} + - 数值型列:{num_cols} + - 历史上下文(仅供参考):{memory_block} + """ + + if user_input is None: + prompt += """ + === 请对每一列进行逐项分析(注意,是逐列分析) === + 请针对每一列依次说明以下四个方面: + + 1. **数据类型**:明确该列的数据类型,若存在混合类型或异常值类型,请指出。 + 2. **缺失值处理建议**:说明该列的缺失值处理策略;若建议调整,请指明具体“缺失值处理 策略”操作。 + 3. **异常值处理建议**:说明该列的异常检测与处理方案;若需调整,请说明“异常值处理 策略或阈值”操作。 + 4. **标准化建议**:说明是否建议标准化或缩放,并在需要时指出“标准化处理 策略”操作。 + + 输出格式要求: + - 按“列名 + 分点说明(1–4)”的形式分段输出; + - 每一列独立成段,并以换行分隔; + - 使用清晰、简洁的专业语言。 + """ + else: + prompt += f""" + === 用户新需求 === + {user_input} + + 请结合以上数据概览与历史上下文,针对该需求,给出下一步操作。 + 可考虑的操作包括:缺失值处理、异常值检测与修正、标准化或归一化、特征类型调整等。 + 输出应保持结构化与连贯性,避免重复说明。 + """ + + suggestions = self.call(prompt) + + return suggestions + + + def code_generation(self, df_head, user_prompt): + """生成 LLM prompt:要求 LLM 输出 process_df(pandas DataFrame)。""" + allowed = ", ".join(self.allowed_libs) + + prompt = f""" + 请**严格只输出纯 Python 代码**,不得包含以下内容: + - 解释性文字、注释、示例; + - Markdown 代码块标记(禁止出现 ``` 或 ```python 等); + - 任何多余输出(如 print、全局变量赋值等)。 + + === 运行环境说明 === + 运行环境中已提供以下对象与库: + - pandas DataFrame 变量:`df` + - 库:numpy (np)、SimpleImputer、StandardScaler、MinMaxScaler、RobustScaler、 + OneHotEncoder、OrdinalEncoder、LabelEncoder、FunctionTransformer、 + ColumnTransformer、Pipeline。 + 若所需功能在这些库中不存在,请自行写 Python code 实现。 + + === 生成要求 === + 1. 若有用户需求,请优先满足用户需求(优先级高于 LLM 返回的通用建议)。 + 2. 若建议指出某列“无需处理”,则对该列不进行任何操作。 + 3. 禁止导入其他库、禁止文件读写。 + 4. 所有括号(圆括号、方括号、大括号)必须成对闭合,不得错位或遗漏。 + 5. 对类别特征,可使用 OneHotEncoder 或 OrdinalEncoder; + 若为单列字符串/类别列,请使用 LabelEncoder 或 OrdinalEncoder,不得 passthrough。 + 6. 在构建 ColumnTransformer 前,需检测并处理“混合型列” + —— 即同时包含数值和字符串的列, + 使用 `FunctionTransformer(lambda x: x.astype(str))` 将其统一为字符串类型。 + 7. ColumnTransformer 的 transformers 中仅包含经过上述处理的列。 + 8. 使用 OneHotEncoder 时,若输出稀疏矩阵,请确保所有输入特征均为数值类型。 + 9. 若 df 中存在重复表头(如第 0 行与 header 相同),需自动检测并删除重复表头行。 + 10. 确保预处理后的 DataFrame 中每一列均有明确列名。 + 11. 脚本最后仅保留一行结果: + `process_df = ...` + 不允许出现 print、显示语句或其他多余输出。 + + === 输入数据示例 === + {df_head} + + === 用户指定需求 === + {user_prompt} + + 请严格依据以上要求,输出完整且可直接执行的 Python 代码(纯代码块,无额外说明)。 + """.strip() + + if self.error is not None: + if self.debug_num < 5 : + self.debug_num += 1 + + prompt += f""" + 上次生成的代码运行失败。 + 【错误信息】: + {self.error} + + 【原始代码】: + {self.code} + + 请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因, + + 要求: + 1. 不输出任何分析、解释或说明(包括文字、列表或注释段落); + 2. 可在代码内部使用简短注释说明关键修改; + 3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整; + 4. 若依赖库方法不适用,可自行实现替代函数; + 5. 生成的代码必须可独立运行,无语法错误; + 6. 保持整体逻辑与原代码意图一致,仅做必要修正。 + """ + + else: + self.debug_num = 0 + + if self.user_input is not None: + prompt += f"用户需求:{self.user_input}。\n请严格遵循并优先执行该需求,其优先级高于所有其他建议或规则。\n" + + if self.refined_suggestions is not None: + prompt += f"LLM返回的预处理建议:{self.refined_suggestions}" + + raw = self.call(prompt) + return raw + + + def summary_html(self): + + if self.code is None: + summary = None + return summary + + else: + processed_df = self.load_processed_df() + prompt = f""" + 你正在撰写数据分析报告的第二章——《数据预处理与标准化》。 + 请根据以下输入内容,提炼关键信息并撰写相应分析段落。 + + - 预处理代码: + {self.code} + + - 预处理结果(数据示例): + {processed_df.head()} + + {f"- 预处理建议对话记录:{self.load_memory}" if self.load_memory else ""} + + 撰写要求: + 1. 使用流畅、自然的中文表达; + 2. 语言应简洁、准确,避免过多形容词或副词; + 3. 不使用“可能”“也许”“似乎”“微妙”等模糊表述; + 4. 不添加大标题,可使用自然段进行叙述; + 5. 内容需逻辑清晰,体现代码与结果之间的分析关联。 + + """.strip() + + desc = self.call(prompt) + + summary = { + "title": "数据预处理", + "desc": desc, + "processed_df": self.processed_df.head(), + "code": self.code, + } + + return summary + + + def summary_word(self): + + return self.summary_html() + + + def check_abstract(self): + + if self.abstract is None: + + processed_df = self.load_processed_df() + + if self.code is None: + self.abstract = None + + else: + + memory = f"【预处理建议对话记录】\n{self.load_memory}\n" if self.load_memory else "" + + prompt = f""" + 这是数据分析流程中的“数据预处理与标准化”阶段。 + + 【预处理代码】 + {self.code} + + 【预处理结果(前五行)】 + {processed_df.head()} + + {memory} + 请在确保信息准确完整的前提下,将上述内容概括为一段简洁的文字摘要。 + 要求: + 1. 语言自然流畅,保持客观和专业; + 2. 内容应涵盖关键点(包括主要预处理步骤与结果特征); + 3. 重点在于“说明核心信息”,而非逐行描述; + 4. 生成的摘要应可用于报告编写时判断该部分是否需要引用。 + """.strip() + + desc = self.call(prompt) + self.abstract = desc + + return self.abstract + + + def check_full(self): + if self.full is None: + processed_df = self.load_processed_df() + if self.code is None: + self.full = None + else: + content = f""" + 【阶段说明】这是数据分析流程中的数据预处理阶段。 + 【预处理代码】{self.code} + 【预处理结果前五行】{processed_df.head()} + """.strip() + if self.load_memory is not None: + content += f"\n【预处理建议聊天对话】{self.load_memory}" + + self.full = content + + return self.full diff --git a/prompt_engineer/sec3_call_llm.py b/prompt_engineer/sec3_call_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..16f355f3236689f832cc386a7a1a809fcbdf1721 --- /dev/null +++ b/prompt_engineer/sec3_call_llm.py @@ -0,0 +1,691 @@ +import streamlit as st +import base64 +import plotly.graph_objs as go +from concurrent.futures import ThreadPoolExecutor, as_completed + +from prompt_engineer.call_llm import LLMClient + +import numpy as np +np.set_printoptions(edgeitems=250, threshold=501) + +class VisualizationAgent(LLMClient): + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.cols_wo_id = None + self.recommendations = None + self.analysis = [] + self.quick_action = None + self.data_meaning = "" + self.allowed_libs = [ + "numpy", "plotly", "plotly.express", "plotly.graph_objects" + ] + self.code = None + self.result = None + self.suggestion = None + self.user_input = None + self.fig = [] + self.par_content = "" + self.error = None + self.abstract=None + self.full = None + self.color = None + self.finish_auto_task = False + self.debug_num = 0 + self.refined_suggestions = None + + + def finish_auto(self): + + self.finish_auto_task = True + + + def save_user_input(self, user_input): + + self.user_input = user_input + + + def load_user_input(self): + + return self.user_input + + + def save_color(self, color): + + self.color = color + + + def load_color(self): + + return self.color + + + def add_fig(self, fig, desc): + + entry = {"fig": fig, "desc": desc} + self.fig.append(entry) + + + def load_fig(self): + + return self.fig + + + def save_cols_wo_id(self, col): + + self.cols_wo_id = col + + + def load_cols_wo_id(self): + + return self.cols_wo_id + + + def save_code(self, code): + + self.code = code + + + def load_code(self): + + return self.code + + + def save_recommendations(self, recommendations): + + self.recommendations = recommendations + + + def load_recommendations(self): + + return self.recommendations + + + def save_suggestion(self, suggestion): + + self.suggestion = suggestion + + + def load_suggestion(self): + + return self.suggestion + + + def load_data_meaning(self): + + return self.data_meaning + + + def save_error(self, error): + + self.error = error + + + def load_error(self): + + return self.error + + + def refine_suggestions(self, rec): + + prompt = f""" + 请根据以下详细的可视化建议,提取每一列与每个变量组的推荐可视化方法。 + + 详细可视化建议: + {rec} + + 输出要求(必须严格遵守): + 1. 输出为纯文本,每条独立换行,且不得有多余说明。 + 2. 单变量格式:列名:图表1, 图表2。 + 3. 多变量格式:关系组:列A,列B:图表1, 图表2。 + 4. 总体变量格式:总体:图表1, 图表2。 + 5. 严格不要添加标题、编号、示例或额外解释。 + 6. 提取可视化方法精准。 + """ + + refined_suggestions = self.call(prompt) + self.refined_suggestions = refined_suggestions + + return refined_suggestions + + + def get_visualization_recommendations( + self, + cols, + user_input=None, + memory_limit: int = 6, + ) -> str: + + dim_info = f"{self.df.shape[0]} 行 x {self.df.shape[1]} 列" + + recent_memory = self.memory[-memory_limit:] if getattr(self, "memory", None) else [] + if recent_memory: + formatted_memory = "\n".join( + f"{m['role']}: {m['content']}" for m in recent_memory + ) + memory_block = f"{formatted_memory}" + else: + memory_block = "" + + if user_input is None: + prompt = f""" + 你是一位资深数据可视化专家,请根据以下信息,为数据分析报告的“可视化设计”章节提供系统、专业的建议。 + + 【数据集信息】 + - 数值型变量:{cols} + - 数据维度:{dim_info} + - 历史上下文(仅供参考):{memory_block} + + 【输出格式】 + 请严格按照以下结构输出(保持标题和层级一致,不得增减): + + 一、单变量可视化(Univariate) + 1. 针对每个数值型变量,推荐 1–2 种最合适的可视化方法,并简要说明理由。 + 例如: + - `列1`:推荐“直方图(Histogram)”和“盒须图(Box Plot)”,理由:…… + + 二、多变量关系可视化(Multivariate) + 1. 从上述变量中选择 1–3 组值得重点分析的变量组合(每组包含 2–3 个变量),并说明选择理由。 + 例如: + - 关系组 1:`[列1, 列2]`,理由:…… + 2. 对每一组变量,推荐最合适的可视化方法,并简要说明。 + 例如: + - 关系组 1:散点图(Scatter Plot)+ 回归线(Regression Line),理由:…… + + 三、整体分布可视化(Distribution Overview) + 1. 针对全数据的总体分布特征,推荐 1–2 种全局可视化方法,并说明用途。 + 例如: + - 推荐“小提琴图矩阵(Violin Plot Matrix)”,用途:…… + - 推荐“热力图(Heatmap)”,用途:…… + + 【执行要求】 + 1. 若列名无实际意义(如索引、冗余 ID),应自动过滤; + 2. 输出内容需保持条理清晰、语言简洁、专业。 + """.strip() + + else: + prompt = f""" + 你是一位资深数据可视化专家,请根据以下信息,请回应用户需求,实现用户需求: + + 【用户需求】 + {user_input} + + 【数据集信息】 + - 数值型变量:{cols} + - 数据维度:{dim_info} + - 数据概览(前几行): + {self.df.head().to_string(index=False)} + - 历史上下文(仅供参考):{memory_block} + + 【执行要求】 + 1. 若用户明确指定可视化列,仅针对这些列给出建议; + 2. 若用户提出特定要求(如图形大小、坐标轴 log 缩放等),必须在输出中体现; + 3. 仅响应用户需求,不输出无关内容; + 4. 若用户要求对先前内容进行局部修改,应保留未更动部分,仅更新相关建议; + 5. 输出内容应结构清晰、逻辑连贯、语言简洁。 + 6. 禁止输出代码。 + """.strip() + + recommendations = self.call(prompt) + return recommendations + + + def desc_fig(self, fig, dtype_info): + + selected = st.session_state.selected_model + + if selected == "智谱AI" or selected == "通义千问" or selected == "GPT-4o" or selected == "GPT-5" or selected == "豆包" or selected == "Claude": + img_bytes = fig.to_image(format="jpg") + fig_info = extract_plotly_info(fig) + base64_bytes = base64.b64encode(img_bytes) + base64_string = base64_bytes.decode('utf-8') + + prompt_payload = [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpg;base64,{base64_string}"} + }, + { + "type": "text", + "text": f""" + 请综合下方可视化图与变量信息,进行**简洁但深入的分析**。 + 从分布形态、趋势特征、变量间关系、潜在异常现象、现实含义五个角度,提炼关键洞察。 + 输出一段不超过 120 字的自然语言分析结论(非摘要)。 + + 【变量信息】 + {dtype_info} + + 【图表结构信息】 + {fig_info} + + 写作要求: + 1. 分析需包含对数据异常的识别与说明: + - 若存在明显异常点、异常段或突变趋势,请指出其特征与潜在影响; + - 若未发现异常,也需明确说明整体分布稳定或无显著异常; + 2. 内容需体现推理与解释性思考,而非表面描述; + 3. 使用逻辑清晰、客观专业的语言; + 4. 使用动词驱动句式(如“呈现出”“反映出”“揭示出”“说明了”等); + 5. 不使用模糊词(如“可能”“似乎”“微妙”等); + 6. 不使用标题、列表或格式符号; + 7. 若变量含义中存在噪声或重复信息,请自动忽略; + 8. 保持语气简洁有力,强调数据特征与分析结论。 + """.strip() + } + ] + + desc_fig = self.call(prompt_payload) + + else: + prompt = f""" + 请综合下方可视化图与变量信息,从数据分布、趋势特征及潜在关系等角度进行分析。 + 以不超过 100 字的自然语言总结关键发现,突出该变量在整体数据结构中的意义或异常现象。 + + 【变量信息】 + {dtype_info} + + 【图表信息】 + {fig.to_dict()} + + 写作要求: + 1. 语言应流畅自然,保持客观、专业; + 2. 使用简洁的动词和名词,不滥用形容词或副词; + 3. 避免“可能”“也许”“似乎”“微妙”等模糊词; + 4. 不添加标题或列表结构; + 5. 结合数据含义和图表特征,给出具有洞察力的简要结论; + 6. 若变量含义中存在杂乱或重复信息,请自动忽略。 + """.strip() + + desc_fig = self.call(prompt) + + return desc_fig + + + def summary_html(self) -> str: + + analysis = self.summary_fig_analysis_list() + + if analysis is None: + + return None + + else: + analysis = {i: item for i, item in enumerate(analysis)} + + summary = { + "title": "数据可视化", + "fig_analysis": analysis, + } + + return summary + + + def summary_word(self) -> str: + + analysis = self.summary_fig_analysis_list() + + if analysis is None: + + return None + + else: + + summary = { + "title": "数据可视化", + "fig_analysis": analysis, + } + + return summary + + + def summary_fig_analysis_list(self) -> str: + + if not self.code: + return self.analysis + + if self.analysis: + return self.analysis + + # state_copy = dict(st.session_state) + selected = st.session_state.get("selected_model", "default") + # selected = state_copy.get("selected_model", "default") + + # --- 定义单个任务 --- + def analyze_one(item, offset): + fig = item["fig"] + desc = item["desc"] + + # 恢复状态(如果需要访问 st.session_state) + # st.session_state.update(state_copy) + selected = st.session_state.get("selected_model", "default") + if isinstance(fig, go.Figure): + if selected == "智谱AI" or selected == "通义千问" or selected == "GPT-4o" or selected == "GPT-5" or selected == "豆包" or selected == "Claude": + img_bytes = fig.to_image(format="jpg") + base64_string = base64.b64encode(img_bytes).decode("utf-8") + + fig_info = extract_plotly_info(fig) + + prompt_payload = [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpg;base64,{base64_string}"} + }, + { + "type": "text", + "text": f""" + 你正在撰写数据分析报告的第三章——《数据可视化》。 + 请针对下方变量,结合其**业务含义、统计特征**与**可视化图表现**,撰写一段专业、逻辑严谨、可直接用于报告正文的分析内容。 + + 【变量信息】 + {self.cols_wo_id} + + 【Plotly 图表结构】 + {fig_info} + + 【基础统计概览】 + {desc} + + 【分析任务】 + 请在脑中先完成以下推理步骤,然后输出结构化正文: + 1. 从图表识别核心模式:整体趋势、峰值、分布形态、异常点或聚集区; + 2. 思考该模式与变量业务含义的关系; + 3. 判断是否存在异常现象(单点异常、阶段性异常或结构性突变),并说明其潜在影响; + 4. 若图中包含其他变量,请分析它们之间的统计或逻辑关联; + 5. 将上述洞察整合成逻辑完整、语言自然的段落。 + + 【输出格式(严格遵守)】 + 输出为纯文本,依次包含以下三部分(不使用 Markdown 或符号): + + 1. 概述 + - 简述变量的定义、业务角色及数据表现的总体趋势; + - 提出该变量在整体数据结构中可能的重要性。 + + 2. 分布与特征分析 + - 从统计与图形角度分析其分布特征(集中趋势、离散程度、偏态、峰度、周期性等); + - 若发现异常或突变,请具体说明其表现形式与潜在机制; + - 若与其他变量有关联趋势,指出方向与强度。 + + 3. 实际含义与推论 + - 结合业务或研究背景,解释观察到的现象; + - 分析其可能揭示的现实规律、风险或优化方向; + - 若合适,可提出合理推测或后续分析建议(保持客观与逻辑自洽)。 + + 【写作要求】 + 1. 保持语言正式、专业、逻辑紧密; + 2. 句式多样、表达自然,避免模板化表述; + 3. 禁用模糊词汇(如“可能”“似乎”“大概”等); + 4. 不使用任何标题符号(如 #、** 等); + 5. 不输出“AI”“模型”“助手”等字样; + 6. 输出为连续正文,不包含解释性语句或附加说明。 + """.strip() + } + ] + + analysis_text = self.call(prompt_payload) + + else: + + prompt = f""" + 你正在撰写数据分析报告的第三章——《数据可视化》。 + 请针对下方变量,结合其业务含义与对应的可视化图,撰写一段结构化、专业的分析文字。 + + 【变量信息】 + {self.cols_wo_id} + + 【Plotly 图表信息】 + {fig.to_dict()} + + 【基础统计概览】 + {desc} + + 请严格按照以下格式撰写内容(使用纯文本,不使用 Markdown 语法或符号): + + 1. 概述 + - 说明该变量的含义及其在数据或业务中的作用; + - 简要描述整体分布特征或变量间的主要关联趋势。 + + 2. 分布 / 关联特征 + - 从统计角度说明变量的分布特征或相关关系; + - 可引用关键统计量(均值、中位数、四分位数、相关系数等)支持分析。 + + 3. 现实含义 + - 结合变量在实际情境中的意义,解释所观察到的分布或关系; + - 指出这些模式可能反映的现实现象或潜在影响(例如:某变量偏高代表风险上升或群体特征差异)。 + + 【写作要求】 + 1. 使用流畅、自然且正式的中文表达; + 2. 语言应客观、简洁,避免冗余修辞; + 3. 禁止使用“可能”“也许”“似乎”“微妙”等模糊词; + 4. 不使用标题符号(#、** 等); + 5. 保持逻辑连贯,分析层次清晰。 + """.strip() + + analysis_text = self.call(prompt) + print(prompt) + return offset, {"figure": fig, "analysis": analysis_text} + + # --- 并行执行 --- + results = [] + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(analyze_one, item, i) for i, item in enumerate(self.fig)] + for f in as_completed(futures): + result = f.result() + if result: + results.append(result) + + # --- 按原顺序排序 --- + results.sort(key=lambda x: x[0]) + self.analysis = [r[1] for r in results] + + return self.analysis + + + def code_generation(self, df_head: str, user_prompt: str) -> str: + """生成 LLM prompt:要求 LLM 输出 result_dict(可 JSON 序列化)。""" + allowed = ", ".join(self.allowed_libs) + + prompt = ( + "请**严格只输出纯 Python 代码**,**不要**输出任何解释性文字、注释、示例、markdown code fence(禁止出现 ``` 或 ```python 等)" + "运行环境已提供 pandas DataFrame 变量 `df`、numpy(np)、" + "plotly.express(px)、plotly.graph_objects(go)。\n\n" + "##严格要求##:\n" + "1) **严格执行用户需求**:若用户指定了要可视化的列,可能是精确列名,也可能是模糊输入" + "(如输入 “ordera” 但实际列名为 “ordertypea”),不要凭空产生虚假列名!!!" + f"请在脚本开头使用 LLM 理解将用户输入映射到 {df_head} 中最合适的真正列名,或采用更保守的索引(如第0列,第1列 推荐!),再仅对这些列绘制图表;\n" + """2) **统计并重命名**:所有类别分布图请按下面模板写,**绝不直接用** `index` 作为列名—— + # === 模板:统计并绘制 Bar Chart === + for col in categorical_cols: + df_counts = df[col] \\ + .value_counts() \\ + .rename_axis(col) \\ + .reset_index(name='count') + fig = px.bar( + df_counts, + x=col, + y='count', + title=f'Bar Chart of {col}', + labels={col: col, 'count': 'Count'} + ) + fig_dict[f'{col}_bar'] = fig + + 3) 智能选图:根据数据类型(数值/类别)自动选择合适的图表。 + 4) 自动检测是否需要按分类列着色,并做两种处理:若存在指定的分类列且想连续映射,先编码为数值 codes;如要离散映射,使用 parallel_categories + 5) 如 Plotly Express 中无合适图表,使用 `go.Figure` 自定义。 + 6) 脚本末尾仅包含 `fig_dict = {...}`,不要 `print`、不要额外全局变量。 + 7) 任何情况下不得“造”列名或直接写 `'index'`;若要使用索引,必须显式使用 `df.index`。 + 8) 不要使用文件读写或其他外部 IO。 + 9) 请只给我python代码,不要给我任何'''python等非代码内容的标识符。""" + f"示例数据头部:\n{df_head}\n\n" + f"每一张图的颜色必须从{self.color}中,选择\n\n" + f"画图建议: {self.refined_suggestions}\n\n" + "返回:完整 Python 代码(纯代码块)。" + ) + + if self.error is not None: + if self.debug_num < 5 : + self.debug_num += 1 + prompt += f""" + 上次生成的代码运行失败。 + 【错误信息】: + {self.error} + + 【原始代码】: + {self.code} + + 请在不输出任何解释性文字的情况下,推理并理解导致错误的根本原因, + + 要求: + 1. 不输出任何分析、解释或说明(包括文字、列表或注释段落); + 2. 可在代码内部使用简短注释说明关键修改; + 3. 若错误源于逻辑、数据结构或函数使用不当,请自行调整; + 4. 若依赖库方法不适用,可自行实现替代函数; + 5. 生成的代码必须可独立运行,无语法错误; + 6. 保持整体逻辑与原代码意图一致,仅做必要修正。 + """ + else: + self.debug_num = 0 + + raw = self.call(prompt) + + return raw + + + def check_abstract(self): + if self.abstract is None: + # 获取所有分析内容 + analysis_list = self.summary_fig_analysis_list() + + if not analysis_list : + self.abstract = "暂无可视化分析内容。" + return self.abstract + + # 合并所有分析内容为一个整体文本 + all_analyses = "\n\n".join([ + f"【变量分析 {i+1}】\n{item['analysis']}" + for i, item in enumerate(analysis_list) + ]) + + prompt = f""" + 请阅读并综合以下多个变量的分析内容: + {all_analyses} + + 任务: + 将这些分析整合为一段结构化、信息充分的**综合语义总结**,供后续大模型自动生成报告目录使用。 + + 目标: + - 输出内容应帮助后续模型理解分析中包含的主题、变量、维度、关系与逻辑顺序; + - 它将作为“目录生成模型”的输入,因此必须让模型能看出报告中应有哪些章节与子章节。 + + 写作要求: + 1. **信息保留**: + - 保留每个变量的关键结论、趋势、特征、显著差异; + - 明确变量间的联系、对比或影响; + - 不得省略任何对分析主题有价值的事实。 + + 2. **结构导向**: + - 按逻辑顺序组织:总体特征 → 各变量分析 → 变量间关系 → 潜在规律; + - 若存在不同主题(如气象因素、污染物指标、模型结果),应自然体现层次; + - 语义中隐含章节边界信号(如“首先…其次…最后…”、“在气象变量方面…”、“在建模部分…”等)。 + + 3. **语言风格**: + - 专业、清晰、客观; + - 使用完整句表达,不使用列表或编号; + - 可以稍微详细,不追求简短。 + + 4. **输出格式**: + - 输出仅为一段完整文字; + - 不得加入标题、注释、JSON、代码块; + - 该文字将被直接送入目录生成模型,不对人类展示。 + + 请生成符合上述要求的综合语义总结。 + """.strip() + + self.abstract = self.call(prompt) + + return self.abstract + + + def check_full(self): + """ + 返回结构化的内容,遵守图片插入协议: + - 每个分析内容前标注索引 + - 图片插入位置用 [FIG:index] 表示 + - 后续处理时可根据此协议替换为实际图像 + """ + if self.full is None: + analysis_list = self.summary_fig_analysis_list() + + if not analysis_list : + self.full = "暂无可视化分析内容。" + return self.full + + # 构造结构化文本:带图片插入标记 + full_parts = ["""【阶段说明】这是数据分析流程中的数据可视化阶段。"""] + for i, item in enumerate(analysis_list): + desc = item["analysis"] + part = f""" + 【对图 {i}的分析】 + {desc} + [FIG:{i}] # 图片插入位置标记 + """.strip() + full_parts.append(part) + + self.full = "\n\n".join(full_parts) + + # 添加协议说明 + protocol_note = """ + --- + # 图片插入处理协议说明: + # [FIG:index] 表示图片插入位置 + # index 对应分析内容中的索引 + # 你在需要放图的地方用 [FIG:index] 代替即可 + """.strip() + + self.full = f"{self.full}\n\n{protocol_note}" + + return self.full + + +def extract_plotly_info(fig): + """ + 从 Plotly Figure(对象 / dict / 字符串)中提取关键信息: + - 图标题 + - X/Y 轴标题 + - 图类型 + - 颜色信息 + - trace 数量 + """ + import ast + import plotly.graph_objects as go + + if isinstance(fig, go.Figure): + fig = fig.to_dict() + elif isinstance(fig, dict): + pass + elif isinstance(fig, str): + clean_str = fig.strip() + if clean_str.startswith("Figure("): + clean_str = clean_str[len("Figure("):-1] + try: + fig = ast.literal_eval(clean_str) + except Exception as e: + raise ValueError(f"无法解析字符串形式的 Figure: {e}") + else: + raise TypeError(f"不支持的 fig 类型: {type(fig)}") + + layout = fig.get("layout", {}) + title = layout.get("title", {}).get("text", "") + xaxis_title = layout.get("xaxis", {}).get("title", {}).get("text", "") + yaxis_title = layout.get("yaxis", {}).get("title", {}).get("text", "") + + data_list = fig.get("data", []) + types = list({d.get("type", "") for d in data_list}) + + + return { + "title": title or "(无标题)", + "xaxis": xaxis_title or "(无X轴标题)", + "yaxis": yaxis_title or "(无Y轴标题)", + "types": types, + + } diff --git a/prompt_engineer/sec4_call_llm.py b/prompt_engineer/sec4_call_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd22a5777633129b70b7339959b5256d0335f0c --- /dev/null +++ b/prompt_engineer/sec4_call_llm.py @@ -0,0 +1,606 @@ +import streamlit as st + +from prompt_engineer.call_llm import LLMClient + + +class ModelingCodingAgent(LLMClient): + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.allowed_libs = [ + "numpy", "sklearn.model_selection", "sklearn.preprocessing", "sklearn.ensemble", 'torch', 'torchvision', 'torchaudio', 'xgboost', 'lightgbm' + ] + self.code = None + self.result = None + self.suggestion = None + self.user_selection = None + self.par_content = "" + self.inference_code = None + self.best_model = None + self.inference_data = None + self.inference_processed_df = None + self.abstract=None + self.full = None + self.error = None + self.inference_error = None + self.target = None + self.finish_auto_task = False + self.best_model_gz_bytes = None + self.debug_num = 0 + self.refined_suggestions = None + + def finish_auto(self): + + self.finish_auto_task = True + + + def save_best_model_gz_bytes(self, best_model_gz_bytes): + + self.best_model_gz_bytes = best_model_gz_bytes + + + def load_best_model_gz_bytes(self): + + return self.best_model_gz_bytes + + + def save_target(self, target): + + self.target = target + + + def load_target(self): + + return self.target + + + def save_error(self, error): + + self.error = error + + + def load_error(self): + + return self.error + + + def save_inference_error(self, inference_error): + + self.inference_error = inference_error + + + def load_inference_error(self): + + return self.inference_error + + + def save_inference_data(self, inference_data): + + self.inference_data = inference_data + + + def load_inference_data(self): + + return self.inference_data + + + def save_inference_processed_df(self, inference_processed_df): + + self.inference_processed_df = inference_processed_df + + + def load_inference_processed_df(self): + + return self.inference_processed_df + + + def save_inference_code(self, code): + + self.inference_code = code + + + def load_inference_code(self): + + return self.inference_code + + + def save_best_model(self, best_model): + + self.best_model = best_model + + + def load_best_model(self): + + return self.best_model + + + def save_code(self, code): + + self.code = code + + + def load_code(self): + + return self.code + + + def save_suggestion(self, suggestion): + + self.suggestion = suggestion + + + def load_suggestion(self): + + return self.suggestion + + + def save_modeling_result(self, result): + + self.result = result + + + def load_modeling_result(self): + + return self.result + + + def save_user_selection(self, user_selection): + + self.user_selection = user_selection + + + def load_user_selection(self): + + return self.user_selection + + + def refine_suggestions(self): + """将 LLM 返回的预处理推荐进行信息提取""" + + prompt = f""" + 请阅读以下建模建议,并将其转化为对下一个 coding agent 的清晰建模任务指令。 + + === 建模建议 === + {self.suggestion} + + === 输出要求(必须严格遵守) === + 1. 输出为纯文本,不使用任何 Markdown、编号或符号; + 2. 指令应简洁明确,便于 coding agent 直接理解并执行; + 3. 内容应聚焦于模型构建、训练或评估的具体任务; + 4. 避免解释性或分析性语言,仅描述“需要执行的操作”; + 5. 输出应覆盖所有关键步骤,使 coding agent 能独立完成建模流程。 + """.strip() + + refined_suggestions = self.call(prompt) + self.refined_suggestions = refined_suggestions + + print(refined_suggestions) + + return refined_suggestions + + + def code_generation(self, df_head: str, user_prompt: str) -> str: + """生成 LLM prompt:要求 LLM 输出 result_dict(可 JSON 序列化)。""" + allowed = ", ".join(self.allowed_libs) + + if self.refined_suggestions is None: + suggestion = user_prompt + else: + suggestion = self.refined_suggestions + + prompt = ( + f"""请**严格只输出纯 Python 代码**,**不要**输出任何解释性文字、注释、示例、markdown code fence(禁止出现 ``` 或 ```python 等)。运行环境已提供 pandas DataFrame 变量 `df`、numpy(np)、train_test_split、StandardScaler、以及用户在 Requirement 中可能提到的任意模型类(例如 RandomForestRegressor、GradientBoostingRegressor、LinearRegression、XGBRegressor、LogisticRegression、SVC 等)。 + + 要求: + + 1) 使用 80/20 切分(random_state=42),根据用户需求决定是否对数值特征标准化(StandardScaler),如果标准化,务必只应用于数值列并在训练/测试集上分别执行 fit_transform/transform。 + 2) **对 Requirement 中列出的所有模型都依次训练和评估**,不得只选随机森林;如果用户在 Requirement 中指定了多个模型名称,脚本必须循环遍历这些模型并分别训练、预测、计算指标。 + 3) 不要导入任何评价库(如 sklearn.metrics),如需评价请用 numpy 手写实现常见指标(回归:MAE、MSE、R2;分类:accuracy、precision、recall、f1)。 + 4) **脚本最后必须只输出并赋值一个变量 `result_dict`,且它是一个可以 JSON 序列化的 Python dict。** + 推荐 schema(必须包含以下键): + {{ + "dataset": "<可选描述字符串>", + "models": [ + {{ + "name": "<模型类名>", + "type": "` 样式;
+ - 圆角矩形样式:背景 #EFF6FF,padding 12px,border-radius 8px,margin-bottom 16px;
+ 5. 如果有图片列表 `images`:
+ - ≤3 张时水平并排;>3 张时自动换行,每行最多 3 张;
+ - `` 带 6px 圆角、轻微阴影 `box-shadow:0 2px 6px rgba(0,0,0,0.1)`;
+ 6. 在 `