jacket0603 commited on
Commit
bae046a
·
1 Parent(s): f58dd42
Files changed (3) hide show
  1. Lagent +1 -0
  2. app.py +198 -0
  3. requirements.txt +164 -0
Lagent ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 723870607138e9c973688036a0b7b7dea4e313bb
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import json
4
+ import re
5
+ import requests
6
+ import streamlit as st
7
+
8
+ from lagent.agents import Agent
9
+ from lagent.prompts.parsers import PluginParser
10
+ from lagent.agents.stream import PLUGIN_CN, get_plugin_prompt
11
+ from lagent.schema import AgentMessage
12
+ from lagent.actions import ArxivSearch
13
+ from lagent.hooks import Hook
14
+ from lagent.llms import GPTAPI
15
+
16
+ YOUR_TOKEN_HERE = os.getenv("token")
17
+ if not YOUR_TOKEN_HERE:
18
+ raise EnvironmentError("未找到环境变量 'token',请设置后再运行程序。")
19
+
20
+ # Hook类,用于对消息添加前缀
21
+ class PrefixedMessageHook(Hook):
22
+ def __init__(self, prefix, senders=None):
23
+ """
24
+ 初始化Hook
25
+ :param prefix: 消息前缀
26
+ :param senders: 指定发送者列表
27
+ """
28
+ self.prefix = prefix
29
+ self.senders = senders or []
30
+
31
+ def before_agent(self, agent, messages, session_id):
32
+ """
33
+ 在代理处理消息前修改消息内容
34
+ :param agent: 当前代理
35
+ :param messages: 消息列表
36
+ :param session_id: 会话ID
37
+ """
38
+ for message in messages:
39
+ if message.sender in self.senders:
40
+ message.content = self.prefix + message.content
41
+
42
+ class AsyncBlogger:
43
+ """博客生成类,整合写作者和批评者。"""
44
+
45
+ def __init__(self, model_type, api_base, writer_prompt, critic_prompt, critic_prefix='', max_turn=2):
46
+ """
47
+ 初始化博客生成器
48
+ :param model_type: 模型类型
49
+ :param api_base: API 基地址
50
+ :param writer_prompt: 写作者提示词
51
+ :param critic_prompt: 批评者提示词
52
+ :param critic_prefix: 批评消息前缀
53
+ :param max_turn: 最大轮次
54
+ """
55
+ self.model_type = model_type
56
+ self.api_base = api_base
57
+ self.llm = GPTAPI(
58
+ model_type=model_type,
59
+ api_base=api_base,
60
+ key=YOUR_TOKEN_HERE,
61
+ max_new_tokens=4096,
62
+ )
63
+ self.plugins = [dict(type='lagent.actions.ArxivSearch')]
64
+ self.writer = Agent(
65
+ self.llm,
66
+ writer_prompt,
67
+ name='写作者',
68
+ output_format=dict(
69
+ type=PluginParser,
70
+ template=PLUGIN_CN,
71
+ prompt=get_plugin_prompt(self.plugins)
72
+ )
73
+ )
74
+ self.critic = Agent(
75
+ self.llm,
76
+ critic_prompt,
77
+ name='批评者',
78
+ hooks=[PrefixedMessageHook(critic_prefix, ['写作者'])]
79
+ )
80
+ self.max_turn = max_turn
81
+
82
+ async def forward(self, message: AgentMessage, update_placeholder):
83
+ """
84
+ 执行多阶段博客生成流程
85
+ :param message: 初始消息
86
+ :param update_placeholder: Streamlit占位符
87
+ :return: 最终优化的博客内容
88
+ """
89
+ step1_placeholder = update_placeholder.container()
90
+ step2_placeholder = update_placeholder.container()
91
+ step3_placeholder = update_placeholder.container()
92
+
93
+ # 第一步:生成初始内容
94
+ step1_placeholder.markdown("**Step 1: 生成初始内容...**")
95
+ message = self.writer(message)
96
+ if message.content:
97
+ step1_placeholder.markdown(f"**生成的初始内容**:\n\n{message.content}")
98
+ else:
99
+ step1_placeholder.markdown("**生成的初始内容为空,请检查生成逻辑。**")
100
+
101
+ # 第二步:批评者提供反馈
102
+ step2_placeholder.markdown("**Step 2: 批评者正在提供反馈和文献推荐...**")
103
+ message = self.critic(message)
104
+ if message.content:
105
+ # 解析批评者反馈
106
+ suggestions = re.search(r"1\. 批评建议:\n(.*?)2\. 推荐的关键词:", message.content, re.S)
107
+ keywords = re.search(r"2\. 推荐的关键词:\n- (.*)", message.content)
108
+ feedback = suggestions.group(1).strip() if suggestions else "未提供批评建议"
109
+ keywords = keywords.group(1).strip() if keywords else "未提供关键词"
110
+
111
+ # Arxiv 文献查询
112
+ arxiv_search = ArxivSearch()
113
+ arxiv_results = arxiv_search.get_arxiv_article_information(keywords)
114
+
115
+ # 显示批评内容和文献推荐
116
+ message.content = f"**批评建议**:\n{feedback}\n\n**推荐的文献**:\n{arxiv_results}"
117
+ step2_placeholder.markdown(f"**批评和文献推荐**:\n\n{message.content}")
118
+ else:
119
+ step2_placeholder.markdown("**批评内容为空,请检查批评逻辑。**")
120
+
121
+ # 第三步:写作者根据反馈优化内容
122
+ step3_placeholder.markdown("**Step 3: 根据反馈改进内容...**")
123
+ improvement_prompt = AgentMessage(
124
+ sender="critic",
125
+ content=(
126
+ f"根据以下批评建议和推荐文献对内容进行改进:\n\n"
127
+ f"批评建议:\n{feedback}\n\n"
128
+ f"推荐文献:\n{arxiv_results}\n\n"
129
+ f"请优化��始内容,使其更加清晰、丰富,并符合专业水准。"
130
+ ),
131
+ )
132
+ message = self.writer(improvement_prompt)
133
+ if message.content:
134
+ step3_placeholder.markdown(f"**最终优化的博客内容**:\n\n{message.content}")
135
+ else:
136
+ step3_placeholder.markdown("**最终优化的博客内容为空,请检查生成逻辑。**")
137
+
138
+ return message
139
+
140
+ def setup_sidebar():
141
+ """设置侧边栏,选择模型。"""
142
+ model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')
143
+ api_base = st.sidebar.text_input(
144
+ 'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
145
+ )
146
+
147
+ return model_name, api_base
148
+
149
+ def main():
150
+ """
151
+ 主函数:构建Streamlit界面并处理用户交互
152
+ """
153
+ st.set_page_config(layout='wide', page_title='Lagent Web Demo', page_icon='🤖')
154
+ st.title("多代理博客优化助手")
155
+
156
+ model_type, api_base = setup_sidebar()
157
+ topic = st.text_input('输入一个话题:', 'Self-Supervised Learning')
158
+ generate_button = st.button('生成博客内容')
159
+
160
+ if (
161
+ 'blogger' not in st.session_state or
162
+ st.session_state['model_type'] != model_type or
163
+ st.session_state['api_base'] != api_base
164
+ ):
165
+ st.session_state['blogger'] = AsyncBlogger(
166
+ model_type=model_type,
167
+ api_base=api_base,
168
+ writer_prompt="你是一位优秀的AI内容写作者,请撰写一篇有吸引力且信息丰富的博客内容。",
169
+ critic_prompt="""
170
+ 作为一位严谨的批评者,请给出建设性的批评和改进建议,并基于相关主题使用已有的工具推荐一些参考文献,推荐的关键词应该是英语形式,简洁且切题。
171
+ 请按照以下格式提供反馈:
172
+ 1. 批评建议:
173
+ - (具体建议)
174
+ 2. 推荐的关键词:
175
+ - (关键词1, 关键词2, ...)
176
+ """,
177
+ critic_prefix="请批评以下内容,并提供改进建议:\n\n"
178
+ )
179
+ st.session_state['model_type'] = model_type
180
+ st.session_state['api_base'] = api_base
181
+
182
+ if generate_button:
183
+ update_placeholder = st.empty()
184
+
185
+ async def run_async_blogger():
186
+ message = AgentMessage(
187
+ sender='user',
188
+ content=f"请撰写一篇关于{topic}的博客文章,要求表达专业,生动有趣,并且易于理解。"
189
+ )
190
+ result = await st.session_state['blogger'].forward(message, update_placeholder)
191
+ return result
192
+
193
+ loop = asyncio.new_event_loop()
194
+ asyncio.set_event_loop(loop)
195
+ loop.run_until_complete(run_async_blogger())
196
+
197
+ if __name__ == '__main__':
198
+ main()
requirements.txt CHANGED
@@ -4,3 +4,167 @@ pandas==2.2.2
4
  torch==2.3.0
