lt676767 commited on
Commit
080c12b
·
verified ·
1 Parent(s): 2d6bf3d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +198 -0
  2. requirements.txt +163 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import json
4
+ import re
5
+ import requests
6
+ import streamlit as st
7
+
8
+ from lagent.agents import Agent
9
+ from lagent.prompts.parsers import PluginParser
10
+ from lagent.agents.stream import PLUGIN_CN, get_plugin_prompt
11
+ from lagent.schema import AgentMessage
12
+ from lagent.actions import ArxivSearch
13
+ from lagent.hooks import Hook
14
+ from lagent.llms import GPTAPI
15
+
16
+ YOUR_TOKEN_HERE = os.getenv("token")
17
+ if not YOUR_TOKEN_HERE:
18
+ raise EnvironmentError("未找到环境变量 'token',请设置后再运行程序。")
19
+
20
+ # Hook类,用于对消息添加前缀
21
+ class PrefixedMessageHook(Hook):
22
+ def __init__(self, prefix, senders=None):
23
+ """
24
+ 初始化Hook
25
+ :param prefix: 消息前缀
26
+ :param senders: 指定发送者列表
27
+ """
28
+ self.prefix = prefix
29
+ self.senders = senders or []
30
+
31
+ def before_agent(self, agent, messages, session_id):
32
+ """
33
+ 在代理处理消息前修改消息内容
34
+ :param agent: 当前代理
35
+ :param messages: 消息列表
36
+ :param session_id: 会话ID
37
+ """
38
+ for message in messages:
39
+ if message.sender in self.senders:
40
+ message.content = self.prefix + message.content
41
+
42
+ class AsyncBlogger:
43
+ """博客生成类,整合写作者和批评者。"""
44
+
45
+ def __init__(self, model_type, api_base, writer_prompt, critic_prompt, critic_prefix='', max_turn=2):
46
+ """
47
+ 初始化博客生成器
48
+ :param model_type: 模型类型
49
+ :param api_base: API 基地址
50
+ :param writer_prompt: 写作者提示词
51
+ :param critic_prompt: 批评者提示词
52
+ :param critic_prefix: 批评消息前缀
53
+ :param max_turn: 最大轮次
54
+ """
55
+ self.model_type = model_type
56
+ self.api_base = api_base
57
+ self.llm = GPTAPI(
58
+ model_type=model_type,
59
+ api_base=api_base,
60
+ key=YOUR_TOKEN_HERE,
61
+ max_new_tokens=4096,
62
+ )
63
+ self.plugins = [dict(type='lagent.actions.ArxivSearch')]
64
+ self.writer = Agent(
65
+ self.llm,
66
+ writer_prompt,
67
+ name='写作者',
68
+ output_format=dict(
69
+ type=PluginParser,
70
+ template=PLUGIN_CN,
71
+ prompt=get_plugin_prompt(self.plugins)
72
+ )
73
+ )
74
+ self.critic = Agent(
75
+ self.llm,
76
+ critic_prompt,
77
+ name='批评者',
78
+ hooks=[PrefixedMessageHook(critic_prefix, ['写作者'])]
79
+ )
80
+ self.max_turn = max_turn
81
+
82
+ async def forward(self, message: AgentMessage, update_placeholder):
83
+ """
84
+ 执行多阶段博客生成流程
85
+ :param message: 初始消息
86
+ :param update_placeholder: Streamlit占位符
87
+ :return: 最终优化的博客内容
88
+ """
89
+ step1_placeholder = update_placeholder.container()
90
+ step2_placeholder = update_placeholder.container()
91
+ step3_placeholder = update_placeholder.container()
92
+
93
+ # 第一步:生成初始内容
94
+ step1_placeholder.markdown("**Step 1: 生成初始内容...**")
95
+ message = self.writer(message)
96
+ if message.content:
97
+ step1_placeholder.markdown(f"**生成的初始内容**:\n\n{message.content}")
98
+ else:
99
+ step1_placeholder.markdown("**生成的初始内容为空,请检查生成逻辑。**")
100
+
101
+ # 第二步:批评者提供反馈
102
+ step2_placeholder.markdown("**Step 2: 批评者正在提供反馈和文献推荐...**")
103
+ message = self.critic(message)
104
+ if message.content:
105
+ # 解析批评者反馈
106
+ suggestions = re.search(r"1\. 批评建议:\n(.*?)2\. 推荐的关键词:", message.content, re.S)
107
+ keywords = re.search(r"2\. 推荐的关键词:\n- (.*)", message.content)
108
+ feedback = suggestions.group(1).strip() if suggestions else "未提供批评建议"
109
+ keywords = keywords.group(1).strip() if keywords else "未提供关键词"
110
+
111
+ # Arxiv 文献查询
112
+ arxiv_search = ArxivSearch()
113
+ arxiv_results = arxiv_search.get_arxiv_article_information(keywords)
114
+
115
+ # 显示批评内容和文献推荐
116
+ message.content = f"**批评建议**:\n{feedback}\n\n**推荐的文献**:\n{arxiv_results}"
117
+ step2_placeholder.markdown(f"**批评和文献推荐**:\n\n{message.content}")
118
+ else:
119
+ step2_placeholder.markdown("**批评内容为空,请检查批评逻辑。**")
120
+
121
+ # 第三步:写作者根据反馈优化内容
122
+ step3_placeholder.markdown("**Step 3: 根据反馈改进内容...**")
123
+ improvement_prompt = AgentMessage(
124
+ sender="critic",
125
+ content=(
126
+ f"根据以下批评建议和推荐文献对内容进行改进:\n\n"
127
+ f"批评建议:\n{feedback}\n\n"
128
+ f"推荐文献:\n{arxiv_results}\n\n"
129
+ f"请优化��始内容,使其更加清晰、丰富,并符合专业水准。"
130
+ ),
131
+ )
132
+ message = self.writer(improvement_prompt)
133
+ if message.content:
134
+ step3_placeholder.markdown(f"**最终优化的博客内容**:\n\n{message.content}")
135
+ else:
136
+ step3_placeholder.markdown("**最终优化的博客内容为空,请检查生成逻辑。**")
137
+
138
+ return message
139
+
140
+ def setup_sidebar():
141
+ """设置侧边栏,选择模型。"""
142
+ model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')
143
+ api_base = st.sidebar.text_input(
144
+ 'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
145
+ )
146
+
147
+ return model_name, api_base
148
+
149
+ def main():
150
+ """
151
+ 主函数:构建Streamlit界面并处理用户交互
152
+ """
153
+ st.set_page_config(layout='wide', page_title='Lagent Web Demo', page_icon='🤖')
154
+ st.title("多代理博客优化助手")
155
+
156
+ model_type, api_base = setup_sidebar()
157
+ topic = st.text_input('输入一个话题:', 'Self-Supervised Learning')
158
+ generate_button = st.button('生成博客内容')
159
+
160
+ if (
161
+ 'blogger' not in st.session_state or
162
+ st.session_state['model_type'] != model_type or
163
+ st.session_state['api_base'] != api_base
164
+ ):
165
+ st.session_state['blogger'] = AsyncBlogger(
166
+ model_type=model_type,
167
+ api_base=api_base,
168
+ writer_prompt="你是一位优秀的AI内容写作者,请撰写一篇有吸引力且信息丰富的博客内容。",
169
+ critic_prompt="""
170
+ 作为一位严谨的批评者,请给出建设性的批评和改进建议,并基于相关主题使用已有的工具推荐一些参考文献,推荐的关键词应该是英语形式,简洁且切题。
171
+ 请按照以下格式提供反馈:
172
+ 1. 批评建议:
173
+ - (具体建议)
174
+ 2. 推荐的关键词:
175
+ - (关键词1, 关键词2, ...)
176
+ """,
177
+ critic_prefix="请批评以下内容,并提供改进建议:\n\n"
178
+ )
179
+ st.session_state['model_type'] = model_type
180
+ st.session_state['api_base'] = api_base
181
+
182
+ if generate_button:
183
+ update_placeholder = st.empty()
184
+
185
+ async def run_async_blogger():
186
+ message = AgentMessage(
187
+ sender='user',
188
+ content=f"请撰写一篇关于{topic}的博客文章,要求表达专业,生动有趣,并且易于理解。"
189
+ )
190
+ result = await st.session_state['blogger'].forward(message, update_placeholder)
191
+ return result
192
+
193
+ loop = asyncio.new_event_loop()
194
+ asyncio.set_event_loop(loop)
195
+ loop.run_until_complete(run_async_blogger())
196
+
197
+ if __name__ == '__main__':
198
+ main()
requirements.txt ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.4.4
2
+ aiohttp==3.11.10
3
+ aiosignal==1.3.2
4
+ altair==5.5.0
5
+ annotated-types==0.7.0
6
+ anyio==4.7.0
7
+ argon2-cffi==23.1.0
8
+ argon2-cffi-bindings==21.2.0
9
+ arrow==1.3.0
10
+ arxiv==2.1.3
11
+ asttokens==3.0.0
12
+ async-lru==2.0.4
13
+ async-timeout==5.0.1
14
+ asyncache==0.3.1
15
+ asyncer==0.0.8
16
+ attrs==24.3.0
17
+ babel==2.16.0
18
+ backports.strenum==1.3.1
19
+ beautifulsoup4==4.12.3
20
+ bleach==6.2.0
21
+ blinker==1.9.0
22
+ Brotli==1.1.0
23
+ cachetools==5.5.0
24
+ certifi==2024.12.14
25
+ cffi==1.17.1
26
+ charset-normalizer==3.4.0
27
+ class-registry==2.1.2
28
+ click==8.1.7
29
+ colorama==0.4.6
30
+ comm==0.2.2
31
+ datasets==3.1.0
32
+ debugpy==1.8.11
33
+ decorator==5.1.1
34
+ defusedxml==0.7.1
35
+ dill==0.3.8
36
+ distro==1.9.0
37
+ duckduckgo_search==5.3.1b1
38
+ exceptiongroup==1.2.2
39
+ executing==2.1.0
40
+ fastjsonschema==2.21.1
41
+ feedparser==6.0.11
42
+ filelock==3.16.1
43
+ fqdn==1.5.1
44
+ frozenlist==1.5.0
45
+ fsspec==2024.9.0
46
+ func_timeout==4.3.5
47
+ gitdb==4.0.11
48
+ GitPython==3.1.43
49
+ griffe==0.48.0
50
+ h11==0.14.0
51
+ h2==4.1.0
52
+ hpack==4.0.0
53
+ httpcore==1.0.7
54
+ httpx==0.28.1
55
+ huggingface-hub==0.27.0
56
+ hyperframe==6.0.1
57
+ idna==3.10
58
+ ipykernel==6.29.5
59
+ ipython==8.30.0
60
+ ipywidgets==8.1.5
61
+ isoduration==20.11.0
62
+ jedi==0.19.2
63
+ Jinja2==3.1.4
64
+ json5==0.10.0
65
+ jsonpointer==3.0.0
66
+ jsonschema==4.23.0
67
+ jsonschema-specifications==2024.10.1
68
+ jupyter==1.0.0
69
+ jupyter-console==6.6.3
70
+ jupyter-events==0.10.0
71
+ jupyter-lsp==2.2.5
72
+ jupyter_client==8.6.2
73
+ jupyter_core==5.7.2
74
+ jupyter_server==2.14.2
75
+ jupyter_server_terminals==0.5.3
76
+ jupyterlab==4.3.3
77
+ jupyterlab_pygments==0.3.0
78
+ jupyterlab_server==2.27.3
79
+ jupyterlab_widgets==3.0.13
80
+ -e git+https://github.com/InternLM/lagent.git@e304e5d323cdbb631257fac9187d16b99476bc2f#egg=lagent
81
+ markdown-it-py==3.0.0
82
+ MarkupSafe==3.0.2
83
+ matplotlib-inline==0.1.7
84
+ mdurl==0.1.2
85
+ mistune==3.0.2
86
+ multidict==6.1.0
87
+ multiprocess==0.70.16
88
+ narwhals==1.18.4
89
+ nbclient==0.10.1
90
+ nbconvert==7.16.4
91
+ nbformat==5.10.4
92
+ nest-asyncio==1.6.0
93
+ notebook==7.3.1
94
+ notebook_shim==0.2.4
95
+ numpy==2.2.0
96
+ overrides==7.7.0
97
+ packaging==24.2
98
+ pandas==2.2.3
99
+ pandocfilters==1.5.1
100
+ parso==0.8.4
101
+ pexpect==4.9.0
102
+ pillow==10.4.0
103
+ platformdirs==4.3.6
104
+ prometheus_client==0.21.1
105
+ prompt_toolkit==3.0.48
106
+ propcache==0.2.1
107
+ protobuf==5.29.1
108
+ psutil==6.1.0
109
+ ptyprocess==0.7.0
110
+ pure_eval==0.2.3
111
+ pyarrow==18.1.0
112
+ pycparser==2.22
113
+ pydantic==2.6.4
114
+ pydantic_core==2.16.3
115
+ pydeck==0.9.1
116
+ Pygments==2.18.0
117
+ python-dateutil==2.9.0.post0
118
+ python-json-logger==3.2.1
119
+ pytz==2024.2
120
+ PyYAML==6.0.2
121
+ pyzmq==26.2.0
122
+ qtconsole==5.6.1
123
+ QtPy==2.4.2
124
+ referencing==0.35.1
125
+ regex==2024.11.6
126
+ requests==2.32.3
127
+ rfc3339-validator==0.1.4
128
+ rfc3986-validator==0.1.1
129
+ rich==13.9.4
130
+ rpds-py==0.22.3
131
+ Send2Trash==1.8.3
132
+ sgmllib3k==1.0.0
133
+ six==1.17.0
134
+ smmap==5.0.1
135
+ sniffio==1.3.1
136
+ socksio==1.0.0
137
+ soupsieve==2.6
138
+ stack-data==0.6.3
139
+ streamlit==1.39.0
140
+ tenacity==9.0.0
141
+ termcolor==2.4.0
142
+ terminado==0.18.1
143
+ tiktoken==0.8.0
144
+ timeout-decorator==0.5.0
145
+ tinycss2==1.4.0
146
+ toml==0.10.2
147
+ tomli==2.2.1
148
+ tornado==6.4.2
149
+ tqdm==4.67.1
150
+ traitlets==5.14.3
151
+ types-python-dateutil==2.9.0.20241206
152
+ typing_extensions==4.12.2
153
+ tzdata==2024.2
154
+ uri-template==1.3.0
155
+ urllib3==2.2.3
156
+ watchdog==5.0.3
157
+ wcwidth==0.2.13
158
+ webcolors==24.11.1
159
+ webencodings==0.5.1
160
+ websocket-client==1.8.0
161
+ widgetsnbextension==4.0.13
162
+ xxhash==3.5.0
163
+ yarl==1.18.3