ElvisWang111 commited on
Commit
a0357aa
·
verified ·
1 Parent(s): d8f115d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -46
app.py CHANGED
@@ -2,7 +2,7 @@ import sys, os
2
  import tempfile
3
  import streamlit as st
4
 
5
- from config import MODEL_CONFIGS
6
  from utils.save_secrets import *
7
  from prompt_engineer.sec1_call_llm import DataLoadingAgent
8
  from prompt_engineer.sec2_call_llm import DataPreprocessAgent
@@ -20,12 +20,7 @@ np.set_printoptions(edgeitems=250, threshold=501)
20
 
21
  sys.path.append(os.path.dirname(__file__))
22
 
23
- CACHE_FILE = os.path.join(tempfile.gettempdir(), "anystat_cache.pkl")
24
- CACHE_DIR = './cache'
25
- SECRETS_PATH = Path(".streamlit") / "secrets.toml"
26
 
27
-
28
- # 设置页面配置
29
  st.set_page_config(
30
  page_title="Autostat",
31
  page_icon="🤖",
@@ -37,21 +32,36 @@ def init_session_state():
37
 
38
  if 'selected_model' not in st.session_state:
39
  st.session_state.selected_model = "DeepSeek"
40
- if "api_keys" not in st.session_state:
41
- st.session_state.api_keys = load_local_api_keys()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  if 'auto_mode' not in st.session_state:
43
  st.session_state.auto_mode = False
44
 
45
- if 'loading_start_time' not in st.session_state:
46
- st.session_state.loading_start_time = None
47
- if 'prep_start_time' not in st.session_state:
48
- st.session_state.prep_start_time = None
49
- if 'vis_start_time' not in st.session_state:
50
- st.session_state.vis_start_time = None
51
- if 'modeling_start_time' not in st.session_state:
52
- st.session_state.modeling_start_time = None
53
- if 'report_start_time' not in st.session_state:
54
- st.session_state.report_start_time = None
55
  if 'preference_select' not in st.session_state:
56
  st.session_state.preference_select = None
57
  if 'additional_preference' not in st.session_state:
@@ -62,37 +72,37 @@ def init_session_state():
62
  if 'data_loading_agent' not in st.session_state:
63
  st.session_state.data_loading_agent = DataLoadingAgent(
64
  api_keys=st.session_state.api_keys,
65
- model_configs=MODEL_CONFIGS,
66
  model=st.session_state.selected_model
67
  )
68
  if 'data_preprocess_agent' not in st.session_state:
69
  st.session_state.data_preprocess_agent = DataPreprocessAgent(
70
  api_keys=st.session_state.api_keys,
71
- model_configs=MODEL_CONFIGS,
72
  model=st.session_state.selected_model
73
  )
74
  if 'visualization_agent' not in st.session_state:
75
  st.session_state.visualization_agent = VisualizationAgent(
76
  api_keys=st.session_state.api_keys,
77
- model_configs=MODEL_CONFIGS,
78
  model=st.session_state.selected_model
79
  )
80
  if 'modeling_coding_agent' not in st.session_state:
81
  st.session_state.modeling_coding_agent = ModelingCodingAgent(
82
  api_keys=st.session_state.api_keys,
83
- model_configs=MODEL_CONFIGS,
84
  model=st.session_state.selected_model
85
  )
86
  if 'report_agent' not in st.session_state:
87
  st.session_state.report_agent = ReportAgent(
88
  api_keys=st.session_state.api_keys,
89
- model_configs=MODEL_CONFIGS,
90
  model=st.session_state.selected_model
91
  )
92
  if 'planner_agent' not in st.session_state:
93
  st.session_state.planner_agent = PlannerAgent(
94
  api_keys=st.session_state.api_keys,
95
- model_configs=MODEL_CONFIGS,
96
  model=st.session_state.selected_model
97
  )
98
 
@@ -111,11 +121,21 @@ def run_app():
111
  init_session_state()
112
  with st.sidebar:
113
  st.subheader("选择大模型")
114
- models = list(MODEL_CONFIGS.keys())
 
 
 
 
 
 
 
 
 
 
115
  st.selectbox(
116
  "选择要使用的大模型",
117
  models,
118
- index=models.index(st.session_state.selected_model),
119
  key="model_selector",
120
  on_change=on_model_selector_change,
121
  )
@@ -123,50 +143,132 @@ def run_app():
123
  st.subheader("API 密钥设置")
124
  selected = st.session_state.selected_model
125
 
126
- api_key_input = st.text_input(
127
- f"{selected} API 密钥",
128
- value=st.session_state.api_keys.get(selected, ""),
129
- type="password",
130
- key="api_key_input",
131
- )
132
- st.session_state.api_keys[selected] = api_key_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- if st.button("💾 保存密钥", use_container_width=True, key="save_key"):
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- st.session_state.api_keys[selected] = api_key_input
137
- st.success("已保存")
138
- st.rerun()
 
 
139
 
140
  if st.button("🧹 清空数据", use_container_width=True, key="clear_data"):
141
 
142
  st.session_state.data_loading_agent = DataLoadingAgent(
143
  api_keys=st.session_state.api_keys,
144
- model_configs=MODEL_CONFIGS,
145
  model=st.session_state.selected_model
146
  )
147
  st.session_state.data_preprocess_agent = DataPreprocessAgent(
148
  api_keys=st.session_state.api_keys,
149
- model_configs=MODEL_CONFIGS,
150
  model=st.session_state.selected_model
151
  )
152
  st.session_state.visualization_agent = VisualizationAgent(
153
  api_keys=st.session_state.api_keys,
154
- model_configs=MODEL_CONFIGS,
155
  model=st.session_state.selected_model
156
  )
157
  st.session_state.modeling_coding_agent = ModelingCodingAgent(
158
  api_keys=st.session_state.api_keys,
159
- model_configs=MODEL_CONFIGS,
160
  model=st.session_state.selected_model
161
  )
162
  st.session_state.report_agent = ReportAgent(
163
  api_keys=st.session_state.api_keys,
164
- model_configs=MODEL_CONFIGS,
165
  model=st.session_state.selected_model
166
  )
167
  st.session_state.planner_agent = PlannerAgent(
168
  api_keys=st.session_state.api_keys,
169
- model_configs=MODEL_CONFIGS,
170
  model=st.session_state.selected_model
171
  )
172
  st.session_state.auto_mode = False
@@ -186,7 +288,7 @@ def run_app():
186
  st.session_state.auto_mode = False
187
  st.session_state.planner_agent = PlannerAgent(
188
  api_keys=st.session_state.api_keys,
189
- model_configs=MODEL_CONFIGS,
190
  model=st.session_state.selected_model
191
  )
192
  st.rerun()
@@ -231,5 +333,4 @@ def run_app():
231
  pg.run()
232
 
233
  if __name__ == "__main__":
234
- run_app()
235
-
 
2
  import tempfile
3
  import streamlit as st
4
 
5
+ from config import MODEL_CONFIGS, CUSTOM_MODEL_KEY
6
  from utils.save_secrets import *
7
  from prompt_engineer.sec1_call_llm import DataLoadingAgent
8
  from prompt_engineer.sec2_call_llm import DataPreprocessAgent
 
20
 
21
  sys.path.append(os.path.dirname(__file__))
22
 
 
 
 
23
 
 
 
24
  st.set_page_config(
25
  page_title="Autostat",
26
  page_icon="🤖",
 
32
 
33
  if 'selected_model' not in st.session_state:
34
  st.session_state.selected_model = "DeepSeek"
35
+
36
+ if 'model_configs_runtime' not in st.session_state:
37
+ # 运行时模型配置,包含预设和自定义模型
38
+ st.session_state.model_configs_runtime = MODEL_CONFIGS.copy()
39
+ # 加载用户配置(包括 API 密钥和自定义模型)
40
+ user_configs = load_local_model_configs()
41
+ for model_name, config in user_configs.items():
42
+ if model_name in MODEL_CONFIGS:
43
+ # 预设模型:只更新 API 密钥
44
+ st.session_state.model_configs_runtime[model_name]["api_key"] = config.get("api_key", "")
45
+ else:
46
+ # 自定义模型:添加完整配置
47
+ st.session_state.model_configs_runtime[model_name] = {
48
+ "api_base": config.get("api_base", ""),
49
+ "model_name": config.get("model_name", ""),
50
+ "api_key": config.get("api_key", ""),
51
+ "api_type": "openai",
52
+ "is_preset": False,
53
+ }
54
+
55
+ # 从 model_configs_runtime 提取 api_keys(用于传递给 Agent)
56
+ if 'api_keys' not in st.session_state:
57
+ st.session_state.api_keys = {
58
+ name: config.get("api_key", "")
59
+ for name, config in st.session_state.model_configs_runtime.items()
60
+ }
61
+
62
  if 'auto_mode' not in st.session_state:
63
  st.session_state.auto_mode = False
64
 
 
 
 
 
 
 
 
 
 
 
65
  if 'preference_select' not in st.session_state:
66
  st.session_state.preference_select = None
67
  if 'additional_preference' not in st.session_state:
 
72
  if 'data_loading_agent' not in st.session_state:
73
  st.session_state.data_loading_agent = DataLoadingAgent(
74
  api_keys=st.session_state.api_keys,
75
+ model_configs=st.session_state.model_configs_runtime,
76
  model=st.session_state.selected_model
77
  )
78
  if 'data_preprocess_agent' not in st.session_state:
79
  st.session_state.data_preprocess_agent = DataPreprocessAgent(
80
  api_keys=st.session_state.api_keys,
81
+ model_configs=st.session_state.model_configs_runtime,
82
  model=st.session_state.selected_model
83
  )
84
  if 'visualization_agent' not in st.session_state:
85
  st.session_state.visualization_agent = VisualizationAgent(
86
  api_keys=st.session_state.api_keys,
87
+ model_configs=st.session_state.model_configs_runtime,
88
  model=st.session_state.selected_model
89
  )
90
  if 'modeling_coding_agent' not in st.session_state:
91
  st.session_state.modeling_coding_agent = ModelingCodingAgent(
92
  api_keys=st.session_state.api_keys,
93
+ model_configs=st.session_state.model_configs_runtime,
94
  model=st.session_state.selected_model
95
  )
96
  if 'report_agent' not in st.session_state:
97
  st.session_state.report_agent = ReportAgent(
98
  api_keys=st.session_state.api_keys,
99
+ model_configs=st.session_state.model_configs_runtime,
100
  model=st.session_state.selected_model
101
  )
102
  if 'planner_agent' not in st.session_state:
103
  st.session_state.planner_agent = PlannerAgent(
104
  api_keys=st.session_state.api_keys,
105
+ model_configs=st.session_state.model_configs_runtime,
106
  model=st.session_state.selected_model
107
  )
108
 
 
121
  init_session_state()
122
  with st.sidebar:
123
  st.subheader("选择大模型")
124
+
125
+ # 获取所有可用的模型(预设模型 + OpenAI API 兼容模型)
126
+ models = list(MODEL_CONFIGS.keys()) + [CUSTOM_MODEL_KEY]
127
+
128
+ # 确保选择的索引有效
129
+ try:
130
+ current_index = models.index(st.session_state.selected_model)
131
+ except ValueError:
132
+ current_index = 0
133
+ st.session_state.selected_model = models[0]
134
+
135
  st.selectbox(
136
  "选择要使用的大模型",
137
  models,
138
+ index=current_index,
139
  key="model_selector",
140
  on_change=on_model_selector_change,
141
  )
 
143
  st.subheader("API 密钥设置")
144
  selected = st.session_state.selected_model
145
 
146
+ # 判断是否为 OpenAI API 兼容模型
147
+ is_custom_model = (selected == CUSTOM_MODEL_KEY)
148
+
149
+ if is_custom_model:
150
+ # 显示 OpenAI API 兼容模型的配置界面
151
+ existing_config = st.session_state.model_configs_runtime.get(CUSTOM_MODEL_KEY, {})
152
+
153
+ base_url_input = st.text_input(
154
+ "Base URL",
155
+ value=existing_config.get("api_base", ""),
156
+ key="base_url_input",
157
+ placeholder="例如: https://api.siliconflow.cn/v1"
158
+ )
159
+
160
+ model_name_input = st.text_input(
161
+ "模型 ID",
162
+ value=existing_config.get("model_name", ""),
163
+ key="model_name_input",
164
+ placeholder="例如: Qwen/Qwen3-8B"
165
+ )
166
+
167
+ api_key_input = st.text_input(
168
+ "API 密钥",
169
+ value=st.session_state.api_keys.get(CUSTOM_MODEL_KEY, ""),
170
+ type="password",
171
+ key="api_key_input",
172
+ )
173
+
174
+ if existing_config and existing_config.get("api_base"):
175
+ st.info(f"当前配置: {existing_config.get('model_name', 'N/A')}")
176
+ else:
177
+ st.info("配置 OpenAI API 兼容模型")
178
+
179
+ if st.button("💾 保存配置", use_container_width=True, key="save_key"):
180
+ if not base_url_input or not model_name_input or not api_key_input:
181
+ st.error("请填写所有必需字段")
182
+ else:
183
+ # 保存到配置文件
184
+ update_local_model_config(
185
+ display_name=CUSTOM_MODEL_KEY,
186
+ api_key=api_key_input,
187
+ base_url=base_url_input,
188
+ model_name=model_name_input
189
+ )
190
+
191
+ # 更新运行时配置
192
+ st.session_state.model_configs_runtime[CUSTOM_MODEL_KEY] = {
193
+ "api_base": base_url_input,
194
+ "model_name": model_name_input,
195
+ "api_key": api_key_input, # 也保存 api_key
196
+ "api_type": "openai",
197
+ "is_preset": False,
198
+ }
199
+ # 同步到 api_keys
200
+ st.session_state.api_keys[CUSTOM_MODEL_KEY] = api_key_input
201
+ st.session_state.selected_model = CUSTOM_MODEL_KEY
202
+
203
+ st.success("已保存配置")
204
+ st.rerun()
205
+ else:
206
+ # 预设模型或已保存的自定义模型
207
+ api_key_input = st.text_input(
208
+ f"{selected} API 密钥",
209
+ value=st.session_state.api_keys.get(selected, ""),
210
+ type="password",
211
+ key="api_key_input",
212
+ )
213
+
214
+ # 如果是自定义模型,显示其配置信息
215
+ if selected in st.session_state.model_configs_runtime:
216
+ config = st.session_state.model_configs_runtime[selected]
217
+ if not config.get("is_preset", False):
218
+ st.caption(f"Base URL: {config.get('api_base', 'N/A')}")
219
+ st.caption(f"Model: {config.get('model_name', 'N/A')}")
220
 
221
+ if st.button("💾 保存密钥", use_container_width=True, key="save_key"):
222
+ # 保存到配置文件
223
+ config = st.session_state.model_configs_runtime.get(selected, {})
224
+ if config.get("is_preset", False):
225
+ # 预设模型,只保存 API key
226
+ update_local_model_config(display_name=selected, api_key=api_key_input)
227
+ else:
228
+ # 自定义模型,保存完整配置
229
+ update_local_model_config(
230
+ display_name=selected,
231
+ api_key=api_key_input,
232
+ base_url=config.get("api_base"),
233
+ model_name=config.get("model_name")
234
+ )
235
 
236
+ # 同步更新运行时配置和 api_keys
237
+ st.session_state.model_configs_runtime[selected]["api_key"] = api_key_input
238
+ st.session_state.api_keys[selected] = api_key_input
239
+ st.success("已保存")
240
+ st.rerun()
241
 
242
  if st.button("🧹 清空数据", use_container_width=True, key="clear_data"):
243
 
244
  st.session_state.data_loading_agent = DataLoadingAgent(
245
  api_keys=st.session_state.api_keys,
246
+ model_configs=st.session_state.model_configs_runtime,
247
  model=st.session_state.selected_model
248
  )
249
  st.session_state.data_preprocess_agent = DataPreprocessAgent(
250
  api_keys=st.session_state.api_keys,
251
+ model_configs=st.session_state.model_configs_runtime,
252
  model=st.session_state.selected_model
253
  )
254
  st.session_state.visualization_agent = VisualizationAgent(
255
  api_keys=st.session_state.api_keys,
256
+ model_configs=st.session_state.model_configs_runtime,
257
  model=st.session_state.selected_model
258
  )
259
  st.session_state.modeling_coding_agent = ModelingCodingAgent(
260
  api_keys=st.session_state.api_keys,
261
+ model_configs=st.session_state.model_configs_runtime,
262
  model=st.session_state.selected_model
263
  )
264
  st.session_state.report_agent = ReportAgent(
265
  api_keys=st.session_state.api_keys,
266
+ model_configs=st.session_state.model_configs_runtime,
267
  model=st.session_state.selected_model
268
  )
269
  st.session_state.planner_agent = PlannerAgent(
270
  api_keys=st.session_state.api_keys,
271
+ model_configs=st.session_state.model_configs_runtime,
272
  model=st.session_state.selected_model
273
  )
274
  st.session_state.auto_mode = False
 
288
  st.session_state.auto_mode = False
289
  st.session_state.planner_agent = PlannerAgent(
290
  api_keys=st.session_state.api_keys,
291
+ model_configs=st.session_state.model_configs_runtime,
292
  model=st.session_state.selected_model
293
  )
294
  st.rerun()
 
333
  pg.run()
334
 
335
  if __name__ == "__main__":
336
+ run_app()