5
  torchvision==0.18.0
6
  tqdm==4.66.4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  torch==2.3.0
5
  torchvision==0.18.0
6
  tqdm==4.66.4
7
+ anthropic
8
+ aiohappyeyeballs==2.4.4
9
+ aiohttp==3.11.10
10
+ aiosignal==1.3.2
11
+ altair==5.5.0
12
+ annotated-types==0.7.0
13
+ anyio==4.7.0
14
+ argon2-cffi==23.1.0
15
+ argon2-cffi-bindings==21.2.0
16
+ arrow==1.3.0
17
+ arxiv==2.1.3
18
+ asttokens==3.0.0
19
+ async-lru==2.0.4
20
+ async-timeout==5.0.1
21
+ asyncache==0.3.1
22
+ asyncer==0.0.8
23
+ attrs==24.3.0
24
+ babel==2.16.0
25
+ backports.strenum==1.3.1
26
+ beautifulsoup4==4.12.3
27
+ bleach==6.2.0
28
+ blinker==1.9.0
29
+ Brotli==1.1.0
30
+ cachetools==5.5.0
31
+ certifi==2024.12.14
32
+ cffi==1.17.1
33
+ charset-normalizer==3.4.0
34
+ class-registry==2.1.2
35
+ click==8.1.7
36
+ colorama==0.4.6
37
+ comm==0.2.2
38
+ datasets==3.1.0
39
+ debugpy==1.8.11
40
+ decorator==5.1.1
41
+ defusedxml==0.7.1
42
+ dill==0.3.8
43
+ distro==1.9.0
44
+ duckduckgo_search==5.3.1b1
45
+ exceptiongroup==1.2.2
46
+ executing==2.1.0
47
+ fastjsonschema==2.21.1
48
+ feedparser==6.0.11
49
+ filelock==3.16.1
50
+ fqdn==1.5.1
51
+ frozenlist==1.5.0
52
+ fsspec==2024.9.0
53
+ func_timeout==4.3.5
54
+ gitdb==4.0.11
55
+ GitPython==3.1.43
56
+ griffe==0.48.0
57
+ h11==0.14.0
58
+ h2==4.1.0
59
+ hpack==4.0.0
60
+ httpcore==1.0.7
61
+ httpx==0.28.1
62
+ huggingface-hub==0.27.0
63
+ hyperframe==6.0.1
64
+ idna==3.10
65
+ ipykernel==6.29.5
66
+ ipython==8.30.0
67
+ ipywidgets==8.1.5
68
+ isoduration==20.11.0
69
+ jedi==0.19.2
70
+ Jinja2==3.1.4
71
+ json5==0.10.0
72
+ jsonpointer==3.0.0
73
+ jsonschema==4.23.0
74
+ jsonschema-specifications==2024.10.1
75
+ jupyter==1.0.0
76
+ jupyter-console==6.6.3
77
+ jupyter-events==0.10.0
78
+ jupyter-lsp==2.2.5
79
+ jupyter_client==8.6.2
80
+ jupyter_core==5.7.2
81
+ jupyter_server==2.14.2
82
+ jupyter_server_terminals==0.5.3
83
+ jupyterlab==4.3.3
84
+ jupyterlab_pygments==0.3.0
85
+ jupyterlab_server==2.27.3
86
+ jupyterlab_widgets==3.0.13
87
+ -e git+https://github.com/InternLM/lagent.git@e304e5d323cdbb631257fac9187d16b99476bc2f#egg=lagent
88
+ markdown-it-py==3.0.0
89
+ MarkupSafe==3.0.2
90
+ matplotlib-inline==0.1.7
91
+ mdurl==0.1.2
92
+ mistune==3.0.2
93
+ multidict==6.1.0
94
+ multiprocess==0.70.16
95
+ narwhals==1.18.4
96
+ nbclient==0.10.1
97
+ nbconvert==7.16.4
98
+ nbformat==5.10.4
99
+ nest-asyncio==1.6.0
100
+ notebook==7.3.1
101
+ notebook_shim==0.2.4
102
+ numpy==2.2.0
103
+ overrides==7.7.0
104
+ packaging==24.2
105
+ pandas==2.2.3
106
+ pandocfilters==1.5.1
107
+ parso==0.8.4
108
+ pexpect==4.9.0
109
+ pillow==10.4.0
110
+ platformdirs==4.3.6
111
+ prometheus_client==0.21.1
112
+ prompt_toolkit==3.0.48
113
+ propcache==0.2.1
114
+ protobuf==5.29.1
115
+ psutil==6.1.0
116
+ ptyprocess==0.7.0
117
+ pure_eval==0.2.3
118
+ pyarrow==18.1.0
119
+ pycparser==2.22
120
+ pydantic==2.6.4
121
+ pydantic_core==2.16.3
122
+ pydeck==0.9.1
123
+ Pygments==2.18.0
124
+ python-dateutil==2.9.0.post0
125
+ python-json-logger==3.2.1
126
+ pytz==2024.2
127
+ PyYAML==6.0.2
128
+ pyzmq==26.2.0
129
+ qtconsole==5.6.1
130
+ QtPy==2.4.2
131
+ referencing==0.35.1
132
+ regex==2024.11.6
133
+ requests==2.32.3
134
+ rfc3339-validator==0.1.4
135
+ rfc3986-validator==0.1.1
136
+ rich==13.9.4
137
+ rpds-py==0.22.3
138
+ Send2Trash==1.8.3
139
+ sgmllib3k==1.0.0
140
+ six==1.17.0
141
+ smmap==5.0.1
142
+ sniffio==1.3.1
143
+ socksio==1.0.0
144
+ soupsieve==2.6
145
+ stack-data==0.6.3
146
+ streamlit==1.39.0
147
+ tenacity==9.0.0
148
+ termcolor==2.4.0
149
+ terminado==0.18.1
150
+ tiktoken==0.8.0
151
+ timeout-decorator==0.5.0
152
+ tinycss2==1.4.0
153
+ toml==0.10.2
154
+ tomli==2.2.1
155
+ tornado==6.4.2
156
+ tqdm==4.67.1
157
+ traitlets==5.14.3
158
+ types-python-dateutil==2.9.0.20241206
159
+ typing_extensions==4.12.2
160
+ tzdata==2024.2
161
+ uri-template==1.3.0
162
+ urllib3==2.2.3
163
+ watchdog==5.0.3
164
+ wcwidth==0.2.13
165
+ webcolors==24.11.1
166
+ webencodings==0.5.1
167
+ websocket-client==1.8.0
168
+ widgetsnbextension==4.0.13
169
+ xxhash==3.5.0
170
+ yarl==1.18.3