wangrongsheng commited on
Commit
8c6097b
·
verified ·
1 Parent(s): f74a3e0

Upload folder using huggingface_hub

Browse files
Files changed (43) hide show
  1. .gitattributes +5 -0
  2. LICENSE +34 -0
  3. Open Source Software Notice +218 -0
  4. README.md +554 -6
  5. README_EN.md +554 -0
  6. checklist.chk +14 -0
  7. config.json +31 -0
  8. configuration_openpangu_dense.py +56 -0
  9. deepdiver_v2/cli/README.md +238 -0
  10. deepdiver_v2/cli/demo.py +668 -0
  11. deepdiver_v2/cli/run_demo.sh +171 -0
  12. deepdiver_v2/config/config.py +239 -0
  13. deepdiver_v2/env.template +44 -0
  14. deepdiver_v2/requirements.txt +8 -0
  15. deepdiver_v2/src/__init__.py +11 -0
  16. deepdiver_v2/src/agents/__init__.py +62 -0
  17. deepdiver_v2/src/agents/base_agent.py +692 -0
  18. deepdiver_v2/src/agents/objective_information_seeker.py +428 -0
  19. deepdiver_v2/src/agents/planner_agent.py +1203 -0
  20. deepdiver_v2/src/agents/subjective_information_seeker.py +417 -0
  21. deepdiver_v2/src/agents/writer_agent.py +477 -0
  22. deepdiver_v2/src/tools/__init__.py +36 -0
  23. deepdiver_v2/src/tools/mcp_client.py +814 -0
  24. deepdiver_v2/src/tools/mcp_server_standard.py +1751 -0
  25. deepdiver_v2/src/tools/mcp_tools.py +0 -0
  26. deepdiver_v2/src/tools/server_config.yaml +73 -0
  27. deepdiver_v2/src/utils/__init__.py +8 -0
  28. deepdiver_v2/src/utils/status_codes.py +12 -0
  29. deepdiver_v2/src/workspace/__init__.py +26 -0
  30. deepdiver_v2/src/workspace/local_workspace_manager.py +420 -0
  31. docs/openpangu-deepdiver-v2-tech-report.pdf +3 -0
  32. generation_config.json +11 -0
  33. model-00001-of-00004.safetensors +3 -0
  34. model-00002-of-00004.safetensors +3 -0
  35. model-00003-of-00004.safetensors +3 -0
  36. model-00004-of-00004.safetensors +3 -0
  37. model.safetensors.index.json +486 -0
  38. modeling_openpangu_dense.py +585 -0
  39. modular_openpangu_dense.py +149 -0
  40. special_tokens_map.json +30 -0
  41. tokenization_openpangu.py +273 -0
  42. tokenizer.model +3 -0
  43. tokenizer_config.json +1 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.pdf filter=lfs diff=lfs merge=lfs -text
37
+ model-00004-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
38
+ model-00001-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
39
+ model-00002-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
40
+ model-00003-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0
2
+
3
+ This OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 (the "Agreement") is a legal agreement between You and Huawei Technologies Co., Ltd. ("Huawei", "We" or "Us"), and it governs Your reproducing, use, modification, and distribution of openPangu as made available by Huawei under this Agreement.
4
+
5
+ By using, reproducing, modifying, distributing, performing or displaying any portion or element of openPangu, or otherwise accepting the terms of this Agreement, You agree to be bound by this Agreement.
6
+
7
+ 1. Definitions.
8
+ 1.1. “openPangu” or “Model” means openPangu large language models and software, including trained model weights, parameters (including optimizer states), accompanying source code and scripts released under this Agreement.
9
+ 1.2. “Derivative Model” means all (1) modifications to the Model, (2) works based on the Model, and (3) any other derivative works of the Model. For clarity, information or content results from operating or otherwise using the Model is not a Derivative Model.
10
+ 1.3. “You” or “Your” means an individual or Legal Entity exercising permissions granted by this Agreement and/or using the Model for any purpose.
11
+ 1.4. “Third Party” or “Third Parties” means individuals or legal entities that are not under common control with Us or You.
12
+
13
+ 2. License Grant. Subject to Your full compliance with the terms and conditions of this Agreement, We hereby grant to You a perpetual, worldwide, non-exclusive, non-transferable, no-charge, royalty-free license (except as stated in Section 3) to use, reproduce, modify, and distribute the Model.
14
+
15
+ 3. Conditions for License Grant. You represent and warrant that You will not access, download, install, run, deploy, integrate, modify, or otherwise use the Model, directly or indirectly, within the European Union.
16
+
17
+
18
+ 4. Redistribution.
19
+ 4.1. If You distribute the Model or Derivative Model, You shall retain in Your distribution (1) a copy of this agreement, and (2) all copyright notices and other notices of origin included in the Model that are applicable to Your distribution.
20
+ 4.2. Further, if You distribute or make available to Third Parties a product or service (including another AI model) based on the Model, You are required to (1) display the acknowledgement “Powered by openPangu” and (2) include a trademark notice “openPangu is a trademark of Huawei Technologies Co., Ltd.” on related webpages, user manuals, product documentations or other advertising materials mentioning features of the Model.
21
+ 4.3. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for Derivative Model made by You as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the terms and conditions of this Agreement.
22
+
23
+ 5. Ownership. We do not claim ownership to any information or content generated using the Model or Derivative Model that are made by You. You are solely responsible for evaluating the accuracy and appropriateness of such information or content for Your use case.
24
+
25
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of Huawei, except as required for complying with Section 4.2.
26
+
27
+ 7. Indemnity. You will indemnify and hold harmless Huawei from and against any claim by any third party arising out of or related to Your use or distribution of the Model or Derivative Model made by You (e.g. a violation against Section 3). For avoidance of doubt, “third party” in this clause includes supervisory authorities.
28
+
29
+ 8. THE MODEL IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NONINFRINGEMENT, ACCURACY, OR THE ABSENCE OF LATENT OR OTHER DEFECTS OR ERRORS, WHETHER OR NOT DISCOVERABLE, ALL TO THE GREATEST EXTENT PERMISSIBLE UNDER APPLICABLE LAW.
30
+
31
+ 9. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MODEL, IN WHOLE OR IN PART, NO MATTER HOW IT’S CAUSED OR THE LEGAL THEORY IT IS BASED ON, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
32
+
33
+
34
+ END OF THE TERMS AND CONDITIONS
Open Source Software Notice ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ OPEN SOURCE SOFTWARE NOTICE
2
+
3
+ Please note we provide an open source software notice along with this product and/or this product firmware (in the following just “this product”). The open source software licenses are granted by the respective right holders. And the open source licenses prevail all other license information with regard to the respective open source software contained in the product, including but not limited to End User Software Licensing Agreement. This notice is provided on behalf of Huawei Technologies Co. Ltd. and any of its local subsidiaries which may have provided this product to you in your local country.
4
+
5
+ Warranty Disclaimer
6
+ THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
7
+
8
+ Copyright Notice and License Texts
9
+
10
+ Software: transformers 4.53.2
11
+ Copyright notice:
12
+ Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
13
+
14
+ License Text:
15
+ ----------------------------------------
16
+
17
+ Apache License
18
+ Version 2.0, January 2004
19
+ http://www.apache.org/licenses/
20
+
21
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
22
+
23
+ 1. Definitions.
24
+
25
+ "License" shall mean the terms and conditions for use, reproduction,
26
+ and distribution as defined by Sections 1 through 9 of this document.
27
+
28
+ "Licensor" shall mean the copyright owner or entity authorized by
29
+ the copyright owner that is granting the License.
30
+
31
+ "Legal Entity" shall mean the union of the acting entity and all
32
+ other entities that control, are controlled by, or are under common
33
+ control with that entity. For the purposes of this definition,
34
+ "control" means (i) the power, direct or indirect, to cause the
35
+ direction or management of such entity, whether by contract or
36
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
37
+ outstanding shares, or (iii) beneficial ownership of such entity.
38
+
39
+ "You" (or "Your") shall mean an individual or Legal Entity
40
+ exercising permissions granted by this License.
41
+
42
+ "Source" form shall mean the preferred form for making modifications,
43
+ including but not limited to software source code, documentation
44
+ source, and configuration files.
45
+
46
+ "Object" form shall mean any form resulting from mechanical
47
+ transformation or translation of a Source form, including but
48
+ not limited to compiled object code, generated documentation,
49
+ and conversions to other media types.
50
+
51
+ "Work" shall mean the work of authorship, whether in Source or
52
+ Object form, made available under the License, as indicated by a
53
+ copyright notice that is included in or attached to the work
54
+ (an example is provided in the Appendix below).
55
+
56
+ "Derivative Works" shall mean any work, whether in Source or Object
57
+ form, that is based on (or derived from) the Work and for which the
58
+ editorial revisions, annotations, elaborations, or other modifications
59
+ represent, as a whole, an original work of authorship. For the purposes
60
+ of this License, Derivative Works shall not include works that remain
61
+ separable from, or merely link (or bind by name) to the interfaces of,
62
+ the Work and Derivative Works thereof.
63
+
64
+ "Contribution" shall mean any work of authorship, including
65
+ the original version of the Work and any modifications or additions
66
+ to that Work or Derivative Works thereof, that is intentionally
67
+ submitted to Licensor for inclusion in the Work by the copyright owner
68
+ or by an individual or Legal Entity authorized to submit on behalf of
69
+ the copyright owner. For the purposes of this definition, "submitted"
70
+ means any form of electronic, verbal, or written communication sent
71
+ to the Licensor or its representatives, including but not limited to
72
+ communication on electronic mailing lists, source code control systems,
73
+ and issue tracking systems that are managed by, or on behalf of, the
74
+ Licensor for the purpose of discussing and improving the Work, but
75
+ excluding communication that is conspicuously marked or otherwise
76
+ designated in writing by the copyright owner as "Not a Contribution."
77
+
78
+ "Contributor" shall mean Licensor and any individual or Legal Entity
79
+ on behalf of whom a Contribution has been received by Licensor and
80
+ subsequently incorporated within the Work.
81
+
82
+ 2. Grant of Copyright License. Subject to the terms and conditions of
83
+ this License, each Contributor hereby grants to You a perpetual,
84
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
85
+ copyright license to reproduce, prepare Derivative Works of,
86
+ publicly display, publicly perform, sublicense, and distribute the
87
+ Work and such Derivative Works in Source or Object form.
88
+
89
+ 3. Grant of Patent License. Subject to the terms and conditions of
90
+ this License, each Contributor hereby grants to You a perpetual,
91
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
92
+ (except as stated in this section) patent license to make, have made,
93
+ use, offer to sell, sell, import, and otherwise transfer the Work,
94
+ where such license applies only to those patent claims licensable
95
+ by such Contributor that are necessarily infringed by their
96
+ Contribution(s) alone or by combination of their Contribution(s)
97
+ with the Work to which such Contribution(s) was submitted. If You
98
+ institute patent litigation against any entity (including a
99
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
100
+ or a Contribution incorporated within the Work constitutes direct
101
+ or contributory patent infringement, then any patent licenses
102
+ granted to You under this License for that Work shall terminate
103
+ as of the date such litigation is filed.
104
+
105
+ 4. Redistribution. You may reproduce and distribute copies of the
106
+ Work or Derivative Works thereof in any medium, with or without
107
+ modifications, and in Source or Object form, provided that You
108
+ meet the following conditions:
109
+
110
+ (a) You must give any other recipients of the Work or
111
+ Derivative Works a copy of this License; and
112
+
113
+ (b) You must cause any modified files to carry prominent notices
114
+ stating that You changed the files; and
115
+
116
+ (c) You must retain, in the Source form of any Derivative Works
117
+ that You distribute, all copyright, patent, trademark, and
118
+ attribution notices from the Source form of the Work,
119
+ excluding those notices that do not pertain to any part of
120
+ the Derivative Works; and
121
+
122
+ (d) If the Work includes a "NOTICE" text file as part of its
123
+ distribution, then any Derivative Works that You distribute must
124
+ include a readable copy of the attribution notices contained
125
+ within such NOTICE file, excluding those notices that do not
126
+ pertain to any part of the Derivative Works, in at least one
127
+ of the following places: within a NOTICE text file distributed
128
+ as part of the Derivative Works; within the Source form or
129
+ documentation, if provided along with the Derivative Works; or,
130
+ within a display generated by the Derivative Works, if and
131
+ wherever such third-party notices normally appear. The contents
132
+ of the NOTICE file are for informational purposes only and
133
+ do not modify the License. You may add Your own attribution
134
+ notices within Derivative Works that You distribute, alongside
135
+ or as an addendum to the NOTICE text from the Work, provided
136
+ that such additional attribution notices cannot be construed
137
+ as modifying the License.
138
+
139
+ You may add Your own copyright statement to Your modifications and
140
+ may provide additional or different license terms and conditions
141
+ for use, reproduction, or distribution of Your modifications, or
142
+ for any such Derivative Works as a whole, provided Your use,
143
+ reproduction, and distribution of the Work otherwise complies with
144
+ the conditions stated in this License.
145
+
146
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
147
+ any Contribution intentionally submitted for inclusion in the Work
148
+ by You to the Licensor shall be under the terms and conditions of
149
+ this License, without any additional terms or conditions.
150
+ Notwithstanding the above, nothing herein shall supersede or modify
151
+ the terms of any separate license agreement you may have executed
152
+ with Licensor regarding such Contributions.
153
+
154
+ 6. Trademarks. This License does not grant permission to use the trade
155
+ names, trademarks, service marks, or product names of the Licensor,
156
+ except as required for reasonable and customary use in describing the
157
+ origin of the Work and reproducing the content of the NOTICE file.
158
+
159
+ 7. Disclaimer of Warranty. Unless required by applicable law or
160
+ agreed to in writing, Licensor provides the Work (and each
161
+ Contributor provides its Contributions) on an "AS IS" BASIS,
162
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
163
+ implied, including, without limitation, any warranties or conditions
164
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
165
+ PARTICULAR PURPOSE. You are solely responsible for determining the
166
+ appropriateness of using or redistributing the Work and assume any
167
+ risks associated with Your exercise of permissions under this License.
168
+
169
+ 8. Limitation of Liability. In no event and under no legal theory,
170
+ whether in tort (including negligence), contract, or otherwise,
171
+ unless required by applicable law (such as deliberate and grossly
172
+ negligent acts) or agreed to in writing, shall any Contributor be
173
+ liable to You for damages, including any direct, indirect, special,
174
+ incidental, or consequential damages of any character arising as a
175
+ result of this License or out of the use or inability to use the
176
+ Work (including but not limited to damages for loss of goodwill,
177
+ work stoppage, computer failure or malfunction, or any and all
178
+ other commercial damages or losses), even if such Contributor
179
+ has been advised of the possibility of such damages.
180
+
181
+ 9. Accepting Warranty or Additional Liability. While redistributing
182
+ the Work or Derivative Works thereof, You may choose to offer,
183
+ and charge a fee for, acceptance of support, warranty, indemnity,
184
+ or other liability obligations and/or rights consistent with this
185
+ License. However, in accepting such obligations, You may act only
186
+ on Your own behalf and on Your sole responsibility, not on behalf
187
+ of any other Contributor, and only if You agree to indemnify,
188
+ defend, and hold each Contributor harmless for any liability
189
+ incurred by, or claims asserted against, such Contributor by reason
190
+ of your accepting any such warranty or additional liability.
191
+
192
+ END OF TERMS AND CONDITIONS
193
+
194
+ APPENDIX: How to apply the Apache License to your work.
195
+
196
+ To apply the Apache License to your work, attach the following
197
+ boilerplate notice, with the fields enclosed by brackets "[]"
198
+ replaced with your own identifying information. (Don't include
199
+ the brackets!) The text should be enclosed in the appropriate
200
+ comment syntax for the file format. We also recommend that a
201
+ file or class name and description of purpose be included on the
202
+ same "printed page" as the copyright notice for easier
203
+ identification within third-party archives.
204
+
205
+ Copyright [yyyy] [name of copyright owner]
206
+
207
+ Licensed under the Apache License, Version 2.0 (the "License");
208
+ you may not use this file except in compliance with the License.
209
+ You may obtain a copy of the License at
210
+
211
+ http://www.apache.org/licenses/LICENSE-2.0
212
+
213
+ Unless required by applicable law or agreed to in writing, software
214
+ distributed under the License is distributed on an "AS IS" BASIS,
215
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
216
+ See the License for the specific language governing permissions and
217
+ limitations under the License.
218
+
README.md CHANGED
@@ -1,6 +1,554 @@
1
- ---
2
- license: other
3
- license_name: openpangu-model-license-agreement-version-1.0
4
- license_link: >-
5
- https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/LICENSE
6
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 开源盘古 Embedded-7B-DeepDiver
2
+ 中文 | [English](README_EN.md)
3
+ 📑[技术报告](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf)
4
+
5
+
6
+ ## 1. 简介
7
+ DeepDiver是openPangu系列中定位于深度信息获取与处理的Agent,支持原生 Multi-Agent System(MAS),用于复杂知识问答与长文调研报告写作。
8
+
9
+ ### 特性
10
+ - 🔍 支持QA模式:回答100步+复杂知识性问题
11
+ - ✍️ 支持长文写作模式:撰写3w+字文章与报告
12
+ - 🔄 支持自适应模式:根据用户问题自动选择知识问答模式或长文写作模式
13
+
14
+ ## 2. 评测结果
15
+
16
+ | 测评集 | 测评指标 | openPangu-7B-DeepDiver|
17
+ | :------------: | :-----------------: | :--------: |
18
+ | **BrowseComp-zh** | Acc | 18.3 |
19
+ | **BrowseComp-en** | Acc | 8.3 |
20
+ |**XBench-DeepSearch** | Acc | 39.0 |
21
+
22
+ 注:上表仅展示复杂问答的结果,长文调研的评测结果请参考[技术报告](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf)
23
+
24
+ ## 3. 快速部署
25
+
26
+ ### 3.1 环境准备
27
+
28
+ ```bash
29
+ # 克隆并安装
30
+ git clone <repository-url>
31
+ cd deepdiver_v2
32
+ pip install -r requirements.txt
33
+ ```
34
+
35
+ ### 3.2 部署推理服务
36
+
37
+ #### 拉取镜像
38
+
39
+ ```
40
+ docker pull quay.io/ascend/vllm-ascend:v0.9.2rc1
41
+ ```
42
+
43
+ 或按照[官方文档](https://vllm-ascend.readthedocs.io/en/stable/installation.html)手动构建 docker 容器。
44
+
45
+ #### 运行容器
46
+
47
+ ```
48
+ docker run -itd --name vllm-deepdiver \
49
+ --network host \
50
+ --device /dev/davinci0 \
51
+ --device /dev/davinci1 \
52
+ --device /dev/davinci2 \
53
+ --device /dev/davinci3 \
54
+ --device /dev/davinci4 \
55
+ --device /dev/davinci5 \
56
+ --device /dev/davinci6 \
57
+ --device /dev/davinci7 \
58
+ -u root \
59
+ --device /dev/davinci_manager \
60
+ --device /dev/devmm_svm \
61
+ --device /dev/hisi_hdc \
62
+ -v /usr/local/dcmi:/usr/local/dcmi:ro \
63
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool:ro \
64
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro \
65
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/:ro \
66
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info:ro \
67
+ -v /etc/ascend_install.info:/etc/ascend_install.info:ro \
68
+ -v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware:ro \
69
+ -v /data:/data:ro \
70
+ -v /home/work:/home/work \ # 配置一个可读写的工作目录
71
+ quay.io/ascend/vllm-ascend:v0.9.2rc1
72
+ ```
73
+
74
+ #### 进入容器
75
+
76
+
77
+ ```
78
+ docker exec -itu root vllm-deepdiver bash
79
+ ```
80
+
81
+ 注意:必须使用 `-itu root`。
82
+
83
+ #### 复制 Pangu 的 modeling 文件
84
+
85
+ `open_pangu.py` 和 `__init__.py` 可以在[这里](https://ai.gitcode.com/ascend-tribe/openpangu-embedded-7b-model/tree/main/inference/vllm_ascend/models)找到。
86
+
87
+ ```
88
+ cp ./vllm_ascend/open_pangu.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
89
+ cp ./vllm_ascend/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
90
+ ```
91
+
92
+ #### 启动部署
93
+
94
+ ```
95
+ PRECHECKPOINT_PATH="path/to/deepdiver_model"
96
+
97
+ export VLLM_USE_V1=1
98
+
99
+ export VLLM_WORKER_MULTIPROC_METHOD=fork
100
+ # export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
101
+
102
+ vllm serve $PRECHECKPOINT_PATH \
103
+ --served-model-name ${SERVED_MODEL_NAME:=pangu_auto} \
104
+ --tensor-parallel-size ${tensor_parallel_size:=8} \
105
+ --trust-remote-code \
106
+ --host 127.0.0.1 \
107
+ --port 8888 \
108
+ --max-num-seqs 256 \
109
+ --max-model-len ${MAX_MODEL_LEN:=131072} \
110
+ --max-num-batched-tokens ${MAX_NUM_BATCHED_TOKENS:=4096} \
111
+ --tokenizer-mode "slow" \
112
+ --dtype bfloat16 \
113
+ --distributed-executor-backend mp \
114
+ --gpu-memory-utilization 0.93 \
115
+ ```
116
+
117
+ #### 测试部署
118
+
119
+ ```
120
+ curl -X POST http://127.0.0.1:8888/v1/completions -H "Content-Type: application/json" -d '{
121
+ "model": "pangu_auto",
122
+ "prompt": ["Tell me who you are?"],
123
+ "max_tokens": 50
124
+ }'
125
+ ```
126
+
127
+ ### 3.3 实现所需工具
128
+
129
+ 在启动服务器前,你需要为 web search 与 URL 抓取工具实现自定义逻辑。
130
+
131
+ #### Web Search(`_generic_search`)
132
+
133
+ 位置:`src/tools/mcp_tools.py` - `_generic_search` 方法
134
+
135
+ 将 `NotImplementedError` 替换为你的搜索工具实现:
136
+
137
+ ```python
138
+ def _generic_search(self, query: str, max_results: int, config: Dict[str, Any]) -> MCPToolResult:
139
+ """Your custom search implementation - based on the commented code example"""
140
+ try:
141
+ # Example implementation for search API:
142
+ url = config.get('base_url', 'https://api.search-provider.com/search')
143
+ payload = json.dumps({"q": query, "num": max_results})
144
+ api_keys = config.get('api_keys', [])
145
+ headers = {
146
+ 'X-API-KEY': random.choice(api_keys),
147
+ 'Content-Type': 'application/json'
148
+ }
149
+
150
+ response = requests.post(url, data=payload, headers=headers)
151
+ response.raise_for_status()
152
+
153
+ # Transform your API response to required format
154
+ search_results = {
155
+ "organic": [
156
+ {
157
+ "title": result["title"],
158
+ "link": result["link"],
159
+ "snippet": result["snippet"],
160
+ "date": result.get("date", "unknown")
161
+ }
162
+ for result in response.json().get("organic", [])
163
+ ]
164
+ }
165
+
166
+ return MCPToolResult(success=True, data=search_results)
167
+
168
+ except Exception as e:
169
+ return MCPToolResult(success=False, error=f"Generic search failed: {e}")
170
+ ```
171
+
172
+ #### URL Crawler(`url_crawler` 与 `_content_extractor`)
173
+
174
+ 位置:`src/tools/mcp_tools.py` - `_content_extractor`
175
+
176
+ 将 `NotImplementedError` 部分替换为你的网页抓取工具实现:
177
+
178
+ ```python
179
+ # Example implementation for content extractor:
180
+ crawler_url = f"{crawler_config.get('base_url', 'https://api.content-extractor.com')}/{url}"
181
+ response = requests.get(crawler_url, headers=headers, timeout=crawler_config.get('timeout', 30))
182
+ response.raise_for_status()
183
+
184
+ content = response.text
185
+
186
+ # Truncate if needed
187
+ if max_tokens and len(content.split()) > max_tokens:
188
+ words = content.split()[:max_tokens]
189
+ content = ' '.join(words) + '...'
190
+
191
+ return MCPToolResult(success=True, data=content)
192
+ ```
193
+
194
+ #### ⚠️ 第三方服务提示
195
+
196
+ 重要:搜索与抓取工具使用外部 API 由用户自行选择和实现。我们不对以下情况负责:
197
+ - 与第三方服务相关的隐私/安全问题
198
+ - 搜索/抓取活动的合规性
199
+ - 内容准确性或版权问题
200
+ - API 停机或变更
201
+
202
+ 使用这些服务需自担风险。请查看其条款与隐私政策。
203
+
204
+ ### 3.4 必要配置
205
+
206
+ #### 配置 .env 文件
207
+ 复制 `env.template` 到 `config/.env` 并配置如下选项:
208
+
209
+ ```bash
210
+ # LLM Service
211
+ MODEL_REQUEST_URL=http://localhost:8888/v1/chat/completions # 你的 LLM endpoint
212
+
213
+ # Agent 限制
214
+ PLANNER_MODE=auto # 在 auto、writing 或 qa 模式间切换
215
+
216
+ # 外部 API(先实现函数)
217
+ SEARCH_ENGINE_BASE_URL= # 搜索 API endpoint
218
+ SEARCH_ENGINE_API_KEYS= # 搜索 API keys
219
+ URL_CRAWLER_BASE_URL= # URL Crawler API endpoint
220
+ URL_CRAWLER_API_KEYS= # URL Crawler API keys
221
+ ```
222
+
223
+ ⚠️ 注意:
224
+ - 请将上一步部署的推理服务 URL 配置到 `MODEL_REQUEST_URL`
225
+ - 在 `PLANNER_MODE` 中指定模式。`auto` 会自动决策回答复杂问题或生成长文;若希望优先长文写作,可设置为 `writing`;若希望专注解决高难度问题,可设置为 `qa`
226
+
227
+ ### 3.5 启动工具服务
228
+
229
+ ```bash
230
+ python src/tools/mcp_server_standard.py
231
+ ```
232
+
233
+ ### 3.6 运行Demo
234
+
235
+ ```bash
236
+ # 交互模式
237
+ python cli/demo.py
238
+
239
+ # 单次查询
240
+ python cli/demo.py -q "$your_query"
241
+ ```
242
+
243
+ 基于上述步骤可以快速运行DeepDiver,如果需要二次开发,可以参考[章节4](#4-自定义工具开发指南)和[5](#5-个性化配置)
244
+
245
+ ## 4. 自定义工具开发指南
246
+
247
+ 当前工具主要分为内置工具和外部MCP工具,内置工具主要包括分发任务,思考/反思等,外部MCP工具则是一些延伸LLM能力的工具,如搜索互联网,爬取链接,下载和读写文件等。
248
+
249
+ ### 4.1 已实现的工具类别
250
+
251
+ #### A. 外部MCP工具
252
+ Web Search 与数据采集:
253
+ - `batch_web_search`:多查询 web 搜索
254
+ - `url_crawler`:从 URL 抽取内容
255
+ - `download_files`:从 URL 下载文件
256
+
257
+ 文件操作:
258
+ - `file_read`、`file_write`:基础文件 I/O
259
+ - `list_workspace`:目录列表
260
+
261
+ 文档处理与内容创作:
262
+ - `document_qa`:针对特定文档问答
263
+ - `document_extract`:多格式文本抽取
264
+ - `section_writer`:结构化内容生成
265
+
266
+ #### B. 内置工具
267
+ - `think`、`reflect`:推理与规划
268
+ - `task_done`:任务完成汇报
269
+ - `assign_task_xxx`: 分发任务并创建子智能体
270
+
271
+ ### 4.2 开发并集成新的外部MCP工具
272
+
273
+ #### A. 实现新的MCP工具
274
+ 位置:`src/tools/mcp_tools.py` - 在 `MCPTools` 类中添加方法
275
+
276
+ ```python
277
+ def your_new_tool(self, param1: str, param2: int) -> MCPToolResult:
278
+ """
279
+ Description of what your tool does.
280
+
281
+ Args:
282
+ param1: Description of parameter 1
283
+ param2: Description of parameter 2
284
+
285
+ Returns:
286
+ MCPToolResult: Standardized result format
287
+ """
288
+ try:
289
+ # Your tool implementation here
290
+ result_data = {
291
+ "output": "Tool result",
292
+ "processed_items": param2
293
+ }
294
+
295
+ return MCPToolResult(
296
+ success=True,
297
+ data=result_data,
298
+ metadata={"tool_name": "your_new_tool"}
299
+ )
300
+
301
+ except Exception as e:
302
+ logger.error(f"Tool execution failed: {e}")
303
+ return MCPToolResult(
304
+ success=False,
305
+ error=f"Tool failed: {str(e)}"
306
+ )
307
+ ```
308
+
309
+ #### B. 在服务器中注册工具
310
+
311
+ ##### 添加工具 Schema
312
+ 位置:`src/tools/mcp_tools.py` - 添加到 `MCP_TOOL_SCHEMAS` 字典
313
+
314
+ ```python
315
+ MCP_TOOL_SCHEMAS = {
316
+ # ... existing tools ...
317
+
318
+ "your_new_tool": {
319
+ "name": "your_new_tool",
320
+ "description": "Brief description of what your tool does",
321
+ "inputSchema": {
322
+ "type": "object",
323
+ "properties": {
324
+ "param1": {
325
+ "type": "string",
326
+ "description": "Description of parameter 1"
327
+ },
328
+ "param2": {
329
+ "type": "integer",
330
+ "default": 10,
331
+ "description": "Description of parameter 2"
332
+ }
333
+ },
334
+ "required": ["param1"]
335
+ }
336
+ }
337
+ }
338
+ ```
339
+
340
+ ##### 注册工具函数
341
+ 位置:`src/tools/mcp_server_standard.py` - 添加到 `get_tool_function()`
342
+
343
+ ```python
344
+ def get_tool_function(tool_name: str):
345
+ """Get the actual function for a tool"""
346
+ tool_map = {
347
+ # ... existing tools ...
348
+ "your_new_tool": lambda tools, **kwargs: tools.your_new_tool(**kwargs),
349
+ }
350
+ return tool_map.get(tool_name)
351
+ ```
352
+
353
+ #### C. 让特定智能体可使用工具
354
+ 工具对各智能体的可见性由 MCP client 中的预定义工具集控制。
355
+
356
+ 位置:`src/tools/mcp_client.py` - 修改各智能体的工具集
357
+
358
+ ```python
359
+ # Define which MCP server tools each agent can access
360
+ PLANNER_AGENT_TOOLS = [
361
+ "download_files",
362
+ "document_qa",
363
+ "file_read",
364
+ "file_write",
365
+ "str_replace_based_edit_tool",
366
+ "list_workspace",
367
+ "file_find_by_name",
368
+ "your_new_tool", # Add your new tool here
369
+ ]
370
+
371
+ INFORMATION_SEEKER_TOOLS = [
372
+ "batch_web_search",
373
+ "url_crawler",
374
+ "document_extract",
375
+ "document_qa",
376
+ "download_files",
377
+ "file_read",
378
+ "file_write",
379
+ "str_replace_based_edit_tool",
380
+ "list_workspace",
381
+ "file_find_by_name",
382
+ "your_new_tool", # Add your new tool here if needed
383
+ ]
384
+
385
+ WRITER_AGENT_TOOLS = [
386
+ "file_read",
387
+ "list_workspace",
388
+ "file_find_by_name",
389
+ "search_result_classifier",
390
+ "section_writer",
391
+ "concat_section_files",
392
+ # Add your tool if the writer agent needs it
393
+ ]
394
+ ```
395
+
396
+ ### 4.3 添加内置智能体工具/函数
397
+
398
+ #### A. 带有真实返回的工具/函数
399
+ DeepDiver中的agent,如planner,集成了`assign_subjective_task_to_writer`, `assign_multi_objective_tasks_to_info_seeker` 等内置函数作为工具, 这类函数除了具体实现之外,还需要使用`_build_agent_specific_tool_schemas()` 添加专属的tool schema。
400
+
401
+ 位置:`src/agents/your_agent.py`
402
+
403
+ ```python
404
+ def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
405
+ """Add built-in agent functions (not MCP server tools)"""
406
+
407
+ # Get base schemas from MCP server via client
408
+ schemas = super()._build_agent_specific_tool_schemas()
409
+
410
+ # Add agent-specific built-in functions like task assignment, completion reporting
411
+ builtin_functions = [
412
+ {
413
+ "type": "function",
414
+ "function": {
415
+ "name": "agent_specific_task_done",
416
+ "description": "Report task completion for this agent",
417
+ "parameters": {
418
+ "type": "object",
419
+ "properties": {
420
+ "result": {"type": "string", "description": "Task result"},
421
+ "status": {"type": "string", "description": "Completion status"}
422
+ },
423
+ "required": ["result", "status"]
424
+ }
425
+ }
426
+ }
427
+ ]
428
+
429
+ schemas.extend(builtin_functions)
430
+ return schemas
431
+ ```
432
+
433
+ #### B. 带有伪返回的内置工具
434
+ DeepDiver中的cognitive tools,比如think和reflect等,这些工具实际没有具体实现,agent在调用这些工具时通过生成工具入参,就已经完成了工具的调用。可以直接在模型生成完入参后,使用类似以下方法进行返回,继续让模型完成后续工作 (参考`planner_agent.py` 中`_execute_react_loop()`的实现):
435
+
436
+ ```python
437
+ if tool_call["name"] in ["think", "reflect"]:
438
+ tool_result = {"tool_results": "You can proceed to invoke other tools if needed. "}
439
+ ```
440
+
441
+ 同理,这种内置工具也需要使用`_build_agent_specific_tool_schemas()` 添加专属的tool schema。
442
+
443
+
444
+ ## 5. 个性化配置
445
+
446
+ ### 5.1 Client 配置
447
+
448
+ 复制 `env.template` 到 `config/.env` 并配置如下选项:
449
+
450
+ ```bash
451
+ # LLM Service
452
+ MODEL_REQUEST_URL=http://localhost:8000 # 你的 LLM endpoint
453
+ MODEL_REQUEST_TOKEN=your-token # LLM auth token
454
+ MODEL_NAME=pangu_auto # 模型名
455
+ MODEL_TEMPERATURE=0.3 # 随机度(0.0-1.0)
456
+ MODEL_MAX_TOKENS=8192 # 最大回复长度
457
+ MODEL_REQUEST_TIMEOUT=60 # 请求超时(秒)
458
+
459
+ # Agent 限制
460
+ PLANNER_MAX_ITERATION=40 # Planner 最大 ReAct 步数
461
+ INFORMATION_SEEKER_MAX_ITERATION=30 # 信息搜集最大 ReAct 步数
462
+ WRITER_MAX_ITERATION=40 # Writer 最大 ReAct 步数
463
+ PLANNER_MODE=auto # auto / 长文优先 / qa 优先
464
+
465
+ # MCP Server
466
+ MCP_SERVER_URL=http://localhost:6274/mcp # MCP server endpoint
467
+ MCP_USE_STDIO=false # 使用 stdio 或 HTTP
468
+
469
+ # 外部 API(先实现函数)
470
+ SEARCH_ENGINE_BASE_URL= # 搜索 API endpoint
471
+ SEARCH_ENGINE_API_KEYS= # 搜索 API keys
472
+ URL_CRAWLER_BASE_URL= # URL Crawler API endpoint
473
+ URL_CRAWLER_API_KEYS= # URL Crawler API keys
474
+ URL_CRAWLER_MAX_TOKENS=100000 # URL Crawler 内容最大长度
475
+
476
+ # 存储路径
477
+ TRAJECTORY_STORAGE_PATH=./workspace # Agent工作目录
478
+ REPORT_OUTPUT_PATH=./report # 报告输出目录
479
+ DOCUMENT_ANALYSIS_PATH=./doc_analysis # 文档分析目录
480
+
481
+ # 系统
482
+ DEBUG_MODE=false # 是否开启调试日志
483
+ MAX_RETRIES=3 # API 重试次数
484
+ TIMEOUT=30 # 通用超时(秒)
485
+ ```
486
+
487
+ ### 5.2 Server 配置(server_config.yaml)
488
+
489
+ `server_config.yaml` 控制服务器行为、工具限流与运行设置:
490
+
491
+ #### 核心服务器设置
492
+
493
+ ```yaml
494
+ server:
495
+ host: "127.0.0.1" # 服务器绑定地址
496
+ port: 6274 # 端口
497
+ debug_mode: false # 调试日志
498
+ session_ttl_seconds: 21600 # 会话过期(6小时)
499
+ max_sessions: 1000 # 并发会话上限
500
+ ```
501
+
502
+ #### 工具限流
503
+
504
+ 对所有会话的外部 API 使用进行控制:
505
+
506
+ ```yaml
507
+ tool_rate_limits:
508
+ batch_web_search:
509
+ requests_per_minute: 9000 # 每分钟限制
510
+ burst_limit: 35 # 短时突发
511
+
512
+ url_crawler:
513
+ requests_per_minute: 9000
514
+ burst_limit: 60
515
+ ```
516
+
517
+ #### 会话管理
518
+
519
+ ```yaml
520
+ server:
521
+ cleanup_interval_seconds: 600 # 清理过期会话(10分钟)
522
+ enable_session_keepalive: true # 长时操作期间保活
523
+ keepalive_touch_interval: 300 # 保活触发间隔(秒)
524
+ ```
525
+
526
+ #### 安全与性能
527
+
528
+ ```yaml
529
+ server:
530
+ request_timeout_seconds: 1800 # 请求超时
531
+ max_request_size_mb: 1000 # 最大请求体
532
+ rate_limit_requests_per_minute: 300000 # 每 IP 限流
533
+ ```
534
+
535
+ 配置文件包含对每项设置的详细注释。请根据你的部署需求与外部 API 限额进行调整。
536
+
537
+ ## 6. 模型许可证
538
+
539
+ 除文件中对开源许可证另有约定外,openPangu-Embedded-7B-DeepDiver 模型根据 OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 授权,旨在允许使用并促进人工智能技术的进一步发展。有关详细信息,请参阅模型存储库根目录中的 [LICENSE](LICENSE) 文件。
540
+
541
+ ## 7. 安全提示与免责声明
542
+ 由于 openPangu-Embedded-7B-DeepDiver 模型和框架所依赖的技术固有的技术限制,以及人工智能生成的内容是由盘古自动生成的,华为无法对以下事项做出任何保证:
543
+
544
+ - 尽管该模型的输出由 AI 算法生成,但不能排除某些信息可能存在缺陷、不合理或引起不适的可能性,生成的内容不代表华为的态度或立场;
545
+ - 无法保证该模型 100% 准确、可靠、功能齐全、及时、安全、无错误、不间断、持续稳定或无任何故障;
546
+ - 该模型的输出内容不构成任何建议或决策,也不保证生成的内容的真实性、完整性、准确性、及时性、合法性、功能性或实用性。生成的内容不能替代医疗、法律等领域的专业人士回答您的问题。生成的内容仅供参考,不代表华为的任何态度、立场或观点。您需要根据实际情况做出独立判断,华为不承担任何责任;
547
+ - DeepDiver MAS系统的组件间通信不包含内置的数据加密或认证(如 tokens、签名)。你需要自行评估安全需求并实施相应防护(例如运行在加密网络中、加入 SSL/TLS、强制组件身份校验);
548
+ - 由于缺乏加密/认证导致的任何安全事件(数据泄露、未授权访问、业务损失)由使用方自行承担。项目开发者不承担责任。
549
+
550
+ ## 8. 反馈
551
+
552
+ 如果有任何意见和建议,请提交issue或联系 openPangu@huawei.com。
553
+
554
+ ---
README_EN.md ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # openPangu-Embedded-7B-DeepDiver
2
+ [中文](README.md) | English
3
+ 📑[Technical Report](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf)
4
+
5
+ ## 1. Introduction
6
+ DeepDiver is an agentic solution within openPangu series aimed at deep information seeking and processing, which natively supports the Multi-Agent System (MAS) and is designed for complex question answering and long-form report writing.
7
+
8
+ ### Features
9
+ - 🔍 Supports QA Mode: Capable of answering 100+ steps of complex knowledge-based questions.
10
+ - ✍️ Supports Long-form Writing Mode: Enables the creation of articles and reports with 30,000+ words.
11
+ - 🔄 Supports Adaptive Mode: Automatically selects between QA Mode and Long-form Writing Mode based on user queries.
12
+
13
+ ## 2. Results
14
+
15
+ | Benchmark | Metric | openPangu-7B-DeepDiver|
16
+ | :------------: | :-----------------: | :--------: |
17
+ | **BrowseComp-zh** | Acc | 18.3 |
18
+ | **BrowseComp-en** | Acc | 8.3 |
19
+ | **XBench-DeepSearch** | Acc | 39.0 |
20
+
21
+ Note: The table above only displays the results of complex QA. For the evaluation results of long-form report writing, please refer to the [technical report](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf)
22
+
23
+ ## 3. Quick Start
24
+
25
+ ### 3.1 Setup
26
+
27
+ ```bash
28
+ # Clone and install
29
+ git clone <repository-url>
30
+ cd deepdiver_v2
31
+ pip install -r requirements.txt
32
+ ```
33
+
34
+ ### 3.2 Deployment of the Inference Service
35
+
36
+ #### Pull Images
37
+
38
+ ```
39
+ docker pull quay.io/ascend/vllm-ascend:v0.9.2rc1
40
+ ```
41
+
42
+ Or follow the [official documentation](https://vllm-ascend.readthedocs.io/en/stable/installation.html) to build the docker container manually.
43
+
44
+ #### Run Docker Container
45
+
46
+ ```
47
+ docker run -itd --name vllm-deepdiver \
48
+ --network host \
49
+ --device /dev/davinci0 \
50
+ --device /dev/davinci1 \
51
+ --device /dev/davinci2 \
52
+ --device /dev/davinci3 \
53
+ --device /dev/davinci4 \
54
+ --device /dev/davinci5 \
55
+ --device /dev/davinci6 \
56
+ --device /dev/davinci7 \
57
+ -u root \
58
+ --device /dev/davinci_manager \
59
+ --device /dev/devmm_svm \
60
+ --device /dev/hisi_hdc \
61
+ -v /usr/local/dcmi:/usr/local/dcmi:ro \
62
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool:ro \
63
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro \
64
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/:ro \
65
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info:ro \
66
+ -v /etc/ascend_install.info:/etc/ascend_install.info:ro \
67
+ -v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware:ro \
68
+ -v /data:/data:ro \
69
+ -v /home/work:/home/work \ # set a working dir
70
+ quay.io/ascend/vllm-ascend:v0.9.2rc1
71
+ ```
72
+
73
+ #### Enter the Container
74
+
75
+ ```
76
+ docker exec -itu root vllm-deepdiver bash
77
+ ```
78
+ Note that `-itu root` is necessary.
79
+
80
+ #### Copy Pangu's Modeling Files
81
+
82
+ `open_pangu.py` and `__init__.py` can be found at [here](https://ai.gitcode.com/ascend-tribe/openpangu-embedded-7b-model/tree/main/inference/vllm_ascend/models)
83
+
84
+ ```
85
+ cp ./vllm_ascend/open_pangu.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
86
+ cp ./vllm_ascend/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
87
+ ```
88
+
89
+ #### Start Deployment
90
+
91
+ ```
92
+ PRECHECKPOINT_PATH="path/to/deepdiver_model"
93
+
94
+ export VLLM_USE_V1=1
95
+
96
+ export VLLM_WORKER_MULTIPROC_METHOD=fork
97
+ # export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
98
+
99
+ vllm serve $PRECHECKPOINT_PATH \
100
+ --served-model-name ${SERVED_MODEL_NAME:=pangu_auto} \
101
+ --tensor-parallel-size ${tensor_parallel_size:=8} \
102
+ --trust-remote-code \
103
+ --host 127.0.0.1 \
104
+ --port 8888 \
105
+ --max-num-seqs 256 \
106
+ --max-model-len ${MAX_MODEL_LEN:=131072} \
107
+ --max-num-batched-tokens ${MAX_NUM_BATCHED_TOKENS:=4096} \
108
+ --tokenizer-mode "slow" \
109
+ --dtype bfloat16 \
110
+ --distributed-executor-backend mp \
111
+ --gpu-memory-utilization 0.93 \
112
+ ```
113
+
114
+ #### Test Deployment
115
+
116
+ ```
117
+ curl -X POST http://127.0.0.1:8888/v1/completions -H "Content-Type: application/json" -d '{
118
+ "model": "pangu_auto",
119
+ "prompt": ["Tell me who you are?"],
120
+ "max_tokens": 50
121
+ }'
122
+ ```
123
+
124
+ ### 3.3 Implement Required Tools
125
+
126
+ Before starting the server, you must implement custom logic for web search and URL crawling tools.
127
+
128
+ #### Web Search (`_generic_search`)
129
+
130
+ **Location**: `src/tools/mcp_tools.py` - `_generic_search` method
131
+
132
+ Replace the `NotImplementedError` with your search API integration:
133
+
134
+ ```python
135
+ def _generic_search(self, query: str, max_results: int, config: Dict[str, Any]) -> MCPToolResult:
136
+ """Your custom search implementation - based on the commented code example"""
137
+ try:
138
+ # Example implementation for search API:
139
+ url = config.get('base_url', 'https://api.search-provider.com/search')
140
+ payload = json.dumps({"q": query, "num": max_results})
141
+ api_keys = config.get('api_keys', [])
142
+ headers = {
143
+ 'X-API-KEY': random.choice(api_keys),
144
+ 'Content-Type': 'application/json'
145
+ }
146
+
147
+ response = requests.post(url, data=payload, headers=headers)
148
+ response.raise_for_status()
149
+
150
+ # Transform your API response to required format
151
+ search_results = {
152
+ "organic": [
153
+ {
154
+ "title": result["title"],
155
+ "link": result["link"],
156
+ "snippet": result["snippet"],
157
+ "date": result.get("date", "unknown")
158
+ }
159
+ for result in response.json().get("organic", [])
160
+ ]
161
+ }
162
+
163
+ return MCPToolResult(success=True, data=search_results)
164
+
165
+ except Exception as e:
166
+ return MCPToolResult(success=False, error=f"Generic search failed: {e}")
167
+ ```
168
+
169
+ #### URL Crawler (`url_crawler` and `_content_extractor`)
170
+
171
+ **Location**: `src/tools/mcp_tools.py` - `_content_extractor`
172
+
173
+ Replace the `NotImplementedError` section with your crawler API integration:
174
+
175
+ ```python
176
+ # Example implementation for content extractor:
177
+ crawler_url = f"{crawler_config.get('base_url', 'https://api.content-extractor.com')}/{url}"
178
+ response = requests.get(crawler_url, headers=headers, timeout=crawler_config.get('timeout', 30))
179
+ response.raise_for_status()
180
+
181
+ content = response.text
182
+
183
+ # Truncate if needed
184
+ if max_tokens and len(content.split()) > max_tokens:
185
+ words = content.split()[:max_tokens]
186
+ content = ' '.join(words) + '...'
187
+
188
+ return MCPToolResult(success=True, data=content)
189
+ ```
190
+
191
+ #### ⚠️ **Third-Party Service Notice**
192
+
193
+ **Important**: Search and crawler tools use external APIs (your choice). We're not responsible for:
194
+ - Privacy/security issues with third-party services
195
+ - Legal compliance with search/crawling activities
196
+ - Content accuracy or copyright issues
197
+ - API downtime or changes
198
+
199
+ Use these services at your own risk. Check their terms and privacy policies.
200
+
201
+ ### 3.4 Mandatory Configuration
202
+
203
+ #### Configure the .env file
204
+ Copy `env.template` to `config/.env` and configure these options:
205
+
206
+ ```bash
207
+ # LLM Service
208
+ MODEL_REQUEST_URL=http://localhost:8888/v1/chat/completions # Your LLM endpoint
209
+
210
+ # Agent Limits
211
+ PLANNER_MODE=auto # Switching between the auto mode, writing mode, or qa mode.
212
+
213
+ # External APIs (implement functions first)
214
+ SEARCH_ENGINE_BASE_URL= # Search API endpoint
215
+ SEARCH_ENGINE_API_KEYS= # Search API keys
216
+ URL_CRAWLER_BASE_URL= # Crawler API endpoint
217
+ URL_CRAWLER_API_KEYS= # Crawler API keys
218
+ ```
219
+
220
+ **⚠️ Important:**
221
+ - Please configure the URL for deploying the inference service from the previous step in `MODEL_REQUEST_URL`
222
+ - Specify the mode in `PLANNER_MODE`. The `auto` mode is designed to automatically determine whether to answer complex questions or generate long-form reports. However, if you wish to prioritize long-form writing, you can set the PLANNER_MODE to ```writing```. Alternatively, if you want to focus solely on solving highly complex problems, configure the mode as ```qa```
223
+
224
+ ### 3.5 Start the Tool Server
225
+
226
+ ```bash
227
+ python src/tools/mcp_server_standard.py
228
+ ```
229
+
230
+ ### 3.6 Run the Demo
231
+
232
+ ```bash
233
+ # Interactive mode
234
+ python cli/demo.py
235
+
236
+ # Single query
237
+ python cli/demo.py -q "$your_query"
238
+ ```
239
+
240
+ Based on the above steps, DeepDiver can be quickly executed. If further development is required, you can refer to [Section 4](#4-customized-tool-development-guide) and [5](#5-customized-configuration).
241
+
242
+
243
+ ## 4. Customized Tool Development Guide
244
+ Currently, tools are mainly categorized into Built-in Tools and External MCP Tools. Built-in Tools primarily include task assignment, think/reflect, etc. External MCP Tools are extensions that enhance LLM capabilities, such as web search, url crawl, file download, read, and write.
245
+
246
+ ### 4.1 Implemented Tool Categories
247
+
248
+ #### A. External MCP Tools
249
+ Web Search and Data Collection:
250
+ - `batch_web_search`: Multi-query web search
251
+ - `url_crawler`: Extract content from URLs
252
+ - `download_files`: Download files from URLs
253
+
254
+ File Operations:
255
+ - `file_read`, `file_write`: Basic file I/O
256
+ - `list_workspace`: Directory listing
257
+
258
+ Document Processing and Content Creation:
259
+ - `document_qa`: Question-answering on documents
260
+ - `document_extract`: Extract text from various formats
261
+ - `section_writer`: Structured content generation
262
+
263
+
264
+ #### B. Built-in Tools
265
+ - `think`, `reflect`: Reasoning and planning
266
+ - `task_done`: Task completion reporting
267
+ - `assign_task_xxx`: Assign tasks and create sub-agents
268
+
269
+
270
+ ### 4.2 Develop and Integrate New External MCP Tools
271
+
272
+ #### A. Implementing a New MCP Tool
273
+ Location: `src/tools/mcp_tools.py` - Add a method to the `MCPTools` class
274
+
275
+ ```python
276
+ def your_new_tool(self, param1: str, param2: int) -> MCPToolResult:
277
+ """
278
+ Description of what your tool does.
279
+
280
+ Args:
281
+ param1: Description of parameter 1
282
+ param2: Description of parameter 2
283
+
284
+ Returns:
285
+ MCPToolResult: Standardized result format
286
+ """
287
+ try:
288
+ # Your tool implementation here
289
+ result_data = {
290
+ "output": "Tool result",
291
+ "processed_items": param2
292
+ }
293
+
294
+ return MCPToolResult(
295
+ success=True,
296
+ data=result_data,
297
+ metadata={"tool_name": "your_new_tool"}
298
+ )
299
+
300
+ except Exception as e:
301
+ logger.error(f"Tool execution failed: {e}")
302
+ return MCPToolResult(
303
+ success=False,
304
+ error=f"Tool failed: {str(e)}"
305
+ )
306
+ ```
307
+
308
+ #### B. Registering the Tool on the Server
309
+
310
+ ##### Adding Tool Schema
311
+ Location: `src/tools/mcp_tools.py` - Add to the `MCP_TOOL_SCHEMAS` dictionary
312
+
313
+ ```python
314
+ MCP_TOOL_SCHEMAS = {
315
+ # ... existing tools ...
316
+
317
+ "your_new_tool": {
318
+ "name": "your_new_tool",
319
+ "description": "Brief description of what your tool does",
320
+ "inputSchema": {
321
+ "type": "object",
322
+ "properties": {
323
+ "param1": {
324
+ "type": "string",
325
+ "description": "Description of parameter 1"
326
+ },
327
+ "param2": {
328
+ "type": "integer",
329
+ "default": 10,
330
+ "description": "Description of parameter 2"
331
+ }
332
+ },
333
+ "required": ["param1"]
334
+ }
335
+ }
336
+ }
337
+ ```
338
+
339
+ ##### Registering the Tool Function
340
+ Location: `src/tools/mcp_server_standard.py` - Add to `get_tool_function()`
341
+
342
+ ```python
343
+ def get_tool_function(tool_name: str):
344
+ """Get the actual function for a tool"""
345
+ tool_map = {
346
+ # ... existing tools ...
347
+ "your_new_tool": lambda tools, **kwargs: tools.your_new_tool(**kwargs),
348
+ }
349
+ return tool_map.get(tool_name)
350
+ ```
351
+
352
+ #### C. Making the Tool Accessible to Specific Agents
353
+ The visibility of tools to each agent is controlled by the predefined tool sets in the MCP client.
354
+
355
+ Location: `src/tools/mcp_client.py` - Modify the tool sets for each agent
356
+
357
+ ```python
358
+ # Define which MCP server tools each agent can access
359
+ PLANNER_AGENT_TOOLS = [
360
+ "download_files",
361
+ "document_qa",
362
+ "file_read",
363
+ "file_write",
364
+ "str_replace_based_edit_tool",
365
+ "list_workspace",
366
+ "file_find_by_name",
367
+ "your_new_tool", # Add your new tool here
368
+ ]
369
+
370
+ INFORMATION_SEEKER_TOOLS = [
371
+ "batch_web_search",
372
+ "url_crawler",
373
+ "document_extract",
374
+ "document_qa",
375
+ "download_files",
376
+ "file_read",
377
+ "file_write",
378
+ "str_replace_based_edit_tool",
379
+ "list_workspace",
380
+ "file_find_by_name",
381
+ "your_new_tool", # Add your new tool here if needed
382
+ ]
383
+
384
+ WRITER_AGENT_TOOLS = [
385
+ "file_read",
386
+ "list_workspace",
387
+ "file_find_by_name",
388
+ "search_result_classifier",
389
+ "section_writer",
390
+ "concat_section_files",
391
+ # Add your tool if the writer agent needs it
392
+ ]
393
+ ```
394
+
395
+
396
+ ### 4.3 Adding Built-in Agent Tools/Functions
397
+
398
+ #### A. Tools/Functions with Actual Return Values
399
+ Agents in DeepDiver (e.g., the Planner) integrate built-in functions as tools, such as `assign_subjective_task_to_writer` and `assign_multi_objective_tasks_to_info_seeker`. In addition to their specific implementations, these functions require adding **agent-specific tool schemas** using `_build_agent_specific_tool_schemas()`.
400
+
401
+ Location: `src/agents/your_agent.py`
402
+
403
+ ```python
404
+ def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
405
+ """Add built-in agent functions (not MCP server tools)"""
406
+
407
+ # Get base schemas from MCP server via client
408
+ schemas = super()._build_agent_specific_tool_schemas()
409
+
410
+ # Add agent-specific built-in functions like task assignment, completion reporting
411
+ builtin_functions = [
412
+ {
413
+ "type": "function",
414
+ "function": {
415
+ "name": "agent_specific_task_done",
416
+ "description": "Report task completion for this agent",
417
+ "parameters": {
418
+ "type": "object",
419
+ "properties": {
420
+ "result": {"type": "string", "description": "Task result"},
421
+ "status": {"type": "string", "description": "Completion status"}
422
+ },
423
+ "required": ["result", "status"]
424
+ }
425
+ }
426
+ }
427
+ ]
428
+
429
+ schemas.extend(builtin_functions)
430
+ return schemas
431
+ ```
432
+
433
+
434
+ #### B. Built-in Tools with Pseudo Return Values
435
+ Cognitive tools in DeepDiver (e.g., `think` and `reflect`) have no specific implementation. When an agent calls these tools, the tool invocation is considered complete once the agent generates the tool's input parameters. You can directly return a result after the model generates the input parameters, allowing the model to continue with subsequent tasks (refer to the implementation of `_execute_react_loop()` in `planner_agent.py`):
436
+
437
+ ```python
438
+ if tool_call["name"] in ["think", "reflect"]:
439
+ tool_result = {"tool_results": "You can proceed to invoke other tools if needed. "}
440
+ ```
441
+
442
+ Similarly, such built-in tools also require adding their exclusive tool schemas using `_build_agent_specific_tool_schemas()`.
443
+
444
+ ## 5. Customized Configuration
445
+
446
+ ### 5.1 Client Configuration
447
+
448
+ Copy `env.template` to `config/.env` and configure these options:
449
+
450
+ ```bash
451
+ # LLM Service
452
+ MODEL_REQUEST_URL=http://localhost:8000 # Your LLM endpoint
453
+ MODEL_REQUEST_TOKEN=your-token # Auth token
454
+ MODEL_NAME=pangu_auto # Model name
455
+ MODEL_TEMPERATURE=0.3 # Response randomness (0.0-1.0)
456
+ MODEL_MAX_TOKENS=8192 # Max response length
457
+ MODEL_REQUEST_TIMEOUT=60 # Request timeout (seconds)
458
+
459
+ # Agent Limits
460
+ PLANNER_MAX_ITERATION=40 # Planner maximum ReAct steps
461
+ INFORMATION_SEEKER_MAX_ITERATION=30 # Info seeker maximum ReAct steps
462
+ WRITER_MAX_ITERATION=40 # Writer maximum ReAct steps
463
+ PLANNER_MODE=auto # Switching between the auto mode, long-form writing - priority mode, or the qa - priority mode.
464
+
465
+ # MCP Server
466
+ MCP_SERVER_URL=http://localhost:6274/mcp # MCP server endpoint
467
+ MCP_USE_STDIO=false # Use stdio vs HTTP
468
+
469
+ # External APIs (implement functions first)
470
+ SEARCH_ENGINE_BASE_URL= # Search API endpoint
471
+ SEARCH_ENGINE_API_KEYS= # Search API keys
472
+ URL_CRAWLER_BASE_URL= # Crawler API endpoint
473
+ URL_CRAWLER_API_KEYS= # Crawler API keys
474
+ URL_CRAWLER_MAX_TOKENS=100000 # Max crawled content length
475
+
476
+ # Storage Paths
477
+ TRAJECTORY_STORAGE_PATH=./workspace # Agent work directory
478
+ REPORT_OUTPUT_PATH=./report # Report output directory
479
+ DOCUMENT_ANALYSIS_PATH=./doc_analysis # Document analysis directory
480
+
481
+ # System
482
+ DEBUG_MODE=false # Enable debug logging
483
+ MAX_RETRIES=3 # API retry attempts
484
+ TIMEOUT=30 # General timeout (seconds)
485
+ ```
486
+
487
+ ### 5.2 Server Configuration (server_config.yaml)
488
+
489
+ The `server_config.yaml` file controls server behavior, tool rate limiting, and operational settings:
490
+
491
+ #### Core Server Settings
492
+
493
+ ```yaml
494
+ server:
495
+ host: "127.0.0.1" # Server bind address
496
+ port: 6274 # Server port
497
+ debug_mode: false # Enable debug logging
498
+ session_ttl_seconds: 21600 # Session timeout (6 hours)
499
+ max_sessions: 1000 # Max concurrent sessions
500
+ ```
501
+
502
+ #### Tool Rate Limiting
503
+
504
+ Controls external API usage across all sessions:
505
+
506
+ ```yaml
507
+ tool_rate_limits:
508
+ batch_web_search:
509
+ requests_per_minute: 9000 # Per-minute limit
510
+ burst_limit: 35 # Short-term burst allowance
511
+
512
+ url_crawler:
513
+ requests_per_minute: 9000
514
+ burst_limit: 60
515
+ ```
516
+
517
+ #### Session Management
518
+
519
+ ```yaml
520
+ server:
521
+ cleanup_interval_seconds: 600 # Clean expired sessions (10 min)
522
+ enable_session_keepalive: true # Keep sessions alive during long operations
523
+ keepalive_touch_interval: 300 # Touch session every N seconds
524
+ ```
525
+
526
+ #### Security & Performance
527
+
528
+ ```yaml
529
+ server:
530
+ request_timeout_seconds: 1800 # Request timeout
531
+ max_request_size_mb: 1000 # Maximum request size
532
+ rate_limit_requests_per_minute: 300000 # Requests per IP
533
+ ```
534
+
535
+ The configuration file includes detailed comments explaining each setting. Modify values based on your deployment requirements and external API limits.
536
+
537
+ ## 6. Model License
538
+
539
+ Unless otherwise noted, openPangu-Embedded-7B-DeepDiver model is licensed under the terms and conditions of OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0, which is intended to be used permissively and enable the further development of artificial intelligence technologies. Please refer to the [LICENSE](LICENSE) file located in the root directory of the model repository for details.
540
+
541
+ ## 7. Security Notice and Disclaimer
542
+
543
+ Due to the inherent technical limitations of the technologies relied upon by the openPangu-Embedded-7B-DeepDiver model and its framework, as well as the fact that AI-generated content is automatically produced by Pangu, Huawei cannot make any warranties regarding the following matters:
544
+
545
+ - The output of this Model is automatically generated via AI algorithms, it does not rule out the possibility that some of the information may be flawed, unreasonable, or cause discomfort, and the generated content does not represent Huawei's attitude or standpoint;
546
+ - There is no guarantee that this Model is 100% accurate, reliable, functional, timely, secure and safe, error-free, uninterrupted, continuously stable, or free of any faults;
547
+ - The output of this Model does not constitute any advice or decisions for you, and it does not guarantee the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in medical, legal, and other fields in answering your questions. The generated content is for your reference only and does not represent any attitude, standpoint, or position of Huawei. You need to make independent judgments based on your actual situation, and Huawei does not assume any responsibilities;
548
+ - The inter-component communication of the DeepDiver MAS system does not include built-in data encryption or authentication mechanisms (e.g., tokens, signatures). You shall independently assess your security requirements and implement corresponding protective measures (such as deploying the system in an encrypted network, integrating SSL/TLS protocols, and enforcing component identity verification);
549
+ - Any security incidents (including but not limited to data leakage, unauthorized access, and business losses) arising from the lack of encryption/authentication mechanisms shall be borne by the user of the system. Huawei shall bear no responsibility therefor.
550
+
551
+ ## 8. Contact Us
552
+ If you have any comments or suggestions, please submit an issue or contact openPangu@huawei.com.
553
+
554
+ ---
checklist.chk ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 880418a2e195ea5221f700fabc446de4af401c777ea66288eedcfe3ca7861a58 *./config.json
2
+ 7302a0fdc386e723b16ad5860787b50a7bff39d30549dae986e073b193f2beb4 *./configuration_openpangu_dense.py
3
+ 5cbfc09f10ae85f0e9bebc1281541dcc7107d86e34282839277782cbb146117d *./generation_config.json
4
+ 9bf645e8399be6d99000eae64bd172b5c457d6d2c44d2257b47eb97a3c41aeda *./model.safetensors.index.json
5
+ f15eaf322af8a0b0f16b26795eb68af836179413d3dbfa4dc44505db6c8b0d6f *./modeling_openpangu_dense.py
6
+ 7b8ec6cd94b1921560d37755c7c0c08280c1f9123195d14d352ad0607788f7f6 *./model-00001-of-00004.safetensors
7
+ fc05d80f52ce44d1433a942e867bf61ea49eb1eebb0700312f76d6b3a3dee917 *./model-00002-of-00004.safetensors
8
+ 1ed37f38214c755b51bea06a71e154c9ea27670eb3b8506c06addcfbea2066f2 *./model-00003-of-00004.safetensors
9
+ 0145e255ba965ed0e75164a037b9a0137c5e5c12ffc42463ff82568054fe0186 *./model-00004-of-00004.safetensors
10
+ c1f2d87f855b994039c52b1e83c8a7f3d71a2d1eb52946c4a2e862e99f19d8b3 *./modular_openpangu_dense.py
11
+ b34cf5e7c7660889303b6e2d0a346c440356385c9db551d06f6615cf9fc600d1 *./special_tokens_map.json
12
+ 6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec *./tokenizer.model
13
+ acb88eac57f8765fedf34e9c10bc16d55c46f0902b0fea74fbf041daca2667ae *./tokenizer_config.json
14
+ c98602d6d1f61792a8bd3393972bbbe7409a205c0bb6299394c74287c26bd723 *./tokenization_openpangu.py
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "PanguEmbeddedForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_openpangu_dense.PanguEmbeddedConfig",
7
+ "AutoModel": "modeling_openpangu_dense.PanguEmbeddedModel",
8
+ "AutoModelForCausalLM": "modeling_openpangu_dense.PanguEmbeddedForCausalLM"
9
+ },
10
+ "bias": true,
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 1,
13
+ "pad_token_id": 0,
14
+ "eos_token_id": 45892,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 4096,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 12800,
19
+ "max_position_embeddings": 147456,
20
+ "model_type": "PanguEmbedded",
21
+ "num_attention_heads": 32,
22
+ "num_hidden_layers": 34,
23
+ "num_key_value_heads": 8,
24
+ "rms_norm_eps": 1e-05,
25
+ "rope_theta": 16000000.0,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "4.53.2",
29
+ "use_cache": true,
30
+ "vocab_size": 153376
31
+ }
configuration_openpangu_dense.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ from transformers.utils import logging
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
class PanguEmbeddedConfig(PretrainedConfig):
    """Configuration for the PanguEmbedded (openPangu-Embedded) dense model.

    Stores the hyperparameters that define the model architecture
    (vocabulary, hidden sizes, layer/head counts, RoPE settings, etc.)
    and forwards the special-token / tying options to ``PretrainedConfig``.
    """

    model_type = "PanguEmbedded"
    _auto_class = "AutoConfig"

    def __init__(
        self,
        vocab_size=153376,
        hidden_size=4096,
        intermediate_size=12800,
        num_hidden_layers=34,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=147456,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=45892,
        tie_word_embeddings=False,
        rope_theta=16000000.0,
        bias=True,
        **kwargs,
    ):
        # Embedding / vocabulary sizes.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        # Transformer depth and attention layout (grouped-query attention:
        # num_key_value_heads < num_attention_heads).
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        # Activation, positions, and normalization.
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        # Runtime / architectural switches.
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.bias = bias

        # Special tokens and weight-tying are handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
deepdiver_v2/cli/README.md ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLI Demo for DeepDiver Long Writer Multi-Agent System
2
+
3
+ This CLI demo showcases the multi-agent system that coordinates between PlannerAgent, InformationSeekerAgent, and WriterAgent to handle complex queries and generate comprehensive long-form content.
4
+
5
+ ## Features
6
+
7
+ - 🧠 **PlannerAgent**: Orchestrates the entire process and coordinates sub-agents
8
+ - 🔍 **InformationSeekerAgent**: Performs web research and gathers information
9
+ - ✍️ **WriterAgent**: Creates comprehensive long-form content
10
+ - 📊 **Real-time Visualization**: Shows tool calls, reasoning traces, and sub-agent responses
11
+ - ⚙️ **Configuration Management**: Loads settings from .env files
12
+
13
+ ## Setup
14
+
15
+ ### 1. Install Dependencies
16
+
17
+ ```bash
18
+ cd deepdiver_v2
19
+ pip install -r requirements.txt
20
+ ```
21
+
22
+ ### 2. Configure Environment
23
+
24
+ Create a `.env` file in the `config/` directory:
25
+
26
+ ```bash
27
+ # From the project root
28
+ cp env.template config/.env
29
+ ```
30
+
31
+ Then edit `config/.env` with your settings:
32
+
33
+ ```bash
34
+ # Custom LLM Service Configuration
35
+ MODEL_REQUEST_URL=http://your-llm-service-endpoint/v1/chat/completions
36
+ MODEL_REQUEST_TOKEN=your-service-token
37
+ MODEL_NAME=pangu_auto
38
+
39
+ # MCP Server Configuration
40
+ MCP_SERVER_URL=http://localhost:6274/mcp
41
+ MCP_AUTH_TOKEN=
42
+ MCP_USE_STDIO=true
43
+
44
+ # Agent Iteration Limits
45
+ PLANNER_MAX_ITERATION=20
46
+ INFORMATION_SEEKER_MAX_ITERATION=30
47
+ WRITER_MAX_ITERATION=20
48
+
49
+ # Mode
50
+ PLANNER_MODE=auto # auto, writing, qa
51
+
52
+ # Other settings...
53
+ ```
54
+
55
+ ### 3. Start Required Services
56
+
57
+ Make sure your MCP server is running:
58
+
59
+ ```bash
60
+ # Start MCP server (if needed)
61
+ python src/tools/mcp_server_standard.py
62
+ ```
63
+
64
+ ## Usage
65
+
66
+ ### Interactive Mode (Recommended)
67
+
68
+ ```bash
69
+ python cli/demo.py
70
+ ```
71
+
72
+ This will start an interactive session where you can enter queries and see the full execution flow.
73
+
74
+ ### Single Query Mode
75
+
76
+ ```bash
77
+ python cli/demo.py -q "Write a comprehensive analysis of artificial intelligence trends in 2024"
78
+ ```
79
+
80
+ ### Configuration Only
81
+
82
+ ```bash
83
+ python cli/demo.py --config-only
84
+ ```
85
+
86
+ ### Debug Mode (Verbose Logging)
87
+
88
+ ```bash
89
+ python cli/demo.py --debug -q "Debug a specific query"
90
+ ```
91
+
92
+ ### Quiet Mode (Clean Output)
93
+
94
+ ```bash
95
+ python cli/demo.py --quiet -q "Run with minimal output"
96
+ ```
97
+
98
+ ### Create Sample Configuration
99
+
100
+ ```bash
101
+ python cli/demo.py --create-env
102
+ ```
103
+
104
+ ## Example Queries
105
+
106
+ ### For Information Seeking Tasks:
107
+ - "What are the latest developments in quantum computing?"
108
+ - "Research the current state of renewable energy adoption globally"
109
+ - "Find information about recent AI breakthroughs in healthcare"
110
+
111
+ ### For Long-form Writing Tasks:
112
+ - "Write a comprehensive report on the impact of AI on education"
113
+ - "Create an in-depth analysis of climate change mitigation strategies"
114
+ - "Generate a detailed guide on sustainable business practices"
115
+
116
+ ## Demo Flow Visualization
117
+
118
+ The demo provides rich visual feedback showing:
119
+
120
+ 1. **🚀 Task Initiation**: Shows the user query and planner startup
121
+ 2. **🧠 Agent Reasoning**: Displays the planner's reasoning at each step
122
+ 3. **🔧 Tool Calls**: Shows what tools are being called with their arguments
123
+ 4. **📋 Tool Results**: Displays the results from each tool execution
124
+ 5. **🤝 Sub-Agent Execution**: Shows when sub-agents (InformationSeeker, Writer) are invoked
125
+ 6. **📊 Sub-Agent Results**: Displays results from sub-agent executions
126
+ 7. **🏁 Final Result**: Shows the complete execution summary
127
+ 8. **🔍 Execution Trace**: Detailed step-by-step trace of the entire process
128
+
129
+ ## Output Modes
130
+
131
+ The CLI demo supports different output modes for different use cases:
132
+
133
+ ### Default Mode
134
+ Shows the full rich interface with welcome screen, progress bars, and detailed visualization of all agent interactions.
135
+
136
+ ### Quiet Mode (`--quiet`)
137
+ Suppresses all non-essential output, showing only final results. Useful for:
138
+ - Integration with scripts or automation
139
+ - Focusing on results without process details
140
+ - Running in environments where rich output isn't needed
141
+
142
+ ### Debug Mode (`--debug`)
143
+ Enables verbose logging with timestamps, showing all internal system messages. Useful for:
144
+ - Troubleshooting configuration issues
145
+ - Understanding detailed agent behavior
146
+ - Development and debugging
147
+
148
+ ```bash
149
+ # Examples of different modes
150
+ python cli/demo.py --query "Test query" # Default rich mode
151
+ python cli/demo.py --quiet --query "Test query" # Minimal output
152
+ python cli/demo.py --debug --query "Test query" # Verbose debugging
153
+ ```
154
+
155
+ ## Troubleshooting
156
+
157
+ ### Configuration Issues
158
+
159
+ If you see configuration errors:
160
+
161
+ 1. Ensure `config/.env` exists and is properly formatted
162
+ 2. Check that all required environment variables are set
163
+ 3. Verify your LLM service endpoint is accessible
164
+ 4. Confirm MCP server is running and reachable
165
+ 5. Use `--debug` mode to see detailed error messages
166
+
167
+ ### Agent Initialization Issues
168
+
169
+ If agent initialization fails:
170
+
171
+ 1. Check MCP server connectivity
172
+ 2. Verify model configuration is correct
173
+ 3. Ensure required permissions for workspace directories
174
+ 4. Check log output for specific error messages
175
+
176
+ ### Tool Execution Issues
177
+
178
+ If tool calls fail:
179
+
180
+ 1. Verify MCP server is running and has the required tools
181
+ 2. Check network connectivity for web search/crawler tools
182
+ 3. Ensure workspace directories exist and are writable
183
+ 4. Review tool arguments for correctness
184
+
185
+ ## Advanced Usage
186
+
187
+ ### Custom Sub-Agent Configurations
188
+
189
+ You can customize sub-agent behavior by modifying the configurations in the demo script:
190
+
191
+ ```python
192
+ sub_agent_configs = {
193
+ "information_seeker": {
194
+ "model": "your-model",
195
+ "max_iterations": 30,
196
+ },
197
+ "writer": {
198
+ "model": "your-model",
199
+ "max_iterations": 20,
200
+ "temperature": 0.3,
201
+ "max_tokens": 16384
202
+ }
203
+ }
204
+ ```
205
+
206
+ ### Monitoring and Debugging
207
+
208
+ Enable debug mode in your `.env` file:
209
+
210
+ ```bash
211
+ DEBUG_MODE=true
212
+ ```
213
+
214
+ This will provide more detailed logging and error information.
215
+
216
+ ## Architecture Overview
217
+
218
+ The demo showcases a sophisticated multi-agent architecture:
219
+
220
+ ```
221
User Query
    ↓
PlannerAgent (Coordinator)
    ↓
├── InformationSeekerAgent (Research)
│   ├── Web Search Tools
│   ├── URL Crawling Tools
│   ├── Document Analysis Tools
│   └── File Management Tools
│
└── WriterAgent (Content Generation)
232
+ ├── File Reading Tools
233
+ ├── Document QA Tools
234
+ ├── Content Synthesis
235
+ └── Long-form Writing
236
+ ```
237
+
238
+ Each agent follows the ReAct pattern (Reasoning + Acting) with iterative refinement until task completion.
deepdiver_v2/cli/demo.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+ """
4
+ CLI Demo for DeepDiver Long Writer Multi-Agent System
5
+
6
+ This demo showcases the multi-agent system that includes:
7
+ - PlannerAgent: Coordinates and orchestrates the entire process
8
+ - InformationSeekerAgent: Gathers and researches information
9
+ - WriterAgent: Creates long-form content
10
+
11
+ Features:
12
+ - Loads configuration from config/.env file
13
+ - Shows real-time tool calls and reasoning traces
14
+ - Displays sub-agent responses and interactions
15
+ - Visualizes the complete execution flow
16
+ - Query preprocessing for safety and task suitability check
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ import json
22
+ import time
23
+ import logging
24
+ import argparse
25
+ import requests
26
+ from pathlib import Path
27
+ from typing import Dict, Any, List, Optional
28
+ from rich.console import Console
29
+ from rich.table import Table
30
+ from rich.panel import Panel
31
+ from rich.syntax import Syntax
32
+ from rich.markdown import Markdown
33
+
34
+ # Add project root to Python path
35
+ project_root = Path(__file__).parent.parent
36
+ sys.path.insert(0, str(project_root))
37
+
38
+
39
+ # Configure logging to keep the CLI clean
40
def setup_clean_logging(debug_mode: bool = False):
    """Configure logging to show only relevant information for the demo.

    In debug mode every record is logged verbosely with timestamps; otherwise
    a handful of chatty third-party/internal loggers are raised to ERROR so
    the CLI output stays readable.
    """
    if debug_mode:
        # Verbose mode: timestamped records at DEBUG level for everything.
        logging.basicConfig(
            level=logging.DEBUG,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%H:%M:%S',
        )
        return

    # Clean demo mode: silence the loggers that are known to be noisy.
    for noisy_name in (
        'httpx',
        'httpcore',
        'urllib3',
        'src.tools.mcp_client',
        # 'src.agents.base_agent',
        'config.config',  # also suppress config messages in quiet mode
    ):
        logging.getLogger(noisy_name).setLevel(logging.ERROR)
64
+
65
# Set up default clean logging before any imports
# (re-configured later by main() if --debug is passed).
setup_clean_logging(debug_mode=False)

# Import the multi-agent system components
from config.config import get_config, reload_config
from src.agents.planner_agent import create_planner_agent
from src.agents.base_agent import AgentResponse

# Shared rich console used by all demo output helpers in this module.
console = Console()
74
+
75
+
76
class DemoVisualizer:
    """Visualizes the execution of the multi-agent system.

    Renders rich panels/tables for each stage of a planner run (reasoning,
    tool calls, sub-agent activity, final result).  In quiet mode only
    output explicitly forced via ``_should_display(force=True)`` is printed.
    """

    def __init__(self, quiet_mode: bool = False):
        # Shared module-level rich console; all panels/tables go through it.
        self.console = console
        # Not written by this class itself; reserved for callers to record events.
        self.execution_log = []
        # When True, only forced output is displayed (see _should_display).
        self.quiet_mode = quiet_mode

    def _should_display(self, force: bool = False) -> bool:
        """Check if output should be displayed based on quiet mode"""
        return not self.quiet_mode or force

    def show_welcome(self):
        """Display welcome message and system info"""
        if not self._should_display():
            return

        welcome_text = """
# 🤖 DeepDiver Long Writer Multi-Agent System Demo

This demo showcases an advanced multi-agent system for research and long-form content generation.

## System Components:
- **🧠 PlannerAgent**: Orchestrates the entire process and coordinates sub-agents
- **🔍 InformationSeekerAgent**: Performs web research and gathers information
- **✍️ WriterAgent**: Creates comprehensive long-form content

## Features:
- Real-time tool execution visualization
- Sub-agent response tracking
- Complete reasoning trace display
- Configuration management
- Query safety and suitability pre-check
"""
        self.console.print(Panel(Markdown(welcome_text), title="[bold blue]Welcome", border_style="blue"))

    def show_config(self, config):
        """Display current configuration as a two-column table."""
        if not self._should_display():
            return

        config_table = Table(title="📋 System Configuration", show_header=True, header_style="bold magenta")
        config_table.add_column("Setting", style="cyan", no_wrap=True)
        config_table.add_column("Value", style="green")

        # Safe config display (hide sensitive values)
        safe_config = config.to_dict()
        for key, value in safe_config.items():
            if value is not None and str(value) != "None":
                display_value = str(value)
                if len(display_value) > 60:
                    # Truncate long values so the table stays readable.
                    display_value = display_value[:57] + "..."
                config_table.add_row(key, display_value)

        self.console.print(config_table)

    def show_planner_start(self, query: str):
        """Show planner starting execution"""
        # NOTE(review): unlike show_welcome/show_config, this does not check
        # _should_display(), so it prints even in quiet mode — confirm intent.
        self.console.print(Panel(
            f"[bold yellow]User Query:[/bold yellow] {query}\n\n"
            f"[bold green]🚀 Starting PlannerAgent execution...[/bold green]",
            title="[bold blue]Task Initiation",
            border_style="green"
        ))

    def show_reasoning_step(self, iteration: int, reasoning: str):
        """Display reasoning step"""
        self.console.print(Panel(
            Markdown(f"**Iteration {iteration} - Reasoning:**\n\n{reasoning}"),
            title=f"[bold yellow]🧠 Agent Reasoning (Step {iteration})",
            border_style="yellow"
        ))

    def show_tool_call(self, iteration: int, tool_name: str, arguments: Dict[str, Any]):
        """Display tool call"""
        args_json = json.dumps(arguments, indent=2, ensure_ascii=False)

        # NOTE(review): embedding a rich Syntax object inside an f-string
        # stringifies it instead of rendering highlighted JSON — verify the
        # resulting output is what was intended.
        self.console.print(Panel(
            f"[bold cyan]Tool:[/bold cyan] {tool_name}\n\n"
            f"[bold cyan]Arguments:[/bold cyan]\n{Syntax(args_json, 'json', theme='monokai', line_numbers=True)}",
            title=f"[bold cyan]🔧 Tool Call (Step {iteration})",
            border_style="cyan"
        ))

    def show_tool_result(self, iteration: int, tool_name: str, result: Dict[str, Any]):
        """Display tool result"""
        # Missing "success" key is treated as success.
        success = result.get("success", True)
        status_icon = "✅" if success else "❌"
        status_color = "green" if success else "red"

        # Format result for display
        if success and "data" in result:
            display_result = result["data"]
        elif "error" in result:
            display_result = {"error": result["error"]}
        else:
            display_result = result

        result_text = json.dumps(display_result, indent=2, ensure_ascii=False)
        if len(result_text) > 1000:
            # Cap the JSON dump so one result cannot flood the terminal.
            result_text = result_text[:997] + "..."

        # NOTE(review): Syntax object inside an f-string — same caveat as
        # show_tool_call above.
        self.console.print(Panel(
            f"[bold {status_color}]Status:[/bold {status_color}] {status_icon} {'Success' if success else 'Failed'}\n\n"
            f"[bold {status_color}]Result:[/bold {status_color}]\n{Syntax(result_text, 'json', theme='monokai', line_numbers=True)}",
            title=f"[bold {status_color}]📋 Tool Result: {tool_name} (Step {iteration})",
            border_style=status_color
        ))

    def show_sub_agent_execution(self, agent_name: str, task_content: str):
        """Show sub-agent starting execution"""
        self.console.print(Panel(
            f"[bold magenta]Agent:[/bold magenta] {agent_name}\n\n"
            f"[bold magenta]Task:[/bold magenta] {task_content[:500]}{'...' if len(task_content) > 500 else ''}",
            title="[bold magenta]🤝 Sub-Agent Execution",
            border_style="magenta"
        ))

    def show_sub_agent_result(self, agent_name: str, result: Dict[str, Any]):
        """Show sub-agent execution result"""
        success = result.get("success", True)
        status_icon = "✅" if success else "❌"
        status_color = "green" if success else "red"

        # Extract key information
        iterations = result.get("iterations", 0)
        execution_time = result.get("execution_time", 0)

        summary = f"[bold {status_color}]Status:[/bold {status_color}] {status_icon} {'Success' if success else 'Failed'}\n"
        summary += f"[bold blue]Iterations:[/bold blue] {iterations}\n"
        summary += f"[bold blue]Execution Time:[/bold blue] {execution_time:.2f}s\n\n"

        if success and "data" in result:
            data = result["data"]
            if isinstance(data, dict):
                for key, value in data.items():
                    # Long string values are truncated to keep the panel compact.
                    if isinstance(value, str) and len(value) > 200:
                        summary += f"[bold blue]{key}:[/bold blue] {value[:197]}...\n"
                    else:
                        summary += f"[bold blue]{key}:[/bold blue] {value}\n"
        elif "error" in result:
            summary += f"[bold red]Error:[/bold red] {result['error']}\n"

        self.console.print(Panel(
            summary,
            title=f"[bold {status_color}]📊 Sub-Agent Result: {agent_name}",
            border_style=status_color
        ))

    def show_final_result(self, response: AgentResponse):
        """Display final execution result"""
        # Always show final results, even in quiet mode
        # NOTE(review): _should_display(force=True) always returns True, so
        # this guard is unreachable; kept for symmetry with other methods.
        if not self._should_display(force=True):
            return

        success = response.success
        status_icon = "✅" if success else "❌"
        status_color = "green" if success else "red"

        summary = f"[bold {status_color}]Final Status:[/bold {status_color}] {status_icon} {'Completed Successfully' if success else 'Failed'}\n"
        summary += f"[bold blue]Total Iterations:[/bold blue] {response.iterations}\n"
        summary += f"[bold blue]Total Execution Time:[/bold blue] {response.execution_time:.2f}s\n"
        summary += f"[bold blue]Agent:[/bold blue] {response.agent_name}\n\n"

        if success and response.result:
            if isinstance(response.result, dict):
                for key, value in response.result.items():
                    # Final answers may be very long; truncate above 3000 chars.
                    if isinstance(value, str) and len(value) > 3000:
                        summary += f"[bold blue]{key}:[/bold blue] {value[:2997]}...\n\n"
                    else:
                        summary += f"[bold blue]{key}:[/bold blue] {value}\n\n"
        elif response.error:
            summary += f"[bold red]Error:[/bold red] {response.error}\n"

        self.console.print(Panel(
            summary,
            title=f"[bold {status_color}]🏁 Final Result",
            border_style=status_color
        ))

    def show_reasoning_trace(self, trace: List[Dict[str, Any]]):
        """Display detailed reasoning trace"""
        if not trace:
            return

        trace_table = Table(title="🔍 Detailed Execution Trace", show_header=True, header_style="bold cyan")
        trace_table.add_column("Step", style="cyan", width=8)
        trace_table.add_column("Type", style="magenta", width=12)
        trace_table.add_column("Details", style="white")

        for i, step in enumerate(trace, 1):
            step_type = step.get("type", "unknown")

            if step_type == "reasoning":
                content = step.get("content", "")[:100] + ("..." if len(step.get("content", "")) > 100 else "")
                trace_table.add_row(str(i), "🧠 Reasoning", content)

            elif step_type == "action":
                tool = step.get("tool", "")
                result_status = "✅" if step.get("result", {}).get("success", True) else "❌"
                trace_table.add_row(str(i), "🔧 Tool Call", f"{result_status} {tool}")

            elif step_type == "error":
                error = step.get("error", "")[:100] + ("..." if len(step.get("error", "")) > 100 else "")
                trace_table.add_row(str(i), "❌ Error", error)
            # NOTE(review): steps of any other type are silently skipped.

        self.console.print(trace_table)

    def show_unsupported_response(self):
        """Display the fixed response for unsupported queries"""
        # Always show this response, even in quiet mode
        self.console.print(Panel(
            "Sorry, your question is not within the current scope of tasks for DeepDiver-V2. Please try asking a question related to long-form writing or complex knowledge Q&A instead.",
            title="[bold yellow]❌ Unsupported Query",
            border_style="yellow"
        ))
292
+
293
+
294
class AgentExecutionMonitor:
    """Monitors agent execution and provides real-time feedback.

    Forwards planner callbacks to a DemoVisualizer, additionally surfacing
    sub-agent assignment tool calls and their results.
    """

    def __init__(self, visualizer: DemoVisualizer):
        self.visualizer = visualizer
        self.current_iteration = 0

    @staticmethod
    def _is_assignment_tool(tool_name: str) -> bool:
        # Sub-agent assignment tools contain both "assign_" and "task".
        return "assign_" in tool_name and "task" in tool_name

    def on_reasoning_step(self, iteration: int, reasoning: str):
        """Called when agent performs reasoning"""
        self.visualizer.show_reasoning_step(iteration, reasoning)

    def on_tool_call(self, iteration: int, tool_name: str, arguments: Dict[str, Any]):
        """Called when agent makes a tool call"""
        self.visualizer.show_tool_call(iteration, tool_name, arguments)

        if not self._is_assignment_tool(tool_name):
            return
        # A batch "tasks" list is routed to the InformationSeeker; a single
        # "task_content" argument is routed to the Writer.
        if "tasks" in arguments:
            for task in arguments.get("tasks", []):
                self.visualizer.show_sub_agent_execution(
                    "InformationSeeker", task.get("task_content", ""))
        elif "task_content" in arguments:
            self.visualizer.show_sub_agent_execution(
                "Writer", arguments.get("task_content", ""))

    def on_tool_result(self, iteration: int, tool_name: str, result: Dict[str, Any]):
        """Called when tool execution completes"""
        self.visualizer.show_tool_result(iteration, tool_name, result)

        if not self._is_assignment_tool(tool_name):
            return
        # Mirror the dispatch used in on_tool_call for the result side.
        if "data" in result and "tasks" in result["data"]:
            for task_result in result["data"]["tasks"]:
                self.visualizer.show_sub_agent_result(
                    task_result.get("agent_name", "InformationSeeker"), task_result)
        elif "data" in result:
            self.visualizer.show_sub_agent_result(
                result["data"].get("agent_name", "Writer"), result["data"])
332
+
333
+
334
def classify_query(query: str, config) -> Dict[str, Any]:
    """
    Classify user query into one of three categories using LLM:
    1. SAFE_SENSITIVE: Contains unsafe content (insults, political risks, etc.)
    2. NON_KNOWLEDGE: Non-knowledge intensive (no need for research, e.g., greetings, simple calculations)
    3. NORMAL: Requires processing (long-form writing or complex knowledge Q&A)

    Falls back to NORMAL (with an explanatory reason) when the model endpoint
    is not configured or classification fails after all retries.

    Returns:
        Dict with 'category' (str) and 'reasoning' (str)
    """
    logger = logging.getLogger(__name__)

    # Get model configuration; environment variables serve as fallback.
    model_config = config.get_custom_llm_config()
    pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
    model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')

    # Validate model configuration
    if not pangu_url:
        logger.error("Model URL not configured for query classification")
        # Fallback to NORMAL category if model config is missing
        return {
            "category": "NORMAL",
            "reasoning": "模型配置不完整,跳过分类检查,默认按正常任务处理"
        }

    headers = {'Content-Type': 'application/json', 'csb-token': model_token}

    # Classification prompt (detailed instructions for accurate categorization)
    prompt_template = """
你是一个Query分类器,需要将用户输入的查询分为以下三类,并给出明确的分类理由:

1. 【SAFE_SENSITIVE - 安全敏感内容】:包含以下任何一种情况的查询
- 辱骂、侮辱性语言(如脏话、人身攻击)
- 涉及政治敏感内容(如国家领导人、敏感政治事件、舆情风险话题)
- 违法违规内容(如暴力、色情、恐怖主义相关)
- 歧视性言论(种族、性别、宗教等歧视)

2. 【NON_KNOWLEDGE - 非知识密集型任务】:不需要进行信息搜索的简单查询
- 问候语(如"你好"、"早上好"、"嗨")
- 简单计算(如"1+1等于几"、"25乘以4是多少")
- 基础闲聊(如"你是谁")
- 指令性语句(如"退出"、"帮助"、"开始")
- 不需要信息收集的简单问题

3. 【NORMAL - 正常任务】:不包含安全敏感内容,需要进行信息搜索或长文写作的任务
- 简单的信息收集任务 (如"华为成立时间是什么时候")
- 复杂知识问答(如"ACL2025举办地有什么美食推荐")
- 长文写作任务(如"写一篇关于气候变化影响的5000字报告")
- 需要数据支持的分析(如"2023年全球经济增长数据及分析")
- 专业领域研究(如"机器学习在医疗诊断中的应用案例")

分类要求:
- 严格按照上述定义进行分类,不要遗漏任何关键特征
- 优先判断是否为SAFE_SENSITIVE,其次判断是否为NON_KNOWLEDGE,最后才是NORMAL
- 必须提供清晰的分类理由,说明为什么属于该类别
- 输出格式必须严格遵循:先输出分类理由的思考,然后换行输出分类结果(SAFE_SENSITIVE/NON_KNOWLEDGE/NORMAL)

示例1(SAFE_SENSITIVE):
该查询包含辱骂性语言"XXX",符合安全敏感内容的定义,属于需要拦截的内容
SAFE_SENSITIVE

示例2(NON_KNOWLEDGE):
该查询是简单的问候语"你好",不需要进行信息搜索,属于非知识密集型任务
NON_KNOWLEDGE

示例3(NORMAL):
该查询不包含安全敏感内容,要求撰写关于"区块链技术在金融领域的应用"的长文,需要进行信息收集、案例研究和深度分析,属于正常的长文写作任务
NORMAL

用户输入query:$query"""

    # Prepare conversation history ("/no_think" disables the model's
    # thinking mode for this fast classification call).
    conversation_history = [
        {"role": "user", "content": prompt_template.replace("$query", query) + " /no_think"}
    ]

    try:
        # Call LLM with retry logic
        retry_num = 1
        max_retry_num = 3
        while retry_num <= max_retry_num:
            try:
                response = requests.post(
                    url=pangu_url,
                    headers=headers,
                    json={
                        "model": config.model_name,
                        "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
                        "spaces_between_special_tokens": False,
                        "messages": conversation_history,
                        "temperature": 0.1,  # Low temperature for deterministic classification
                        "max_tokens": 5000,
                    },
                    timeout=model_config.get("timeout", 60)
                )
                # Surface HTTP-level failures explicitly so the retry loop
                # logs the status instead of an opaque JSON decode error.
                response.raise_for_status()

                response_json = response.json()
                logger.debug(f"Classification API response: {json.dumps(response_json, indent=2)}")

                # Extract and parse result: first line is the reasoning,
                # the remainder is the category label.
                assistant_message = response_json["choices"][0]["message"]["content"].strip()
                lines = assistant_message.split('\n', 1)

                if len(lines) < 2:
                    raise ValueError(f"Invalid response format: {assistant_message}")

                reasoning = lines[0].strip()
                category = lines[1].strip()

                # Validate category.  Bug fix: capture the raw value before
                # replacing it — previously `category` was overwritten with
                # "NORMAL" first, so the warning and fallback reasoning always
                # reported 'NORMAL' instead of the actual invalid label.
                valid_categories = ["SAFE_SENSITIVE", "NON_KNOWLEDGE", "NORMAL"]
                if category not in valid_categories:
                    raw_category = category
                    logger.warning(f"Invalid category '{raw_category}', using fallback NORMAL")
                    category = "NORMAL"
                    reasoning = f"模型返回无效分类 '{raw_category}',默认按正常任务处理。原始理由:{reasoning}"

                return {
                    "category": category,
                    "reasoning": reasoning
                }

            except Exception as e:
                logger.error(f"Classification attempt {retry_num} failed: {str(e)}")
                if retry_num == max_retry_num:
                    raise
                time.sleep(2)  # Wait before retry
                retry_num += 1

    except Exception as e:
        logger.error(f"Query classification failed: {str(e)}")
        # Fallback to NORMAL category if classification fails
        return {
            "category": "NORMAL",
            "reasoning": f"分类服务暂时不可用(错误:{str(e)[:100]}...),默认按正常任务处理"
        }
470
+
471
+
472
def load_environment_config(quiet: bool = False):
    """Load configuration from the config/.env file.

    Returns the loaded config object, or None when the .env file is missing
    or loading fails.  Console messages are suppressed when *quiet* is True.
    """
    try:
        # Check for .env file in config directory
        env_file = Path(__file__).parent.parent / "config" / ".env"

        if not env_file.exists():
            if not quiet:
                console.print(f"[yellow]⚠️ No .env file found at {env_file}[/yellow]")
                console.print(f"[yellow]💡 Please copy env.template to config/.env and configure your settings[/yellow]")
            return None

        # Reload configuration to pick up .env file
        reload_config()
        loaded = get_config()

        if not quiet:
            console.print("[green]✅ Configuration loaded successfully[/green]")
        return loaded

    except Exception as e:
        if not quiet:
            console.print(f"[red]❌ Failed to load configuration: {e}[/red]")
        return None
497
+
498
+
499
def create_sample_env_file():
    """Create a sample .env file for demo purposes.

    Copies env.template into config/.env; an existing .env is never touched.
    """
    project_dir = Path(__file__).parent.parent
    env_file = project_dir / "config" / ".env"

    # Never overwrite an existing configuration.
    if env_file.exists():
        return

    # Copy from template
    template_file = project_dir / "env.template"
    if not template_file.exists():
        console.print(f"[red]❌ Could not find env.template to copy[/red]")
        return

    import shutil
    shutil.copy2(template_file, env_file)
    console.print(f"[green]✅ Created .env file from template at {env_file}[/green]")
    console.print("[yellow]⚠️ Please edit the .env file with your actual configuration values[/yellow]")
516
+
517
+
518
def run_demo_query(planner, query: str, visualizer: DemoVisualizer, config) -> Optional[AgentResponse]:
    """Run a demo query through the planner with preprocessing.

    Flow: show the query, classify it, refuse unsupported categories, and
    otherwise hand it to the planner, displaying the final result and (when
    present) the detailed reasoning trace.  Returns the planner response, or
    None when the query was refused or execution failed.
    """
    # Step 1: show what we are about to run.
    visualizer.show_planner_start(query)

    # Step 2/3: pre-check — refuse unsafe or non-knowledge-intensive queries.
    classification_result = classify_query(query, config)
    if classification_result["category"] in ("SAFE_SENSITIVE", "NON_KNOWLEDGE"):
        visualizer.show_unsupported_response()
        return None

    # Step 4: normal query — let the planner do the work.
    try:
        with console.status("[bold green]Executing planner task...", spinner="dots"):
            response = planner.execute_task(query)

        visualizer.show_final_result(response)

        # Show detailed trace if available
        trace = getattr(response, 'reasoning_trace', None)
        if trace:
            visualizer.show_reasoning_trace(trace)

        return response

    except Exception as e:
        console.print(f"[red]❌ Error during execution: {e}[/red]")
        return None
552
+
553
+
554
def main():
    """Main CLI demo function.

    Parses command-line arguments, loads configuration, builds the
    PlannerAgent, and either runs a single query (--query) or an interactive
    loop.

    Returns:
        Process exit code: 0 on success, 1 on configuration or
        initialization failure (passed to sys.exit by the __main__ guard).
    """
    parser = argparse.ArgumentParser(description="DeepDiver Multi-Agent System Demo")
    parser.add_argument("--query", "-q", type=str, help="Query to execute (interactive mode if not provided)")
    parser.add_argument("--config-only", "-c", action="store_true", help="Only show configuration and exit")
    parser.add_argument("--create-env", "-e", action="store_true", help="Create sample .env file from template")
    parser.add_argument("--debug", "-d", action="store_true", help="Enable debug mode with verbose logging")
    # Bug fix: --quiet previously lacked action="store_true", so argparse
    # required a value ("--quiet --query ..." swallowed "--query" as the
    # value).  README and run_demo.sh document it as a boolean flag.
    parser.add_argument("--quiet", action="store_true", help="Suppress all non-essential output")

    args = parser.parse_args()

    # Setup logging based on arguments (re-configure if debug mode is requested)
    if args.debug:
        setup_clean_logging(debug_mode=True)

    # Initialize visualizer
    visualizer = DemoVisualizer(quiet_mode=args.quiet)
    if not args.quiet:
        visualizer.show_welcome()

    # Create sample .env file if requested
    if args.create_env:
        create_sample_env_file()
        return 0

    # Load configuration
    config = load_environment_config(quiet=args.quiet)
    if not config:
        if not args.quiet:
            console.print("[red]❌ Cannot proceed without valid configuration[/red]")
            console.print("[yellow]💡 Use --create-env to create a sample configuration file[/yellow]")
        return 1

    # Show configuration (show_config itself respects quiet mode)
    visualizer.show_config(config)

    if args.config_only:
        return 0

    # Initialize planner agent
    try:
        if not args.quiet:
            console.print("[blue]🔄 Initializing PlannerAgent...[/blue]")

        # Create planner with sub-agent configurations
        sub_agent_configs = {
            "information_seeker": {
                "model": config.model_name,
                "max_iterations": config.information_seeker_max_iterations or 30,
            },
            "writer": {
                "model": config.model_name,
                "max_iterations": config.writer_max_iterations or 30,
                "temperature": config.model_temperature,
                "max_tokens": config.model_max_tokens
            }
        }

        planner = create_planner_agent(
            model=config.model_name,
            max_iterations=config.planner_max_iterations or 40,
            sub_agent_configs=sub_agent_configs
        )

        if not args.quiet:
            console.print("[green]✅ PlannerAgent initialized successfully[/green]")

    except Exception as e:
        if not args.quiet:
            console.print(f"[red]❌ Failed to initialize PlannerAgent: {e}[/red]")
        return 1

    # Handle query execution
    if args.query:
        # Single query mode
        run_demo_query(planner, args.query, visualizer, config)
    else:
        # Interactive mode: loop until quit/exit, Ctrl-C, or EOF.
        if not args.quiet:
            console.print("\n[bold blue]🎯 Interactive Mode[/bold blue]")
            console.print("Enter your queries below. Type 'quit' or 'exit' to leave.")

        while True:
            try:
                prompt_text = "\n[bold cyan]Enter your query:[/bold cyan] " if not args.quiet else "Query: "
                query = console.input(prompt_text).strip()

                if query.lower() in ['quit', 'exit', 'q']:
                    if not args.quiet:
                        console.print("[green]👋 Goodbye![/green]")
                    break

                if not query:
                    continue

                if not args.quiet:
                    console.print("\n" + "="*80 + "\n")
                run_demo_query(planner, query, visualizer, config)
                if not args.quiet:
                    console.print("\n" + "="*80 + "\n")

            except KeyboardInterrupt:
                if not args.quiet:
                    console.print("\n[yellow]⚠️ Interrupted by user[/yellow]")
                break
            except EOFError:
                if not args.quiet:
                    console.print("\n[green]👋 Goodbye![/green]")
                break

    return 0


if __name__ == "__main__":
    sys.exit(main())
deepdiver_v2/cli/run_demo.sh ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # DeepDiver Multi-Agent System CLI Demo Runner
4
+ # This script makes it easier to run the CLI demo with different options
5
+
6
+ set -e
7
+
8
# Colors for output
# ANSI escape sequences used by the status helpers below.
# (BLUE is declared but not currently referenced — kept for future use.)
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Function to print colored output
# print_status MESSAGE — informational message in green.
print_status() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

# print_warning MESSAGE — warning message in yellow.
print_warning() {
    echo -e "${YELLOW}[WARNING]${NC} $1"
}

# print_error MESSAGE — error message in red.
print_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}
27
+
28
# Get script directory.
# Resolve the absolute directory of this script (robust to being sourced or
# invoked through a relative path), then derive the project root as its parent
# so the demo can be launched from any working directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
31
+
32
# Function to show help.
# Prints usage, option descriptions, and invocation examples to stdout.
show_help() {
    echo "DeepDiver Multi-Agent System CLI Demo Runner"
    echo ""
    echo "Usage: $0 [OPTIONS] [QUERY]"
    echo ""
    echo "Options:"
    echo " -h, --help Show this help message"
    echo " -i, --interactive Start interactive mode (default)"
    echo " -c, --config-only Show configuration and exit"
    echo " -e, --create-env Create sample .env file from template"
    echo " -q, --query \"QUERY\" Execute a specific query"
    echo " -d, --debug Enable debug mode with verbose logging"
    echo " --quiet Suppress all non-essential output"
    echo " --setup Install dependencies and setup"
    echo ""
    echo "Examples:"
    echo " $0 --interactive"
    echo " $0 --query \"Research the latest trends in AI\""
    echo " $0 --config-only"
    echo " $0 --debug --query \"Debug a specific query\""
    echo " $0 --quiet --query \"Run quietly\""
    echo " $0 --setup"
    echo ""
}
57
+
58
# Function to setup the demo.
# Installs Python dependencies, creates config/.env from the template when
# missing, and marks the demo entry point as executable.
setup_demo() {
    print_status "Setting up DeepDiver CLI Demo..."

    # Check if we're in the right directory
    if [ ! -f "$PROJECT_ROOT/cli/demo.py" ]; then
        print_error "Cannot find demo.py. Please run this script from the CLI directory or project root."
        exit 1
    fi

    # Install dependencies
    print_status "Installing Python dependencies..."
    cd "$PROJECT_ROOT"

    # NOTE(review): requirements.txt lives at the project root in this repo
    # layout; confirm whether cli/requirements.txt is the intended path.
    if [ -f "cli/requirements.txt" ]; then
        pip install -r cli/requirements.txt
        print_status "Dependencies installed successfully"
    else
        print_warning "requirements.txt not found, skipping dependency installation"
    fi

    # Check for .env file
    if [ ! -f "config/.env" ]; then
        print_warning "No .env file found in config/ directory"
        print_status "Creating sample .env file from template..."

        if [ -f "env.template" ]; then
            # Ensure the target directory exists first: under `set -e` a
            # failing cp into a missing config/ would abort the whole setup.
            mkdir -p config
            cp env.template config/.env
            print_status "Sample .env file created at config/.env"
            print_warning "Please edit config/.env with your actual configuration values"
        else
            print_error "No env.template found. Please create config/.env manually"
        fi
    else
        print_status ".env file found at config/.env"
    fi

    # Make demo script executable
    chmod +x "$PROJECT_ROOT/cli/demo.py"
    print_status "Made demo.py executable"

    print_status "Setup complete! You can now run the demo with:"
    echo " $0 --interactive"
}
102
+
103
# Function to run the demo.
# Changes into the project root and launches the Python CLI, forwarding every
# argument this function receives straight through to demo.py.
run_demo() {
    cd "$PROJECT_ROOT"
    print_status "Starting DeepDiver CLI Demo..."
    python cli/demo.py "$@"
}
113
+
114
# Parse command line arguments.
# Translates the wrapper's options into the flag list handed to demo.py.
# --help and --setup are handled locally and exit; everything else is
# accumulated into DEMO_ARGS. A bare non-flag argument is treated as a query.
DEMO_ARGS=()

while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --setup)
            setup_demo
            exit 0
            ;;
        -c|--config-only)
            DEMO_ARGS+=("--config-only")
            shift
            ;;
        -e|--create-env)
            DEMO_ARGS+=("--create-env")
            shift
            ;;
        -q|--query)
            # --query requires a value; ${2:-} avoids an unbound-variable
            # error when no second argument was supplied.
            if [ -z "${2:-}" ]; then
                print_error "Query argument is required with --query option"
                show_help
                exit 1
            fi
            DEMO_ARGS+=("--query" "$2")
            shift 2
            ;;
        -d|--debug)
            DEMO_ARGS+=("--debug")
            shift
            ;;
        --quiet)
            DEMO_ARGS+=("--quiet")
            shift
            ;;
        -i|--interactive)
            # Interactive is default, no need to add args
            shift
            ;;
        *)
            # If it's not a flag, treat it as a query
            if [[ "$1" != -* ]]; then
                DEMO_ARGS+=("--query" "$1")
                shift
            else
                print_error "Unknown option: $1"
                show_help
                exit 1
            fi
            ;;
    esac
done

# Run the demo with collected arguments
run_demo "${DEMO_ARGS[@]}"
deepdiver_v2/config/config.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ import os
3
+ from typing import Optional, Dict, Any
4
+ from dataclasses import dataclass
5
+ import logging
6
+ from pathlib import Path
7
+ from dotenv import load_dotenv
8
+
9
+
10
+ load_dotenv()
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
@dataclass
class APIConfig:
    """Configuration class for API keys and settings.

    Field defaults act as fallbacks: ``__post_init__`` immediately calls
    :meth:`load_from_env`, which overwrites each field from its environment
    variable when that variable is set.
    """

    # Custom LLM Service Configuration
    # Your own deployed LLM service accessed via requests
    model_request_url: Optional[str] = None
    model_request_token: Optional[str] = None
    model_name: str = "pangu_auto"  # Default model name

    # Custom Planner Mode
    planner_mode: str = "auto"  # Default planner mode

    # MCP Server Configuration
    mcp_server_url: Optional[str] = None
    mcp_auth_token: Optional[str] = None
    mcp_use_stdio: bool = True  # Default to stdio for backward compatibility

    # Search Engine Configuration (Generic)
    search_engine_base_url: Optional[str] = None
    search_engine_api_keys: Optional[str] = None  # Can be comma-separated for rotation

    # URL Crawler Configuration (Generic)
    url_crawler_base_url: Optional[str] = None
    url_crawler_api_keys: Optional[str] = None  # Can be comma-separated for rotation
    url_crawler_max_tokens: int = 100000

    # Model Interaction Configuration
    model_temperature: float = 0.3
    model_max_tokens: int = 8192
    model_request_timeout: int = 180

    # Tool Trajectory and Output Configuration
    trajectory_storage_path: str = "./workspace"
    report_output_path: str = "./report"
    document_analysis_path: str = "./doc_analysis"

    # Per-agent iteration controls (optional; resolved by agent factories)
    planner_max_iterations: Optional[int] = None
    information_seeker_max_iterations: Optional[int] = None
    writer_max_iterations: Optional[int] = None

    # General Settings
    debug_mode: bool = False
    max_retries: int = 3
    timeout: int = 30

    def __post_init__(self):
        """Load configuration from environment variables."""
        self.load_from_env()

    def load_from_env(self):
        """Load API keys and settings from environment variables.

        Unset variables leave the corresponding field at its current
        (default) value, except the plain ``os.getenv`` lookups which
        reset their field to None when the variable is absent.
        """
        # Custom LLM Service
        self.model_request_url = os.getenv("MODEL_REQUEST_URL")
        self.model_request_token = os.getenv("MODEL_REQUEST_TOKEN")
        # BUGFIX: the fallback used to be the hyphenated "pangu-auto",
        # which disagreed with both the field default and env.template
        # ("pangu_auto"); fall back to the declared field default instead.
        self.model_name = os.getenv("MODEL_NAME", self.model_name)

        # Custom Planner Mode
        self.planner_mode = os.getenv("PLANNER_MODE", self.planner_mode)

        # MCP Server
        self.mcp_server_url = os.getenv("MCP_SERVER_URL")
        self.mcp_auth_token = os.getenv("MCP_AUTH_TOKEN")
        self.mcp_use_stdio = os.getenv("MCP_USE_STDIO", "true").lower() == "true"

        # Search Engine Configuration
        self.search_engine_base_url = os.getenv("SEARCH_ENGINE_BASE_URL")
        self.search_engine_api_keys = os.getenv("SEARCH_ENGINE_API_KEYS")

        # URL Crawler Configuration
        self.url_crawler_base_url = os.getenv("URL_CRAWLER_BASE_URL")
        self.url_crawler_api_keys = os.getenv("URL_CRAWLER_API_KEYS")
        self.url_crawler_max_tokens = int(os.getenv("URL_CRAWLER_MAX_TOKENS", self.url_crawler_max_tokens))

        # Model Interaction Configuration
        self.model_temperature = float(os.getenv("MODEL_TEMPERATURE", self.model_temperature))
        self.model_max_tokens = int(os.getenv("MODEL_MAX_TOKENS", self.model_max_tokens))
        self.model_request_timeout = int(os.getenv("MODEL_REQUEST_TIMEOUT", self.model_request_timeout))

        # Tool Trajectory and Output Configuration
        self.trajectory_storage_path = os.getenv("TRAJECTORY_STORAGE_PATH", self.trajectory_storage_path)
        self.report_output_path = os.getenv("REPORT_OUTPUT_PATH", self.report_output_path)
        self.document_analysis_path = os.getenv("DOCUMENT_ANALYSIS_PATH", self.document_analysis_path)

        # Per-agent iteration controls (None when the variable is unset,
        # letting the agent factories apply their own defaults)
        self.planner_max_iterations = (
            int(os.getenv("PLANNER_MAX_ITERATION")) if os.getenv("PLANNER_MAX_ITERATION") else None
        )
        self.information_seeker_max_iterations = (
            int(os.getenv("INFORMATION_SEEKER_MAX_ITERATION")) if os.getenv("INFORMATION_SEEKER_MAX_ITERATION") else None
        )
        self.writer_max_iterations = (
            int(os.getenv("WRITER_MAX_ITERATION")) if os.getenv("WRITER_MAX_ITERATION") else None
        )

        # General Settings
        self.debug_mode = os.getenv("DEBUG_MODE", "false").lower() == "true"
        self.max_retries = int(os.getenv("MAX_RETRIES", self.max_retries))
        self.timeout = int(os.getenv("TIMEOUT", self.timeout))

    def get_custom_llm_config(self) -> Dict[str, Any]:
        """Get configuration for custom LLM service."""
        return {
            "url": self.model_request_url,
            "token": self.model_request_token,
            "model": self.model_name,
            "temperature": self.model_temperature,
            "max_tokens": self.model_max_tokens,
            "timeout": self.model_request_timeout,
            "base_url": self.model_request_url  # For backward compatibility with model_config.get('base_url')
        }

    def get_available_search_providers(self) -> list:
        """Get list of available search providers based on API keys."""
        providers = []
        if self.search_engine_api_keys:
            providers.append("custom")
        return providers

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary (excluding sensitive data).

        Masks values whose attribute name contains "api_key"/"password"
        or ends with "_token" — previously model_request_token and
        mcp_auth_token leaked in plain text. The "_token" suffix check
        deliberately leaves numeric *_tokens settings (e.g.
        model_max_tokens) unmasked.
        """
        config_dict = {}
        for key, value in self.__dict__.items():
            lowered = key.lower()
            is_secret = (
                "api_key" in lowered
                or "password" in lowered
                or lowered.endswith("_token")
            )
            if is_secret:
                config_dict[key] = "***" if value else None
            else:
                config_dict[key] = value
        return config_dict
145
+
146
+ # Global configuration instance
147
+ config = APIConfig()
148
+
149
+
150
def get_config() -> APIConfig:
    """Return the module-level :class:`APIConfig` singleton.

    Note: callers that hold this reference across a :func:`reload_config`
    call keep the stale instance, because reload rebinds the module global
    rather than mutating the object in place.
    """
    return config
153
+
154
+
155
def reload_config():
    """Reload configuration from environment variables.

    Rebinds the module-level ``config`` global to a fresh
    :class:`APIConfig`, whose ``__post_init__`` re-reads the environment.
    """
    global config
    config = APIConfig()
    logger.info("Configuration reloaded")
160
+
161
+
162
def validate_api_key(api_key: Optional[str], service_name: str) -> bool:
    """Return True when *api_key* is a non-empty, non-blank string.

    Logs an error naming *service_name* when the key is missing or blank.
    """
    if api_key and api_key.strip():
        return True
    logger.error(f"Missing or empty API key for {service_name}")
    return False
168
+
169
+
170
def get_url_crawler_config() -> Dict[str, Any]:
    """Get generic URL crawler configuration.

    Returns an empty dict when no API keys are configured; otherwise a
    dict with the parsed key list (comma-separated keys are split for
    rotation), base URL, token limit, and timeout.
    """
    raw_keys = config.url_crawler_api_keys
    if not raw_keys:
        return {}

    # A comma-separated string becomes a rotation list; anything else is
    # wrapped in a single-element list as-is.
    if isinstance(raw_keys, str):
        key_list = [key.strip() for key in raw_keys.split(",")]
    else:
        key_list = [raw_keys]

    return {
        "api_keys": key_list,
        "base_url": config.url_crawler_base_url,
        "max_tokens": config.url_crawler_max_tokens,
        "timeout": config.timeout,
    }
187
+
188
+
189
def get_search_engine_config() -> Dict[str, Any]:
    """Get generic search engine configuration.

    Returns an empty dict when no API keys are configured; otherwise a
    dict with the parsed key list (comma-separated keys are split for
    rotation), base URL, and timeout.
    """
    raw_keys = config.search_engine_api_keys
    if not raw_keys:
        return {}

    # A comma-separated string becomes a rotation list; anything else is
    # wrapped in a single-element list as-is.
    if isinstance(raw_keys, str):
        key_list = [key.strip() for key in raw_keys.split(",")]
    else:
        key_list = [raw_keys]

    return {
        "api_keys": key_list,
        "base_url": config.search_engine_base_url,
        "timeout": config.timeout,
    }
205
+
206
+
207
def get_model_config() -> Dict[str, Any]:
    """Get model interaction configuration for custom LLM service.

    Thin wrapper over :meth:`APIConfig.get_custom_llm_config` on the
    global config instance.
    """
    return config.get_custom_llm_config()
210
+
211
+
212
def get_storage_config() -> Dict[str, Any]:
    """Get storage and trajectory configuration.

    Returns the three output path settings (trajectory workspace, report
    output, document analysis) from the global config.
    """
    return {
        "trajectory_storage_path": config.trajectory_storage_path,
        "report_output_path": config.report_output_path,
        "document_analysis_path": config.document_analysis_path
    }
219
+
220
+
221
def get_mcp_config() -> Dict[str, Any]:
    """Get MCP server specific configuration.

    Bundles the MCP endpoint, auth token, transport choice (stdio vs
    HTTP), and the general request timeout from the global config.
    """
    return {
        "server_url": config.mcp_server_url,
        "auth_token": config.mcp_auth_token,
        "use_stdio": config.mcp_use_stdio,
        "timeout": config.timeout
    }
229
+
230
+
231
# Example usage and testing: when run directly, print a summary of the
# resolved configuration (to_dict() masks values it treats as sensitive).
if __name__ == "__main__":
    print("=== Multi Agent System Configuration ===")
    print(f"Debug Mode: {config.debug_mode}")
    print(f"Custom LLM Service URL: {config.model_request_url}")
    print(f"Available Search Providers: {config.get_available_search_providers()}")
    print("\nConfiguration Summary:")
    for key, value in config.to_dict().items():
        print(f" {key}: {value}")
deepdiver_v2/env.template ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===================================
2
+ # DeepDiver Configuration Template
3
+ # ===================================
4
+ # Copy this file to .env and fill in your values
5
+
6
+ # Custom LLM Service Configuration
7
+ # Your own deployed LLM service endpoint
8
+
9
+ MODEL_REQUEST_URL=
10
+ MODEL_REQUEST_TOKEN=
11
+ MODEL_NAME=pangu_auto
12
+ PLANNER_MAX_ITERATION=40
13
+ INFORMATION_SEEKER_MAX_ITERATION=30
14
+ WRITER_MAX_ITERATION=40
15
+ PLANNER_MODE=auto # auto, writing, qa
16
+
17
+ # Model Interaction Settings
18
+ MODEL_TEMPERATURE=0.6
19
+ MODEL_MAX_TOKENS=8192
20
+ MODEL_REQUEST_TIMEOUT=180
21
+
22
+ # MCP Server Configuration
23
+ MCP_SERVER_URL=http://localhost:6274/mcp
24
+ MCP_AUTH_TOKEN=
25
+ MCP_USE_STDIO=false
26
+
27
+ # Search Engine Configuration
28
+ SEARCH_ENGINE_BASE_URL=
29
+ SEARCH_ENGINE_API_KEYS=
30
+
31
+ # URL Crawler Configuration
32
+ URL_CRAWLER_BASE_URL=
33
+ URL_CRAWLER_API_KEYS=
34
+ URL_CRAWLER_MAX_TOKENS=100000
35
+
36
+ # Tool Trajectory and Output Paths
37
+ TRAJECTORY_STORAGE_PATH=./workspace
38
+ REPORT_OUTPUT_PATH=./report
39
+ DOCUMENT_ANALYSIS_PATH=./doc_analysis
40
+
41
+ # General Settings
42
+ DEBUG_MODE=false
43
+ MAX_RETRIES=3
44
+ TIMEOUT=30
deepdiver_v2/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ beautifulsoup4==4.13.5
2
+ httpx[http2]==0.28.1
3
+ python-dotenv==1.1.1
4
+ python_dateutil==2.9.0.post0
5
+ Requests==2.32.5
6
+ rich==14.1.0
7
+ starlette==0.47.3
8
+ uvicorn==0.35.0
deepdiver_v2/src/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ """
3
+ DeepDiver Multi-Agent System
4
+
5
+ A comprehensive multi-agent system with MCP integration, local workspace management,
6
+ and advanced knowledge management capabilities.
7
+ """
8
+
9
+ __version__ = "2.0.0"
10
+ __author__ = "DeepDiver Team"
11
+ __description__ = "Multi-Agent System with MCP and Local Workspace Integration"
deepdiver_v2/src/agents/__init__.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ """
3
+ Multi-Agent System - Agent Module
4
+
5
+ This module provides the core agents for the multi-agent system:
6
+ - BaseAgent: Abstract base class with common functionality
7
+ - InformationSeekerAgent: Research and information gathering
8
+ - WriterAgent: Content creation and writing
9
+ - PlannerAgent: Top-level orchestrator
10
+
11
+ All agents follow the ReAct pattern and use standardized TaskInput format.
12
+ """
13
+
14
+ from .base_agent import (
15
+ BaseAgent,
16
+ AgentConfig,
17
+ AgentResponse,
18
+ TaskInput,
19
+ create_agent_config
20
+ )
21
+
22
+ from .subjective_information_seeker import (
23
+ InformationSeekerAgent,
24
+ create_subjective_information_seeker
25
+ )
26
+
27
+ from .objective_information_seeker import (
28
+ InformationSeekerAgent,
29
+ create_objective_information_seeker
30
+ )
31
+
32
+ from .writer_agent import (
33
+ WriterAgent,
34
+ create_writer_agent
35
+ )
36
+
37
+ from .planner_agent import (
38
+ PlannerAgent,
39
+ create_planner_agent
40
+ )
41
+
42
+ __all__ = [
43
+ # Base classes
44
+ "BaseAgent",
45
+ "AgentConfig",
46
+ "AgentResponse",
47
+ "TaskInput",
48
+ "create_agent_config",
49
+
50
+ # Specific agents
51
+ "InformationSeekerAgent",
52
+ "create_subjective_information_seeker",
53
+ "create_objective_information_seeker",
54
+ "WriterAgent",
55
+ "create_writer_agent",
56
+ "PlannerAgent",
57
+ "create_planner_agent"
58
+ ]
59
+
60
+ # Version info
61
+ __version__ = "0.1.0"
62
+ __author__ = "DeepDiver Multi-Agent System"
deepdiver_v2/src/agents/base_agent.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ import logging
3
+ import time
4
+ from abc import ABC, abstractmethod
5
+ from typing import Dict, Any, List, Optional
6
+ from dataclasses import dataclass, field
7
+
8
+ # Import MCP client availability flag without binding unused symbols
9
+ try:
10
+ from ..tools import mcp_client as _mcp_client_module # noqa: F401
11
+ MCP_CLIENT_AVAILABLE = True
12
+ except ImportError:
13
+ MCP_CLIENT_AVAILABLE = False
14
+
15
+
16
@dataclass
class AgentConfig:
    """Configuration for agents - session management handled entirely by MCP server"""
    # Used for logging and for tool filtering (see BaseAgent._get_agent_type)
    agent_name: str = "base_agent"
    # Planner operating mode; presumably one of auto/writing/qa — confirm against env.template
    planner_mode: str = "auto"
    # LLM model identifier; None defers to the caller/factory defaults
    model: Optional[str] = None
    # Upper bound on reasoning iterations for the agent loop
    max_iterations: int = 10
    # Sampling temperature; None defers to the shared model config
    temperature: Optional[float] = None
    # Completion token limit; None defers to the shared model config
    max_tokens: Optional[int] = None
    # Paths used by writer and other agents
    trajectory_storage_path: Optional[str] = None
    report_output_path: Optional[str] = None
    document_analysis_path: Optional[str] = None
29
+
30
+
31
@dataclass
class AgentResponse:
    """Standardized response format for all agents"""
    # True when the agent finished its task without a fatal error
    success: bool
    # Agent-specific payload, populated on success
    result: Optional[Dict[str, Any]] = None
    # Human-readable failure description, populated on failure
    error: Optional[str] = None
    # Number of reasoning iterations actually executed
    iterations: int = 0
    # Step-by-step trace entries accumulated during the run
    reasoning_trace: List[Dict[str, Any]] = field(default_factory=list)
    # Name of the agent that produced this response
    agent_name: str = ""
    # Run duration in seconds (presumably wall-clock; set by the agent runner)
    execution_time: float = 0.0
41
+
42
+
43
@dataclass
class TaskInput:
    """Standardized task input format for all agents."""
    task_content: str  # The specific task content
    task_steps_for_reference: Optional[str] = None  # Reference steps for execution
    deliverable_contents: Optional[str] = None  # Format of final deliverable
    current_task_status: Optional[str] = None  # Description of current task status
    task_executor: str = "info_seeker"  # Name of task executor (info_seeker, writer)
    workspace_id: Optional[str] = None  # Workspace ID for stored files and memory
    acceptance_checking_criteria: Optional[str] = None  # Criteria for determining task completion and quality

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this task input into a plain dictionary."""
        return {
            name: getattr(self, name)
            for name in (
                "task_content",
                "task_steps_for_reference",
                "deliverable_contents",
                "current_task_status",
                "task_executor",
                "workspace_id",
                "acceptance_checking_criteria",
            )
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TaskInput':
        """Build a TaskInput from a dictionary, applying field defaults."""
        return cls(
            task_content=data.get("task_content", ""),
            task_steps_for_reference=data.get("task_steps_for_reference"),
            deliverable_contents=data.get("deliverable_contents"),
            current_task_status=data.get("current_task_status"),
            task_executor=data.get("task_executor", "info_seeker"),
            workspace_id=data.get("workspace_id"),
            acceptance_checking_criteria=data.get("acceptance_checking_criteria"),
        )

    def format_for_prompt(self) -> str:
        """Render the populated fields as a prompt-ready text block.

        Optional sections are emitted only when set; the executor line is
        always present, followed by the workspace id when available.
        """
        parts = [f"Task Content:\n{self.task_content}\n\n"]

        if self.task_steps_for_reference:
            parts.append(f"Task Steps for Reference:\n{self.task_steps_for_reference}\n\n")
        if self.deliverable_contents:
            parts.append(f"Deliverable Contents:\n{self.deliverable_contents}\n\n")
        if self.current_task_status:
            parts.append(f"Current Task Status:\n{self.current_task_status}\n\n")
        if self.acceptance_checking_criteria:
            parts.append(f"Acceptance Checking Criteria:\n{self.acceptance_checking_criteria}\n\n")

        parts.append(f"Task Executor: {self.task_executor}\n")
        if self.workspace_id:
            parts.append(f"Workspace ID: {self.workspace_id}\n")

        return "".join(parts)
101
+
102
+
103
class SectionWriterTaskInput(TaskInput):
    """
    Specialized TaskInput for section writing tasks.

    Only stores the essential parameters. The section_writer agent
    will handle prompt assembly internally.

    Extra attributes (outline, chapter context, key files) are plain
    instance attributes and are NOT dataclass fields, so they do not
    appear in ``to_dict()``.
    """

    def __init__(
        self,
        task_content: str,
        user_query: str,
        write_file_path: str,
        overall_outline: str,
        current_chapter_outline: str,
        key_files: List[Dict[str, Any]],
        written_chapters: str = "",
        workspace_id: Optional[str] = None
    ):
        """Store section-writer context, then delegate to TaskInput.

        Args:
            task_content: The specific writing task description.
            user_query: Original user request the report answers.
            write_file_path: Destination path for the written section.
            overall_outline: Outline of the whole report.
            current_chapter_outline: Outline of the chapter being written.
            key_files: File descriptors the writer should draw on.
            written_chapters: Previously completed chapter text, if any.
            workspace_id: Workspace holding stored files and memory.
        """
        # Store the section writer specific parameters
        self.write_file_path = write_file_path
        self.user_query = user_query
        self.current_chapter_outline = current_chapter_outline
        self.key_files = key_files
        self.written_chapters = written_chapters
        self.overall_outline = overall_outline

        # Initialize parent TaskInput with minimal required fields
        super().__init__(
            task_content=task_content,
            task_executor="section_writer",
            workspace_id=workspace_id,
        )
136
+
137
+
138
class WriterAgentTaskInput(TaskInput):
    """
    Specialized TaskInput for writer-agent tasks.

    Only stores the essential parameters (task content, user query, key
    files, workspace); the writer agent handles prompt assembly
    internally. The extra attributes are plain instance attributes, not
    dataclass fields, so they do not appear in ``to_dict()``.
    """

    def __init__(
        self,
        task_content: str,
        user_query: str,
        key_files: List[Dict[str, Any]],
        workspace_id: Optional[str] = None
    ):
        """Store writer-agent context, then delegate to TaskInput.

        Args:
            task_content: The specific writing task description.
            user_query: Original user request being answered.
            key_files: File descriptors the writer should draw on.
            workspace_id: Workspace holding stored files and memory.
        """
        # Store the writer agent specific parameters
        self.user_query = user_query
        self.key_files = key_files

        # Initialize parent TaskInput with minimal required fields
        super().__init__(
            task_content=task_content,
            task_executor="writer_agent",
            workspace_id=workspace_id,
        )
163
+
164
+
165
+ class BaseAgent(ABC):
166
+ """
167
+ Base class for all agents with MCP server-managed sessions.
168
+
169
+ Session management is now entirely handled by the MCP server:
170
+ - Server assigns session IDs on connection
171
+ - Server creates workspace folders with UUID names
172
+ - All tool operations are performed in server-managed workspaces
173
+ """
174
+
175
    def __init__(self, config: AgentConfig, shared_mcp_client=None):
        """Create the agent and wire it to MCP tooling.

        Args:
            config: Agent settings (name, model, iteration limits, paths).
            shared_mcp_client: Optional already-connected MCP client to
                reuse; when None a standalone filtered client is created.
        """
        self.execution_stats = None
        self.reasoning_trace = None
        self.config = config
        self.logger = logging.getLogger(f"{__name__}.{config.agent_name}")

        # Session info is populated by the MCP server
        self.session_info = None

        # Tool management
        self.mcp_tools = None
        self.available_tools = {}

        # Reset per-run trace state (reset_trace is defined elsewhere in
        # this class; presumably initializes execution_stats/reasoning_trace)
        self.reset_trace()

        # Initialize MCP tools (server will handle session creation or use shared client)
        self._initialize(shared_mcp_client)
192
+
193
    def _initialize(self, shared_mcp_client=None):
        """Initialize agent with MCP server connection or shared client.

        When a shared client is supplied, wrap it in an agent-type-filtered
        adapter; otherwise create a standalone filtered client. Then discover
        the tools the server exposes and build their calling schemas.

        Raises:
            Exception: re-raised from any failing initialization step after
                logging it.
        """
        try:
            self.logger.info(f"Initializing agent {self.config.agent_name}")

            if shared_mcp_client:
                # Use shared MCP client with agent-specific tool filtering
                agent_type = self._get_agent_type()
                self.mcp_tools = self._create_filtered_mcp_tools(shared_mcp_client, agent_type)
                self.logger.info(f"Agent {self.config.agent_name} using shared MCP client with {agent_type} tools")
            else:
                # Create MCP tools with agent-specific filtering (no more unfiltered access)
                self.mcp_tools = self._create_filtered_mcp_tools_standalone()

            # Discover available tools
            self.available_tools = self._discover_mcp_tools()

            # Build tool schemas for function calling
            # (_build_tool_schemas is defined elsewhere in this class)
            self.tool_schemas = self._build_tool_schemas()

            self.logger.info(f"Agent {self.config.agent_name} initialized successfully")
            self.logger.info(f"Available tools: {list(self.available_tools.keys())}")

        except Exception as e:
            self.logger.error(f"Failed to initialize agent {self.config.agent_name}: {e}")
            raise
218
+ raise
219
+
220
    def _discover_mcp_tools(self) -> Dict[str, Any]:
        """Discover available tools from MCP server or fallback tools.

        Returns:
            Mapping of tool name to the server-provided tool info
            (preferred path) or, on fallback, to a bound callable taken
            from the adapter via reflection. Empty dict when nothing is
            discoverable.
        """
        available_tools = {}

        # Try to get tools from MCP client first
        if hasattr(self.mcp_tools, 'get_available_tools'):
            try:
                mcp_tools_dict = self.mcp_tools.get_available_tools()
                for tool_name, tool_info in mcp_tools_dict.items():
                    # For proper MCP architecture, store tool info for direct client calls
                    # instead of creating wrapper lambda functions
                    available_tools[tool_name] = tool_info

                if available_tools:
                    self.logger.info(f"Discovered {len(available_tools)} tools from MCP server")
                    return available_tools
            except Exception as e:
                self.logger.warning(f"Failed to discover MCP tools: {e}")

        # Fallback: if MCP client not available, use direct method access
        # This should rarely be needed with proper MCP setup
        # NOTE(review): this reflects over every public callable on the
        # adapter, so non-tool helper methods may be picked up as "tools".
        if hasattr(self.mcp_tools, '__dict__'):
            for attr_name in dir(self.mcp_tools):
                if not attr_name.startswith('_') and callable(getattr(self.mcp_tools, attr_name)):
                    available_tools[attr_name] = getattr(self.mcp_tools, attr_name)

        return available_tools
247
+
248
+ def _get_agent_type(self) -> str:
249
+ """Get agent type for tool filtering"""
250
+ agent_name = self.config.agent_name.lower()
251
+ if "planner" in agent_name:
252
+ return "planner"
253
+ elif "information" in agent_name or "seeker" in agent_name:
254
+ return "information_seeker"
255
+ elif "writer" in agent_name:
256
+ return "writer"
257
+ else:
258
+ # Default to planner tools for unknown agent types
259
+ return "planner"
260
+
261
    def _create_filtered_mcp_tools(self, shared_client, agent_type: str):
        """Create filtered MCP tools adapter using shared client.

        Args:
            shared_client: An already-connected MCP client to wrap.
            agent_type: Filtering category from :meth:`_get_agent_type`.

        Returns:
            A filtered adapter when available; otherwise a bare
            MCPToolsAdapter constructed without __init__ (no filtering),
            sharing the given client.
        """
        try:
            from src.tools.mcp_client import create_filtered_mcp_tools_adapter
            return create_filtered_mcp_tools_adapter(shared_client, agent_type)
        except ImportError:
            # Fallback if FilteredMCPToolsAdapter not available
            self.logger.warning("FilteredMCPToolsAdapter not available, using regular adapter")
            from src.tools.mcp_client import MCPToolsAdapter
            # __new__ skips MCPToolsAdapter.__init__ so the shared client
            # can be attached without opening a new connection
            adapter = MCPToolsAdapter.__new__(MCPToolsAdapter)
            adapter.client = shared_client
            return adapter
273
+
274
    def _create_filtered_mcp_tools_standalone(self):
        """Create filtered MCP tools adapter with its own client connection.

        Returns:
            A filtered adapter wrapping a freshly created MCP client.

        Raises:
            RuntimeError: when client creation or filtering fails.
        """
        try:
            # Get agent type for filtering
            agent_type = self._get_agent_type()

            # Create a new MCP client
            client = self._create_new_mcp_client()

            # Apply filtering based on agent type
            from src.tools.mcp_client import create_filtered_mcp_tools_adapter
            filtered_adapter = create_filtered_mcp_tools_adapter(client, agent_type)

            self.logger.info(f"Agent {self.config.agent_name} created filtered MCP adapter with {agent_type} tools")
            return filtered_adapter

        except Exception as e:
            self.logger.error(f"Failed to create filtered MCP tools: {e}")
            raise RuntimeError(f"Failed to create filtered MCP client for {self.config.agent_name}: {e}")
293
+
294
    def _create_new_mcp_client(self):
        """Create a new MCP client connection.

        Uses the HTTP endpoint from get_mcp_config() when one is
        configured and stdio is disabled; otherwise falls back to the
        default local HTTP MCP server (http://localhost:6274/mcp). Note
        that even with use_stdio=True the fallback is still HTTP.

        Returns:
            A connected MCPClient instance.

        Raises:
            RuntimeError: when the client cannot be created.
        """
        try:
            # Get MCP configuration
            from config.config import get_mcp_config
            mcp_config = get_mcp_config()

            # Create MCP client
            from src.tools.mcp_client import MCPClient

            if mcp_config.get("server_url") and not mcp_config.get("use_stdio", True):
                # HTTP-based MCP server
                client = MCPClient(server_url=mcp_config["server_url"])
                self.logger.info(
                    f"Agent {self.config.agent_name} connected to HTTP MCP server: {mcp_config['server_url']}")
            else:
                # Default to the expected HTTP MCP server on port 6274
                client = MCPClient(server_url="http://localhost:6274/mcp")
                self.logger.info(
                    f"Agent {self.config.agent_name} connected to default HTTP MCP server: http://localhost:6274/mcp")

            return client

        except Exception as e:
            self.logger.error(f"Failed to create MCP client: {e}")
            raise RuntimeError(f"MCP client creation failed for {self.config.agent_name}: {e}")
320
+
321
+ # NOTE: _create_mcp_tools() method removed to prevent unfiltered tool access.
322
+ # All agents now use _create_filtered_mcp_tools_standalone() or _create_filtered_mcp_tools()
323
+ # to ensure proper tool isolation and security.
324
+
325
def get_session_info(self) -> Optional[Dict[str, Any]]:
    """Get information about the current server-managed session.

    Tries, in order: the adapter's own ``get_session_info``; the adapter's
    underlying client (``_session_id`` + ``is_connected``); session fields
    stored directly on the adapter; finally a minimal info dict. Never
    raises — failures are logged and reported via the returned dict.

    Returns:
        A dict with at least ``session_id``, ``server_managed``,
        ``agent_name`` and ``connected`` keys.
    """
    try:
        # First try the adapter's get_session_info method if available
        if hasattr(self.mcp_tools, 'get_session_info'):
            session_info = self.mcp_tools.get_session_info()
            if session_info:
                # Add agent-specific information
                session_info.update({
                    "server_managed": True,
                    "agent_name": self.config.agent_name
                })
                return session_info

        # Fallback: Check if we have an MCP tools adapter with a client
        if hasattr(self.mcp_tools, 'client'):
            client = self.mcp_tools.client

            # Check if client has session ID and connection status
            if hasattr(client, '_session_id') and hasattr(client, 'is_connected'):
                return {
                    "session_id": client._session_id,
                    "server_managed": True,
                    "agent_name": self.config.agent_name,
                    "connected": client.is_connected()
                }

        # Fallback: check if mcp_tools has session info directly
        if hasattr(self.mcp_tools, '_session_id'):
            return {
                "session_id": self.mcp_tools._session_id,
                "server_managed": True,
                "agent_name": self.config.agent_name,
                # NOTE(review): default here is "connected" (lambda: True),
                # while the branch below defaults to False — confirm intended.
                "connected": getattr(self.mcp_tools, 'is_connected', lambda: True)()
            }

        # If no session info available, return basic info
        return {
            "session_id": None,
            "server_managed": True,
            "agent_name": self.config.agent_name,
            "connected": hasattr(self.mcp_tools, 'client') and getattr(self.mcp_tools.client, 'is_connected',
                                                                       lambda: False)()
        }

    except Exception as e:
        self.logger.warning(f"Failed to get session info: {e}")
        return {
            "session_id": None,
            "server_managed": True,
            "agent_name": self.config.agent_name,
            "connected": False,
            "error": str(e)
        }
379
+
380
+ def _build_tool_schemas(self) -> List[Dict[str, Any]]:
381
+ """Build tool schemas for function calling"""
382
+ schemas = []
383
+
384
+ # Get agent-specific tool schemas
385
+ agent_schemas = self._build_agent_specific_tool_schemas()
386
+ schemas.extend(agent_schemas)
387
+
388
+ return schemas
389
+
390
+ def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
391
+ """
392
+ Build agent-specific tool schemas using proper MCP architecture.
393
+ Schemas come from MCP server via client, not direct imports.
394
+ """
395
+ schemas = []
396
+
397
+ # Proper MCP way: Get schemas from MCP client (which got them from server)
398
+ try:
399
+ if hasattr(self.mcp_tools, 'get_tool_schemas'):
400
+ # Use the MCP client to get schemas (proper MCP architecture)
401
+ schemas = self.mcp_tools.get_tool_schemas()
402
+ self.logger.info(f"Retrieved {len(schemas)} tool schemas from MCP server")
403
+ else:
404
+ # Fallback for adapters that don't have the new method yet
405
+ self.logger.warning("MCP adapter doesn't support get_tool_schemas, using fallback")
406
+ schemas = self._build_fallback_schemas()
407
+ except Exception as e:
408
+ self.logger.warning(f"Failed to get schemas from MCP client: {e}, using fallback")
409
+ schemas = self._build_fallback_schemas()
410
+
411
+ return schemas
412
+
413
+ def _build_fallback_schemas(self) -> List[Dict[str, Any]]:
414
+ """Fallback schema building if MCP client method fails"""
415
+ schemas = []
416
+
417
+ # Try to get tool info from MCP client
418
+ if hasattr(self.mcp_tools, 'get_available_tools'):
419
+ try:
420
+ available_tools = self.mcp_tools.get_available_tools()
421
+ for tool_name, tool_info in available_tools.items():
422
+ schema = {
423
+ "type": "function",
424
+ "function": {
425
+ "name": tool_name,
426
+ "description": getattr(tool_info, 'description', f"Tool: {tool_name}"),
427
+ "parameters": getattr(tool_info, 'input_schema', {"type": "object", "properties": {}, "required": []})
428
+ }
429
+ }
430
+ schemas.append(schema)
431
+ self.logger.info(f"Built {len(schemas)} schemas using fallback method")
432
+ except Exception as e:
433
+ self.logger.warning(f"Fallback schema building failed: {e}")
434
+
435
+ return schemas
436
+
437
def execute_tool_call(self, tool_call) -> Dict[str, Any]:
    """Execute a tool call and return results using proper MCP architecture.

    Routing: built-in tools (callables registered in ``self.available_tools``)
    run locally; otherwise the call is delegated to the MCP client. Results
    are normalized to a dict with ``success``/``data``/``error``/``metadata``.

    Args:
        tool_call: Mapping with ``name`` and ``arguments`` keys.

    Returns:
        A result dict; never raises — failures are reported via ``success``.
    """
    tool_name = tool_call["name"]

    try:
        # Parse arguments
        arguments = tool_call["arguments"]

        # Check if tool is available
        if tool_name not in self.available_tools:
            return {
                "success": False,
                "error": f"Tool '{tool_name}' not available for this agent"
            }

        # Route tool execution based on tool type
        # Built-in tools (like assign_task_to_*) are callable methods, not MCP server tools
        if callable(self.available_tools[tool_name]):
            # Built-in tool: execute locally
            tool_function = self.available_tools[tool_name]
            result = tool_function(**arguments)

            # Convert result to standard format
            if hasattr(result, 'to_dict'):
                return result.to_dict()
            elif isinstance(result, dict):
                return result
            else:
                # Bare return values are wrapped as a successful payload.
                return {
                    "success": True,
                    "data": result,
                    "error": None,
                    "metadata": {}
                }

        elif hasattr(self.mcp_tools, 'client') and hasattr(self.mcp_tools.client, 'call_tool'):
            # MCP server tool: execute via client
            result = self.mcp_tools.client.call_tool(tool_name, arguments)

            # Convert MCPClientResult to standard format
            if hasattr(result, 'success'):
                return {
                    "success": result.success,
                    "data": result.data,
                    "error": result.error,
                    "metadata": getattr(result, 'metadata', {})
                }
            else:
                # Client already returned a plain dict-like result.
                return result
        else:
            return {
                "success": False,
                "error": f"Tool '{tool_name}' is not executable (neither built-in nor MCP)"
            }

    except Exception as e:
        self.logger.error(f"Error executing tool {tool_name}: {e}")
        return {
            "success": False,
            "error": f"Tool execution failed: {str(e)}"
        }
497
+ }
498
+
499
def log_reasoning(self, iteration: int, reasoning: str):
    """Record a reasoning step in the trace and update step counters."""
    entry = {
        "type": "reasoning",
        "iteration": iteration,
        "content": reasoning,
        "timestamp": time.time()
    }
    self.reasoning_trace.append(entry)
    for key in ("reasoning_steps", "total_steps"):
        self.execution_stats[key] += 1
    # Only the first 100 characters are logged to keep log lines short.
    self.logger.info(f"Reasoning (Iter {iteration}): {reasoning[:100]}...")
510
+
511
def log_action(self, iteration: int, tool: str, arguments: Dict[str, Any], result: Dict[str, Any]):
    """Record a tool invocation and its outcome in the reasoning trace."""
    self.reasoning_trace.append({
        "type": "action",
        "iteration": iteration,
        "tool": tool,
        "arguments": arguments,
        "result": result,
        "timestamp": time.time()
    })
    for key in ("action_steps", "total_steps"):
        self.execution_stats[key] += 1

    # Results without an explicit success flag are treated as successful.
    status = "Success" if result.get("success", True) else "Failed"
    self.logger.info(f"Action (Iter {iteration}): {tool} -> {status} -> {str(arguments)[:400]}...")
528
+
529
def log_error(self, iteration: int, error: str):
    """Record an error event in the reasoning trace and update counters."""
    self.reasoning_trace.append({
        "type": "error",
        "iteration": iteration,
        "error": error,
        "timestamp": time.time()
    })
    for key in ("error_steps", "total_steps"):
        self.execution_stats[key] += 1
    self.logger.error(f"Error (Iter {iteration}): {error}")
540
+
541
def reset_trace(self):
    """Clear the reasoning trace and reset execution statistics for a new task."""
    self.reasoning_trace = []
    fresh_stats = {
        "total_steps": 0,
        "reasoning_steps": 0,
        "action_steps": 0,
        "error_steps": 0,
        "tool_usage": {},
        "success_rate": 1.0
    }
    self.execution_stats = fresh_stats
552
+
553
def get_execution_stats(self) -> Dict[str, Any]:
    """Return a copy of the execution statistics with a refreshed success rate.

    The success rate is recomputed from the trace: any recorded action whose
    result explicitly reports ``success: False`` counts as a failure.
    """
    actions = self.execution_stats["action_steps"]
    if actions > 0:
        failures = 0
        for step in self.reasoning_trace:
            if step.get("type") == "action" and not step.get("result", {}).get("success", True):
                failures += 1
        self.execution_stats["success_rate"] = (actions - failures) / actions

    # Copy so callers cannot mutate the live stats dict.
    return self.execution_stats.copy()
566
+
567
def create_response(self, success: bool, result: Dict[str, Any] = None,
                    error: str = None, iterations: int = 0,
                    execution_time: float = 0.0) -> AgentResponse:
    """Package execution results into a standardized AgentResponse.

    The reasoning trace is snapshotted (copied) so later mutations of the
    agent's live trace do not affect the returned response.
    """
    trace_snapshot = self.reasoning_trace.copy()
    return AgentResponse(
        success=success,
        result=result,
        error=error,
        iterations=iterations,
        reasoning_trace=trace_snapshot,
        agent_name=self.config.agent_name,
        execution_time=execution_time
    )
580
+
581
def validate_config(self) -> bool:
    """Check that the agent configuration carries sane required values.

    Returns True only when the name and model are non-empty, iteration and
    token limits are positive, and temperature lies within [0.0, 2.0].
    Any attribute/type error during the checks counts as invalid.
    """
    try:
        cfg = self.config
        return all((
            bool(cfg.agent_name),
            bool(cfg.model),
            cfg.max_iterations > 0,
            0.0 <= cfg.temperature <= 2.0,
            cfg.max_tokens > 0,
        ))
    except Exception:
        # Missing attributes or incomparable types mean a broken config.
        return False
599
+
600
@abstractmethod
def execute_task(self, task_input: TaskInput) -> AgentResponse:
    """
    Execute a task using the standardized TaskInput format.

    Implemented by concrete agent subclasses.

    Args:
        task_input: TaskInput object with standardized task information

    Returns:
        AgentResponse with results and process trace
    """
    pass
612
+
613
@abstractmethod
def _build_system_prompt(self) -> str:
    """Build the system prompt for this agent.

    Implemented by concrete agent subclasses; the returned string is used
    as the system message for the agent's LLM conversation.
    """
    pass
617
+
618
+
619
+ # Simple factory function for creating agent configurations
620
+
621
def create_agent_config(
    agent_name: str,
    model: Optional[str] = None,
    max_iterations: Optional[int] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None
) -> AgentConfig:
    """
    Create an AgentConfig instance for server-managed sessions.

    Explicit arguments win; anything left as None is resolved from the
    env-backed global configuration, and a ValueError is raised when no
    value can be resolved at all.

    Args:
        agent_name: Name of the agent (also used to pick the per-agent
            max-iterations default: planner / writer / information-seeker)
        model: LLM model to use
        max_iterations: Maximum number of iterations
        temperature: LLM temperature setting
        max_tokens: Maximum tokens for LLM response

    Returns:
        Configured AgentConfig instance

    Raises:
        ValueError: If the global config cannot load, or a required value
            is neither passed nor available from the environment.
    """
    # Load env-backed defaults
    try:
        from config.config import get_config
        api_cfg = get_config()
    except Exception as e:
        raise ValueError(f"Failed to load global configuration: {e}")

    planner_mode = getattr(api_cfg, "planner_mode", "auto")

    # `is None` checks (not truthiness) so explicit 0/0.0 overrides survive.
    resolved_model = model if model is not None else getattr(api_cfg, "model_name", None)
    if not resolved_model:
        raise ValueError("Model is not specified and MODEL_NAME is not set in environment")

    resolved_temperature = temperature if temperature is not None else getattr(api_cfg, "model_temperature", None)
    if resolved_temperature is None:
        raise ValueError("Temperature is not specified and MODEL_TEMPERATURE is not set in environment")

    resolved_max_tokens = max_tokens if max_tokens is not None else getattr(api_cfg, "model_max_tokens", None)
    if resolved_max_tokens is None:
        raise ValueError("Max tokens is not specified and MODEL_MAX_TOKENS is not set in environment")

    # Optional paths used by writer and others
    trajectory_storage_path = getattr(api_cfg, "trajectory_storage_path", None)
    report_output_path = getattr(api_cfg, "report_output_path", None)
    document_analysis_path = getattr(api_cfg, "document_analysis_path", None)

    # Resolve max_iterations per agent type (matched by substring of the name)
    if max_iterations is None:
        agent_lower = (agent_name or "").lower()
        resolved_max_iterations = None
        if "planner" in agent_lower:
            resolved_max_iterations = getattr(api_cfg, "planner_max_iterations", None)
        elif "writer" in agent_lower:
            resolved_max_iterations = getattr(api_cfg, "writer_max_iterations", None)
        elif "information" in agent_lower or "seeker" in agent_lower:
            resolved_max_iterations = getattr(api_cfg, "information_seeker_max_iterations", None)
        # if not found in env, raise
        if resolved_max_iterations is None:
            raise ValueError("Max iterations not specified and no env override (PLANNER_MAX_ITERATION/WRITER_MAX_ITERATION/INFORMATION_SEEKER_MAX_ITERATION)")
        max_iterations = resolved_max_iterations

    return AgentConfig(
        agent_name=agent_name,
        planner_mode=planner_mode,
        model=resolved_model,
        max_iterations=int(max_iterations),
        temperature=resolved_temperature,
        max_tokens=resolved_max_tokens,
        trajectory_storage_path=trajectory_storage_path,
        report_output_path=report_output_path,
        document_analysis_path=document_analysis_path
    )
deepdiver_v2/src/agents/objective_information_seeker.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ import json
3
+ from typing import Dict, Any, List
4
+ import time
5
+ import requests
6
+ import os
7
+ from .base_agent import BaseAgent, AgentConfig, AgentResponse, TaskInput
8
+
9
+
10
+
11
class InformationSeekerAgent(BaseAgent):
    """
    Information Seeker Agent that follows ReAct pattern (Reasoning + Acting)

    This agent takes decomposed sub-questions or tasks from parent agents,
    thinks interleaved (reasoning -> action -> reasoning -> action),
    uses MCP tools to gather information, and returns structured results.
    """

    def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
        # Set default agent name if not specified
        if config is None:
            config = AgentConfig(agent_name="InformationSeekerAgent")
        elif config.agent_name == "base_agent":
            config.agent_name = "InformationSeekerAgent"

        super().__init__(config, shared_mcp_client)

    def _build_system_prompt(self) -> str:
        """Build the system prompt for the ReAct agent.

        Serializes the agent's tool schemas into the prompt so the model can
        emit [unused11]/[unused12]-delimited tool calls.
        """
        tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)
        system_prompt_template = """You are an Information Seeker Agent that follows the ReAct pattern (Reasoning + Acting).

Your role is to:
1. Take decomposed sub-questions or tasks from parent agents
2. Think step-by-step through reasoning
3. Use available tools to gather information when needed
4. Continue reasoning based on tool results
5. Repeat this process until you have sufficient information
6. Call info_seeker_objective_task_done to provide a structured summary

### Optimized Workflow:
Follow this optimized workflow for information gathering:

1. INITIAL RESEARCH:
- Use `batch_web_search` to find relevant URLs for your queries. When calling the search statement, consider the language of the user's question. For example, for a Chinese question, generate a part of the search statement in Chinese.
- Analyze the search results (titles, snippets, URLs) to identify promising sources

2. CONTENT EXTRACTION:
- For important URLs, use `url_crawler` to:
a) Extract full content from the webpage
b) Save the content to a file in the workspace
- Store results with meaningful file paths (e.g., \"research/ai_trends_2024.txt\")

3. CONTENT ANALYSIS:
- Use `document_qa` to ask specific questions about the saved files:
a) Formulate focused questions to extract key insights
b) Use answers to deepen your understanding
- You can ask multiple questions about the same file

4. FILE MANAGEMENT:
- Use `file_write` to save important findings or summaries
- For reviewing saved content:
a) Prefer `document_qa` to ask specific questions about the content
b) Use `file_read` ONLY for small files (<1000 tokens) when you need the entire content
c) Avoid reading large files directly as it may exceed context limits

5. TASK COMPLETION:
- When ready to report, call `info_seeker_objective_task_done` with:
a) Comprehensive markdown summary of your process and findings
b) List of key files created with descriptions

### Usage of Systematic Tool:
- `think` is a systematic tool. After receiving the response from the complex tool or before invoking any other tools, you must **first invoke the `think` tool**: to deeply reflect on the results of previous tool invocations (if any), and to thoroughly consider and plan the user's task. The `think` tool does not acquire new information; it only saves your thoughts into memory.
- `reflect` is a systematic tool. When encountering a failure in tool execution, it is necessary to invoke the reflect tool to conduct a review and revise the task plan. It does not acquire new information; it only saves your thoughts into memory.

Always provide clear reasoning for your actions and synthesize information effectively.

Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
<tools>
$tool_schemas
</tools>
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]
"""
        return system_prompt_template.replace("$tool_schemas", tool_schemas_str)

    @staticmethod
    def _build_initial_message_from_task_input(task_input: TaskInput) -> str:
        """Build the initial user message from TaskInput."""
        message = task_input.format_for_prompt()

        message += "\nPlease analyze this task and start your ReAct process:\n"
        message += "1. Reason about what information you need to gather\n"
        message += "2. Use appropriate tools to get that information\n"
        message += "3. Continue reasoning and acting until you have sufficient information\n"
        message += "4. Call task_done when ready to provide your complete findings\n\n"
        message += "Begin with your initial reasoning about the task."

        return message

    def execute_task(self, task_input: TaskInput) -> AgentResponse:
        """
        Execute a task using ReAct pattern (Reasoning + Acting)

        Args:
            task_input: TaskInput object with standardized task information

        Returns:
            AgentResponse with results and process trace
        """
        start_time = time.time()

        try:
            self.logger.info(f"Starting information seeker task: {task_input.task_content}")

            # Reset trace for new task
            self.reset_trace()

            # Initialize conversation history
            conversation_history = []

            # Build initial system prompt for ReAct
            system_prompt = self._build_system_prompt()

            # Build initial user message from TaskInput
            user_message = self._build_initial_message_from_task_input(task_input)

            # Add to conversation
            conversation_history.append({"role": "system", "content": system_prompt})
            conversation_history.append({"role": "user", "content": user_message+" /no_think"})

            iteration = 0
            task_completed = False
            # Get model endpoint configuration from env-backed config
            from config.config import get_config
            config = get_config()
            model_config = config.get_custom_llm_config()

            pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
            model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
            headers = {'Content-Type': 'application/json', 'csb-token': model_token}

            # ReAct Loop: Reasoning -> Acting -> Reasoning -> Acting...
            while iteration < self.config.max_iterations and not task_completed:
                iteration += 1
                self.logger.info(f"Planning iteration {iteration}")

                try:
                    # Get LLM response (reasoning + potential tool calls)
                    # NOTE(review): starting at 1 with `<` allows at most 9
                    # attempts, not 10 — confirm intended.
                    retry_num = 1
                    max_retry_num = 10
                    while retry_num < max_retry_num:
                        try:
                            response = requests.post(
                                url=pangu_url,
                                headers=headers,
                                json={
                                    "model": self.config.model,
                                    "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
                                    "messages": conversation_history,
                                    "temperature": self.config.temperature,
                                    "spaces_between_special_tokens": False,
                                    "max_tokens": self.config.max_tokens,
                                },
                                timeout=model_config.get("timeout", 180)
                            )
                            response = response.json()
                            self.logger.debug(f"API response received")
                            break
                        except Exception as e:
                            # Back off briefly, then retry; last failure re-raises.
                            time.sleep(3)
                            retry_num += 1
                            if retry_num == max_retry_num:
                                raise ValueError(str(e))
                            continue

                    assistant_message = response["choices"][0]["message"]

                    # Log the reasoning
                    try:
                        # Reasoning text is delimited by [unused16]/[unused17] markers.
                        if assistant_message["content"]:
                            reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
                            if len(reasoning_content) > 0:
                                self.log_reasoning(iteration, reasoning_content)
                    except Exception as e:
                        self.logger.warning(f"Tool call parsing error: {e}")
                        # Parse error, rerun
                        followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
                        conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
                        continue

                    def extract_tool_calls(content):
                        # Pull the first [unused11]...[unused12] span and parse it
                        # as JSON; any parse failure yields "no tool calls".
                        # NOTE(review): bare `except` swallows all JSON errors and
                        # the parsed value is assumed to be a list — confirm.
                        import re
                        if not content:
                            return []
                        tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
                        if len(tool_call_str) > 0:
                            try:
                                tool_calls = json.loads(tool_call_str[0].strip())
                            except:
                                return []
                        else:
                            return []
                        return tool_calls

                    # Add assistant message to conversation
                    conversation_history.append({
                        "role": "assistant",
                        "content": assistant_message["content"]
                    })

                    tool_calls = extract_tool_calls(assistant_message["content"])

                    # Execute tool calls if any (Acting phase)

                    for tool_call in tool_calls:
                        arguments = tool_call["arguments"]

                        # Check if planning is complete
                        if tool_call["name"] in ["info_seeker_objective_task_done"]:
                            task_completed = True
                            # The task_done arguments double as the recorded result.
                            self.log_action(iteration, tool_call["name"], arguments, arguments)
                            break
                        if tool_call["name"] in ["think", "reflect"]:
                            # Systematic tools are no-ops server-side; acknowledge only.
                            tool_result = {"tool_results": "You can proceed to invoke other tools if needed."}
                        else:
                            tool_result = self.execute_tool_call(tool_call)

                        # Log the action using base class method
                        self.log_action(iteration, tool_call["name"], arguments, tool_result)

                        # Add tool result to conversation
                        conversation_history.append({
                            "role": "tool",
                            "content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
                        })

                    # If no tool calls, encourage continued planning
                    if len(tool_calls) == 0:
                        # Add follow-up prompt to encourage action or completion
                        followup_prompt = (
                            "Continue your planning process. Use available tools to assign tasks to agents, "
                            "search for information, or coordinate work. When you have a complete answer, "
                            "call info_seeker_objective_task_done. /no_think"
                        )
                        conversation_history.append({"role": "user", "content": followup_prompt})
                    if iteration == self.config.max_iterations-3:
                        # Force wrap-up shortly before the iteration budget runs out.
                        followup_prompt = "Due to length and number of rounds restrictions, you must now call the `info_seeker_objective_task_done` tool to report the completion of your task. /no_think"
                        conversation_history.append({"role": "user", "content": followup_prompt})

                except Exception as e:
                    error_msg = f"Error in planning iteration {iteration}: {e}"
                    self.log_error(iteration, error_msg)
                    break

            execution_time = time.time() - start_time
            # Extract final result
            if task_completed:
                # Find the info_seeker_objective_task_done result in the trace
                task_done_result = None
                for step in reversed(self.reasoning_trace):
                    if step.get("type") == "action" and step.get("tool") == "info_seeker_objective_task_done":
                        task_done_result = step.get("result")
                        break

                return self.create_response(
                    success=True,
                    result=task_done_result,
                    iterations=iteration,
                    execution_time=execution_time
                )
            else:
                return self.create_response(
                    success=False,
                    error=f"Task not completed within {self.config.max_iterations} iterations",
                    iterations=iteration,
                    execution_time=execution_time
                )

        except Exception as e:
            execution_time = time.time() - start_time
            self.logger.error(f"Error in execute_task: {e}")
            return self.create_response(
                success=False,
                error=str(e),
                iterations=iteration if 'iteration' in locals() else 0,
                execution_time=execution_time
            )

    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build tool schemas for InformationSeekerAgent using proper MCP architecture.
        Schemas come from MCP server via client, not direct imports.
        """

        # Get MCP tool schemas from server via client (proper MCP architecture)
        schemas = super()._build_agent_specific_tool_schemas()

        # Add schemas for built-in task assignment tools
        builtin_assignment_schemas = [
            {
                "type": "function",
                "function": {
                    "name": "think",
                    "description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "thought": {
                                "type": "string",
                                "description": "Your thoughts."
                            }
                        },
                        "required": ["thought"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "reflect",
                    "description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "reflect": {
                                "type": "string",
                                "description": "The specific content of your reflection"
                            }
                        },
                        "required": ["reflect"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "info_seeker_objective_task_done",
                    "description": "Structured reporting of task completion details including summary, decisions, outputs, and status",
                    # NOTE(review): this entry uses "inputSchema" while the sibling
                    # tools above use "parameters" — confirm the prompt consumer
                    # accepts both keys.
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "task_summary": {
                                "type": "string",
                                "description": "Comprehensive markdown covering what the agent was asked to do, steps taken, tools used, key findings, files created, challenges, and final deliverables.",
                                "format": "markdown"
                            },
                            "task_name": {
                                "type": "string",
                                "description": "The name of the task currently assigned to the agent, usually with underscores (e.g., 'web_research_ai_trends')"
                            },
                            "key_files": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "file_path": {
                                            "type": "string",
                                            "description": "Relative path to created/modified file"
                                        },
                                        "desc": {
                                            "type": "string",
                                            "description": "File contents and creation purpose"
                                        },
                                        "is_final_output_file": {
                                            "type": "boolean",
                                            "description": "Whether file is primary deliverable"
                                        }
                                    },
                                    "required": ["file_path", "desc", "is_final_output_file"]
                                },
                                "description": "List of key files generated or modified during the task, with their details."
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final task status"
                            }
                        },
                        "required": ["task_summary", "task_name", "key_files", "completion_status"]
                    }
                }
            },
        ]

        schemas.extend(builtin_assignment_schemas)

        return schemas
393
+
394
+
395
+ # Factory function for creating the agent
396
def create_objective_information_seeker(
    model: Any = None,
    max_iterations: Any = None,
    shared_mcp_client=None,
    **kwargs
) -> InformationSeekerAgent:
    """
    Create an InformationSeekerAgent instance with server-managed sessions.

    Args:
        model: The LLM model to use (None resolves from environment config)
        max_iterations: Maximum number of iterations (None resolves from config)
        shared_mcp_client: Optional shared MCP client from parent agent
            (prevents extra sessions)
        **kwargs: Additional configuration options forwarded to create_agent_config

    Returns:
        Configured InformationSeekerAgent instance with appropriate tools
    """
    # Resolve configuration via the shared env-backed factory.
    from ..agents.base_agent import create_agent_config

    agent_config = create_agent_config(
        agent_name="InformationSeekerAgent",
        model=model,
        max_iterations=max_iterations,
        **kwargs
    )

    # Filtered MCP tools are wired up inside the agent constructor.
    return InformationSeekerAgent(config=agent_config, shared_mcp_client=shared_mcp_client)
deepdiver_v2/src/agents/planner_agent.py ADDED
@@ -0,0 +1,1203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ """
3
+ Planner Agent for Multi-Agent Task Coordination
4
+
5
+ This agent serves as a coordinator for complex tasks that require multiple agents
6
+ working together. It implements the ReAct pattern for reasoning and action.
7
+ """
8
+ import time
9
+ import json
10
+ import requests
11
+ import os
12
+ from typing import Dict, Any, List
13
+ from concurrent.futures import ThreadPoolExecutor
14
+
15
+ # Base imports
16
+ from .base_agent import BaseAgent, AgentConfig, AgentResponse, WriterAgentTaskInput
17
+ # Import agent creators for built-in task assignment
18
+ from .writer_agent import create_writer_agent
19
+
20
+
21
+ class PlannerAgent(BaseAgent):
22
+ """
23
+ PlannerAgent coordinates multiple agents to handle complex user queries.
24
+
25
+ The agent uses the ReAct pattern (Reasoning + Acting) to analyze user requests,
26
+ break them down into manageable tasks, and coordinate the appropriate agents
27
+ to complete the work.
28
+ """
29
+
30
+ def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
31
+ # Set default agent name if not specified
32
+ if config and not config.agent_name:
33
+ config.agent_name = "PlannerAgent"
34
+ elif not config:
35
+ config = AgentConfig(agent_name="PlannerAgent")
36
+
37
+ super().__init__(config, shared_mcp_client)
38
+
39
+ # Planner-specific state
40
+ self.execution_plan = []
41
+ self.task_queue = []
42
+
43
+ # Add built-in task assignment methods to available tools
44
+ self._add_builtin_assignment_tools()
45
+
46
+ # Regenerate tool schemas with built-in assignment tools
47
+ self.tool_schemas = self._build_tool_schemas()
48
+
49
+ self.sub_agent_configs = {}
50
+
51
+ def _add_builtin_assignment_tools(self):
52
+ """Add built-in task assignment methods as available tools"""
53
+ # Add assignment methods that share the MCP client connection
54
+ self.available_tools.update({
55
+ "assign_subjective_task_to_writer": self.assign_subjective_task_to_writer, # assign_subjective_task_to_writer
56
+ "assign_multi_objective_tasks_to_info_seeker": self.assign_multi_objective_tasks_to_info_seeker,
57
+ "assign_multi_subjective_tasks_to_info_seeker": self.assign_multi_subjective_tasks_to_info_seeker
58
+ })
59
+
60
    def assign_multi_objective_tasks_to_info_seeker(
        self,
        tasks: List[Dict[str, str]],
        max_workers: int = 5
    ) -> Dict[str, Any]:
        """
        Fan out multiple objective research tasks to InformationSeekerAgent
        instances and execute them concurrently in a thread pool.

        Each task dict is converted into a TaskInput and run by a freshly
        created objective information seeker that shares this planner's MCP
        client, so all sub-agents operate in the same server-managed session.

        Args:
            tasks: List of task dictionaries with the following keys:
                - task_content (required): The specific task content
                - task_steps_for_reference: Optional reference steps for execution
                - deliverable_contents: Format of expected deliverable
                - acceptance_checking_criteria: Criteria for task completion and quality
                - current_task_status: Description of current task status
                (any workspace_id key is ignored here: the session/workspace is
                managed by the server, so TaskInput.workspace_id is set to None)
            max_workers: Maximum concurrent threads (default=5)

        Returns:
            Dict with overall "success", per-task results under data["tasks"],
            an "error" message when any task failed, and success/failure
            counts in "metadata".
        """
        try:
            # Validate task count: exactly 1 to 5 tasks per dispatch.
            if not (1 <= len(tasks) <= 5):
                return {
                    "success": False,
                    "error": f"Invalid task count ({len(tasks)}). Must assign 1~5 tasks. Please re-plan the task execution schedule or re-decompose the task."
                }

            # Import here to avoid circular imports; the fallback covers both
            # flat-script and package-relative execution layouts.
            try:
                from agents import TaskInput, create_objective_information_seeker
            except ImportError:
                from ..agents import TaskInput, create_objective_information_seeker

            results = []
            import threading
            # Guards `results`, which worker threads append to concurrently.
            lock = threading.Lock()

            def process_task(task: Dict[str, str]):
                """Run one task in a worker thread; record its outcome in `results`."""
                try:

                    # Create TaskInput object
                    task_input = TaskInput(
                        task_content=task["task_content"],
                        task_steps_for_reference=task.get("task_steps_for_reference"),
                        deliverable_contents=task.get("deliverable_contents"),
                        current_task_status=task.get("current_task_status"),
                        workspace_id=None,  # Session/workspace is managed by the server; no need to set explicitly
                        acceptance_checking_criteria=task.get("acceptance_checking_criteria")
                    )

                    # Create the info seeker with any sub-agent overrides, and
                    # share the MCP client so it joins this planner's session.
                    info_seeker_config = getattr(self, 'sub_agent_configs', {}).get('information_seeker', {})
                    info_seeker = create_objective_information_seeker(
                        model=info_seeker_config.get('model', self.config.model),
                        max_iterations=info_seeker_config.get('max_iterations', 30),
                        shared_mcp_client=self.mcp_tools.client if hasattr(self.mcp_tools, 'client') else self.mcp_tools
                    )

                    self.logger.info(f"Assigning task to InformationSeekerAgent: {task['task_content'][:8000]}...")

                    # Execute the task
                    response = info_seeker.execute_task(task_input)

                    if response.success:
                        response_data = {
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": True,
                            "data": response.result,
                            "agent_name": response.agent_name,
                            "iterations": response.iterations,
                            "execution_time": response.execution_time,
                            # "reasoning_trace": response.reasoning_trace
                        }
                    else:
                        response_data = {
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": False,
                            "error": response.error,
                            "agent_name": response.agent_name
                        }

                    # Thread-safe result collection
                    with lock:
                        results.append(response_data)

                    return response_data

                except Exception as e:
                    # Per-task failures are recorded, not raised, so one bad
                    # task cannot abort the other parallel tasks.
                    error_msg = f"Task processing failed: {str(e)}"
                    self.logger.error(error_msg)
                    with lock:
                        results.append({
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": False,
                            "error": error_msg
                        })
                    return None

            # Execute tasks in parallel with a thread pool.
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(process_task, task) for task in tasks]
                # Wait for all tasks to complete (process_task swallows its own
                # exceptions, so result() is effectively a join here).
                for future in futures:
                    future.result()  # Raise exceptions if any

            # Overall success requires every individual task to succeed.
            all_success = all(task_result.get("success", False) for task_result in results)

            return {
                "success": all_success,
                "data": {"tasks": results},
                "error": None if all_success else "Some tasks failed",
                "metadata": {
                    "tool_name": "assign_multi_objective_tasks_to_info_seeker",
                    "task_count": len(tasks),
                    "success_count": sum(1 for r in results if r.get("success")),
                    "failure_count": sum(1 for r in results if not r.get("success"))
                }
            }

        except Exception as e:
            self.logger.error(f"Multi-task assignment failed: {e}")
            return {
                "success": False,
                "error": f"Multi-task assignment failed: {str(e)}"
            }
193
+
194
+
195
    def assign_multi_subjective_tasks_to_info_seeker(
        self,
        tasks: List[Dict[str, str]],
        max_workers: int = 5
    ) -> Dict[str, Any]:
        """
        Fan out multiple subjective (long-form-writing oriented) research tasks
        to InformationSeekerAgent instances and run them concurrently.

        Unlike the objective variant, each TaskInput carries this planner's
        session id as its workspace_id, so the sub-agents write into the same
        workspace the writer will later read from.

        Args:
            tasks: List of task dictionaries with the following keys:
                - task_content (required): The specific task content
                - task_steps_for_reference: Optional reference steps for execution
                - deliverable_contents: Format of expected deliverable
                - acceptance_checking_criteria: Criteria for task completion and quality
                - current_task_status: Description of current task status
            max_workers: Maximum concurrent threads (default=5)

        Returns:
            Dict with overall "success", per-task results under data["tasks"],
            an "error" message when any task failed, and success/failure
            counts in "metadata".
        """
        try:
            # Validate task count: exactly 1 to 6 tasks per dispatch.
            if not (1 <= len(tasks) <= 6):
                return {
                    "success": False,
                    "error": f"Invalid task count ({len(tasks)}). Must assign 1-6 tasks."
                }

            # Import here to avoid circular imports; the fallback covers both
            # flat-script and package-relative execution layouts.
            try:
                from agents import TaskInput, create_subjective_information_seeker
            except ImportError:
                from ..agents import TaskInput, create_subjective_information_seeker

            results = []
            import threading
            # Guards `results`, which worker threads append to concurrently.
            lock = threading.Lock()

            def process_task(task: Dict[str, str]):
                """Run one task in a worker thread; record its outcome in `results`."""
                try:
                    # Create TaskInput object
                    task_input = TaskInput(
                        task_content=task["task_content"],
                        task_steps_for_reference=task.get("task_steps_for_reference"),
                        deliverable_contents=task.get("deliverable_contents"),
                        current_task_status=task.get("current_task_status"),
                        workspace_id=self.get_session_info()["session_id"],  # reuse the planner's session id so sub-agents share its workspace
                        acceptance_checking_criteria=task.get("acceptance_checking_criteria")
                    )

                    # Create the info seeker with any sub-agent overrides, and
                    # share the MCP client so it joins this planner's session.
                    info_seeker_config = getattr(self, 'sub_agent_configs', {}).get('information_seeker', {})
                    info_seeker = create_subjective_information_seeker(
                        model=info_seeker_config.get('model', self.config.model),
                        max_iterations=info_seeker_config.get('max_iterations', 30),
                        shared_mcp_client=self.mcp_tools.client if hasattr(self.mcp_tools, 'client') else self.mcp_tools
                    )

                    self.logger.info(f"Assigning task to InformationSeekerAgent: {task['task_content'][:8000]}...")

                    # Execute the task
                    response = info_seeker.execute_task(task_input)

                    if response.success:
                        response_data = {
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": True,
                            "data": response.result,
                            "agent_name": response.agent_name,
                            "iterations": response.iterations,
                            "execution_time": response.execution_time,
                            # "reasoning_trace": response.reasoning_trace
                        }
                    else:
                        response_data = {
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": False,
                            "error": response.error,
                            "agent_name": response.agent_name
                        }

                    # Thread-safe result collection
                    with lock:
                        results.append(response_data)

                    return response_data

                except Exception as e:
                    # Per-task failures are recorded, not raised, so one bad
                    # task cannot abort the other parallel tasks.
                    error_msg = f"Task processing failed: {str(e)}"
                    self.logger.error(error_msg)
                    with lock:
                        results.append({
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": False,
                            "error": error_msg
                        })
                    return None

            # Execute tasks in parallel with a thread pool.
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(process_task, task) for task in tasks]
                # Wait for all tasks to complete (process_task swallows its own
                # exceptions, so result() is effectively a join here).
                for future in futures:
                    future.result()  # Raise exceptions if any

            # Overall success requires every individual task to succeed.
            all_success = all(task_result.get("success", False) for task_result in results)

            return {
                "success": all_success,
                "data": {"tasks": results},
                "error": None if all_success else "Some tasks failed",
                "metadata": {
                    "tool_name": "assign_multi_subjective_tasks_to_info_seeker",
                    "task_count": len(tasks),
                    "success_count": sum(1 for r in results if r.get("success")),
                    "failure_count": sum(1 for r in results if not r.get("success"))
                }
            }

        except Exception as e:
            self.logger.error(f"Multi-task assignment failed: {e}")
            return {
                "success": False,
                "error": f"Multi-task assignment failed: {str(e)}"
            }
325
+
326
+ def assign_subjective_task_to_writer(
327
+ self,
328
+ task_content: str,
329
+ user_query: str,
330
+ key_files: List[Dict[str, str]]
331
+ ) -> Dict[str, Any]:
332
+ """
333
+ Assign a writing or content creation task to the WriterAgent
334
+
335
+ Args:
336
+ task_content: Detailed description of the writing task to be performed
337
+ user_query: List storing previous information seeker subtask summaries intact to preserve information from each completed research task
338
+ key_files: Curated list of relevant files with file_path and desc for each file
339
+
340
+ Returns:
341
+ Dictionary with task assignment results
342
+ """
343
+ try:
344
+
345
+ self.logger.info("Assigning task to WriterAgent")
346
+
347
+ # Create task input
348
+ task_input = WriterAgentTaskInput(
349
+ task_content=task_content,
350
+ user_query=user_query,
351
+ key_files=key_files,
352
+ workspace_id=self.get_session_info()["session_id"],
353
+ )
354
+
355
+ # Create writer agent with shared MCP client and sub-agent configuration
356
+ writer_config = getattr(self, 'sub_agent_configs', {}).get('writer', {})
357
+ writer = create_writer_agent(
358
+ shared_mcp_client=self.mcp_tools.client,
359
+ model=writer_config.get('model', self.config.model),
360
+ max_iterations=writer_config.get('max_iterations', 20),
361
+ temperature=writer_config.get('temperature', 0.3),
362
+ max_tokens=writer_config.get('max_tokens', 16384)
363
+ )
364
+
365
+ self.logger.info(f"Assigning task to WriterAgent: {task_content[:800]}...")
366
+
367
+ # Execute the task with shared connection
368
+ response = writer.execute_task(task_input)
369
+
370
+ if response.success:
371
+ return {
372
+ "success": True,
373
+ "data": response.result,
374
+ "agent_name": response.agent_name,
375
+ "iterations": response.iterations,
376
+ "execution_time": response.execution_time,
377
+ # "reasoning_trace": response.reasoning_trace
378
+ }
379
+ else:
380
+ return {
381
+ "success": False,
382
+ "error": response.error,
383
+ "agent_name": response.agent_name
384
+ }
385
+
386
+ except Exception as e:
387
+ self.logger.error(f"Failed to assign task to WriterAgent: {e}")
388
+ return {
389
+ "success": False,
390
+ "error": f"Task assignment failed: {str(e)}"
391
+ }
392
+
393
+ def _build_system_prompt(self) -> str:
394
+ """Build the system prompt for the planner agent"""
395
+ tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)
396
+
397
+ auto_system_prompt_template = """# PlannerAgent: Multi-Agent Task Coordinator
398
+ **Role:** Analyze complex queries, first distinguish query type (long-form writing type/objective question type), then create structured plans, and coordinate specialized agents to deliver comprehensive solutions—call corresponding tools based on query type, and only invoke writer for long-form writing type queries.
399
+
400
+ #### Available Sub-Agents:
401
+ - **`information_seeker`**: Research, data gathering, web search (supports single/parallel multi-task; long-form writing type uses assign_multi_subjective_tasks_to_info_seeker, other types use assign_multi_objective_tasks_to_info_seeker)
402
+ - **`writer`**: Only invoke this sub-agent when long-form writing is required.
403
+
404
+ ---
405
+
406
+ ## Optimized Workflow
407
+ ### 1. Query Type Judgment & Analysis & Planning Phase
408
+ **Goal:** Use the `think` tool to analyze the problem and determine whether it is a simple task (refers to tasks that do not require calling the information search agent or tool) or a complex task (requires calling info seeker). If it is a complex task, it is necessary to further analyze whether it is a objective question(do not require calling the writer agent)or a long-form writing question (requires long-form expression and need to call the writer agent later).
409
+ - **Simple Tasks:** For simple tasks that do not require info seeker invocation, you can directly call the `planner_objective_task_done` tool and write the answer in `final_answer` field without creating a todo.md file.
410
+ - **Complex Tasks:**
411
+ - For objective tasks, must use `assign_multi_objective_tasks_to_info_seeker`
412
+ - For long-form writing tasks, must use `assign_multi_subjective_tasks_to_info_seeker`, and call the writer agent to integrate the collected information to generate a very long text
413
+ - **Task Decomposition Rules:**
414
+ - Construct a task tree with a tree-like structure, where the root node represents the user's input query. Each subtask is marked with its depth in the task tree, and the entire task tree is executed from shallow to deep. Tasks at the same depth in the task tree must be independent and can be executed in parallel (via `assign_multi_xxx_tasks_to_info_seeker`) without mutual dependencies.
415
+ - At the first level of the task tree, it is essential to thoroughly design subtasks that can be executed in parallel to explore various potential background information, thereby providing more specific clues for the next step of planning.
416
+ - Competitive Redundancy Mechanism:
417
+ - For key subtasks that have a significant impact on subsequent reasoning and planning, a redundancy mechanism should be established. This involves duplicating the task at the same depth level in the task tree, enabling the parallel execution of nearly identical tasks to enhance the completion rate and robustness of the task execution.
418
+ - **Task Parallel Sending Requirements:**
419
+ - When using `assign_multi_xxx_tasks_to_info_seeker`, all parallel-sent subtasks must be independent of each other; the description of each subtask must not contain any mutual references or dependency requirements for other subtasks.
420
+ - There is no sequential execution relationship among all parallel-sent subtasks.
421
+
422
+ - **Mandatory Documentation:** Create and write `todo.md` (e.g., `todo_v1.md`) with fields:
423
+ ```markdown
424
+ # Task Planning Document
425
+ ## task_name: [Clear identifier]
426
+ ## task_desc: [Detailed requirements - focus on WHAT not HOW]
427
+ ## deliverable_contents: [Exact output format specs]
428
+ ## success_criteria: [Measurable 100% completion metrics]
429
+ ## context: [Background, constraints, prior results]
430
+ ## task_steps_for_reference: [Tree-structured preliminary execution plan, tag tasks with the depth in task tree `[DEPTH:xx]`]
431
+ ```
432
+
433
+ ### 2. Execution & Iteration Phase
434
+ #### A. Unified Iteration Triggers (Shared by Both Types)
435
+ - Based on upper-layer task results, refine the next layer of planning and document it in a new version of `todo.md` (e.g., `todo_v2.md`).
436
+ - If upper-layer tasks fail/encounter challenges: Invoke the `reflect` tool for introspection (no new information acquired, only saves thoughts), adjust the plan, and re-invoke the corresponding `information_seeker` method (objective: `assign_multi_objective_tasks_to_info_seeker`; long-form writing: `assign_multi_subjective_tasks_to_info_seeker`).
437
+ - If current tasks require prior round information: Clearly specify the context of each task and referenced files (e.g., `./data/agent_output_v1.json`) when calling `information_seeker`.
438
+ - Decompose and refine clues from upper-layer results, then execute verification in parallel.
439
+
440
+ #### B. Query-Type-Specific Operations
441
+ - **Objective tasks**: No additional operations (strictly no writer invocation). Continue iterating until information meets `success_criteria`.
442
+ - **Long-form writing tasks**: Add **information sufficiency check before writer invocation**:
443
+ 1. Evaluate collected information from two dimensions: quantity (e.g., "Enough case studies for 3 chapters") and comprehensiveness (e.g., "Covers both positive and negative impacts of AI on education").
444
+ 2. If information is insufficient: Adjust subtask directions (e.g., "Supplement AI education failure cases") and re-invoke `assign_multi_subjective_tasks_to_info_seeker` for targeted collection.
445
+ 3. If information is sufficient: Invoke the writer via `assign_subjective_task_to_writer` (provide all collected materials and `todo.md` as context).
446
+ 4. If the writer returns an incomplete result: Do not assist in completing it; only feed back the current completion status to the user.
447
+
448
+ ### 3. Completion & Synthesis Phase
449
+ #### A. Unified Validation & Integration (Shared by Both Types)
450
+ - **Validation**: Cross-check multi-source `information_seeker` outputs for consistency (e.g., "NBS and World Bank GDP data differ by ≤1%").
451
+ - **Integration**: Combine parallel outputs into a unified deliverable (e.g., "Merge two GDP data sources into a single table" or "Integrate writer’s report with supplementary case studies").
452
+ - **Delivery**: Output language must match the user’s query language (e.g., Chinese query → Chinese deliverable).
453
+
454
+ #### B. Query-Type-Specific Task Completion (Critical)
455
+ - **Objective tasks**: Call the `planner_objective_task_done` tool **only when** all planned tasks are completed and the final deliverable (e.g., verified data, clear answers) is ready for user delivery.
456
+ - **Long-form writing tasks**: Call the `planner_subjective_task_done` tool **only when** the writer has finished executing and the final long-form content meets the `success_criteria` in `todo.md`.
457
+
458
+ ---
459
+
460
+ ## Critical Protocols
461
+ 1. **Dependency Management:**
462
+ - Prohibit parallel dispatch for sequential dependent tasks unless using competitive redundancy mechanism
463
+ - Convert sequential chains to parallel where possible (e.g., Hypothesis_A vs Hypothesis_B testing)
464
+ 2. **File Traceability:**
465
+ - All output references use relative paths (`./data/agent_output_1.json`)
466
+ - Version `todo.md` after each iteration (e.g., `todo_v2.md`)
467
+ 3. **Local File Reading Recommendations:**
468
+ - For files crawled natively, it is not recommended to directly use the `file_read` tool to read the entire content (maybe too long). Instead, the `document_qa` tool should be used to extract and verify the required information.
469
+ - For task deliverables and summary documents from sub-agents, the `file_read` tool can be used to read them.
470
+ 4. The final deliverable presented to the user should be consistent with the language used in the user's question.
471
+ 5. **Writer invocation**: Strictly prohibit calling the writer for objective tasks; for long-form writing tasks, **never directly answer based on collected information**—must invoke the writer to generate the final long-form content.
472
+
473
+ Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
474
+ <tools>
475
+ $tool_schemas
476
+ </tools>
477
+ For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
478
+ [unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]"""
479
+
480
+ writing_system_prompt_template = """### PlannerAgent: Multi-Agent Task Coordinator
481
+ **Role:** Analyze complex queries, create structured plans, and coordinate specialized agents to deliver comprehensive solutions.
482
+
483
+ #### Available Sub-Agents:
484
+ - **`information_seeker`**: Research, data gathering, web search (supports single/parallel multi-task)
485
+ - **`writer`**: Creates content (e.g., reports, analysis, etc.), and synthesizes from existing materials
486
+
487
+ ---
488
+
489
+ ### Optimized Workflow
490
+ #### 1. Analysis & Planning Phase
491
+ **Goal:** Analyze the problem and determine whether it is a simple task or a complex task. If it is a complex task, it is necessary to further analyze whether it is a subject-driven question or an objective-driven question, so as to decompose the problem into multiple clear and executable subtasks according to the specific problem type. The main characteristic of objective-driven questions is that their answers are clear and verifiable entities, otherwise they are subject-driven questions.
492
+ - **Simple Tasks:** For simple tasks that do not require sub-agent invocation, you can directly answer without creating a todo.md file
493
+ - **Complex Tasks:**
494
+ - For Objective-driven tasks, Adopt *diverge-converge* strategy:
495
+ 1. Use `assign_multi_subjective_tasks_to_info_seeker` call for divergent background research
496
+ 2. Converge findings to define specific sub-problems
497
+ - For Subject-driven tasks, Adopt *multi-perspective* strategy:
498
+ 1. Use assign_multi_subjective_tasks_to_info_seeker call for divergent multi-source exploration (each task targets independent dimensions)
499
+ 2. Converge findings to define focused sub-problems addressing distinct knowledge gaps
500
+ 3. When the information seeker collects information, start to call the writer agent to integrate the collected information to generate a very long text
501
+ - **Task Decomposition Rules:**
502
+ - Construct a task tree with a tree-like structure, where the root node represents the user's input query. Each subtask is marked with its depth in the task tree, and the entire task tree is executed from shallow to deep. Tasks at the same depth in the task tree must be independent and can be executed in parallel (via `assign_multi_subjective_tasks_to_info_seeker`) without mutual dependencies.
503
+ - At the first level of the task tree, it is essential to thoroughly design subtasks that can be executed in parallel to explore various potential background information, thereby providing more specific clues for the next step of planning.
504
+ - Competitive Redundancy Mechanism:
505
+ - For key subtasks that have a significant impact on subsequent reasoning and planning, a redundancy mechanism should be established. This involves duplicating the task at the same depth level in the task tree, enabling the parallel execution of nearly identical tasks to enhance the completion rate and robustness of the task execution.
506
+ - **Task Parallel Sending Requirements:**
507
+ - When using `assign_multi_subjective_tasks_to_info_seeker`, all parallel-sent subtasks must be independent of each other; the description of each subtask must not contain any mutual references or dependency requirements for other subtasks.
508
+ - There is no sequential execution relationship among all parallel-sent subtasks.
509
+
510
+ - **Mandatory Documentation:** Create and write `todo.md` (e.g., `todo_v1.md`) with fields:
511
+ ```markdown
512
+ # Task Planning Document
513
+ ## task_name: [Clear identifier]
514
+ ## task_desc: [Detailed requirements - focus on WHAT not HOW]
515
+ ## deliverable_contents: [Exact output format specs]
516
+ ## success_criteria: [Measurable 100% completion metrics]
517
+ ## context: [Background, constraints, prior results]
518
+ ## task_steps_for_reference: [Tree-structured preliminary execution plan, tag tasks with the depth in task tree `[DEPTH:xx]`]
519
+ ```
520
+
521
+ #### 2. Execution & Iteration Phase
522
+ - **Iteration Triggers:**
523
+ - Based on the execution results of the upper layer of the task tree, specify and refine the next layer and subsequent task planning, and document them in a new `todo.md` file (e.g., `todo_v2.md`).
524
+ - If there are tasks in the previous layer that have failed or encountered challenges, it is necessary to invoke `reflect` for introspection, consider more possibilities, and make new task planning and invoke `assign_multi_subjective_tasks_to_info_seeker` again.
525
+ - If the tasks sent in the current round require reference to task information from previous rounds, it is essential to clearly specify the context of each task and the files that may need to be used or referenced when calling `assign_multi_subjective_tasks_to_info_seeker`.
526
+ - For the multiple clues of the execution results from the previous layer, they should be decomposed and refined, and executed in parallel for verification.
527
+ - **Information check required before calling writer:**
528
+ - Before invoking writer, analyze collected information for sufficiency: evaluate both quantity and comprehensiveness to ensure adequate material for long article generation
529
+ - If information is insufficient, adjust subtask direction and initiate additional targeted information collection
530
+ - **When information is sufficient, invoke writer agent** via `assign_subjective_task_to_writer`
531
+
532
+ #### 3. Completion & Synthesis Phase
533
+ - **Validation:** Cross-check multi-source outputs for consistency, and Check whether the information source is sufficient
534
+ - **Integration:** Combine parallel outputs into unified deliverable
535
+ - **Delivery:** Output language must match user's query language
536
+ - When the writer agent is finished executing, planner_subjective_task_done tool needs to be called to end the current task
537
+
538
+ ---
539
+
540
+ ### Critical Protocols
541
+ 1. **Dependency Management:**
542
+ - Prohibit parallel dispatch for sequential dependent tasks unless using competitive redundancy mechanism
543
+ - Convert sequential chains to parallel where possible (e.g., Hypothesis_A vs Hypothesis_B testing)
544
+ 2. **File Traceability:**
545
+ - All output references use relative paths (`./data/agent_output_1.json`)
546
+ - Version `todo.md` after each iteration (e.g., `todo_v2.md`)
547
+ 3. **Iteration Discipline:**
548
+ - Minimum 2 parallel agents for critical hypothesis-validation tasks
549
+ - Terminate only when ALL success criteria are met at 100%
550
+ 5. **Usage of Think Tool:**
551
+ - `think` is a systematic tool. After receiving the response from the complex tool or before invoking any other tools, you must **first invoke the `think` tool**: to deeply reflect on the results of previous tool invocations (if any), and to thoroughly consider and plan the user's task. The `think` tool does not acquire new information; it only saves your thoughts into memory.
552
+ 6. **Usage of Reflect Tool:**
553
+ `reflect` is a systematic tool. When encountering a failure in tool execution, it is necessary to invoke the reflect tool to conduct a review and revise the task plan. It does not acquire new information; it only saves your thoughts into memory.
554
+ 7. Always prioritize complete solutions over partial delivery. Use parallel redundancy for critical path tasks, and convert agent disagreements into new parallel investigation branches.
555
+ 8. **CRITICAL:** When you determine that the information_seeker has gathered sufficient information, you must invoke the writer agent to draft the final article in response to the user's query. You are not allowed to reply directly based on the collected information!
556
+ 9.Also note that when the writing agent returns a result that shows it is not completed, you do not need to help it complete it further. You only need to feedback the current completion status to the user.
557
+
558
+ Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
559
+ <tools>
560
+ $tool_schemas
561
+ </tools>
562
+ For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
563
+ [unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]"""
564
+
565
+ qa_system_prompt_template = """### PlannerAgent: Multi-Agent Task Coordinator
566
+ **Role:** Analyze complex queries, create structured plans, and coordinate specialized agents to deliver comprehensive solutions.
567
+
568
+ #### Available Sub-Agents:
569
+ - **`information_seeker`**: Research, data gathering, web search (supports single/parallel multi-task)
570
+
571
+ ---
572
+
573
+ ### Optimized Workflow
574
+ #### 1. Analysis & Planning Phase
575
+ **Goal:** Decompose problems into executable units with clear dependencies
576
+ - **Simple Tasks:** For simple tasks that do not require sub-agent invocation, you can directly answer and call `planner_objective_task_done` without creating a todo.md file
577
+ - **Complex Tasks:**
578
+ - **Task Decomposition Rules:**
579
+ - Construct a task tree with a tree-like structure, where the root node represents the user\'s input query. Each subtask is marked with its depth in the task tree, and the entire task tree is executed from shallow to deep. Tasks at the same depth in the task tree must be independent and can be executed in parallel (via `assign_multi_objective_tasks_to_info_seeker`) without mutual dependencies.
580
+ - At the first level of the task tree, it is essential to thoroughly design subtasks that can be executed in parallel to explore various potential background information, thereby providing more specific clues for the next step of planning.
581
+ - Competitive Redundancy Mechanism:
582
+ - For key subtasks that have a significant impact on subsequent reasoning and planning, a redundancy mechanism should be established. This involves duplicating the task at the same depth level in the task tree, enabling the parallel execution of nearly identical tasks to enhance the completion rate and robustness of the task execution.
583
+ - **Task Parallel Sending Requirements:**
584
+ - When using `assign_multi_objective_tasks_to_info_seeker`, all parallel-sent subtasks must be independent of each other; the description of each subtask must not contain any mutual references or dependency requirements for other subtasks.
585
+ - There is no sequential execution relationship among all parallel-sent subtasks.
586
+
587
+ - **Mandatory Documentation:** Create and write `todo.md` (e.g., `todo_v1.md`) with fields:
588
+ ```markdown
589
+ # Task Planning Document
590
+ ## task_name: [Clear identifier]
591
+ ## task_desc: [Detailed requirements - focus on WHAT not HOW]
592
+ ## deliverable_contents: [Exact output format specs]
593
+ ## success_criteria: [Measurable 100% completion metrics]
594
+ ## context: [Background, constraints, prior results]
595
+ ## task_steps_for_reference: [Tree-structured preliminary execution plan, tag tasks with the depth in task tree `[DEPTH:xx]`]
596
+ ```
597
+
598
+ #### 2. Execution & Iteration Phase
599
+ - **Iteration Triggers:**
600
+ - Based on the execution results of the upper layer of the task tree, specify and refine the next layer and subsequent task planning, and document them in a new `todo.md` file (e.g., `todo_v2.md`).
601
+ - If there are tasks in the previous layer that have failed or encountered challenges, it is necessary to invoke `reflect` for introspection, consider more possibilities, and make new task planning and invoke `assign_multi_objective_tasks_to_info_seeker` again.
602
+ - If the tasks sent in the current round require reference to task information from previous rounds, it is essential to clearly specify the context of each task and the files that may need to be used or referenced when calling `assign_multi_objective_tasks_to_info_seeker`.
603
+ - For the multiple clues of the execution results from the previous layer, they should be decomposed and refined, and executed in parallel for verification.
604
+
605
+ #### 3. Completion & Synthesis Phase
606
+ - **Validation:** Cross-check multi-source outputs for consistency
607
+ - **Integration:** Combine parallel outputs into unified deliverable
608
+ - **Delivery:** Output language must match user\'s query language
609
+ - **Task Completed:** The `planner_objective_task_done` can only be called when all planned tasks have been completed and the final results are ready to be delivered to the user.
610
+
611
+ ---
612
+
613
+ ### Critical Protocols
614
+ 1. **Dependency Management:**
615
+ - Prohibit parallel dispatch for sequential dependent tasks unless using competitive redundancy mechanism
616
+ - Convert sequential chains to parallel where possible (e.g., Hypothesis_A vs Hypothesis_B testing)
617
+ 2. **File Traceability:**
618
+ - All output references use relative paths (`./data/agent_output_1.json`)
619
+ - Version `todo.md` after each iteration (e.g., `todo_v2.md`)
620
+ 3. **Local File Reading Recommendations:**
621
+ - For files crawled natively, it is not recommended to directly use the `file_read` tool to read the entire content (maybe too long). Instead, the `document_qa` tool should be used to extract and verify the required information.
622
+ - For task deliverables and summary documents from sub-agents, the `file_read` tool can be used to read them.
623
+ 4. The final deliverable presented to the user should be consistent with the language used in the user\'s question.
624
+
625
+ Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
626
+ <tools>
627
+ $tool_schemas
628
+ </tools>
629
+ For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
630
+ [unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]"""
631
+
632
+ planner_mode_system_prompt_map = {
633
+ "auto": auto_system_prompt_template,
634
+ "writing": writing_system_prompt_template,
635
+ "qa": qa_system_prompt_template
636
+ }
637
+
638
+ system_prompt = planner_mode_system_prompt_map[self.config.planner_mode].replace("$tool_schemas", tool_schemas_str)
639
+
640
+ return system_prompt
641
+
642
    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build tool schemas for PlannerAgent using proper MCP architecture.
        Schemas come from MCP server via client, not direct imports.

        Returns:
            List of OpenAI-style function schemas: the MCP server schemas
            obtained from the base class, plus the built-in planner tools
            permitted for the configured planner mode
            (``self.config.planner_mode``: "auto", "writing", or "qa").
        """

        # Get MCP tool schemas from server via client (proper MCP architecture)
        schemas = super()._build_agent_specific_tool_schemas()

        # Add schemas for built-in task assignment tools.
        # Maps each planner mode to the subset of built-in tool names it may
        # use; applied below to filter builtin_assignment_schemas.
        planner_mode_builtin_tools_map = {
            "auto": ["think", "reflect", "assign_multi_subjective_tasks_to_info_seeker", "assign_multi_objective_tasks_to_info_seeker", "assign_subjective_task_to_writer", "writer_subjective_task_done", "planner_subjective_task_done", "planner_objective_task_done"],
            "writing": ["think", "reflect", "assign_multi_subjective_tasks_to_info_seeker", "assign_subjective_task_to_writer", "writer_subjective_task_done", "planner_subjective_task_done"],
            "qa": ["think", "reflect", "assign_multi_objective_tasks_to_info_seeker", "planner_objective_task_done"],
        }
        # Full catalogue of built-in tool schemas (OpenAI function-calling format).
        builtin_assignment_schemas = [
            # `think`: logs reasoning only; per its description it obtains no
            # new information and makes no changes.
            {
                "type": "function",
                "function": {
                    "name": "think",
                    "description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "thought": {
                                "type": "string",
                                "description": "Your thoughts."
                            }
                        },
                        "required": ["thought"]
                    }
                }
            },
            # `reflect`: logs a post-failure review; likewise has no external effect.
            {
                "type": "function",
                "function": {
                    "name": "reflect",
                    "description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "reflect": {
                                "type": "string",
                                "description": "The specific content of your reflection"
                            }
                        },
                        "required": ["reflect"]
                    }
                }
            },
            # Parallel dispatch of 1~6 research tasks (subjective/writing flow).
            {
                "type": "function",
                "function": {
                    "name": "assign_multi_subjective_tasks_to_info_seeker",
                    "description": "Assign 1~6 research or information gathering tasks to different InformationSeekerAgents for parallel execution, each task descriptions must be semantically complete and clearly provide contextual information and potentially important reference documents.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "tasks": {
                                "type": "array",
                                "description": "List of tasks to be assigned to multiple InformationSeekerAgents",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "task_content": {
                                            "type": "string",
                                            "description": "Detailed description of the task to be performed"
                                        },
                                        "task_steps_for_reference": {
                                            "type": "string",
                                            "description": "Optional reference steps for task execution"
                                        },
                                        "deliverable_contents": {
                                            "type": "string",
                                            "description": "Expected format and content of deliverables"
                                        },
                                        "current_task_status": {
                                            "type": "string",
                                            "description": "Current status and context of the task, important documents that may be used and referenced"
                                        },
                                        "acceptance_checking_criteria": {
                                            "type": "string",
                                            "description": "Criteria for determining task completion and quality"
                                        },
                                    },
                                    "required": ["task_content"]
                                }
                            }
                        },
                        "required": ["tasks"]
                    }
                }
            },
            # Parallel dispatch of 1~5 research tasks (objective/QA flow).
            {
                "type": "function",
                "function": {
                    "name": "assign_multi_objective_tasks_to_info_seeker",
                    "description": "Assign 1~5 research or information gathering tasks to different InformationSeekerAgents for parallel execution, each task descriptions must be semantically complete and clearly provide contextual information and potentially important reference documents.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "tasks": {
                                "type": "array",
                                "description": "List of tasks to be assigned to multiple InformationSeekerAgents",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "task_content": {
                                            "type": "string",
                                            "description": "Detailed description of the task to be performed, the task description must be semantically complete"
                                        },
                                        "task_steps_for_reference": {
                                            "type": "string",
                                            "description": "Optional reference steps for task execution"
                                        },
                                        "deliverable_contents": {
                                            "type": "string",
                                            "description": "Expected format and content of deliverables"
                                        },
                                        "current_task_status": {
                                            "type": "string",
                                            "description": "Current status and context of the task, important documents that may be used and referenced"
                                        },
                                        "acceptance_checking_criteria": {
                                            "type": "string",
                                            "description": "Criteria for determining task completion and quality, and the requirements in the event of task completion failure"
                                        },
                                    },
                                    "required": ["task_content"]
                                }
                            }
                        },
                        "required": ["tasks"]
                    }
                }
            },
            # Hand off collected research material to the WriterAgent.
            {
                "type": "function",
                "function": {
                    "name": "assign_subjective_task_to_writer",
                    "description": "Assign a writing or content creation task to the WriterAgent",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "user_query": {
                                "type": "string",
                                "description": "Pass in the original user question."
                            },
                            "task_content": {
                                "type": "string",
                                "description": "Integrate and synthesize provided materials to generate comprehensive long-form content exceeding 10,000 words, especially careful not to give specific details, such as an outline plan, you are only providing the writer with a general description of the task."
                            },
                            "key_files": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "file_path": {
                                            "type": "string",
                                            "description": "Relative path to the file containing research content"
                                        }
                                    },
                                    "required": ["file_path"]
                                },
                                "description": "Collect all key_files returned by the information seeker for long-form content creation."
                            }
                        },
                        "required": ["user_query", "task_content", "key_files"]
                    }
                }
            },
            # Writer-side completion report for the finished long-form article.
            {
                "type": "function",
                "function": {
                    "name": "writer_subjective_task_done",
                    "description": "Writer Agent task completion reporting for complete long-form content. Called after all chapters/sections are written to provide a summary of the complete long article, final completion status and analysis, and the storage path of the final consolidated article.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "final_article_path": {
                                "type": "string",
                                "description": "The file path where the final article is saved."
                            },
                            "article_summary": {
                                "type": "string",
                                "description": "Comprehensive summary of the complete long-form article, including main themes, key points covered, and overall narrative structure.",
                                "format": "markdown"
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final status of the complete long-form writing task"
                            },
                            "completion_analysis": {
                                "type": "string",
                                "description": "Analysis of the overall writing project completion including: assessment of article coherence and quality, evaluation of content organization and flow, identification of any challenges in the writing process, and overall evaluation of the long-form content creation success."
                            }
                        },
                        "required": ["final_article_path", "article_summary", "completion_status", "completion_analysis"]
                    }
                }
            },
            # Planner-side completion report for the subjective (writing) flow.
            {
                "type": "function",
                "function": {
                    "name": "planner_subjective_task_done",
                    "description": "When the writer agent is executed, the task done tool is called to end the planner's task.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "final_article_path": {
                                "type": "string",
                                "description": "The file path where the final article is saved."
                            },
                            "task_summary": {
                                "type": "string",
                                "description": "This field is mainly used to describe the main content of the article, briefly summarize it, and finally indicate the path where the final article is saved.",
                                "format": "markdown"
                            },
                            "task_name": {
                                "type": "string",
                                "description": "The name of the task currently assigned to the agent, usually with underscores (e.g., 'web_research_ai_trends')"
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final task status"
                            }
                        },
                        "required": ["final_article_path", "task_summary", "task_name", "completion_status"]
                    }
                }
            },
            # Planner-side completion report for the objective (QA) flow.
            {
                "type": "function",
                "function": {
                    "name": "planner_objective_task_done",
                    "description": "Structured reporting of task completion details including summary, decisions, and final answer",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "task_summary": {
                                "type": "string",
                                "description": "Comprehensive markdown covering what the agent was asked to do, steps taken, tools used, key findings, files created, challenges",
                                "format": "markdown"
                            },
                            "task_name": {
                                "type": "string",
                                "description": "The name of the task currently assigned to the agent, usually with underscores (e.g., 'web_research_ai_trends')"
                            },
                            "key_files": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "file_path": {
                                            "type": "string",
                                            "description": "Relative path to created/modified file"
                                        },
                                        "desc": {
                                            "type": "string",
                                            "description": "File contents and creation purpose"
                                        },
                                        "is_final_output_file": {
                                            "type": "boolean",
                                            "description": "Whether file is primary deliverable"
                                        }
                                    },
                                    "required": ["file_path", "desc", "is_final_output_file"]
                                },
                                "description": "List of key files generated or modified during the task, with their details."
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final task status"
                            },
                            "final_answer": {
                                "type": "string",
                                "description": "The final response displayed to the user",
                            }
                        },
                        "required": ["task_summary", "task_name", "key_files", "completion_status", "final_answer"]
                    }
                }
            },
        ]

        # Keep only the built-in tools permitted for the current planner mode.
        used_builtin_schemas = [schema for schema in builtin_assignment_schemas if schema["function"]["name"] in planner_mode_builtin_tools_map[self.config.planner_mode]]
        schemas.extend(used_builtin_schemas)

        return schemas
934
+
935
+ def _execute_react_loop(self, initial_message: str, max_iterations: int = 20) -> Dict[str, Any]:
936
+ """
937
+ Execute the ReAct loop for planning tasks
938
+
939
+ Args:
940
+ initial_message: Initial message to start the planning process
941
+ max_iterations: Maximum number of iterations to perform
942
+
943
+ Returns:
944
+ Dictionary with execution results and trace
945
+ """
946
+ start_time = time.time()
947
+ try:
948
+ # Reset trace for new task
949
+ self.reset_trace()
950
+ # Initialize conversation history
951
+ conversation_history = []
952
+
953
+ # Build system prompt for planning
954
+ system_prompt = self._build_system_prompt()
955
+ # Add to conversation
956
+ conversation_history.append({"role": "system", "content": system_prompt})
957
+ conversation_history.append({"role": "user", "content": initial_message + " /no_think"})
958
+
959
+ iteration = 0
960
+ task_completed = False
961
+
962
+ # Get model endpoint configuration from env-backed config
963
+ from config.config import get_config
964
+ config = get_config()
965
+ model_config = config.get_custom_llm_config()
966
+
967
+ pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
968
+ model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
969
+ headers = {'Content-Type': 'application/json', 'csb-token': model_token}
970
+ # ReAct Loop: Reasoning -> Acting -> Reasoning -> Acting...
971
+ while iteration < self.config.max_iterations and not task_completed:
972
+ iteration += 1
973
+ self.logger.info(f"Planning iteration {iteration}")
974
+
975
+ try:
976
+ # Get LLM response (reasoning + potential tool calls)
977
+ retry_num = 1
978
+ max_retry_num = 10
979
+ while retry_num < max_retry_num:
980
+ try:
981
+ response = requests.post(
982
+ url=pangu_url,
983
+ headers=headers,
984
+ json={
985
+ "model": self.config.model,
986
+ "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
987
+ "spaces_between_special_tokens": False,
988
+ "messages": conversation_history,
989
+ "temperature": self.config.temperature,
990
+ "max_tokens": self.config.max_tokens,
991
+ },
992
+ timeout=model_config.get("timeout", 180)
993
+ )
994
+ response = response.json()
995
+ self.logger.debug(f"API response received")
996
+ break
997
+ except Exception as e:
998
+ time.sleep(3)
999
+ retry_num += 1
1000
+ if retry_num == max_retry_num:
1001
+ raise ValueError(str(e))
1002
+ continue
1003
+ assistant_message = response["choices"][0]["message"]
1004
+
1005
+ # Log the reasoning
1006
+ try:
1007
+ if assistant_message["content"]:
1008
+ reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
1009
+ if len(reasoning_content) > 0:
1010
+ self.log_reasoning(iteration, reasoning_content)
1011
+ except Exception as e:
1012
+ self.logger.warning(f"Tool call parsing error: {e}")
1013
+ # Parse error, rerun
1014
+ followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
1015
+ conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
1016
+ continue
1017
+
1018
+ def extract_tool_calls(content):
1019
+ import re
1020
+ if not content:
1021
+ return []
1022
+ tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
1023
+ if len(tool_call_str) > 0:
1024
+ try:
1025
+ tool_calls = json.loads(tool_call_str[0].strip())
1026
+ except:
1027
+ return []
1028
+ else:
1029
+ return []
1030
+ return tool_calls
1031
+
1032
+ # Add assistant message to conversation
1033
+ conversation_history.append({
1034
+ "role": "assistant",
1035
+ "content": assistant_message["content"]
1036
+ })
1037
+
1038
+ tool_calls = extract_tool_calls(assistant_message["content"])
1039
+
1040
+ # Execute tool calls if any (Acting phase)
1041
+
1042
+ for tool_call in tool_calls:
1043
+ arguments = tool_call["arguments"]
1044
+ self.logger.debug(f"Arguments is string: {isinstance(arguments, str)}")
1045
+
1046
+ # Check if planning is complete
1047
+ if tool_call["name"] in ["planner_subjective_task_done", "planner_objective_task_done", "writer_subjective_task_done"]:
1048
+ task_completed = True
1049
+ self.log_action(iteration, tool_call["name"], arguments, arguments)
1050
+ break
1051
+ if tool_call["name"] in ["think", "reflect"]:
1052
+ tool_result = {"tool_results": "You can proceed to invoke other tools if needed. "}
1053
+ else:
1054
+ tool_result = self.execute_tool_call(tool_call)
1055
+
1056
+ # Log the action using base class method
1057
+ self.log_action(iteration, tool_call["name"], arguments, tool_result)
1058
+
1059
+ # Add tool result to conversation
1060
+ conversation_history.append({
1061
+ "role": "tool",
1062
+ "content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
1063
+ })
1064
+
1065
+ # If no tool calls, encourage continued planning
1066
+ if len(tool_calls) == 0:
1067
+ # Add follow-up prompt to encourage action or completion
1068
+ followup_prompt = (
1069
+ "Continue your planning process. Use available tools to assign tasks to agents, "
1070
+ "search for information, or coordinate work. When you have a complete answer, "
1071
+ "call planner_subjective_task_done or planner_objective_task_done. /no_think"
1072
+ )
1073
+ conversation_history.append({"role": "user", "content": followup_prompt})
1074
+
1075
+ except Exception as e:
1076
+ error_msg = f"Error in planning iteration {iteration}: {e}"
1077
+ self.log_error(iteration, error_msg)
1078
+ break
1079
+
1080
+ execution_time = time.time() - start_time
1081
+
1082
+ # Extract final result
1083
+ if task_completed:
1084
+ # Find the completion result in the trace
1085
+ completion_result = None
1086
+ for step in reversed(self.reasoning_trace):
1087
+ if step.get("type") == "action" and step.get("tool") in ["planner_subjective_task_done",
1088
+ "planner_objective_task_done"]:
1089
+ completion_result = step.get("result")
1090
+ break
1091
+
1092
+ return {
1093
+ "success": True,
1094
+ "data": completion_result,
1095
+ "reasoning_trace": self.reasoning_trace,
1096
+ "iterations": iteration,
1097
+ "execution_time": execution_time
1098
+ }
1099
+ else:
1100
+ return {
1101
+ "success": False,
1102
+ "error": f"Planning task not completed within {max_iterations} iterations",
1103
+ "reasoning_trace": self.reasoning_trace,
1104
+ "iterations": iteration,
1105
+ "execution_time": execution_time
1106
+ }
1107
+ except Exception as e:
1108
+ execution_time = time.time() - start_time if 'start_time' in locals() else 0
1109
+ self.logger.error(f"Error in execute_react_loop: {e}")
1110
+ return {
1111
+ "success": False,
1112
+ "error": str(e),
1113
+ "reasoning_trace": self.reasoning_trace,
1114
+ "iterations": iteration if 'iteration' in locals() else 0,
1115
+ "execution_time": execution_time
1116
+ }
1117
+
1118
+
1119
+ def execute_task(self, user_query: str) -> AgentResponse:
1120
+ """
1121
+ Execute a planning task for the given user query
1122
+
1123
+ Args:
1124
+ user_query: The user's query or request
1125
+
1126
+ Returns:
1127
+ AgentResponse with planning results and process trace
1128
+ """
1129
+ start_time = time.time()
1130
+
1131
+ try:
1132
+ self.logger.info(f"Starting planner task: {user_query}")
1133
+
1134
+ # Execute the planning task using ReAct pattern
1135
+ result = self._execute_react_loop(
1136
+ initial_message=user_query,
1137
+ max_iterations=self.config.max_iterations # Reasonable limit for planning tasks
1138
+ )
1139
+
1140
+ execution_time = time.time() - start_time
1141
+
1142
+ return AgentResponse(
1143
+ success=result.get("success", False),
1144
+ result=result.get("data"),
1145
+ error=result.get("error"),
1146
+ reasoning_trace=result.get("reasoning_trace", []),
1147
+ iterations=result.get("iterations", 0),
1148
+ execution_time=execution_time,
1149
+ agent_name=self.config.agent_name
1150
+ )
1151
+
1152
+ except Exception as e:
1153
+ execution_time = time.time() - start_time
1154
+ self.logger.error(f"Planner execution failed: {e}")
1155
+
1156
+ return AgentResponse(
1157
+ success=False,
1158
+ error=f"Planner execution failed: {str(e)}",
1159
+ reasoning_trace=[],
1160
+ iterations=0,
1161
+ execution_time=execution_time,
1162
+ agent_name=self.config.agent_name
1163
+ )
1164
+
1165
+
1166
def create_planner_agent(
    model: Any = None,
    sub_agent_configs: Dict[str, Dict[str, Any]] = None,
    shared_mcp_client=None,
    **kwargs
) -> PlannerAgent:
    """
    Create a PlannerAgent instance with server-managed sessions.

    Args:
        model: The LLM model to use
        sub_agent_configs: Configuration for sub-agents (information_seeker, writer)
        shared_mcp_client: Optional shared MCP client to prevent duplicate connections
        **kwargs: Additional configuration options

    Returns:
        Configured PlannerAgent instance
    """
    # The enhanced config factory is imported lazily, function-local.
    from .base_agent import create_agent_config

    # Build the agent configuration (session managed by MCP server).
    planner_config = create_agent_config(
        agent_name="PlannerAgent",
        model=model,
        **kwargs
    )

    # Instantiate the planner, reusing a shared MCP client when provided.
    agent = PlannerAgent(config=planner_config, shared_mcp_client=shared_mcp_client)

    # Record sub-agent configurations for later sub-agent creation; fall back
    # to empty per-agent configs when none are supplied.
    if sub_agent_configs:
        agent.sub_agent_configs = sub_agent_configs
    else:
        agent.sub_agent_configs = {
            "information_seeker": {},
            "writer": {}
        }

    return agent
deepdiver_v2/src/agents/subjective_information_seeker.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ import json
3
+ from typing import Dict, Any, List
4
+ import time
5
+ import requests
6
+ import os
7
+ from .base_agent import BaseAgent, AgentConfig, AgentResponse, TaskInput
8
+
9
+
10
+
11
+ class InformationSeekerAgent(BaseAgent):
12
+ """
13
+ Information Seeker Agent that follows ReAct pattern (Reasoning + Acting)
14
+
15
+ This agent takes decomposed sub-questions or tasks from parent agents,
16
+ thinks interleaved (reasoning -> action -> reasoning -> action),
17
+ uses MCP tools to gather information, and returns structured results.
18
+ """
19
+
20
+ def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
21
+ # Set default agent name if not specified
22
+ if config is None:
23
+ config = AgentConfig(agent_name="InformationSeekerAgent")
24
+ elif config.agent_name == "base_agent":
25
+ config.agent_name = "InformationSeekerAgent"
26
+
27
+ super().__init__(config, shared_mcp_client)
28
+
29
    def _build_system_prompt(self) -> str:
        """Build the ReAct system prompt for this agent.

        Returns:
            The prompt template with the ``$tool_schemas`` placeholder replaced
            by the JSON-serialized tool schemas available to this agent.
        """
        tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)
        system_prompt_template = """You are an Information Seeker Agent that follows the ReAct pattern (Reasoning + Acting).

Your role is to:
1. Take decomposed sub-questions or tasks from parent agents
2. Think step-by-step through reasoning
3. Use available tools to gather information when needed
4. Continue reasoning based on tool results
5. Repeat this process until you have sufficient information
6. Call info_seeker_subjective_task_done to provide a structured summary and key files

TOOL USAGE STRATEGY:
Follow this optimized workflow for information gathering:

1. INITIAL RESEARCH:
- Generate focused search queries (≤10): Limit to no more than 10 initial search queries to avoid increased failure rates from excessive decomposition.
- Use `batch_web_search` to find relevant URLs for your queries. When calling the search statement, consider the language of the user's question. For example, for a Chinese question, generate a part of the search statement in Chinese.
- Analyze the search results (titles, snippets, URLs) to identify promising sources

2. CONTENT EXTRACTION:
- For important URLs, use `url_crawler` to:
a) Extract full content from the webpage
b) Save the content to a file in the workspace **under the relative path `./url_crawler_save_files/`**
- Store results with meaningful file paths (e.g., `url_crawler_save_files/research/ai_trends_2024.txt`)

3. CONTENT ANALYSIS:
- Use `document_extract` for multi-dimensional analysis of saved files:
a) Provides structured analysis across five key dimensions: doc time source authority, core content and task relevance

4. FILE MANAGEMENT:
- For reviewing saved content:
a) Prefer `document_extract` to get comprehensive multi-dimensional analysis of saved files
b) Use `file_read` ONLY for small files (<1000 tokens) when you need the entire content
c) Avoid reading large files directly as it may exceed context limits

### Usage of Systematic Tool:
- `think` is a systematic tool. After receiving the response from the complex tool or before invoking any other tools, you must **first invoke the `think` tool**: to deeply reflect on the results of previous tool invocations (if any), and to thoroughly consider and plan the user's task. The `think` tool does not acquire new information; it only saves your thoughts into memory.

Always provide clear reasoning for your actions and synthesize information effectively.

Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
<tools>
$tool_schemas
</tools>
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]
"""
        # str.replace is used (rather than str.format / string.Template.substitute)
        # presumably because the template contains literal braces and $-free text
        # that formatting engines could misinterpret — TODO confirm.
        return system_prompt_template.replace("$tool_schemas", tool_schemas_str)
79
+
80
+ @staticmethod
81
+ def _build_initial_message_from_task_input(task_input: TaskInput) -> str:
82
+ """Build the initial user message from TaskInput"""
83
+ message = task_input.format_for_prompt()
84
+
85
+ message += "\nPlease analyze this task and start your ReAct process:\n"
86
+ message += "1. Reason about what information you need to gather\n"
87
+ message += "2. Use appropriate tools to get that information\n"
88
+ message += "3. Continue reasoning and acting until you have sufficient information\n"
89
+ message += "4. Call info_seeker_subjective_task_done when ready to provide your complete findings\n\n"
90
+ message += "Begin with your initial reasoning about the task."
91
+
92
+ return message
93
+
94
+ def execute_task(self, task_input: TaskInput) -> AgentResponse:
95
+ """
96
+ Execute a task using ReAct pattern (Reasoning + Acting)
97
+
98
+ Args:
99
+ task_input: TaskInput object with standardized task information
100
+
101
+ Returns:
102
+ AgentResponse with results and process trace
103
+ """
104
+ start_time = time.time()
105
+
106
+ try:
107
+ self.logger.info(f"Starting information seeker task: {task_input.task_content}")
108
+
109
+ # Reset trace for new task
110
+ self.reset_trace()
111
+
112
+ # Initialize conversation history
113
+ conversation_history = []
114
+
115
+ # Build initial system prompt for ReAct
116
+ system_prompt = self._build_system_prompt()
117
+
118
+ # Build initial user message from TaskInput
119
+ user_message = self._build_initial_message_from_task_input(task_input)
120
+
121
+ # Add to conversation
122
+ conversation_history.append({"role": "system", "content": system_prompt})
123
+ conversation_history.append({"role": "user", "content": user_message + " /no_think"})
124
+
125
+ iteration = 0
126
+ task_completed = False
127
+ # Get model configuration from config
128
+ from config.config import get_config
129
+ config = get_config()
130
+ model_config = config.get_custom_llm_config()
131
+
132
+ pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
133
+ model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
134
+ headers = {'Content-Type': 'application/json', 'csb-token': model_token}
135
+
136
+ # ReAct Loop: Reasoning -> Acting -> Reasoning -> Acting...
137
+ self.config.max_iterations = 30
138
+ while iteration < self.config.max_iterations and not task_completed:
139
+ iteration += 1
140
+ self.logger.info(f"Planning iteration {iteration}")
141
+
142
+ try:
143
+ # Get LLM response (reasoning + potential tool calls)
144
+ retry_num = 1
145
+ max_retry_num = 10
146
+ while retry_num < max_retry_num:
147
+ try:
148
+ response = requests.post(
149
+ url=pangu_url,
150
+ headers=headers,
151
+ json={
152
+ "model": model_config.get('model', 'pangu_auto'),
153
+ "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
154
+ "messages": conversation_history,
155
+ "spaces_between_special_tokens": False,
156
+ "temperature": self.config.temperature,
157
+ },
158
+ timeout=model_config.get("timeout", 180)
159
+ )
160
+ response = response.json()
161
+
162
+ self.logger.debug(f"API response received")
163
+ break
164
+ except Exception as e:
165
+ time.sleep(3)
166
+ retry_num += 1
167
+ if retry_num == max_retry_num:
168
+ raise ValueError(str(e))
169
+ continue
170
+
171
+ assistant_message = response["choices"][0]["message"]
172
+ # Log the reasoning
173
+ try:
174
+ if assistant_message["content"]:
175
+ reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
176
+ if len(reasoning_content) > 0:
177
+ self.log_reasoning(iteration, reasoning_content)
178
+ except Exception as e:
179
+ self.logger.warning(f"Tool call parsing error: {e}")
180
+ # Parse error, rerun
181
+ followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
182
+ conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
183
+ continue
184
+
185
+
186
+ def extract_tool_calls(content):
187
+ import re
188
+ if not content:
189
+ return []
190
+ tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
191
+ if len(tool_call_str) > 0:
192
+ try:
193
+ tool_calls = json.loads(tool_call_str[0].strip())
194
+ except Exception as ee:
195
+ return ["fail_tools_load", ee]
196
+ else:
197
+ return []
198
+ return tool_calls
199
+
200
+ # Add assistant message to conversation
201
+ conversation_history.append({
202
+ "role": "assistant",
203
+ "content": assistant_message["content"]
204
+ })
205
+
206
+ tool_calls = extract_tool_calls(assistant_message["content"])
207
+
208
+ if tool_calls[0] == "fail_tools_load":
209
+ # Parse error, rerun
210
+ followup_prompt = f"There was a parsing error in the format of the tool call" \
211
+ f" you generated:{tool_calls[1]} Please regenerate it."
212
+ conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
213
+ continue
214
+
215
+
216
+ # Execute tool calls if any (Acting phase)
217
+
218
+ for tool_call in tool_calls:
219
+ arguments = tool_call["arguments"]
220
+
221
+ # Check if planning is complete
222
+ if tool_call["name"] in ["info_seeker_subjective_task_done"]:
223
+ task_completed = True
224
+ self.log_action(iteration, tool_call["name"], arguments, arguments)
225
+ break
226
+ if tool_call["name"] in ["think", "reflect"]:
227
+ tool_result = {"tool_results": "You can proceed to invoke other tools if needed."}
228
+ else:
229
+ tool_result = self.execute_tool_call(tool_call)
230
+
231
+ # Log the action using base class method
232
+ self.log_action(iteration, tool_call["name"], arguments, tool_result)
233
+
234
+ # Add tool result to conversation
235
+ conversation_history.append({
236
+ "role": "tool",
237
+ "content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
238
+ })
239
+
240
+ # If no tool calls, encourage continued planning
241
+ if len(tool_calls) == 0:
242
+ # Add follow-up prompt to encourage action or completion
243
+ followup_prompt = (
244
+ "Continue your analysis. If you need more information, use available tools. "
245
+ "If you have enough information to answer the question, call info_seeker_subjective_task_done with your complete context."
246
+ )
247
+ conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
248
+ if iteration == self.config.max_iterations-3:
249
+ followup_prompt = "Due to length and number of rounds restrictions, you must now call the `info_seeker_subjective_task_done` tool to report the completion of your task."
250
+ conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
251
+
252
+
253
+ except Exception as e:
254
+ error_msg = f"Error in planning iteration {iteration}: {e}"
255
+ self.log_error(iteration, error_msg)
256
+ break
257
+
258
+ execution_time = time.time() - start_time
259
+ # Extract final result
260
+ if task_completed:
261
+ # Find the task_done result in the trace
262
+ task_done_result = None
263
+ for step in reversed(self.reasoning_trace):
264
+ if step.get("type") == "action" and step.get("tool") == "info_seeker_subjective_task_done":
265
+ task_done_result = step.get("result")
266
+ break
267
+
268
+ return self.create_response(
269
+ success=True,
270
+ result=task_done_result,
271
+ iterations=iteration,
272
+ execution_time=execution_time
273
+ )
274
+ else:
275
+ return self.create_response(
276
+ success=False,
277
+ error=f"Task not completed within {self.config.max_iterations} iterations",
278
+ iterations=iteration,
279
+ execution_time=execution_time
280
+ )
281
+
282
+ except Exception as e:
283
+ execution_time = time.time() - start_time
284
+ self.logger.error(f"Error in execute_task: {e}")
285
+ return self.create_response(
286
+ success=False,
287
+ error=str(e),
288
+ iterations=iteration if 'iteration' in locals() else 0,
289
+ execution_time=execution_time
290
+ )
291
+
292
    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build tool schemas for InformationSeekerAgent using proper MCP architecture.
        Schemas come from MCP server via client, not direct imports.

        Returns:
            The MCP-provided schemas extended with the three built-in tools
            ("think", "reflect", "info_seeker_subjective_task_done") that
            execute_task handles locally instead of dispatching to MCP.
        """
        # Get MCP tool schemas from server via client (proper MCP architecture)
        schemas = super()._build_agent_specific_tool_schemas()

        # Add schemas for built-in task assignment tools.
        # "think"/"reflect" are short-circuited in execute_task with a canned
        # acknowledgement; "info_seeker_subjective_task_done" terminates the
        # ReAct loop and carries the final report payload.
        builtin_assignment_schemas = [
            {
                "type": "function",
                "function": {
                    "name": "think",
                    "description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "thought": {
                                "type": "string",
                                "description": "Your thoughts."
                            }
                        },
                        "required": ["thought"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "reflect",
                    "description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "reflect": {
                                "type": "string",
                                "description": "The specific content of your reflection"
                            }
                        },
                        "required": ["reflect"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "info_seeker_subjective_task_done",
                    "description": "Information Seeker Agent task completion reporting with information collection summary and related files.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "task_summary": {
                                "type": "string",
                                "description": "Simple summary of what information has been collected for the current task and what new discoveries have been made.",
                                "format": "markdown"
                            },
                            "key_files": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "file_path": {
                                            "type": "string",
                                            "description": "Relative path to the file with collected content"
                                        },
                                    },
                                    "required": ["file_path"]
                                },
                                "description": "Collect files highly relevant to this task. "
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final status of the information gathering task"
                            },
                            "completion_analysis": {
                                "type": "string",
                                "description": "Brief analysis of task completion quality, information thoroughness, and any limitations or gaps."
                            }
                        },
                        "required": ["task_summary", "key_files", "completion_status", "completion_analysis"]
                    }
                }
            },
        ]

        schemas.extend(builtin_assignment_schemas)

        return schemas
382
+
383
+
384
# Factory function for creating the agent
def create_subjective_information_seeker(
    model: str = "pangu_auto",
    max_iterations: int = 10,
    shared_mcp_client=None,
    **kwargs
) -> InformationSeekerAgent:
    """
    Create an InformationSeekerAgent instance with server-managed sessions.

    Args:
        model: The LLM model to use
        max_iterations: Maximum number of iterations
        shared_mcp_client: Optional shared MCP client from parent agent (prevents extra sessions)
        **kwargs: Additional configuration options

    Returns:
        Configured InformationSeekerAgent instance with appropriate tools
    """
    # Local import avoids a circular dependency at module load time.
    from .base_agent import create_agent_config

    # Assemble the agent configuration; the MCP server owns session lifecycle.
    agent_config = create_agent_config(
        agent_name="InformationSeekerAgent",
        model=model,
        max_iterations=max_iterations,
        **kwargs
    )

    # Reusing the parent's MCP client keeps tool access on a single session.
    return InformationSeekerAgent(config=agent_config, shared_mcp_client=shared_mcp_client)
deepdiver_v2/src/agents/writer_agent.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ import json
3
+ from typing import Dict, Any, List
4
+ import time
5
+ import requests
6
+ import os
7
+ from .base_agent import BaseAgent, AgentConfig, AgentResponse, WriterAgentTaskInput
8
+
9
+
10
+
11
class WriterAgent(BaseAgent):
    """
    Writer Agent that follows ReAct pattern for content synthesis and generation

    This agent takes writing tasks from parent agents, searches through existing
    files and knowledge base, and creates long-form content through iterative
    reasoning and refinement. It does NOT access internet resources, only
    local files and memories.
    """

    def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
        # Guarantee a writer-specific name: build a fresh config when none is
        # supplied, and overwrite the generic "base_agent" placeholder.
        if config is None:
            config = AgentConfig(agent_name="WriterAgent")
        if config.agent_name == "base_agent":
            config.agent_name = "WriterAgent"

        super().__init__(config, shared_mcp_client)

        # Rebuild tool schemas so only writer-specific tools are exposed.
        self.tool_schemas = self._build_tool_schemas()
32
+
33
    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build tool schemas for WriterAgent using proper MCP architecture.
        Schemas come from MCP server via client, not direct imports.

        Returns:
            The MCP-provided schemas extended with the built-in tools
            ("think", "reflect", "writer_subjective_task_done").
        """
        # Get MCP tool schemas from server via client (proper MCP architecture)
        schemas = super()._build_agent_specific_tool_schemas()

        # Add schemas for built-in task assignment tools; these are the local
        # (non-MCP) tools the writer's ReAct loop recognizes.
        builtin_assignment_schemas = [
            {
                "type": "function",
                "function": {
                    "name": "think",
                    "description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "thought": {
                                "type": "string",
                                "description": "Your thoughts."
                            }
                        },
                        "required": ["thought"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "reflect",
                    "description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "reflect": {
                                "type": "string",
                                "description": "The specific content of your reflection"
                            }
                        },
                        "required": ["reflect"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "writer_subjective_task_done",
                    "description": "Writer Agent task completion reporting for complete long-form content. Called after all chapters/sections are written to provide a summary of the complete long article, final completion status and analysis, and the storage path of the final consolidated article.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "final_article_path": {
                                "type": "string",
                                "description": "The file path where the final article is saved."
                            },
                            "article_summary": {
                                "type": "string",
                                "description": "Comprehensive summary of the complete long-form article, including main themes, key points covered, and overall narrative structure.",
                                "format": "markdown"
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final status of the complete long-form writing task"
                            },
                            "completion_analysis": {
                                "type": "string",
                                "description": "Analysis of the overall writing project completion including: assessment of article coherence and quality, evaluation of content organization and flow, identification of any challenges in the writing process, and overall evaluation of the long-form content creation success."
                            }
                        },
                        "required": ["final_article_path", "article_summary", "completion_status",
                                     "completion_analysis"]
                    }
                }
            },
        ]

        schemas.extend(builtin_assignment_schemas)

        return schemas
114
+
115
    def _build_system_prompt(self) -> str:
        """Build the system prompt for the writer agent.

        Returns:
            The writing-workflow prompt with the ``$tool_schemas`` placeholder
            replaced by the JSON-serialized tool schemas for this agent.
        """
        tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)
        system_prompt_template = """You are a professional writing master. You will receive key files and user problems. Your task is to generate an outline highly consistent with the user problem, classify files into sections, and iteratively call section_writer tool to create comprehensive content. Then you strictly follow the steps given below:

MANDATORY WORKFLOW:

1. OUTLINE GENERATION
Based on the core content of the provided key files collection(file_core_content), generate a high-quality outline suitable for long-form writing. Strictly adhere to the following requirements during generation:
- Before generating the outline, carefully review the provided **file_core_content**, prioritizing sections with:
1.**Higher authority** (credible sources)
2.**Greater information richness** (substantive, detailed content)
3.**Stronger relevance** (direct alignment with user query)
4.**Timeliness** (if user’s query is time-sensitive, prioritize recent/updated content)
Select these segments as the basis for outline generation. Note that we only focus on relevance to the question, so when generating the outline, do not add unrelated sections just for the sake of length. Additionally, the sections should flow logically and not be too disjointed, as this would harm the readability of the final output.
- The overall structure must be **logically clear**, with **no repetition or redundancy** between chapters.
- **Note1:** The generated outline must not only have chapter-level headings (Level 1) highly relevant to the user’s question, but the subheadings (Level 2) must also be highly relevant to the user’s question. It is not permitted to generate chapter titles with weak relevance, whether Level 1 or Level 2.
- **Note2:** The number of chapters must not exceed 7, dynamic evaluation can be performed based on the collected content. For example, if there is a lot of content, more chapters can be generated, and vice versa. But each chapter should only include Level 1 and Level 2 headings. Also, be careful not to generate too many Level 2 headings, limit them to 4. However, if the first chapter is an abstract or introduction, do not generate subheadings (level-2 headings)—only include the main heading (level-1). Additionally, tailor the outline style based on the type of document. For example, in a research report, the first chapter should preferably be titled \"Abstract\" or \"Introduction.\"

2. FILE CLASSIFICATION
- Use the search_result_classifier tool to reasonably split the outline generated above and accurately assign key files to each chapter of the outline.
- Ensure optimal distribution of reference materials across chapters based on content relevance.

3. ITERATIVE SECTION WRITING
- Call section_writer tool sequentially for each chapter
- CRITICAL: Must wait for previous chapter completion before starting the next chapter
- Pass only the specific chapter outline , target file path and corresponding classified files to each section writer
- Generate save path for each chapter using \"./report/part_X.md\" format (e.g., \"./report/part_1.md\" for first chapter)
- Check section writer results after completion; retry up to 2 times per chapter if quality is insufficient based on returned fields (do not read saved files)
- When you call the section_writer tool, pay special attention to the fact that the parameter value of written_chapters_summary is a summary of the content returned by all previously completed chapters. Be careful not to make any changes to the summary content, including compressing the content.

4. TASK COMPLETION
- After all chapters are written, you must first call the concat_section_files tool to merge the saved chapter files into one file, then call writer_subjective_task_done to finalize and return.

CRITICAL REQUIREMENTS:
- The creation of the outline is crucial! Therefore, you must strictly adhere to the above requirements for generating the outline.
- No parallel writing - strictly sequential chapter execution
- Wait for each section writer completion before proceeding to next chapter
- Classify files appropriately to support each chapter's content needs
- Note again that to merge all the written chapter files, you must use the concat_section_files tool!!! You are not allowed to call any other tools for merging!!!

FORBIDDEN CONTENT PATTERNS:
- NEVER generate meta-structural chapters that describe how the article is organized
- AVOID introductory sections that outline \"Chapter 1 will cover..., Chapter 2 will discuss...\"
- DO NOT create chapters that explain the report structure or methodology
- Each chapter must contain SUBSTANTIVE CONTENT, not descriptions of what other chapters contain
- When generating an outline, if it is not a professional term, the language should remain consistent with the user's question.\"

Usage of TOOLS:
- search_result_classifier: Classify key files into outline sections
- section_writer: Write individual chapters sequentially
- writer_subjective_task_done: Complete the writing task
- concat_section_files: Concatenate the content of the saved section files into a single file
- think tool: \"Think\" is a systematic tool requiring its use during key steps. Before executing actions like generating an outline, you must first call this tool to deeply consider the given content and key requirements, ensuring the output meets specifications. Similarly, during iterative chapter generation, after receiving feedback and before writing the next chapter, call \"think\" to reflect on the current chapter. This provides guidance to avoid content repetition and ensure smooth transitions between chapters.

Execute workflow systematically to produce high-quality, coherent long-form content with substantive chapters.

Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
<tools>
$tool_schemas
</tools>
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]
"""
        # str.replace is used (rather than a formatting engine) — presumably
        # because the template contains literal braces and markup that
        # str.format would misinterpret; TODO confirm.
        return system_prompt_template.replace("$tool_schemas", tool_schemas_str)
180
+
181
+ def _build_initial_message_from_task_input(self, task_input: WriterAgentTaskInput) -> str:
182
+ """Build the initial user message from TaskInput"""
183
+ message = ""
184
+
185
+ # Add key files information with reliability dimensions
186
+ def load_json_from_server(file_path):
187
+ """Load JSONL file from MCP server using unlimited internal tool"""
188
+ res = []
189
+ try:
190
+ # Use json read tool directly through raw MCP client
191
+ raw_result = self.mcp_tools.client.call_tool("load_json", {"file_path": file_path})
192
+
193
+ if not raw_result.success:
194
+ self.logger.error(f"Failed to read file from server: {raw_result.error}")
195
+ return res
196
+
197
+ res = json.loads(raw_result.data["content"][0]["text"])["data"]
198
+
199
+ except Exception as e:
200
+ self.logger.error(f"Error loading file {file_path} from MCP server: {e}")
201
+ import traceback
202
+ self.logger.debug(f"Full traceback: {traceback.format_exc()}")
203
+
204
+ return res
205
+
206
+ key_files_dict = {}
207
+
208
+ server_analysis_path = f"doc_analysis/file_analysis.jsonl"
209
+ self.logger.debug(f"Loading analysis from MCP server: {server_analysis_path}")
210
+ file_analysis_list = load_json_from_server(server_analysis_path)
211
+
212
+ for file_info in file_analysis_list:
213
+ if file_info.get('file_path'):
214
+ key_files_dict[file_info.get('file_path')] = file_info
215
+
216
+ file_core_content = ""
217
+ if hasattr(task_input, 'key_files') and task_input.key_files:
218
+ message += "Key Files:\n"
219
+ for i, file_ in enumerate(task_input.key_files, 1):
220
+ file_path = file_.get('file_path')
221
+ if file_path in key_files_dict:
222
+ file_info = key_files_dict[file_path]
223
+ doc_time = file_info.get('doc_time', 'Not specified')
224
+ source_authority = file_info.get('source_authority', 'Not assessed')
225
+ task_relevance = file_info.get('task_relevance', 'Not assessed')
226
+ information_richness = file_info.get('information_richness', 'Not assessed')
227
+ message += f"{i}. File: {file_path}\n"
228
+
229
+ file_core_content += f"[{str(i)}]doc_time:{doc_time}|||source_authority:{source_authority}|||task_relevance:{task_relevance}|||information_richness:{information_richness}|||summary_content:{file_info.get('core_content', '')}\n"
230
+ message += "\n"
231
+ message += f"file_core_content: {file_core_content}\n"
232
+ else:
233
+ message += "Key Files: None provided\n"
234
+
235
+ message += "\n"
236
+ # Add user query
237
+ if hasattr(task_input, 'user_query') and task_input.user_query:
238
+ message += f"User Query: {task_input.user_query}\n"
239
+ else:
240
+ message += "User Query: Not provided\n"
241
+
242
+ return message
243
+
244
+ def execute_task(self, task_input: WriterAgentTaskInput) -> AgentResponse:
245
+ """
246
+ Execute a writing task using ReAct pattern
247
+
248
+ Args:
249
+ task_input: TaskInput object with standardized task information
250
+
251
+ Returns:
252
+ AgentResponse with writing results and process trace
253
+ """
254
+ start_time = time.time()
255
+
256
+ try:
257
+ self.logger.info(f"Starting writing task: {task_input.task_content}")
258
+
259
+ # Reset trace for new task
260
+ self.reset_trace()
261
+
262
+ # Initialize conversation history
263
+ conversation_history = []
264
+
265
+ # Build system prompt for writing
266
+ system_prompt = self._build_system_prompt()
267
+
268
+ # Build initial user message from TaskInput
269
+ user_message = self._build_initial_message_from_task_input(task_input)
270
+
271
+ # Add to conversation
272
+ conversation_history.append({"role": "system", "content": system_prompt})
273
+ conversation_history.append({"role": "user", "content": user_message + " /no_think"})
274
+
275
+ iteration = 0
276
+ task_completed = False
277
+
278
+ self.logger.debug("Checking conversation history before model call")
279
+ self.logger.debug(f"Conversation history: {conversation_history}")
280
+ # ReAct Loop for Writing: Research → Plan → Write → Refine → Complete
281
+ # Get model configuration from config
282
+ from config.config import get_config
283
+ config = get_config()
284
+ model_config = config.get_custom_llm_config()
285
+
286
+ pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
287
+ model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
288
+ headers = {'Content-Type': 'application/json', 'csb-token': model_token}
289
+
290
+ while iteration < self.config.max_iterations and not task_completed:
291
+ iteration += 1
292
+ self.logger.info(f"Writing iteration {iteration}")
293
+
294
+ try:
295
+ # Get LLM response (reasoning + potential tool calls) with retry
296
+
297
+ max_retries = 10
298
+ response = None
299
+
300
+ for attempt in range(max_retries):
301
+ try:
302
+
303
+ response = requests.post(
304
+ url=pangu_url,
305
+ headers=headers,
306
+ json={
307
+ "model": self.config.model,
308
+ "chat_template":"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
309
+ "messages": conversation_history,
310
+ "temperature": self.config.temperature,
311
+ "max_tokens": self.config.max_tokens,
312
+ "spaces_between_special_tokens": False,
313
+ },
314
+ timeout=model_config.get("timeout", 180)
315
+ )
316
+ response = response.json()
317
+
318
+ self.logger.debug(f"API response received")
319
+ break # Success, exit retry loop
320
+
321
+ except Exception as e:
322
+ self.logger.warning(f"LLM API call attempt {attempt + 1} failed: {e}")
323
+ if attempt == max_retries - 1:
324
+ raise e # Last attempt, re-raise the exception
325
+ time.sleep(6) # Simple 1 second delay between retries
326
+
327
+ if response is None:
328
+ raise Exception("Failed to get response after all retries")
329
+
330
+ assistant_message = response["choices"][0]["message"]
331
+
332
+ try:
333
+ if assistant_message["content"]:
334
+ reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
335
+ if len(reasoning_content) > 0:
336
+ self.log_reasoning(iteration, reasoning_content)
337
+ except Exception as e:
338
+ self.logger.warning(f"Tool call parsing error: {e}")
339
+ # Parse error, rerun
340
+ followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
341
+ conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
342
+ continue
343
+
344
+ def extract_tool_calls(content):
345
+ import re
346
+ tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
347
+ if len(tool_call_str) > 0:
348
+ try:
349
+ tool_calls = json.loads(tool_call_str[0])
350
+ except:
351
+ return []
352
+ else:
353
+ return []
354
+ return tool_calls
355
+
356
+ # Add assistant message to conversation
357
+ conversation_history.append({
358
+ "role": "assistant",
359
+ "content": assistant_message["content"]
360
+ })
361
+
362
+ tool_calls = extract_tool_calls(assistant_message["content"])
363
+
364
+ # Execute tool calls if any (Acting phase)
365
+ for tool_call in tool_calls:
366
+ # Str
367
+ arguments = tool_call["arguments"]
368
+ self.logger.debug(f"Arguments is string: {isinstance(arguments, str)}")
369
+
370
+ # Check if planning is complete
371
+ if tool_call["name"] in ["writer_subjective_task_done"]:
372
+ task_completed = True
373
+ self.log_action(iteration, tool_call["name"], arguments, arguments)
374
+ break
375
+ if tool_call["name"] in ["think"]:
376
+ tool_result = {
377
+ "tool_results": "You can proceed to invoke other tools if needed. But the next step cannot call the reflect tool"}
378
+ else:
379
+ tool_result = self.execute_tool_call(tool_call)
380
+
381
+ # Log the action using base class method
382
+ self.log_action(iteration, tool_call["name"], arguments, tool_result)
383
+
384
+ # Add tool result to conversation
385
+ conversation_history.append({
386
+ "role": "tool",
387
+ "content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
388
+ })
389
+
390
+ # If no tool calls, encourage continued writing
391
+ if len(tool_calls) == 0:
392
+ # Add follow-up prompt to encourage action or completion
393
+ followup_prompt = (
394
+ "Continue your writing process. If you need to research more, use available tools. "
395
+ "If you need to write or edit content, use file operations. "
396
+ "If your writing is complete and meets requirements, call writer_subjective_task_done. /no_think"
397
+ )
398
+ conversation_history.append({"role": "user", "content": followup_prompt})
399
+
400
+ except Exception as e:
401
+ error_msg = f"Error in writing iteration {iteration}: {e}"
402
+ self.log_error(iteration, error_msg)
403
+ break
404
+
405
+ execution_time = time.time() - start_time
406
+ # Extract final result
407
+ if task_completed:
408
+ # Find the completion result in the trace
409
+ completion_result = None
410
+ for step in reversed(self.reasoning_trace):
411
+ if step.get("type") == "action" and step.get("tool") in ["writer_subjective_task_done"]:
412
+ completion_result = step.get("result")
413
+ break
414
+ return self.create_response(
415
+ success=True,
416
+ result=completion_result,
417
+ iterations=iteration,
418
+ execution_time=execution_time
419
+ )
420
+ else:
421
+
422
+ return self.create_response(
423
+ success=False,
424
+ error=f"Writing task not completed within {self.config.max_iterations} iterations",
425
+ iterations=iteration,
426
+ execution_time=execution_time
427
+ )
428
+
429
+ except Exception as e:
430
+ execution_time = time.time() - start_time if 'start_time' in locals() else 0
431
+ self.logger.error(f"Error in execute_react_loop: {e}")
432
+
433
+ return self.create_response(
434
+ success=False,
435
+ error=str(e),
436
+ iterations=iteration if 'iteration' in locals() else 0,
437
+ execution_time=execution_time
438
+ )
439
+
440
+
441
+ # Factory function for creating the writer agent
442
def create_writer_agent(
    model: Any = None,
    max_iterations: int = 15,  # More iterations for writing tasks
    temperature: Any = None,  # Resolved from env if not provided
    max_tokens: Any = None,
    shared_mcp_client=None
) -> WriterAgent:
    """
    Build a WriterAgent instance with server-managed sessions.

    Args:
        model: The LLM model to use
        max_iterations: Maximum number of iterations for writing tasks
        temperature: Temperature setting for creativity
        max_tokens: Maximum tokens for the AI response
        shared_mcp_client: Optional shared MCP client from parent agent (prevents extra sessions)

    Returns:
        Configured WriterAgent instance with writing-focused tools
    """
    # Imported lazily to avoid pulling base_agent in at module import time.
    from .base_agent import create_agent_config

    # Session management lives on the MCP server; the config only carries the
    # model and iteration parameters.
    writer_config = create_agent_config(
        agent_name="WriterAgent",
        model=model,
        max_iterations=max_iterations,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # Reusing the parent's MCP client keeps all agents in one workspace session.
    return WriterAgent(config=writer_config, shared_mcp_client=shared_mcp_client)
deepdiver_v2/src/tools/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
"""
Model Context Protocol (MCP) Integration

This package contains MCP server implementations, tools, and integrations
for the DeepDiver multi-agent system.
"""

from .mcp_tools import MCPTools

# Server imports.
# Each optional server is imported independently: a missing mcp_server_simple
# module must not mask the availability of the standard server (previously a
# single try block set MCP_STANDARD_AVAILABLE = False whenever either import
# failed).
try:
    from .mcp_server_standard import create_app as create_standard_app
    MCP_STANDARD_AVAILABLE = True
except ImportError:
    MCP_STANDARD_AVAILABLE = False
    create_standard_app = None

try:
    from .mcp_server_simple import app as simple_app
except ImportError:
    simple_app = None

# For backward compatibility
standard_app = simple_app  # Keep simple app for basic compatibility
MCP_AVAILABLE = MCP_STANDARD_AVAILABLE

__all__ = [
    'MCPTools',
    'create_standard_app',
    'simple_app',
    'standard_app',  # Backward compatibility
    'MCP_AVAILABLE',
    'MCP_STANDARD_AVAILABLE'
]
deepdiver_v2/src/tools/mcp_client.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ #!/usr/bin/env python3
3
+ """
4
+ MCP Client for Agent-to-Server Communication
5
+ Provides a proper MCP client that uses the official MCP package
6
+ to connect to and communicate with MCP servers through the Model Context Protocol.
7
+ """
8
+
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import Dict, Any, List, Optional
13
+ from dataclasses import dataclass, field
14
+ from pathlib import Path
15
+ import sys
16
+ sys.path.append(str(Path(__file__).parent.parent.parent))
17
+ from ..utils.status_codes import JsonRpcErr
18
+ from http import HTTPStatus
19
+
20
+ try:
21
+ import httpx
22
+ MCP_AVAILABLE = True
23
+ except ImportError:
24
+ MCP_AVAILABLE = False
25
+ logging.warning("HTTP client dependencies not available. Falling back to direct tools.")
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
@dataclass
class MCPClientResult:
    """Standard result format for MCP client operations."""
    success: bool                                        # True when the call succeeded
    data: Any = None                                     # payload returned by the server, if any
    error: Optional[str] = None                          # error description when success is False
    metadata: Dict[str, Any] = field(default_factory=dict)  # transport details (retries, session, ...)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this result, suitable for JSON serialization."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "metadata": self.metadata
        }
45
+
46
+
47
@dataclass
class MCPTool:
    """Lightweight record describing one tool exposed by an MCP server."""
    name: str                                                # tool identifier used in tools/call
    description: str = ""                                    # human-readable summary from the server
    input_schema: Dict[str, Any] = field(default_factory=dict)  # JSON schema for the tool's arguments
53
+
54
+
55
@dataclass
class RetryConfig:
    """Tunables controlling retry behavior when the server rate-limits requests."""
    # Maximum number of retry attempts before giving up.
    max_retries: int = 20
    # Base delay between retries, in seconds.
    base_delay: float = 2.0
    # Upper bound on any single retry delay, in seconds.
    max_delay: float = 60.0
    # When True, delays grow as base_delay * 2**attempt (capped at max_delay).
    exponential_backoff: bool = True
    # When True, honor the server's Retry-After response header.
    respect_retry_after: bool = True
    # Master switch for automatic retry on HTTP 429 responses.
    retry_on_rate_limit: bool = True
64
+
65
+
66
class MCPClient:
    """
    Simple HTTP-based MCP Client for dynamic tool discovery and execution.

    This client makes direct HTTP JSON-RPC calls to the MCP server,
    avoiding the complexity of streaming connections.

    Session management is handled entirely by the server:
    - Server assigns session IDs on connection
    - Server manages workspace creation and isolation
    - All tool operations use server-managed workspaces
    """

    def __init__(self, server_url: str = "http://localhost:6274/mcp", retry_config: Optional[RetryConfig] = None):
        """Create the client and, when httpx is available, connect and discover tools eagerly."""
        self.server_url = server_url.rstrip('/')
        self.retry_config = retry_config or RetryConfig()
        self._tools: Dict[str, MCPTool] = {}
        self._connected = False
        self._request_id = 0
        self._session_id = None  # assigned by the server on first response

        if not MCP_AVAILABLE:
            logger.warning("HTTP client not available. Some functionality may be limited.")
            return

        # Initialize connection and discover tools
        self._initialize_connection()

    def _get_next_id(self) -> int:
        """Return the next monotonically increasing JSON-RPC request id."""
        self._request_id += 1
        return self._request_id

    @staticmethod
    def _parse_sse_response(sse_text: str) -> Dict[str, Any]:
        """Parse a Server-Sent Events response body and extract its JSON payload."""
        try:
            # SSE format: "event: message\ndata: {json}\n\n"
            lines = sse_text.strip().split('\n')

            for line in lines:
                if line.startswith('data: '):
                    json_data = line[6:]  # Remove "data: " prefix
                    return json.loads(json_data)

            # If no data line found, try parsing entire response as JSON
            return json.loads(sse_text)

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse SSE response: {e}")
            logger.error(f"SSE text: {sse_text[:200]}...")
            return {"error": {"code": JsonRpcErr.PARSE_ERROR, "message": f"Parse error: {e}"}}

    def _make_request(self, method: str, params: Dict[str, Any] = None) -> MCPClientResult:
        """
        Make a JSON-RPC request to the MCP server with automatic retry on rate limits.

        Retries on HTTP 429 (honoring Retry-After) and on transient network
        exceptions, per self.retry_config.
        """
        if not MCP_AVAILABLE:
            return MCPClientResult(success=False, error="HTTP client not available")

        # Prepare JSON-RPC request envelope
        request_data = {
            "jsonrpc": "2.0",
            "id": self._get_next_id(),
            "method": method,
            "params": params or {}
        }

        # Make HTTP request with proper MCP headers
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream"
        }

        # Add session ID if available
        if self._session_id:
            headers["X-Session-ID"] = self._session_id

        last_error = None
        retry_count = 0

        while retry_count <= self.retry_config.max_retries:
            try:
                # Disable proxy for localhost/127.0.0.1 connections to avoid proxy interference
                import os
                from urllib.parse import urlparse
                parsed_url = urlparse(self.server_url)
                is_localhost = parsed_url.hostname in ['localhost', '127.0.0.1', '::1']

                # Add localhost to NO_PROXY for localhost connections.
                # NOTE(review): trust_env=False below already makes httpx ignore proxy
                # settings; this env juggling looks like belt-and-braces — confirm it
                # is still needed before removing.
                original_no_proxy = None
                if is_localhost:
                    original_no_proxy = os.environ.get('NO_PROXY', os.environ.get('no_proxy', ''))
                    # Add localhost and 127.0.0.1 to NO_PROXY
                    no_proxy_hosts = ['localhost', '127.0.0.1', '::1']
                    if original_no_proxy:
                        existing_hosts = [h.strip() for h in original_no_proxy.split(',')]
                        no_proxy_hosts.extend(existing_hosts)
                    os.environ['NO_PROXY'] = ','.join(no_proxy_hosts)
                    os.environ['no_proxy'] = ','.join(no_proxy_hosts)

                try:
                    # Create client with connection pooling for high-concurrency
                    limits = httpx.Limits(
                        max_keepalive_connections=3000,  # Keep more connections alive
                        max_connections=3000,            # Allow more concurrent connections
                        keepalive_expiry=1000.0          # Keep connections alive longer
                    )
                    timeout = httpx.Timeout(
                        connect=100.0,
                        read=None,  # no read timeout: tool calls may run for a long time
                        write=60.0,
                        pool=30.0
                    )
                    with httpx.Client(
                        timeout=timeout,      # Higher timeout for high-concurrency scenarios
                        limits=limits,        # Connection pooling for better performance
                        trust_env=False,
                        http2=True            # Enable HTTP/2 for better multiplexing
                    ) as client:
                        response = client.post(
                            self.server_url,
                            json=request_data,
                            headers=headers
                        )

                        # Check for rate limiting (HTTP 429)
                        if response.status_code == 429:
                            if not self.retry_config.retry_on_rate_limit:
                                return MCPClientResult(
                                    success=False,
                                    error="Rate limit exceeded (HTTP 429) - retries disabled",
                                    metadata={"status_code": 429, "retry_count": retry_count}
                                )

                            if retry_count >= self.retry_config.max_retries:
                                return MCPClientResult(
                                    success=False,
                                    error=f"Rate limit exceeded (HTTP 429) - max retries ({self.retry_config.max_retries}) reached",
                                    metadata={"status_code": 429, "retry_count": retry_count}
                                )

                            # Calculate retry delay
                            delay = self._calculate_retry_delay(response, retry_count)

                            logger.warning(f"Rate limit exceeded for {method} (attempt {retry_count + 1}/{self.retry_config.max_retries + 1}). Retrying in {delay:.1f}s...")

                            # Wait before retry
                            time.sleep(delay)
                            retry_count += 1
                            continue

                        # Handle other HTTP errors
                        if response.status_code != HTTPStatus.OK:
                            return MCPClientResult(
                                success=False,
                                error=f"HTTP {response.status_code}: {response.text}",
                                metadata={"status_code": response.status_code, "retry_count": retry_count}
                            )

                        # Parse successful response (could be JSON or SSE format)
                        if response.headers.get("content-type", "").startswith("text/event-stream"):
                            # Parse SSE format
                            response_data = self._parse_sse_response(response.text)
                        else:
                            # Parse regular JSON
                            response_data = response.json()

                        if "error" in response_data:
                            return MCPClientResult(
                                success=False,
                                error=f"MCP Error: {response_data['error']}",
                                metadata={"retry_count": retry_count}
                            )

                        # Capture session ID from response data (for all methods, not just initialize)
                        if "session_id" in response_data:
                            self._session_id = response_data["session_id"]
                            logger.info(f"Captured session ID from response: {self._session_id}")

                        # Success! Log retry info if this wasn't the first attempt
                        if retry_count > 0:
                            logger.info(f"Request {method} succeeded after {retry_count} retries")

                        return MCPClientResult(
                            success=True,
                            data=response_data.get("result"),
                            metadata={
                                "method": method,
                                "server_url": self.server_url,
                                "session_id": self._session_id,
                                "retry_count": retry_count
                            }
                        )
                finally:
                    # Restore original NO_PROXY environment variable
                    if is_localhost:
                        if original_no_proxy is not None:
                            if original_no_proxy:
                                os.environ['NO_PROXY'] = original_no_proxy
                                os.environ['no_proxy'] = original_no_proxy
                            else:
                                # Remove NO_PROXY if it wasn't set originally
                                os.environ.pop('NO_PROXY', None)
                                os.environ.pop('no_proxy', None)

            except Exception as e:
                last_error = str(e)
                logger.error(f"MCP request failed for {method} (attempt {retry_count + 1}): {e}")

                # Only retry on certain exceptions (network issues, timeouts)
                if not self._should_retry_exception(e) or retry_count >= self.retry_config.max_retries:
                    break

                # Calculate retry delay for exceptions
                delay = self._calculate_exception_retry_delay(retry_count)
                logger.warning(f"Request {method} failed, retrying in {delay:.1f}s... (attempt {retry_count + 1}/{self.retry_config.max_retries + 1})")

                time.sleep(delay)
                retry_count += 1

        # All retries exhausted
        return MCPClientResult(
            success=False,
            error=f"Request failed after {retry_count} retries. Last error: {last_error}",
            metadata={"retry_count": retry_count}
        )

    def _calculate_retry_delay(self, response, retry_count: int) -> float:
        """Calculate delay before retry based on server response and retry count."""
        delay = self.retry_config.base_delay

        # Respect server's Retry-After header if available
        if self.retry_config.respect_retry_after and "Retry-After" in response.headers:
            try:
                retry_after = float(response.headers["Retry-After"])
                delay = min(retry_after, self.retry_config.max_delay)
                logger.debug("Using server Retry-After: %ss", delay)
            except (ValueError, TypeError):
                logger.warning(f"Invalid Retry-After header: {response.headers.get('Retry-After')}")

        # Apply exponential backoff if enabled
        elif self.retry_config.exponential_backoff:
            delay = min(
                self.retry_config.base_delay * (2 ** retry_count),
                self.retry_config.max_delay
            )

        return delay

    def _calculate_exception_retry_delay(self, retry_count: int) -> float:
        """Calculate delay for exception-based retries (exponential backoff when enabled)."""
        if self.retry_config.exponential_backoff:
            return min(
                self.retry_config.base_delay * (2 ** retry_count),
                self.retry_config.max_delay
            )
        return self.retry_config.base_delay

    @staticmethod
    def _should_retry_exception(exception: Exception) -> bool:
        """Determine if an exception warrants a retry."""
        # Retry on network-related exceptions
        if isinstance(exception, (httpx.RequestError, httpx.TimeoutException, httpx.ConnectError)):
            return True

        # Don't retry on other exceptions (parsing errors, etc.)
        return False

    def _initialize_connection(self):
        """Initialize MCP client connection and fetch available tools."""
        if not MCP_AVAILABLE:
            return

        try:
            # Initialize session
            init_result = self._make_request("initialize", {
                "protocolVersion": "2025-06-18",
                "capabilities": {},
                "clientInfo": {
                    "name": "DeepDiver-MCP-Client",
                    "version": "1.0.0"
                }
            })
            # Debug print replaced with proper logging (library code must not print).
            logger.debug("MCP initialize result: %s", init_result)
            if not init_result.success:
                logger.error(f"MCP initialization failed: {init_result.error}")
                return

            logger.info("MCP client initialized successfully")

            # Fetch available tools
            tools_result = self._make_request("tools/list")

            if tools_result.success and tools_result.data:
                tools_data = tools_result.data.get("tools", [])
                self._tools = {}

                for tool_data in tools_data:
                    tool = MCPTool(
                        name=tool_data.get("name", ""),
                        description=tool_data.get("description", ""),
                        input_schema=tool_data.get("inputSchema", {})
                    )
                    self._tools[tool.name] = tool

                logger.info(f"Discovered {len(self._tools)} tools from MCP server: {list(self._tools.keys())}")

            self._connected = True

        except Exception as e:
            logger.error(f"Failed to initialize MCP client: {e}")
            self._connected = False

    def _ensure_connection(self):
        """Ensure the client is connected, attempting one reconnect before failing."""
        if not MCP_AVAILABLE:
            raise RuntimeError("HTTP client not available")

        if not self._connected:
            self._initialize_connection()

        if not self._connected:
            raise RuntimeError("MCP client not connected to server")

    def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPClientResult:
        """
        Generic method to call any tool available on the MCP server.

        Args:
            tool_name: Name of the tool to call
            arguments: Dictionary of arguments to pass to the tool

        Returns:
            MCPClientResult with the tool execution result
        """
        try:
            self._ensure_connection()

            if tool_name not in self._tools:
                return MCPClientResult(
                    success=False,
                    error=f"Tool '{tool_name}' not available on server. Available tools: {list(self._tools.keys())}"
                )

            # Call the tool via JSON-RPC
            result = self._make_request("tools/call", {
                "name": tool_name,
                "arguments": arguments
            })

            return result

        except Exception as e:
            logger.error(f"Error calling tool '{tool_name}': {e}")
            return MCPClientResult(
                success=False,
                error=str(e)
            )

    def get_available_tools(self) -> Dict[str, MCPTool]:
        """Get dictionary of available tools from the server (a shallow copy)."""
        return self._tools.copy()

    def list_tools(self) -> List[str]:
        """Get list of available tool names."""
        return list(self._tools.keys())

    def get_tool_info(self, tool_name: str) -> Optional[MCPTool]:
        """Get detailed information about a specific tool, or None if unknown."""
        return self._tools.get(tool_name)

    def is_connected(self) -> bool:
        """Check if client is connected to MCP server."""
        return self._connected and MCP_AVAILABLE

    def refresh_tools(self):
        """Refresh the list of available tools from the server."""
        try:
            # Fetch available tools
            tools_result = self._make_request("tools/list")

            if tools_result.success and tools_result.data:
                tools_data = tools_result.data.get("tools", [])
                self._tools = {}

                for tool_data in tools_data:
                    tool = MCPTool(
                        name=tool_data.get("name", ""),
                        description=tool_data.get("description", ""),
                        input_schema=tool_data.get("inputSchema", {})
                    )
                    self._tools[tool.name] = tool

                # (A stray debug print of the just-reset tool dict was removed here.)
                logger.info(f"Refreshed {len(self._tools)} tools from MCP server")
            else:
                logger.error(f"Failed to refresh tools: {tools_result.error}")

        except Exception as e:
            logger.error(f"Error refreshing tools: {e}")

    def close(self):
        """Close MCP client connection."""
        # Since we create connections per request, just mark as disconnected
        self._connected = False
470
+
471
+
472
class MCPToolsAdapter:
    """
    Adapter exposing the legacy MCPTools interface on top of the generic MCP client.

    Existing agents keep calling familiar method names; each call is translated
    into a generic tool invocation against the MCP server, preserving backward
    compatibility.
    """

    def __init__(self, server_url: str = "http://localhost:6274/mcp", retry_config: Optional[RetryConfig] = None):
        self.client = MCPClient(server_url, retry_config)

    def _call_tool(self, tool_name: str, **kwargs) -> MCPClientResult:
        """Forward a tool invocation to the underlying MCP client."""
        return self.client.call_tool(tool_name, kwargs)

    def __getattr__(self, name: str):
        """
        Create methods on the fly for any tool available on the server,
        e.g. adapter.batch_web_search(...) or adapter.file_read(...).
        """
        if name.startswith('_'):
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

        def tool_method(**kwargs):
            outcome = self._call_tool(name, **kwargs)
            # Backward compatibility: surface only the data (or an error dict).
            if outcome.success:
                return outcome.data
            return {"error": outcome.error}

        return tool_method

    def is_connected(self) -> bool:
        """Check if the MCP client is connected to the server."""
        return self.client.is_connected()

    def get_available_tools(self) -> Dict[str, MCPTool]:
        """Get available tools from the MCP server."""
        return self.client.get_available_tools()

    def list_tools(self) -> List[str]:
        """Get list of available tool names."""
        return self.client.list_tools()

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build function-call schemas for every available tool.
        This is the proper MCP way - schemas come from server, not direct imports.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": tool_info.description,
                    "parameters": tool_info.input_schema
                }
            }
            for tool_name, tool_info in self.get_available_tools().items()
        ]

    def refresh_tools(self):
        """Refresh the list of available tools from the server."""
        self.client.refresh_tools()

    def get_session_info(self) -> Optional[Dict[str, Any]]:
        """Get session information from the underlying MCP client."""
        try:
            if not hasattr(self.client, '_session_id'):
                return None
            return {
                "session_id": self.client._session_id,
                "connected": self.client.is_connected(),
                "server_url": getattr(self.client, 'server_url', 'unknown')
            }
        except Exception:
            return None

    def close(self):
        """Close the MCP client connection."""
        self.client.close()
558
+
559
+
560
class FilteredMCPToolsAdapter:
    """
    Filtered adapter that shares an MCP client connection while restricting
    tool access per agent type.

    Agents using this adapter:
    - share the same session/workspace (via the shared client)
    - see only the tool set appropriate for their role
    - keep a proper separation of concerns
    """

    def __init__(self, shared_client: MCPClient, allowed_tools: List[str]):
        """
        Initialize with shared client and allowed tools list

        Args:
            shared_client: Shared MCPClient instance (same session)
            allowed_tools: List of tools this agent can access
        """
        self.client = shared_client
        self.allowed_tools = set(allowed_tools)

        # Drop any requested tool the server does not actually expose.
        server_tools = set(self.client.list_tools())
        missing = self.allowed_tools - server_tools
        if missing:
            logger.warning(f"Requested tools not available on server: {missing}")
            self.allowed_tools = self.allowed_tools & server_tools

    def _call_tool(self, tool_name: str, **kwargs) -> MCPClientResult:
        """Invoke a tool if permitted; otherwise return an error result."""
        if tool_name not in self.allowed_tools:
            return MCPClientResult(
                success=False,
                error=f"Tool '{tool_name}' not allowed for this agent. Allowed tools: {list(self.allowed_tools)}"
            )

        # The server owns workspace selection; strip a stray workspace_path if passed.
        kwargs.pop('workspace_path', None)
        return self.client.call_tool(tool_name, kwargs)

    def __getattr__(self, name: str):
        """
        Dynamic method resolution restricted to this agent's allowed tool set.
        """
        if name in self.allowed_tools:
            def _invoke(**kwargs):
                return self._call_tool(name, **kwargs)
            return _invoke

        if name in self.client.list_tools():
            # Tool exists on the server but is off-limits for this agent.
            raise AttributeError(f"Tool '{name}' not allowed for this agent. Allowed tools: {list(self.allowed_tools)}")
        # Tool doesn't exist on the server at all.
        raise AttributeError(f"Tool '{name}' not available on server. Available tools: {self.client.list_tools()}")

    # ================ CLIENT MANAGEMENT ================

    def is_connected(self) -> bool:
        """Check if client is connected to MCP server."""
        return self.client.is_connected()

    def get_available_tools(self) -> Dict[str, MCPTool]:
        """Get the subset of server tools this agent may use."""
        permitted = self.allowed_tools
        return {name: tool for name, tool in self.client.get_available_tools().items() if name in permitted}

    def list_tools(self) -> List[str]:
        """Get list of allowed tool names for this agent."""
        return list(self.allowed_tools)

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build function-call schemas for the tools allowed for this agent.
        This is the proper MCP way - schemas come from server, not direct imports.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": tool_info.description,
                    "parameters": tool_info.input_schema
                }
            }
            for tool_name, tool_info in self.get_available_tools().items()
        ]

    def refresh_tools(self):
        """Refresh the underlying client's tools and re-validate the allowed set."""
        self.client.refresh_tools()

        server_tools = set(self.client.list_tools())
        missing = self.allowed_tools - server_tools
        if missing:
            logger.warning(f"Some allowed tools no longer available after refresh: {missing}")
            self.allowed_tools = self.allowed_tools & server_tools

    def close(self):
        """Close MCP client connection."""
        self.client.close()
671
+
672
+
673
# ================ AGENT TOOL SETS ================
# Define what tools each agent type should have access to.
# Order is preserved deliberately: it is the order schemas are emitted in.

# Planner: document inspection plus read/write access to the workspace.
PLANNER_AGENT_TOOLS = [
    "download_files",
    "document_qa",
    "file_read",
    "file_write",
    "str_replace_based_edit_tool",
    "list_workspace",
    "file_find_by_name",
]

# Information seeker: planner capabilities plus web search, crawling and
# document extraction.
INFORMATION_SEEKER_TOOLS = [
    "batch_web_search",
    "url_crawler",
    "document_extract",
    "document_qa",
    "download_files",
    "file_read",
    "file_write",
    "str_replace_based_edit_tool",
    "list_workspace",
    "file_find_by_name",
]

# Writer: read-only workspace access plus report-composition tools.
WRITER_AGENT_TOOLS = [
    "file_read",
    "list_workspace",
    "file_find_by_name",
    "search_result_classifier",
    "section_writer",
    "concat_section_files",
]
711
+
712
+
713
def create_filtered_mcp_tools_adapter(
    shared_client: MCPClient,
    agent_type: str
) -> FilteredMCPToolsAdapter:
    """
    Create a filtered MCP tools adapter for a specific agent type.

    Args:
        shared_client: Shared MCPClient instance
        agent_type: Type of agent ("planner", "information_seeker", "writer")

    Returns:
        FilteredMCPToolsAdapter with appropriate tools for agent type
    """
    if agent_type == "information_seeker":
        allowed_tools = INFORMATION_SEEKER_TOOLS
    elif agent_type == "writer":
        allowed_tools = WRITER_AGENT_TOOLS
    else:
        # "planner" and any unrecognized agent_type fall back to planner tools.
        allowed_tools = PLANNER_AGENT_TOOLS

    return FilteredMCPToolsAdapter(
        shared_client=shared_client,
        allowed_tools=allowed_tools
    )
739
+
740
+
741
def create_agent_mcp_tools(
    agent_type: str,
    server_url: str = "http://localhost:6274/mcp",
    retry_config: Optional[RetryConfig] = None
) -> FilteredMCPToolsAdapter:
    """
    Convenience factory for a filtered MCP tools adapter with retry support.

    This is the RECOMMENDED way to create MCP tools for agents: it builds a
    retry-capable client and wraps it in the per-agent tool filter.

    Args:
        agent_type: Type of agent ("planner", "information_seeker", "writer")
        server_url: URL of the MCP server (default: http://localhost:6274/mcp)
        retry_config: Optional retry configuration for handling rate limits

    Returns:
        FilteredMCPToolsAdapter with appropriate tools and retry support
    """
    shared_client = create_mcp_client(server_url=server_url, retry_config=retry_config)
    return create_filtered_mcp_tools_adapter(shared_client, agent_type)
763
+
764
+
765
def create_mcp_client(
    server_url: str = "http://localhost:6274/mcp",
    retry_config: Optional[RetryConfig] = None
) -> MCPClient:
    """
    Factory for a generic MCP client with optional retry configuration.

    Args:
        server_url: URL of the MCP server (default: http://localhost:6274/mcp)
        retry_config: Optional retry configuration for handling rate limits

    Returns:
        MCPClient instance for direct tool calling with automatic retry on
        rate limits
    """
    return MCPClient(server_url=server_url, retry_config=retry_config)
780
+
781
+
782
def create_mcp_tools_adapter(
    server_url: str = "http://localhost:6274/mcp",
    retry_config: Optional[RetryConfig] = None
) -> MCPToolsAdapter:
    """
    Factory for an MCP Tools Adapter (backward compatibility) with retry support.

    Args:
        server_url: URL of the MCP server (default: http://localhost:6274/mcp)
        retry_config: Optional retry configuration for handling rate limits

    Returns:
        MCPToolsAdapter instance that behaves like MCPTools but routes through
        the MCP client with automatic retries
    """
    return MCPToolsAdapter(server_url=server_url, retry_config=retry_config)
797
+
798
+
799
+ # Export for compatibility
800
# Public API of this module, grouped by kind.
__all__ = [
    # Core client types
    'MCPClientResult',
    'MCPClient',
    'MCPTool',
    'RetryConfig',
    # Adapters
    'MCPToolsAdapter',
    'FilteredMCPToolsAdapter',
    # Factories
    'create_mcp_client',
    'create_mcp_tools_adapter',
    'create_filtered_mcp_tools_adapter',
    'create_agent_mcp_tools',  # RECOMMENDED for agents
    # Agent tool allow-lists
    'PLANNER_AGENT_TOOLS',
    'INFORMATION_SEEKER_TOOLS',
    'WRITER_AGENT_TOOLS',
]
deepdiver_v2/src/tools/mcp_server_standard.py ADDED
@@ -0,0 +1,1751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ #!/usr/bin/env python3
3
+ """
4
+ Demo-Ready MCP Server - New Standard Implementation
5
+ Combines robust session management with comprehensive tool definitions.
6
+ Features: workspace isolation, tool call tracking, rate limiting, security, and full tool suite.
7
+ """
8
+
9
+ import argparse
10
+ import asyncio
11
+ import json
12
+ import logging
13
+ import time
14
+ import uuid
15
+ import yaml
16
+ from collections import defaultdict, deque
17
+ from dataclasses import dataclass, field
18
+ from datetime import datetime, timedelta
19
+ from pathlib import Path
20
+ from threading import Thread, Event
21
+ from typing import Any, Dict, List, Optional
22
+
23
+ # Third-party imports
24
+ from starlette.applications import Starlette
25
+ from starlette.middleware.base import BaseHTTPMiddleware
26
+ from starlette.requests import Request
27
+ from starlette.responses import JSONResponse, StreamingResponse
28
+ import uvicorn
29
+
30
+ # Add project root to Python path for imports
31
+ import sys
32
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
33
+ from src.utils.status_codes import JsonRpcErr
34
+ from http import HTTPStatus
35
+
36
+ # Handle both relative and absolute imports
37
+ try:
38
+ from .mcp_tools import MCPTools, get_tool_schemas
39
+ from .mcp_tools_async import AsyncMCPTools
40
+ except ImportError:
41
+ # Fallback for direct script execution
42
+ from src.tools.mcp_tools import MCPTools, get_tool_schemas
43
+ try:
44
+ from src.tools.mcp_tools_async import AsyncMCPTools
45
+ except ImportError:
46
+ AsyncMCPTools = None
47
+
48
+ # Workspace knowledge manager disabled
49
+ WORKSPACE_KNOWLEDGE_AVAILABLE = False
50
+
51
+ # Configure structured logging
52
+ logging.basicConfig(
53
+ level=logging.INFO,
54
+ format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
55
+ handlers=[
56
+ logging.StreamHandler(sys.stdout),
57
+ logging.FileHandler('mcp_server.log')
58
+ ]
59
+ )
60
+ logger = logging.getLogger(__name__)
61
+
62
+ # ================ CONFIGURATION ================
63
+
64
+
65
@dataclass
class ServerConfig:
    """Server configuration with only actually implemented options."""
    # Server Core Settings
    host: str = "127.0.0.1"
    port: int = 6274
    debug_mode: bool = False

    # Session Management
    session_ttl_seconds: int = 3600  # 1 hour default
    max_sessions: int = 1000
    cleanup_interval_seconds: int = 300  # 5 minutes
    enable_session_keepalive: bool = True
    keepalive_touch_interval: int = 300

    # Request Handling
    request_timeout_seconds: int = 120
    max_request_size_mb: int = 10

    # Client Rate Limiting (per IP)
    rate_limit_requests_per_minute: int = 300

    # Workspace Management
    base_workspace_dir: str = "workspaces"

    # Tool Call Tracking & Logging
    enable_tool_tracking: bool = True
    max_tracked_calls_per_session: int = 1000
    track_detailed_errors: bool = True

    # Per-tool Rate Limiting Configuration
    tool_rate_limits: Dict[str, Dict[str, int]] = field(default_factory=dict)

    @classmethod
    def from_yaml(cls, config_path: str) -> 'ServerConfig':
        """
        Load configuration from a YAML file.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            A ServerConfig populated from the file; on any read/parse error
            the error is logged and an all-defaults instance is returned.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                # An empty YAML document parses to None; treat it as an empty
                # mapping instead of crashing into the fallback path below.
                config_data = yaml.safe_load(f) or {}

            # Extract configuration sections with defaults
            server_config = config_data.get('server', {})
            tracking_config = config_data.get('tracking', {})
            tool_rate_limits = config_data.get('tool_rate_limits', {})

            return cls(
                # Server Core Settings
                host=server_config.get('host', "127.0.0.1"),
                port=server_config.get('port', 6274),
                debug_mode=server_config.get('debug_mode', False),

                # Session Management
                session_ttl_seconds=server_config.get('session_ttl_seconds', 3600),
                max_sessions=server_config.get('max_sessions', 1000),
                cleanup_interval_seconds=server_config.get('cleanup_interval_seconds', 300),
                enable_session_keepalive=server_config.get('enable_session_keepalive', True),
                keepalive_touch_interval=server_config.get('keepalive_touch_interval', 300),

                # Request Handling
                request_timeout_seconds=server_config.get('request_timeout_seconds', 120),
                max_request_size_mb=server_config.get('max_request_size_mb', 10),

                # Client Rate Limiting
                rate_limit_requests_per_minute=server_config.get('rate_limit_requests_per_minute', 300),

                # Workspace Management
                base_workspace_dir=server_config.get('base_workspace_dir', "workspaces"),

                # Tool Call Tracking & Logging
                enable_tool_tracking=tracking_config.get('enable_tool_tracking', True),
                max_tracked_calls_per_session=tracking_config.get('max_tracked_calls_per_session', 1000),
                track_detailed_errors=tracking_config.get('track_detailed_errors', True),

                # Per-tool Rate Limiting
                tool_rate_limits=tool_rate_limits
            )

        except Exception as e:
            logger.error(f"Failed to load configuration from {config_path}: {e}")
            logger.info("Using default configuration")
            return cls()


# Global configuration instance - will be set during startup
config: Optional[ServerConfig] = None
151
+
152
+ # ================ GLOBAL PER-TOOL RATE LIMITING ================
153
+
154
+
155
@dataclass
class ToolRateLimit:
    """Per-tool rate-limit thresholds (float('inf') means unbounded)."""
    requests_per_minute: float   # max requests in any 60s window
    requests_per_hour: float     # max requests in any 3600s window
    burst_limit: int             # max requests in any 1s window
161
+
162
+
163
class GlobalToolRateLimiter:
    """
    Global rate limiter that controls QPS to external APIs per tool.
    This is shared across all sessions and clients to manage upstream service load.
    """

    def __init__(self, tool_rate_limits: Dict[str, Dict[str, int]]):
        """
        Args:
            tool_rate_limits: Mapping of tool name -> limit config with optional
                'requests_per_minute', 'requests_per_hour' and 'burst_limit' keys.
                Missing minute/hour limits default to unbounded; burst defaults to 10.
        """
        self.tool_limits: Dict[str, ToolRateLimit] = {}
        # Per-tool timestamps (time.time() seconds) of recorded requests.
        self.tool_requests: Dict[str, deque] = defaultdict(deque)
        self.lock = asyncio.Lock()

        for tool_name, limits_config in tool_rate_limits.items():
            self.tool_limits[tool_name] = ToolRateLimit(
                requests_per_minute=limits_config.get('requests_per_minute', float('inf')),
                requests_per_hour=limits_config.get('requests_per_hour', float('inf')),
                burst_limit=limits_config.get('burst_limit', 10)
            )
            self.tool_requests[tool_name] = deque()

        logger.info(f"Initialized global tool rate limiter for {len(self.tool_limits)} tools")

    async def is_allowed(self, tool_name: str) -> tuple[bool, Optional[str]]:
        """
        Check if a request to the specified tool is allowed based on global rate limits.

        Returns:
            tuple[bool, Optional[str]]: (allowed, reason_if_denied)
        """
        if tool_name not in self.tool_limits:
            # Tool not configured for rate limiting - allow
            return True, None

        async with self.lock:
            now = time.time()
            limits = self.tool_limits[tool_name]
            requests = self.tool_requests[tool_name]

            # Drop entries older than the largest (1 hour) window before counting.
            self._cleanup_old_requests(requests, now)
            recent_requests = list(requests)

            # Burst limit: rapid requests within the last second - only if specified
            if limits.burst_limit != float('inf'):
                burst_count = sum(1 for req_time in recent_requests if now - req_time < 1.0)
                if burst_count >= limits.burst_limit:
                    return False, f"Tool '{tool_name}' burst limit exceeded ({limits.burst_limit} requests/burst)"

            # Per-minute limit - only if specified
            if limits.requests_per_minute != float('inf'):
                minute_count = sum(1 for req_time in recent_requests if now - req_time < 60.0)
                if minute_count >= limits.requests_per_minute:
                    return False, f"Tool '{tool_name}' per-minute limit exceeded ({limits.requests_per_minute} requests/minute)"

            # Per-hour limit - only if specified
            if limits.requests_per_hour != float('inf'):
                hour_count = sum(1 for req_time in recent_requests if now - req_time < 3600.0)
                if hour_count >= limits.requests_per_hour:
                    return False, f"Tool '{tool_name}' per-hour limit exceeded ({limits.requests_per_hour} requests/hour)"

            return True, None

    async def record_request(self, tool_name: str):
        """Record a successful request for rate limiting tracking."""
        if tool_name not in self.tool_limits:
            return

        async with self.lock:
            now = time.time()
            self.tool_requests[tool_name].append(now)
            # Keep deque size manageable (only keep last hour of requests)
            self._cleanup_old_requests(self.tool_requests[tool_name], now)

    @staticmethod
    def _cleanup_old_requests(requests: deque, now: float):
        """Remove requests older than 1 hour to keep memory usage bounded."""
        while requests and now - requests[0] > 3600.0:  # 1 hour
            requests.popleft()

    async def get_tool_stats(self, tool_name: str) -> Dict[str, Any]:
        """Get current usage statistics for a tool."""
        if tool_name not in self.tool_limits:
            return {"error": f"Tool '{tool_name}' not configured for rate limiting"}

        async with self.lock:
            now = time.time()
            requests = self.tool_requests[tool_name]
            limits = self.tool_limits[tool_name]

            # Clean old requests first
            self._cleanup_old_requests(requests, now)
            recent_requests = list(requests)

            # Count each window once and reuse the totals below.
            second_count = sum(1 for req_time in recent_requests if now - req_time < 1.0)
            minute_count = sum(1 for req_time in recent_requests if now - req_time < 60.0)
            hour_count = sum(1 for req_time in recent_requests if now - req_time < 3600.0)

            return {
                "tool_name": tool_name,
                "current_usage": {
                    "last_second": second_count,
                    "last_minute": minute_count,
                    "last_hour": hour_count
                },
                "limits": {
                    "requests_per_minute": limits.requests_per_minute if limits.requests_per_minute != float('inf') else None,
                    "requests_per_hour": limits.requests_per_hour if limits.requests_per_hour != float('inf') else None,
                    "burst_limit": limits.burst_limit if limits.burst_limit != float('inf') else None
                },
                "utilization": {
                    "minute_utilization": minute_count / limits.requests_per_minute if limits.requests_per_minute != float('inf') else 0,
                    "hour_utilization": hour_count / limits.requests_per_hour if limits.requests_per_hour != float('inf') else 0
                }
            }

    async def get_all_stats(self) -> Dict[str, Any]:
        """
        Get usage statistics for all tools.

        BUG FIX: this was previously a *synchronous* method that called the
        async get_tool_stats() without awaiting, so callers received a dict of
        un-awaited coroutine objects. It is now async and awaits each result.
        """
        return {
            tool_name: await self.get_tool_stats(tool_name)
            for tool_name in self.tool_limits
        }


# Global tool rate limiter instance - will be initialized during startup
global_tool_rate_limiter: Optional[GlobalToolRateLimiter] = None
287
+
288
+ # ================ TOOL DEFINITIONS ================
289
+
290
+ # Tool execution function mapping - maps tool names to their implementation functions
291
+
292
+
293
+ def get_tool_function(tool_name: str):
294
+ """Get the actual function for a tool"""
295
+ tool_map = {
296
+ "batch_web_search": lambda tools, **kwargs: tools.batch_web_search(**kwargs),
297
+ "url_crawler": lambda tools, **kwargs: tools.url_crawler(**kwargs),
298
+ "download_files": lambda tools, **kwargs: tools.download_files(**kwargs),
299
+ "list_workspace": lambda tools, **kwargs: tools.list_workspace(**kwargs),
300
+ "str_replace_based_edit_tool": lambda tools, **kwargs: tools.str_replace_based_edit_tool(**kwargs),
301
+ "file_stats": lambda tools, **kwargs: tools.file_stats(**kwargs),
302
+ "file_read": lambda tools, **kwargs: tools.file_read(**kwargs),
303
+ "file_read_lines": lambda tools, **kwargs: tools.file_read_lines(**kwargs),
304
+ "content_preview": lambda tools, **kwargs: tools.content_preview(**kwargs),
305
+ "file_write": lambda tools, **kwargs: tools.file_write(**kwargs),
306
+ "file_grep_search": lambda tools, **kwargs: tools.file_grep_search(**kwargs),
307
+ "file_grep_with_context": lambda tools, **kwargs: tools.file_grep_with_context(**kwargs),
308
+ "file_find_by_name": lambda tools, **kwargs: tools.file_find_by_name(**kwargs),
309
+ "bash": lambda tools, **kwargs: tools.bash(**kwargs),
310
+ "task_done": lambda tools, **kwargs: tools.task_done(**kwargs),
311
+ "think": lambda tools, **kwargs: tools.think(**kwargs),
312
+ "reflect": lambda tools, **kwargs: tools.reflect(**kwargs),
313
+ "document_qa": lambda tools, **kwargs: tools.document_qa(**kwargs),
314
+ "extract_markdown_toc": lambda tools, **kwargs: tools.extract_markdown_toc(**kwargs),
315
+ "extract_markdown_section": lambda tools, **kwargs: tools.extract_markdown_section(**kwargs),
316
+
317
+ "document_extract": lambda tools, **kwargs: tools.document_extract(**kwargs),
318
+ "search_result_classifier": lambda tools, **kwargs: tools.search_result_classifier(**kwargs),
319
+ "info_seeker_subjective_task_done": None,
320
+ "writer_subjective_task_done": None,
321
+ "section_writer": lambda tools, **kwargs: tools.section_writer(**kwargs),
322
+ "concat_section_files": lambda tools, **kwargs: tools.concat_section_files(**kwargs),
323
+
324
+ # Internal tools - available to server but NOT exposed to agents via tool schemas
325
+ "internal_file_read_unlimited": lambda tools, **kwargs: tools.internal_file_read_unlimited(**kwargs),
326
+ }
327
+ return tool_map.get(tool_name)
328
+
329
+
330
+ # ================ TOOL CALL TRACKING ================
331
+
332
+
333
@dataclass
class ToolCallLog:
    """Individual tool call log entry."""
    call_id: str                                   # unique id for this call
    timestamp: datetime                            # when the call happened
    tool_name: str
    input_args: Dict[str, Any]                     # sanitized arguments
    output_result: Dict[str, Any]                  # sanitized result
    success: bool
    duration_ms: float
    error_details: Optional[str] = None
    session_id: str = ""
    agent_info: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-serializable dict (timestamp as ISO-8601 string)."""
        field_order = (
            "call_id", "timestamp", "tool_name", "input_args", "output_result",
            "success", "duration_ms", "error_details", "session_id", "agent_info",
        )
        return {
            name: (self.timestamp.isoformat() if name == "timestamp" else getattr(self, name))
            for name in field_order
        }
361
+
362
+
363
+ class ToolCallTracker:
364
+ """Tracks and saves tool calls to workspace-specific files"""
365
+
366
+ def __init__(self, workspace_path: Path, session_id: str):
367
+ self.workspace_path = workspace_path
368
+ self.session_id = session_id
369
+ self.logs_dir = workspace_path / "tool_call_logs"
370
+ self.logs_dir.mkdir(exist_ok=True)
371
+
372
+ # Create daily log file
373
+ today = datetime.now().strftime("%Y-%m-%d")
374
+ self.current_log_file = self.logs_dir / f"tool_calls_{today}.jsonl"
375
+ self.summary_file = self.logs_dir / "session_summary.json"
376
+
377
+ # Track call counts
378
+ self.call_count = 0
379
+ self.tool_usage_stats = defaultdict(int)
380
+
381
+ # Initialize session summary
382
+ self._initialize_session_summary()
383
+
384
+ def _initialize_session_summary(self):
385
+ """Initialize or update session summary file"""
386
+ summary = {
387
+ "session_id": self.session_id,
388
+ "session_start": datetime.now().isoformat(),
389
+ "last_updated": datetime.now().isoformat(),
390
+ "total_tool_calls": 0,
391
+ "tool_usage_stats": {},
392
+ "agent_activity": {},
393
+ "workspace_path": str(self.workspace_path)
394
+ }
395
+
396
+ # Load existing summary if it exists
397
+ if self.summary_file.exists():
398
+ try:
399
+ with open(self.summary_file, 'r') as f:
400
+ existing_summary = json.load(f)
401
+ summary.update(existing_summary)
402
+ # Don't overwrite session_start if it already exists
403
+ if "session_start" in existing_summary:
404
+ summary["session_start"] = existing_summary["session_start"]
405
+ except Exception as e:
406
+ logger.warning(f"Could not load existing session summary: {e}")
407
+
408
+ self._save_summary(summary)
409
+
410
+ def _save_summary(self, summary: Dict[str, Any]):
411
+ """Save session summary to file"""
412
+ try:
413
+ with open(self.summary_file, 'w') as f:
414
+ json.dump(summary, f, indent=2, ensure_ascii=False)
415
+ except Exception as e:
416
+ logger.error(f"Failed to save session summary: {e}")
417
+
418
+ def log_tool_call(self,
419
+ tool_name: str,
420
+ input_args: Dict[str, Any],
421
+ output_result: Dict[str, Any],
422
+ success: bool,
423
+ duration_ms: float,
424
+ error_details: Optional[str] = None,
425
+ agent_info: Optional[Dict[str, Any]] = None) -> str:
426
+ """Log a tool call and return the call ID"""
427
+
428
+ if not config.enable_tool_tracking:
429
+ return ""
430
+
431
+ # Respect max call limit per session
432
+ if self.call_count >= config.max_tracked_calls_per_session:
433
+ logger.warning(f"Max tracked calls reached for session {self.session_id}")
434
+ return ""
435
+
436
+ call_id = str(uuid.uuid4())
437
+ timestamp = datetime.now()
438
+
439
+ # Create log entry
440
+ log_entry = ToolCallLog(
441
+ call_id=call_id,
442
+ timestamp=timestamp,
443
+ tool_name=tool_name,
444
+ input_args=self._sanitize_args(input_args),
445
+ output_result=self._sanitize_result(output_result),
446
+ success=success,
447
+ duration_ms=duration_ms,
448
+ error_details=error_details if config.track_detailed_errors else None,
449
+ session_id=self.session_id,
450
+ agent_info=agent_info
451
+ )
452
+
453
+ # Save to JSONL file (one JSON object per line)
454
+ try:
455
+ with open(self.current_log_file, 'a', encoding="utf-8") as f:
456
+ f.write(json.dumps(log_entry.to_dict(), ensure_ascii=False) + '\n')
457
+ except Exception as e:
458
+ logger.error(f"Failed to save tool call log: {e}")
459
+
460
+ # Update session summary
461
+ self._update_session_summary(log_entry)
462
+
463
+ self.call_count += 1
464
+ self.tool_usage_stats[tool_name] += 1
465
+
466
+ return call_id
467
+
468
+ @staticmethod
469
+ def _sanitize_args(args: Dict[str, Any]) -> Dict[str, Any]:
470
+ """Sanitize arguments for logging (remove sensitive data)"""
471
+ sanitized = {}
472
+ for key, value in args.items():
473
+ if isinstance(value, str) and len(value) > 1000:
474
+ sanitized[key] = value[:1000] + "... [truncated]"
475
+ elif key.lower() in ['password', 'token', 'secret', 'key']:
476
+ sanitized[key] = "[REDACTED]"
477
+ else:
478
+ sanitized[key] = value
479
+ return sanitized
480
+
481
+ def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
482
+ """Sanitize result for logging (remove large content)"""
483
+ if not isinstance(result, dict):
484
+ return result
485
+
486
+ sanitized = {}
487
+ for key, value in result.items():
488
+ if isinstance(value, str) and len(value) > 2000:
489
+ sanitized[key] = value[:2000] + "... [truncated]"
490
+ elif isinstance(value, dict):
491
+ sanitized[key] = self._sanitize_result(value)
492
+ else:
493
+ sanitized[key] = value
494
+ return sanitized
495
+
496
+ def _update_session_summary(self, log_entry: ToolCallLog):
497
+ """Update session summary with new tool call"""
498
+ try:
499
+ summary = {
500
+ "session_id": self.session_id,
501
+ "last_updated": datetime.now().isoformat(),
502
+ "total_tool_calls": self.call_count + 1,
503
+ "tool_usage_stats": dict(self.tool_usage_stats),
504
+ "workspace_path": str(self.workspace_path)
505
+ }
506
+
507
+ # Load existing summary
508
+ if self.summary_file.exists():
509
+ with open(self.summary_file, 'r') as f:
510
+ existing_summary = json.load(f)
511
+ summary.update(existing_summary)
512
+
513
+ # Update with new data
514
+ summary["last_updated"] = datetime.now().isoformat()
515
+ summary["total_tool_calls"] = self.call_count + 1
516
+ summary["tool_usage_stats"] = dict(self.tool_usage_stats)
517
+ summary["tool_usage_stats"][log_entry.tool_name] = self.tool_usage_stats[log_entry.tool_name] + 1
518
+
519
+ # Track agent activity
520
+ if log_entry.agent_info:
521
+ agent_type = log_entry.agent_info.get('type', 'unknown')
522
+ if 'agent_activity' not in summary:
523
+ summary['agent_activity'] = {}
524
+ if agent_type not in summary['agent_activity']:
525
+ summary['agent_activity'][agent_type] = {
526
+ 'tool_calls': 0,
527
+ 'last_active': log_entry.timestamp.isoformat()
528
+ }
529
+ summary['agent_activity'][agent_type]['tool_calls'] += 1
530
+ summary['agent_activity'][agent_type]['last_active'] = log_entry.timestamp.isoformat()
531
+
532
+ self._save_summary(summary)
533
+
534
+ except Exception as e:
535
+ logger.error(f"Failed to update session summary: {e}")
536
+
537
+ # ================ SESSION KEEP-ALIVE FOR LONG OPERATIONS ================
538
+
539
+
540
class KeepAliveSessionWrapper:
    """Wrapper that keeps a session alive during long-running operations."""

    def __init__(self, session: 'Session', touch_interval: int = 300):
        """
        Args:
            session: Session to keep alive (must expose .touch() and .id).
            touch_interval: Seconds between keep-alive touches (default 5 min).
        """
        self.session = session
        self.touch_interval = touch_interval
        self.keep_alive_thread = None
        self.stop_event = Event()
        self.active = False

    def _keep_alive_loop(self):
        """Background loop: touch the session until stop_event fires or touch fails."""
        # Event.wait doubles as both the sleep and the stop signal.
        while not self.stop_event.wait(self.touch_interval):
            try:
                self.session.touch()
                logger.debug("Keep-alive: Touched session {%s}", self.session.id)
            except Exception as e:
                logger.error(f"Keep-alive error for session {self.session.id}: {e}")
                break

    def start_keep_alive(self):
        """Start the keep-alive mechanism; no-op if already running."""
        if self.active:
            return

        self.active = True
        self.stop_event.clear()
        self.keep_alive_thread = Thread(target=self._keep_alive_loop, daemon=True)
        self.keep_alive_thread.start()
        logger.info(f"Started keep-alive for session {self.session.id}")

    def stop_keep_alive(self):
        """Stop the keep-alive mechanism; no-op if not running."""
        if not self.active:
            return

        self.active = False
        self.stop_event.set()

        thread = self.keep_alive_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=1.0)

        # Final touch
        try:
            self.session.touch()
        except Exception as e:
            logger.error(f"Final keep-alive touch error for session {self.session.id}: {e}")

        logger.info(f"Stopped keep-alive for session {self.session.id}")

    def __enter__(self):
        self.start_keep_alive()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_keep_alive()
596
+
597
+ # ================ SESSION MANAGEMENT ================
598
+
599
+
600
@dataclass
class Session:
    """Thread-safe session data structure with workspace management"""
    id: str
    created_at: datetime
    last_accessed: datetime
    initialized: bool = False
    request_count: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)
    workspace_path: Optional[Path] = None
    mcp_tools: Optional[MCPTools] = None
    tool_tracker: Optional[ToolCallTracker] = None

    def is_expired(self, ttl_seconds: int) -> bool:
        """Return True when the idle time since last access exceeds the TTL."""
        idle = datetime.now() - self.last_accessed
        return idle > timedelta(seconds=ttl_seconds)

    def touch(self):
        """Record activity: refresh the access timestamp and bump the counter."""
        self.last_accessed = datetime.now()
        self.request_count += 1

    def get_mcp_tools(self, prefer_async: bool = True) -> MCPTools:
        """Lazily build and cache the MCP tools bound to this session's workspace."""
        if self.mcp_tools is None:
            workspace = str(self.workspace_path) if self.workspace_path else None
            # Prefer the async implementation when it is importable and requested.
            use_async = prefer_async and AsyncMCPTools is not None
            factory = AsyncMCPTools if use_async else MCPTools
            self.mcp_tools = factory(workspace_path=workspace)
        return self.mcp_tools

    def get_tool_tracker(self) -> Optional[ToolCallTracker]:
        """Lazily build the tracker; None when tracking is disabled or no workspace."""
        if not (config.enable_tool_tracking and self.workspace_path):
            return None
        if self.tool_tracker is None:
            self.tool_tracker = ToolCallTracker(self.workspace_path, self.id)
        return self.tool_tracker
640
+
641
+
642
+
643
class AsyncRLock:
    """Async re-entrant lock: the asyncio counterpart of ``threading.RLock``."""

    def __init__(self):
        self._lock = asyncio.Lock()
        self._owner: Optional[asyncio.Task] = None  # task currently holding the lock
        self._count = 0  # re-entry depth

    async def acquire(self):
        task = asyncio.current_task()
        # Re-entrant acquire by the current owner: just deepen the count.
        if self._owner == task:
            self._count += 1
            return
        # Otherwise contend for the underlying lock.
        await self._lock.acquire()
        self._owner = task
        self._count = 1

    async def release(self):
        if self._owner != asyncio.current_task():
            raise RuntimeError("不能释放非当前协程持有的锁")
        self._count -= 1
        # Only release the underlying lock once the re-entry depth hits zero.
        if self._count == 0:
            self._owner = None
            self._lock.release()

    # Support the ``async with`` protocol.
    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.release()
676
+
677
+
678
class ThreadSafeSessionManager:
    """Thread-safe session manager with workspace management.

    Sessions live in an in-memory dict guarded by an AsyncRLock. Re-entrancy is
    required: create_session holds the lock while awaiting
    _cleanup_oldest_sessions, which acquires it again. A background daemon
    thread periodically evicts expired sessions.
    """

    def __init__(self, ttl_seconds: int = 3600, max_sessions: int = 1000, base_workspace_dir: str = "workspaces"):
        self.ttl_seconds = ttl_seconds
        self.max_sessions = max_sessions
        self.base_workspace_dir = Path(base_workspace_dir)
        self.base_workspace_dir.mkdir(exist_ok=True)

        # Thread-safe session storage
        self.sessions: Dict[str, Session] = {}
        self.lock = AsyncRLock()

        # Start cleanup thread
        self._start_cleanup_thread()

    async def create_session(self) -> str:
        """Create a new session and return session ID"""
        session_id = str(uuid.uuid4())

        async with self.lock:
            # Check session limits; eviction re-acquires self.lock (re-entrant).
            if len(self.sessions) >= self.max_sessions:
                await self._cleanup_oldest_sessions()

            # Create workspace directory (one isolated dir per session id)
            workspace_path = self.base_workspace_dir / session_id
            workspace_path.mkdir(exist_ok=True, parents=True)

            # Create session
            session = Session(
                id=session_id,
                created_at=datetime.now(),
                last_accessed=datetime.now(),
                workspace_path=workspace_path
            )

            self.sessions[session_id] = session

        logger.info(f"Created session {session_id} with workspace {workspace_path}")
        return session_id

    async def get_session(self, session_id: str) -> Optional[Session]:
        """Get session by ID if it exists and is not expired.

        Side effect: touching refreshes the TTL; expired sessions are removed
        eagerly here (their workspace directory is NOT deleted — only the
        in-memory record).
        """
        async with self.lock:
            session = self.sessions.get(session_id)
            if session and not session.is_expired(self.ttl_seconds):
                session.touch()
                return session
            elif session:
                # Remove expired session
                del self.sessions[session_id]
                logger.info(f"Removed expired session {session_id}")
            return None

    async def get_or_create_session(self, session_id: Optional[str] = None) -> Session:
        """Get existing session or create new one.

        NOTE(review): the final dict read happens outside the lock; a
        concurrent eviction between create_session returning and this lookup
        would raise KeyError — confirm whether that window matters here.
        """
        if session_id:
            session = await self.get_session(session_id)
            if session:
                return session

        # Create new session
        new_session_id = await self.create_session()
        return self.sessions[new_session_id]

    async def _cleanup_expired_sessions(self):
        """Remove expired sessions (in-memory records only)."""
        async with self.lock:
            expired_sessions = []
            for session_id, session in self.sessions.items():
                if session.is_expired(self.ttl_seconds):
                    expired_sessions.append(session_id)

            for session_id in expired_sessions:
                del self.sessions[session_id]
                logger.info(f"Cleaned up expired session {session_id}")

    async def _cleanup_oldest_sessions(self):
        """Remove oldest sessions when limit is reached.

        Evicts down to (max_sessions - 10) by least-recent access so repeated
        creations do not immediately re-trigger eviction.
        """
        async with self.lock:
            if len(self.sessions) < self.max_sessions:
                return

            # Sort by last accessed time and remove oldest
            sorted_sessions = sorted(
                self.sessions.items(),
                key=lambda x: x[1].last_accessed
            )

            sessions_to_remove = len(self.sessions) - self.max_sessions + 10  # Remove extra
            for i in range(sessions_to_remove):
                if i < len(sorted_sessions):
                    session_id = sorted_sessions[i][0]
                    del self.sessions[session_id]
                    logger.info(f"Removed old session {session_id} due to session limit")

    def _start_cleanup_thread(self):
        """Start background cleanup thread.

        NOTE(review): each cycle builds a throwaway event loop to drive the
        async cleanup, while self.lock's asyncio.Lock is normally used from the
        server's loop. This relies on the lock being uncontended across loops —
        verify this is safe under load.
        """
        def cleanup_worker():
            while True:
                try:
                    time.sleep(config.cleanup_interval_seconds)
                    # Run async method in sync context
                    loop = asyncio.new_event_loop()
                    loop.run_until_complete(self._cleanup_expired_sessions())
                    loop.close()
                except Exception as e:
                    logger.error(f"Error in cleanup thread: {e}")

        import threading
        cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        cleanup_thread.start()
        logger.info("Started session cleanup thread")

    async def get_stats(self) -> Dict[str, Any]:
        """Get session manager statistics (snapshot taken under the lock)."""
        async with self.lock:
            return {
                "total_sessions": len(self.sessions),
                "max_sessions": self.max_sessions,
                "ttl_seconds": self.ttl_seconds,
                "session_ids": list(self.sessions.keys())
            }
804
+
805
+ # ================ MIDDLEWARE AND SECURITY ================
806
+
807
+
808
class RateLimiter:
    """Simple rate limiter with time-window tracking"""

    def __init__(self, requests_per_minute: int = 60):
        self.requests_per_minute = requests_per_minute
        self.requests: Dict[str, List[float]] = defaultdict(list)
        self.lock = asyncio.Lock()

    async def is_allowed(self, client_id: str) -> bool:
        """Check-and-record: True when the client is under its per-minute budget."""
        async with self.lock:
            now = time.time()
            cutoff = now - 60

            # Keep only timestamps inside the sliding one-minute window.
            recent = [stamp for stamp in self.requests[client_id] if stamp > cutoff]
            self.requests[client_id] = recent

            # Over budget: deny without recording.
            if len(recent) >= self.requests_per_minute:
                return False

            # Under budget: record this request and allow it.
            recent.append(now)
            return True
835
+
836
+
837
class RequestValidator:
    """Validates incoming MCP requests"""

    @staticmethod
    def validate_mcp_request(data: Dict[str, Any]) -> tuple[bool, Optional[str]]:
        """Check the basic JSON-RPC envelope: a dict carrying 'method' and 'id'."""
        if not isinstance(data, dict):
            return False, "Request must be a JSON object"

        for required in ("method", "id"):
            if required not in data:
                return False, f"Missing '{required}' field"

        return True, None

    @staticmethod
    def validate_tool_call(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
        """Check tool-call params and that the named tool exists in the registry."""
        if not isinstance(params, dict):
            return False, "Tool parameters must be a JSON object"

        if "name" not in params:
            return False, "Missing tool 'name'"

        if "arguments" not in params:
            return False, "Missing tool 'arguments'"

        tool_name = params["name"]

        # The registry is the source of truth for which tools may be called.
        detailed_schemas = get_tool_schemas()
        if tool_name not in detailed_schemas:
            return False, f"Unknown tool: {tool_name}. Available tools: {sorted(list(detailed_schemas.keys()))}"

        return True, None
875
+
876
+
877
class SecurityMiddleware(BaseHTTPMiddleware):
    """Security middleware for basic protection"""

    async def dispatch(self, request: Request, call_next):
        """Reject oversized bodies up front, then attach hardening headers."""
        declared = request.headers.get("content-length")
        max_bytes = config.max_request_size_mb * 1024 * 1024
        if declared and int(declared) > max_bytes:
            return JSONResponse(
                status_code=HTTPStatus.REQUEST_ENTITY_TOO_LARGE,
                content={"error": "Request too large"}
            )

        response = await call_next(request)

        # Standard browser-hardening headers on every response.
        for header, value in (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
        ):
            response.headers[header] = value

        return response
896
+
897
+
898
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware"""

    def __init__(self, app, input_rate_limiter: RateLimiter):
        super().__init__(app)
        self.rate_limiter = input_rate_limiter

    async def dispatch(self, request: Request, call_next):
        """Throttle per client IP before handing the request to the app."""
        # Identify the caller by source address; "unknown" when unavailable.
        client_ip = request.client.host if request.client else "unknown"

        allowed = await self.rate_limiter.is_allowed(client_ip)
        if not allowed:
            return JSONResponse(
                status_code=HTTPStatus.TOO_MANY_REQUESTS,
                content={"error": "Rate limit exceeded"}
            )

        return await call_next(request)
916
+
917
# Global session manager
# Module-level singletons; populated during server startup, None until then.
session_manager: Optional["ThreadSafeSessionManager"] = None
rate_limiter: Optional["RateLimiter"] = None
920
+
921
+
922
@dataclass
class RateLimitViolation:
    """Represents a rate limit violation with standardized error information.

    Carries what the limiter reported so responses can include both a
    user-facing message and a technical one for debugging.
    """
    tool_name: str
    limit_type: str  # "burst", "second", "minute", "hour"
    current_usage: int
    limit_value: float
    retry_after_seconds: float

    def to_user_friendly_message(self) -> str:
        """Generate user-friendly error message"""
        if self.limit_type == "burst":
            return f"Service temporarily unavailable: Too many rapid requests to {self.tool_name}. Please wait {self.retry_after_seconds:.0f} seconds before trying again."
        elif self.limit_type == "second":
            return f"Service temporarily unavailable: {self.tool_name} request rate exceeded ({self.limit_value}/second). Please wait {self.retry_after_seconds:.0f} seconds before trying again."
        elif self.limit_type == "minute":
            return f"Service temporarily unavailable: {self.tool_name} quota exceeded ({self.limit_value}/minute). Please try again in {self.retry_after_seconds:.0f} seconds."
        elif self.limit_type == "hour":
            # BUG FIX: retry_after_seconds is in seconds; the previous message
            # printed the raw seconds value labeled "minutes" (300s -> "300
            # minutes"). Convert to minutes before formatting.
            return f"Service temporarily unavailable: {self.tool_name} hourly quota exceeded ({self.limit_value}/hour). Please try again in {self.retry_after_seconds / 60:.0f} minutes."
        else:
            return f"Service temporarily unavailable: {self.tool_name} rate limit exceeded. Please try again later."

    def to_technical_message(self) -> str:
        """Generate technical error message for debugging"""
        return f"Tool '{self.tool_name}' {self.limit_type} limit exceeded ({self.current_usage}/{self.limit_value} {self.limit_type})"
947
+
948
+
949
def _parse_rate_limit_denial(tool_name: str, denial_reason: str) -> RateLimitViolation:
    """Parse rate limit denial reason into structured violation information"""
    import re

    # (marker substring, limit type, retry delay in seconds, value-extracting regex)
    rules = (
        ("burst limit exceeded", "burst", 1.0, r'\((\d+) requests/burst\)'),
        ("per-second limit exceeded", "second", 1.0, r'\(([0-9.]+) requests/second\)'),
        ("per-minute limit exceeded", "minute", 10.0, r'\(([0-9.]+) requests/minute\)'),
        ("per-hour limit exceeded", "hour", 300.0, r'\(([0-9.]+) requests/hour\)'),
    )

    # Defaults when the reason text matches no known pattern.
    limit_type = "unknown"
    current_usage = 0
    limit_value = 0.0
    retry_after_seconds = 60.0  # fall back to retrying after one minute

    for marker, kind, delay, pattern in rules:
        if marker not in denial_reason:
            continue
        limit_type = kind
        retry_after_seconds = delay
        match = re.search(pattern, denial_reason)
        if match:
            limit_value = float(match.group(1))
            current_usage = int(limit_value)  # approximation
        break

    return RateLimitViolation(
        tool_name=tool_name,
        limit_type=limit_type,
        current_usage=current_usage,
        limit_value=limit_value,
        retry_after_seconds=retry_after_seconds,
    )
999
+
1000
+
1001
async def _call_session_tool_async(session: Session, tool_name: str, tool_args: Dict[str, Any],
                                   client_ip: str = "unknown") -> Dict[str, Any]:
    """Execute a tool within a session context with full tracking, workspace
    management, and global rate limiting.

    Returns a result dict; failures (including rate limiting) are reported in
    the dict rather than raised. Every call is logged via the session's tool
    tracker when tracking is enabled.
    """

    start_time = time.time()
    success = False
    error_details = None
    result_data = None

    # Touch session at start of tool execution to prevent expiry during long operations
    session.touch()

    try:
        # CHECK GLOBAL TOOL RATE LIMITS FIRST
        if global_tool_rate_limiter:
            allowed, deny_reason = await global_tool_rate_limiter.is_allowed(tool_name)
            if not allowed:
                # Parse the denial reason to create structured rate limit violation
                rate_limit_violation = _parse_rate_limit_denial(tool_name, deny_reason)

                # Create user-friendly error message
                user_message = rate_limit_violation.to_user_friendly_message()
                technical_message = rate_limit_violation.to_technical_message()

                logger.warning(f"Session {session.id}: {technical_message}")

                # Structured error payload; "rate_limited" lets the HTTP layer
                # turn this into a 429 response.
                result_data = {
                    "success": False,
                    "error": user_message,
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "error_type": "rate_limit",
                    "tool_name": tool_name,
                    "limit_type": rate_limit_violation.limit_type,
                    "retry_after_seconds": rate_limit_violation.retry_after_seconds,
                    "data": None,
                    "rate_limited": True,  # Keep for backward compatibility
                    "technical_details": technical_message  # For debugging
                }

                # Still log this for tracking purposes
                duration_ms = (time.time() - start_time) * 1000
                tracker = session.get_tool_tracker()
                if tracker:
                    try:
                        agent_info = {
                            "client_ip": client_ip,
                            "type": "unknown",
                            "session_request_count": session.request_count
                        }

                        tracker.log_tool_call(
                            tool_name=tool_name,
                            input_args=tool_args,
                            output_result=result_data,
                            success=False,
                            duration_ms=duration_ms,
                            error_details=user_message,
                            agent_info=agent_info
                        )
                    except Exception as e:
                        logger.error(f"Failed to log rate-limited tool call: {e}")

                return result_data

        # Get MCP tools instance for this session (handles workspace isolation)
        mcp_tools = session.get_mcp_tools(prefer_async=True)

        # Get tool method directly from the mcp_tools instance
        if not hasattr(mcp_tools, tool_name):
            raise ValueError(f"Tool '{tool_name}' not implemented")

        tool_method = getattr(mcp_tools, tool_name)

        # Add session context to tool arguments for workspace-aware tools
        if hasattr(mcp_tools, 'set_session_context'):
            mcp_tools.set_session_context(session.id, str(session.workspace_path))

        # Execute tool with keep-alive for potentially long operations
        logger.info(f"Session {session.id}: Executing tool '{tool_name}' with args: {list(tool_args.keys())}")

        # Use keep-alive wrapper for tools that might take a long time
        long_running_tools = {'batch_web_search', 'url_crawler', 'document_qa', 'document_extract', 'bash'}

        # Check if the tool method is async
        import inspect
        is_async_tool = inspect.iscoroutinefunction(tool_method)

        # Execute tool based on whether it's async or sync
        if is_async_tool:
            # Tool is async - execute directly
            # NOTE(review): the "{%s}" placeholder renders literal braces in the log.
            logger.debug("Executing async tool '{%s}'", tool_name)

            if config.enable_session_keepalive and tool_name in long_running_tools:
                # For long-running async tools, use keep-alive
                with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
                    result = await tool_method(**tool_args)
            else:
                # For regular async tools, execute directly
                result = await tool_method(**tool_args)
        else:
            # Tool is sync - execute in thread pool
            logger.debug("Executing sync tool '{%s}' in thread pool", tool_name)

            # Define the synchronous tool execution function
            def execute_tool_sync():
                """Synchronous tool execution to be run in thread pool"""
                return tool_method(**tool_args)

            # Execute tool asynchronously in thread pool for true non-blocking execution
            import asyncio
            import concurrent.futures

            # Create a thread pool executor for CPU-bound/blocking operations
            loop = asyncio.get_event_loop()

            if config.enable_session_keepalive and tool_name in long_running_tools:
                # For long-running tools, use keep-alive with async execution
                with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
                    # Run in thread pool to avoid blocking the event loop
                    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                        result = await loop.run_in_executor(executor, execute_tool_sync)
            else:
                # For regular tools, use async execution without keep-alive
                with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                    result = await loop.run_in_executor(executor, execute_tool_sync)

        # Touch session after tool execution to update activity
        session.touch()

        # Handle different result formats
        if hasattr(result, 'to_dict'):
            result_data = result.to_dict()
        elif isinstance(result, dict):
            result_data = result
        else:
            result_data = {"result": result}

        # Results without an explicit "success" key are treated as successful.
        success = result_data.get('success', True)

        if success:
            logger.info(f"Session {session.id}: Tool '{tool_name}' completed successfully")

            # RECORD SUCCESSFUL REQUEST FOR RATE LIMITING
            if global_tool_rate_limiter:
                await global_tool_rate_limiter.record_request(tool_name)

        else:
            error_details = result_data.get('error', 'Unknown error')
            logger.warning(f"Session {session.id}: Tool '{tool_name}' failed: {error_details}")

    except Exception as e:
        # Any exception is converted into a failure result so callers always
        # receive a dict.
        success = False
        error_details = str(e)
        result_data = {
            "success": False,
            "error": error_details,
            "data": None
        }
        logger.error(f"Session {session.id}: Tool '{tool_name}' exception: {e}")

    # Calculate execution time
    duration_ms = (time.time() - start_time) * 1000

    # Log tool call if tracking is enabled
    tracker = session.get_tool_tracker()
    if tracker:
        try:
            agent_info = {
                "client_ip": client_ip,
                "type": "unknown",  # Could be enhanced to detect agent type
                "session_request_count": session.request_count
            }

            tracker.log_tool_call(
                tool_name=tool_name,
                input_args=tool_args,
                output_result=result_data,
                success=success,
                duration_ms=duration_ms,
                error_details=error_details,
                agent_info=agent_info
            )
        except Exception as e:
            logger.error(f"Failed to log tool call: {e}")

    return result_data
1189
+
1190
+
1191
+
1192
def create_sse_response(response_data: dict, session_id: str = None) -> StreamingResponse:
    """Create Server-Sent Events response with proper formatting"""

    def generate_sse():
        # Emit one SSE "message" event; fall back to an "error" event when the
        # payload cannot be serialized.
        try:
            if session_id:
                response_data["session_id"] = session_id
            payload = json.dumps(response_data, ensure_ascii=False)
            event_name = "message"
        except Exception as e:
            fallback = {
                "jsonrpc": "2.0",
                "error": {"code": JsonRpcErr.INTERNAL_ERROR, "message": f"Internal error: {str(e)}"},
                "id": response_data.get("id")
            }
            payload = json.dumps(fallback, ensure_ascii=False)
            event_name = "error"
        yield f"event: {event_name}\n"
        yield f"data: {payload}\n"
        yield "\n"

    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        }
    )
1224
+
1225
+
1226
def create_error_response(request_id: Any, code: int, message: str, session_id: str = None) -> StreamingResponse:
    """Create error response in SSE format"""
    # Build a JSON-RPC 2.0 error envelope and stream it as SSE.
    payload = {
        "jsonrpc": "2.0",
        "error": {"code": code, "message": message},
        "id": request_id,
    }
    return create_sse_response(payload, session_id)
1234
+
1235
+
1236
def create_rate_limit_response(
    request_id: Any,
    tool_name: str,
    error_message: str,
    retry_after_seconds: float,
    limit_type: str,
    technical_details: str = "",
    session_id: str = None
) -> JSONResponse:
    """
    Create HTTP 429 Rate Limit Exceeded response with proper headers and error format.

    Returns proper HTTP status code instead of SSE for rate limiting errors.
    """
    # Retry-After must be an integral number of seconds, at least 1.
    retry_after_header = int(max(1.0, retry_after_seconds))

    # Standardized error body carrying both user-facing and debug details.
    body = {
        "error": {
            "type": "rate_limit_exceeded",
            "code": "RATE_LIMIT_EXCEEDED",
            "message": error_message,
            "details": {
                "tool_name": tool_name,
                "limit_type": limit_type,
                "retry_after_seconds": retry_after_seconds,
                "technical_details": technical_details
            }
        },
        "request_id": request_id,
        "timestamp": datetime.now().isoformat(),
        "session_id": session_id
    }

    response_headers = {
        "Retry-After": str(retry_after_header),  # HTTP standard header
        "X-RateLimit-Limit-Type": limit_type,
        "X-RateLimit-Tool": tool_name,
        "X-RateLimit-Retry-After": str(retry_after_seconds),
        "Content-Type": "application/json"
    }

    return JSONResponse(
        status_code=HTTPStatus.TOO_MANY_REQUESTS,  # Too Many Requests
        content=body,
        headers=response_headers
    )
1286
+
1287
+
1288
async def handle_mcp_request(request: Request) -> StreamingResponse:
    """Main MCP request handler with session management and tool execution.

    Validates the JSON-RPC envelope, resolves/creates the session from the
    X-Session-ID header, dispatches on the MCP method ("initialize",
    "tools/list", "tools/call"), and streams the result back as SSE. Rate
    limited tool calls are returned as HTTP 429 instead of SSE.
    """

    try:
        # Check content length before reading body
        content_length = request.headers.get("content-length")
        if content_length:
            content_size_mb = int(content_length) / (1024 * 1024)
            if content_size_mb > config.max_request_size_mb:
                logger.warning(f"Request too large: {content_size_mb:.2f}MB > {config.max_request_size_mb}MB")
                return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Request too large: {content_size_mb:.2f}MB")

        # Parse request with timeout protection
        try:
            body = await asyncio.wait_for(request.body(), timeout=config.request_timeout_seconds)
        except asyncio.TimeoutError:
            logger.error("Timeout while reading request body")
            return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request body read timeout")

        if not body:
            return create_error_response(None, JsonRpcErr.PARSE_ERROR, "Empty request body")

        try:
            data = json.loads(body.decode('utf-8'))
        except json.JSONDecodeError as e:
            return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Invalid JSON: {str(e)}")

        # Validate MCP request structure
        is_valid, error_msg = RequestValidator.validate_mcp_request(data)
        if not is_valid:
            return create_error_response(data.get("id"), JsonRpcErr.INVALID_REQUEST, error_msg)

        request_id = data["id"]
        method = data["method"]
        params = data.get("params", {})

        # Get or create session (a missing/expired session id yields a new one)
        session_id = request.headers.get("X-Session-ID")
        client_ip = request.client.host if request.client else "unknown"

        session = await session_manager.get_or_create_session(session_id)
        logger.info(f"Processing {method} request for session {session.id} from {client_ip}")

        # Handle different MCP methods
        if method == "initialize":
            # MCP initialization: advertise protocol version and capabilities
            response_data = {
                "jsonrpc": "2.0",
                "result": {
                    "protocolVersion": "2025-06-18",
                    "capabilities": {
                        "tools": {"supportsProgress": True},
                        "resources": {},
                        "prompts": {}
                    },
                    "serverInfo": {
                        "name": "DeepDiver-Demo-MCP",
                        "version": "1.0.0"
                    }
                },
                "id": request_id
            }

        elif method == "tools/list":
            # List available tools using detailed schemas from get_tool_schemas()
            tools_list = []
            detailed_schemas = get_tool_schemas()

            # Build tools list from schemas
            for _, detailed_schema in detailed_schemas.items():
                tools_list.append({
                    "name": detailed_schema["name"],
                    "description": detailed_schema["description"],
                    "inputSchema": detailed_schema["inputSchema"]
                })

            logger.info(f"Serving {len(tools_list)} tools with detailed schemas to client")

            response_data = {
                "jsonrpc": "2.0",
                "result": {"tools": tools_list},
                "id": request_id
            }

        elif method == "tools/call":
            # Execute tool call
            is_valid, error_msg = RequestValidator.validate_tool_call(params)
            if not is_valid:
                return create_error_response(request_id, JsonRpcErr.INVALID_PARAMS, error_msg, session.id)

            tool_name = params["name"]
            tool_arguments = params["arguments"]

            # Execute tool in session context asynchronously
            result = await _call_session_tool_async(session, tool_name, tool_arguments, client_ip)

            # CHECK FOR RATE LIMITING AND RETURN PROPER HTTP STATUS
            if result.get("rate_limited", False):
                return create_rate_limit_response(
                    request_id=request_id,
                    tool_name=tool_name,
                    error_message=result.get("error", "Rate limit exceeded"),
                    retry_after_seconds=result.get("retry_after_seconds", 60),
                    limit_type=result.get("limit_type", "unknown"),
                    technical_details=result.get("technical_details", ""),
                    session_id=session.id
                )

            # Format normal response: the tool result is embedded as JSON text
            response_data = {
                "jsonrpc": "2.0",
                "result": {
                    "content": [
                        {
                            "type": "text",
                            "text": json.dumps(result, indent=2, ensure_ascii=False)
                        }
                    ]
                },
                "id": request_id
            }

        else:
            return create_error_response(request_id, JsonRpcErr.METHOD_NOT_FOUND, f"Method not found: {method}", session.id)

        return create_sse_response(response_data, session.id)

    except asyncio.TimeoutError:
        logger.warning("Request timeout - client may have disconnected")
        return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request timeout")
    except Exception as e:
        # Handle client disconnects gracefully (detected by exception name in
        # the message since the exception classes are framework-specific)
        if "ClientDisconnect" in str(e) or "ConnectionClosedError" in str(e):
            logger.warning(f"Client disconnected during request processing: {e}")
            return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Client disconnected")

        logger.error(f"Unexpected error in MCP request handler: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return create_error_response(None, JsonRpcErr.INTERNAL_ERROR, f"Internal server error: {str(e)}")
1428
+
1429
+
1430
async def handle_health_check(request: Request) -> JSONResponse:
    """Health check endpoint.

    Reports session statistics, enabled features, rate-limiting configuration,
    and the error format contract for 429 responses. Returns 500 with
    status "unhealthy" if gathering stats fails.
    """
    try:
        stats = await session_manager.get_stats() if session_manager else {}

        # Get rate limiting summary
        rate_limit_summary = {}
        if global_tool_rate_limiter:
            all_stats = global_tool_rate_limiter.get_all_stats()
            rate_limit_summary = {
                "enabled": True,
                "tools_with_limits": len(all_stats),
                "total_configured_tools": list(all_stats.keys())
            }
        else:
            rate_limit_summary = {"enabled": False}

        health_data = {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
            "version": "1.0.0",
            "session_stats": stats,
            "features": {
                "workspace_isolation": True,
                "tool_call_tracking": config.enable_tool_tracking if config else False,
                "client_rate_limiting": True,
                "global_tool_rate_limiting": rate_limit_summary["enabled"],
                "security_middleware": True,
                "standardized_rate_limit_responses": True
            },
            "rate_limiting": rate_limit_summary,
            # Documents for clients what a rate-limit error response looks like.
            "error_formats": {
                "rate_limit_exceeded": {
                    "http_status": HTTPStatus.TOO_MANY_REQUESTS,
                    "headers": ["Retry-After", "X-RateLimit-*"],
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "response_format": "application/json"
                }
            }
        }

        return JSONResponse(content=health_data)

    except Exception as e:
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"status": "unhealthy", "error": str(e)}
        )
1478
+
1479
+
1480
async def handle_tracking_info(request: Request) -> JSONResponse:
    """Get tool call tracking information for a session (GET /tracking).

    Query params:
        session_id: required; the session whose tracking data to return.

    Responses:
        400 if ``session_id`` is missing, 404 if the session is unknown,
        200 with ``tracking_enabled: false`` if the session has no tracker,
        otherwise 200 with the persisted summary plus log-file locations.
        Any unexpected failure yields HTTP 500 with the error string.
    """
    try:
        session_id = request.query_params.get("session_id")
        if not session_id:
            return JSONResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content={"error": "session_id parameter required"}
            )

        session = await session_manager.get_session(session_id)
        if not session:
            return JSONResponse(
                status_code=HTTPStatus.NOT_FOUND,
                content={"error": f"Session {session_id} not found"}
            )

        tracker = session.get_tool_tracker()
        if not tracker:
            # Not an error: tracking is optional per session.
            return JSONResponse(
                content={
                    "session_id": session_id,
                    "tracking_enabled": False,
                    "message": "Tool call tracking not enabled or no workspace"
                }
            )

        # Read session summary; a read failure is logged but still returns
        # a 200 with an empty summary (best effort).
        summary_data = {}
        if tracker.summary_file.exists():
            try:
                with open(tracker.summary_file, 'r') as f:
                    summary_data = json.load(f)
            except Exception as e:
                logger.error(f"Failed to read session summary: {e}")

        return JSONResponse(content={
            "session_id": session_id,
            "tracking_enabled": True,
            "summary": summary_data,
            "logs_directory": str(tracker.logs_dir),
            "current_log_file": str(tracker.current_log_file)
        })

    except Exception as e:
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error": str(e)}
        )
1529
+
1530
+
1531
+
1532
async def handle_rate_limit_stats(request: Request) -> JSONResponse:
    """Report global per-tool rate limiting statistics (GET /rate-limits).

    With a ``tool`` query parameter, returns stats for that single tool;
    otherwise returns stats for every configured tool plus a summary.
    Returns 404 when the global limiter was never initialized and 500 on
    unexpected errors.
    """
    try:
        # Guard: the limiter is only created when limits are configured.
        if not global_tool_rate_limiter:
            return JSONResponse(
                status_code=HTTPStatus.NOT_FOUND,
                content={"error": "Global tool rate limiter not initialized"}
            )

        requested_tool = request.query_params.get("tool")
        if requested_tool:
            # Single-tool view.
            single_stats = await global_tool_rate_limiter.get_tool_stats(requested_tool)
            return JSONResponse(content=single_stats)

        # Aggregate view across all tools with configured limits.
        per_tool = global_tool_rate_limiter.get_all_stats()
        payload = {
            "timestamp": datetime.now().isoformat(),
            "global_tool_rate_limiting": True,
            "tools": per_tool,
            "summary": {
                "total_tools_with_limits": len(per_tool),
                "tools_configured": list(per_tool.keys())
            }
        }
        return JSONResponse(content=payload)

    except Exception as e:
        logger.error(f"Failed to get rate limit stats: {e}")
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error": str(e)}
        )
1567
+
1568
+
1569
def create_app() -> Starlette:
    """Create and configure the Starlette application.

    Initializes the module-level session manager, per-client rate limiter,
    and (when limits are configured) the global per-tool rate limiter, then
    wires middleware and the four HTTP routes.

    Raises:
        RuntimeError: if the module-level ``config`` has not been set yet
            (``main()`` loads it before calling this).
    """
    global session_manager, rate_limiter, global_tool_rate_limiter

    if not config:
        raise RuntimeError("Server configuration not initialized")

    # Initialize global components
    session_manager = ThreadSafeSessionManager(
        ttl_seconds=config.session_ttl_seconds,
        max_sessions=config.max_sessions,
        base_workspace_dir=config.base_workspace_dir
    )
    rate_limiter = RateLimiter(config.rate_limit_requests_per_minute)

    # Initialize global tool rate limiter (optional: only when limits exist)
    if config.tool_rate_limits:
        global_tool_rate_limiter = GlobalToolRateLimiter(config.tool_rate_limits)
        logger.info(f"Initialized global tool rate limiter with {len(config.tool_rate_limits)} tool limits")
    else:
        logger.info("No tool rate limits configured - tools will run without global rate limiting")

    # Create app
    app = Starlette(debug=config.debug_mode)

    # NOTE(review): add_middleware order determines wrapping order in
    # Starlette — confirm Security vs RateLimit ordering is intentional.
    app.add_middleware(SecurityMiddleware)
    app.add_middleware(RateLimitMiddleware, input_rate_limiter=rate_limiter)

    # Add routes
    app.add_route("/mcp", handle_mcp_request, methods=["POST"])
    app.add_route("/health", handle_health_check, methods=["GET"])
    app.add_route("/tracking", handle_tracking_info, methods=["GET"])
    app.add_route("/rate-limits", handle_rate_limit_stats, methods=["GET"])

    return app
1604
+
1605
+
1606
def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional explicit list of argument strings. Defaults to
            ``None``, in which case argparse falls back to ``sys.argv[1:]``
            (preserving the original behavior). Accepting an explicit list
            makes the parser testable and usable when embedding the server.

    Returns:
        argparse.Namespace with ``config``, ``host``, ``port``, ``debug``
        and ``workspace_dir`` attributes (``None``/``False`` when unset).
    """
    parser = argparse.ArgumentParser(
        description="Demo-Ready MCP Server with Per-Tool Rate Limiting",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python src/tools/mcp_server_standard.py --config src/tools/server_config.yaml
  python src/tools/mcp_server_standard.py --host 127.0.0.1 --port 8080
  python src/tools/mcp_server_standard.py --config custom_config.yaml --debug
"""
    )

    parser.add_argument(
        '--config', '-c',
        type=str,
        help='Path to YAML configuration file'
    )

    parser.add_argument(
        '--host',
        type=str,
        help='Server host (overrides config file)'
    )

    parser.add_argument(
        '--port', '-p',
        type=int,
        help='Server port (overrides config file)'
    )

    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug mode (overrides config file)'
    )

    parser.add_argument(
        '--workspace-dir',
        type=str,
        help='Base workspace directory (overrides config file)'
    )

    return parser.parse_args(argv)
1650
+
1651
+
1652
def print_startup_info():
    """Log a startup banner summarizing server features and tools.

    Reads the module-level ``config`` (must be loaded first) and the static
    tool schemas; purely informational, no side effects beyond logging.
    """
    logger.info("🚀 DeepDiver Demo MCP Server")
    logger.info("=" * 50)
    logger.info(f"📊 Features:")
    logger.info(f" • Session Management: ✅ (TTL: {config.session_ttl_seconds}s)")
    logger.info(f" • Workspace Isolation: ✅ (Base: {config.base_workspace_dir})")
    logger.info(f" • Tool Call Tracking: {'✅' if config.enable_tool_tracking else '❌'}")
    logger.info(f" • Client Rate Limiting: ✅ ({config.rate_limit_requests_per_minute}/min)")
    logger.info(f" • Global Tool Rate Limiting: {'✅' if config.tool_rate_limits else '❌'}")
    logger.info(f" • Security Middleware: ✅")

    # Tool rate limiting information (only the first 3 tools are shown
    # in full; the rest are summarized to keep the banner short)
    if config.tool_rate_limits:
        logger.info(f"🚦 Tool Rate Limits: {len(config.tool_rate_limits)} tools configured")
        for tool_name, limits in list(config.tool_rate_limits.items())[:3]:
            # '∞' when a limit key is omitted (meaning: unlimited)
            burst = limits.get('burst_limit', '∞')
            rpm = limits.get('requests_per_minute', '∞')
            logger.info(f" • {tool_name}: {rpm}/min, burst: {burst}")
        if len(config.tool_rate_limits) > 3:
            logger.info(f" • ... and {len(config.tool_rate_limits) - 3} more tools")

    # Tool information from schemas
    tool_schemas = get_tool_schemas()
    available_tools = list(tool_schemas.keys())

    logger.info(f"🔧 Tools Available: {len(available_tools)}")
    logger.info(f" • All tools defined in schemas: {len(available_tools)} tools")
    logger.info(f" • Sample tools: {', '.join(sorted(available_tools)[:5])}...")
    logger.info("=" * 50)
1682
+
1683
+
1684
def main():
    """Run the production MCP server.

    Loads configuration (honoring ``--config`` when given, otherwise the
    default YAML path), applies CLI overrides, prints the startup banner,
    then starts a single-worker uvicorn server tuned for async concurrency.

    Raises:
        Exception: re-raised after logging if server startup fails.
    """
    global config

    # Parse command line arguments
    args = parse_arguments()

    # BUGFIX: honor the --config flag. It was previously parsed (and
    # documented in --help) but ignored; the default path was always used.
    config_path = args.config or "./src/tools/server_config.yaml"
    config = ServerConfig.from_yaml(config_path)

    # Apply CLI overrides
    if args.host:
        config.host = args.host
        logger.info(f"🔧 Override: Host = {config.host}")

    if args.port:
        config.port = args.port
        logger.info(f"🔧 Override: Port = {config.port}")

    if args.debug:
        config.debug_mode = True
        logger.info(f"🔧 Override: Debug mode enabled")

    if args.workspace_dir:
        config.base_workspace_dir = args.workspace_dir
        logger.info(f"🔧 Override: Workspace directory = {config.base_workspace_dir}")

    print_startup_info()

    try:
        # (Removed a dead FORCE_HIGH_CONCURRENCY env check that only
        # executed `pass` — it had no effect on configuration.)
        app = create_app()

        logger.info(f"🌐 Starting server at http://{config.host}:{config.port}")
        logger.info(f"📡 MCP endpoint: http://{config.host}:{config.port}/mcp")
        logger.info(f"🏥 Health check: http://{config.host}:{config.port}/health")
        logger.info(f"📊 Tracking info: http://{config.host}:{config.port}/tracking?session_id=<id>")
        logger.info(f"🚦 Rate limit stats: http://{config.host}:{config.port}/rate-limits")

        uvicorn.run(
            app,  # Use app instance directly for single worker with async optimizations
            host=config.host,
            port=config.port,
            log_level="info",
            timeout_keep_alive=config.request_timeout_seconds,
            workers=1,  # Single worker with async optimizations
            backlog=1024,  # Larger backlog for high-concurrency
            access_log=False,  # Disable access logs for better performance
            limit_concurrency=None,  # No artificial concurrency limit
        )

    except KeyboardInterrupt:
        print("\n⏹️ Server stopped by user")
    except Exception as e:
        print(f"❌ Server startup failed: {e}")
        import traceback
        traceback.print_exc()
        raise
1749
+
1750
# Script entry point: start the MCP server when executed directly.
if __name__ == "__main__":
    main()
deepdiver_v2/src/tools/mcp_tools.py ADDED
The diff for this file is too large to render. See raw diff
 
deepdiver_v2/src/tools/server_config.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =================================================================
2
+ # DeepDiver MCP Server Configuration
3
+ # =================================================================
4
+ # This file contains ONLY the configuration options that are actually
5
+ # implemented and used by the server. No unused options!
6
+
7
+ # =================================================================
8
+ # SERVER CORE SETTINGS
9
+ # =================================================================
10
+ server:
11
+ # Network Configuration
12
+ host: "127.0.0.1" # Server bind address
13
+ port: 6274 # Server port
14
+ debug_mode: false # Enable debug logging and error details
15
+
16
+ # Session Management
17
+ session_ttl_seconds: 21600 # Session timeout (6 hours)
18
+ max_sessions: 1000 # Maximum concurrent sessions
19
+ cleanup_interval_seconds: 600 # How often to clean expired sessions (10 min)
20
+ enable_session_keepalive: true # Keep sessions alive during long operations
21
+ keepalive_touch_interval: 300 # Touch session every N seconds during long ops
22
+
23
+ # Request Handling
24
+ request_timeout_seconds: 1800 # Request timeout
25
+ max_request_size_mb: 1000 # Maximum request size
26
+
27
+ # Client Rate Limiting (per IP address)
28
+ rate_limit_requests_per_minute: 300000 # Requests per minute per client IP
29
+
30
+ # Workspace Management
31
+ base_workspace_dir: "workspaces" # Base directory for session workspaces
32
+
33
+ # =================================================================
34
+ # TOOL CALL TRACKING & LOGGING
35
+ # =================================================================
36
+ tracking:
37
+ enable_tool_tracking: true # Enable detailed tool call logging
38
+ max_tracked_calls_per_session: 10000 # Limit tool calls logged per session
39
+ track_detailed_errors: true # Include full error details in logs
40
+
41
+
42
+ # =================================================================
43
+ # GLOBAL PER-TOOL RATE LIMITING
44
+ # =================================================================
45
+ # These limits control requests to external APIs to avoid hitting provider limits.
46
+ # They are shared across ALL sessions and clients.
47
+ #
48
+ # Each tool can have these limits:
49
+ # - requests_per_second: QPS limit
50
+ # - requests_per_minute: Per-minute limit
51
+ # - requests_per_hour: Hourly limit
52
+ # - burst_limit: Short-term burst allowance
53
+ #
54
+ # Omit a limit to disable it (infinite). All limits are optional.
55
+
56
+ tool_rate_limits:
57
+ # API-based tools with external service limits
58
+ batch_web_search:
59
+ requests_per_minute: 9000
60
+ burst_limit: 35
61
+
62
+ url_crawler:
63
+ requests_per_minute: 9000
64
+ burst_limit: 60
65
+
66
+ document_qa:
67
+ requests_per_minute: 15000
68
+ burst_limit: 150
69
+
70
+ document_extract:
71
+ requests_per_minute: 15000
72
+ burst_limit: 150
73
+
deepdiver_v2/src/utils/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ #!/usr/bin/env python3
3
+
4
+ from .status_codes import JsonRpcErr
5
+
6
+ __all__ = [
7
+ 'JsonRpcErr',
8
+ ]
deepdiver_v2/src/utils/status_codes.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ #!/usr/bin/env python3
3
+ from enum import IntEnum
4
+
5
class JsonRpcErr(IntEnum):
    """JSON-RPC 2.0 error codes used by the MCP server.

    The -32700..-32600 codes are defined by the JSON-RPC 2.0 specification.
    REQUEST_TIMEOUT uses -32000 from the spec's reserved
    implementation-defined server-error range (-32000..-32099).
    """
    PARSE_ERROR = -32700       # invalid JSON was received
    INVALID_REQUEST = -32600   # payload is not a valid Request object
    METHOD_NOT_FOUND = -32601  # method does not exist / is unavailable
    INVALID_PARAMS = -32602    # invalid method parameters
    INTERNAL_ERROR = -32603    # internal JSON-RPC error
    REQUEST_TIMEOUT = -32000   # server-defined: request or client timed out
12
+
deepdiver_v2/src/workspace/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ """
3
+ Workspace management module for the DeepDiver Multi-Agent System.
4
+
5
+ This module provides local workspace management capabilities that don't require
6
+ external dependencies like E2B. Each chat session gets its own isolated workspace
7
+ directory for file operations and data persistence.
8
+ """
9
+
10
+ from .local_workspace_manager import (
11
+ LocalWorkspaceManager,
12
+ WorkspaceInfo,
13
+ WorkspaceStatus,
14
+ get_workspace_manager,
15
+ initialize_workspace_manager,
16
+ shutdown_workspace_manager
17
+ )
18
+
19
+ __all__ = [
20
+ 'LocalWorkspaceManager',
21
+ 'WorkspaceInfo',
22
+ 'WorkspaceStatus',
23
+ 'get_workspace_manager',
24
+ 'initialize_workspace_manager',
25
+ 'shutdown_workspace_manager'
26
+ ]
deepdiver_v2/src/workspace/local_workspace_manager.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
2
+ """
3
+ Local Workspace Manager for Multi-Agent System
4
+
5
+ This module provides session-based workspace management using local directories.
6
+ Each chat session gets its own isolated workspace directory that persists
7
+ throughout the conversation and can be cleaned up when the session ends.
8
+
9
+ Features:
10
+ - Session-based workspace lifecycle management
11
+ - Local directory isolation per session
12
+ - File operations within session workspaces
13
+ - Integration with existing MCP tools
14
+ - Comprehensive error handling and logging
15
+ """
16
+
17
+ import shutil
18
+ import logging
19
+ from typing import Dict, Optional, Any, List, Union
20
+ from datetime import datetime, timedelta
21
+ from pathlib import Path
22
+ from dataclasses import dataclass, field
23
+ from enum import Enum
24
+ import json
25
+
26
+ # Configure logging
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class WorkspaceStatus(Enum):
    """Workspace lifecycle status.

    Normal flow: CREATING -> ACTIVE -> DESTROYING -> DESTROYED,
    with ERROR set when creation or teardown fails.
    """
    CREATING = "creating"      # directory is being provisioned
    ACTIVE = "active"          # workspace is ready for file operations
    DESTROYING = "destroying"  # teardown in progress
    DESTROYED = "destroyed"    # directory removed; no longer tracked
    ERROR = "error"            # creation/teardown failed (see error_message)
37
+
38
+
39
@dataclass
class WorkspaceInfo:
    """Information about a workspace instance.

    Tracks identity, location, lifecycle status and activity timestamps for
    one session's workspace; persisted as JSON (via ``to_dict``) to the
    ``.workspace_metadata.json`` file inside the workspace directory.
    """
    workspace_id: str            # directory name under the base workspace dir
    session_id: str              # owning chat session
    workspace_path: Path         # path to this workspace's directory
    created_at: datetime
    last_activity: datetime      # refreshed on each get_workspace() lookup
    status: WorkspaceStatus
    workspace_files: List[str] = field(default_factory=list)  # relative file paths
    metadata: Dict[str, Any] = field(default_factory=dict)    # caller-supplied extras
    error_message: Optional[str] = None                       # set when status == ERROR

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a JSON-safe dictionary for serialization.

        Path and datetime values are rendered as strings; status as its
        string value.
        """
        return {
            "workspace_id": self.workspace_id,
            "session_id": self.session_id,
            "workspace_path": str(self.workspace_path),
            "created_at": self.created_at.isoformat(),
            "last_activity": self.last_activity.isoformat(),
            "status": self.status.value,
            "workspace_files": self.workspace_files,
            "metadata": self.metadata,
            "error_message": self.error_message
        }
65
+
66
+
67
class LocalWorkspaceManager:
    """
    Manages local workspaces for multi-agent chat sessions.

    Each chat session gets its own isolated workspace directory that persists
    throughout the conversation. Workspaces are automatically managed
    with cleanup capabilities.

    Lifecycle: CREATING -> ACTIVE -> DESTROYING -> DESTROYED (ERROR on
    failure). Per-workspace metadata is persisted to a
    ``.workspace_metadata.json`` file inside each workspace directory so
    state survives manager restarts.
    """

    def __init__(
        self,
        base_workspace_dir: str = "workspaces",
        default_timeout: int = 86400,  # 24 hours default
        cleanup_on_exit: bool = False  # Don't auto-cleanup by default
    ):
        """
        Initialize the workspace manager.

        Args:
            base_workspace_dir: Base directory for all workspaces
            default_timeout: Default workspace timeout in seconds
            cleanup_on_exit: Whether to cleanup workspaces on manager shutdown
        """
        self.base_workspace_dir = Path(base_workspace_dir)
        # FIX: parents=True so a nested base path (e.g. "var/run/workspaces")
        # does not raise when intermediate directories are missing;
        # exist_ok alone only tolerates the leaf already existing.
        self.base_workspace_dir.mkdir(parents=True, exist_ok=True)
        self.default_timeout = default_timeout
        self.cleanup_on_exit = cleanup_on_exit

        # Active workspaces by session ID
        self.workspaces: Dict[str, WorkspaceInfo] = {}

        # Re-attach to any workspaces left behind by a previous process
        self._load_existing_workspaces()

        logger.info(f"LocalWorkspaceManager initialized with base_dir={base_workspace_dir}")

    def _load_existing_workspaces(self):
        """Load existing workspaces from metadata files.

        Scans each directory under the base dir for a
        ``.workspace_metadata.json`` file and re-registers the workspace.
        Unreadable metadata is logged and skipped (best effort).
        """
        try:
            for workspace_dir in self.base_workspace_dir.iterdir():
                if workspace_dir.is_dir():
                    metadata_file = workspace_dir / ".workspace_metadata.json"
                    if metadata_file.exists():
                        try:
                            with open(metadata_file, 'r') as f:
                                data = json.load(f)

                            workspace_info = WorkspaceInfo(
                                workspace_id=data["workspace_id"],
                                session_id=data["session_id"],
                                workspace_path=Path(data["workspace_path"]),
                                created_at=datetime.fromisoformat(data["created_at"]),
                                last_activity=datetime.fromisoformat(data["last_activity"]),
                                status=WorkspaceStatus(data["status"]),
                                workspace_files=data.get("workspace_files", []),
                                metadata=data.get("metadata", {}),
                                error_message=data.get("error_message")
                            )

                            self.workspaces[workspace_info.session_id] = workspace_info
                            logger.info(f"Loaded existing workspace for session {workspace_info.session_id}")

                        except Exception as e:
                            logger.warning(f"Failed to load workspace metadata from {metadata_file}: {e}")

        except Exception as e:
            logger.warning(f"Failed to load existing workspaces: {e}")

    @staticmethod
    def _save_workspace_metadata(workspace_info: WorkspaceInfo):
        """Save workspace metadata to disk (best effort; failures are logged)."""
        try:
            metadata_file = workspace_info.workspace_path / ".workspace_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(workspace_info.to_dict(), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save workspace metadata: {e}")

    def create_workspace(
        self,
        session_id: str,
        workspace_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> WorkspaceInfo:
        """
        Create a new workspace for a chat session.

        Args:
            session_id: Unique session identifier
            workspace_id: Optional custom workspace ID (defaults to session_id)
            metadata: Additional metadata to store with the workspace

        Returns:
            WorkspaceInfo: Information about the created workspace

        Raises:
            ValueError: If session already has an active workspace
            Exception: If workspace creation fails
        """
        if session_id in self.workspaces:
            raise ValueError(f"Session {session_id} already has an active workspace")

        workspace_id = workspace_id or session_id
        workspace_path = self.base_workspace_dir / workspace_id

        logger.info(f"Creating workspace for session {session_id} at {workspace_path}")

        # Create workspace info with creating status
        workspace_info = WorkspaceInfo(
            workspace_id=workspace_id,
            session_id=session_id,
            workspace_path=workspace_path,
            created_at=datetime.now(),
            last_activity=datetime.now(),
            status=WorkspaceStatus.CREATING,
            metadata=metadata or {}
        )

        try:
            # Create workspace directory (exist_ok: reuse a stale directory)
            workspace_path.mkdir(parents=True, exist_ok=True)

            # Standard layout used by tools: downloads / outputs / temp
            (workspace_path / "downloads").mkdir(exist_ok=True)
            (workspace_path / "outputs").mkdir(exist_ok=True)
            (workspace_path / "temp").mkdir(exist_ok=True)

            # Mark ready and register before persisting metadata
            workspace_info.status = WorkspaceStatus.ACTIVE
            self.workspaces[session_id] = workspace_info

            # Save metadata
            self._save_workspace_metadata(workspace_info)

            # Update workspace files list
            self._update_workspace_files(session_id)

            logger.info(f"Workspace created successfully: {workspace_path} for session {session_id}")
            return workspace_info

        except Exception as e:
            workspace_info.status = WorkspaceStatus.ERROR
            workspace_info.error_message = str(e)
            logger.error(f"Failed to create workspace for session {session_id}: {e}")
            raise

    def get_workspace(self, session_id: str) -> Optional[WorkspaceInfo]:
        """Get workspace info for a session, refreshing its last-activity time."""
        workspace_info = self.workspaces.get(session_id)
        if workspace_info:
            # Touch: every lookup counts as activity (keeps cleanup at bay)
            workspace_info.last_activity = datetime.now()
            self._save_workspace_metadata(workspace_info)
        return workspace_info

    def get_workspace_path(self, session_id: str) -> Optional[Path]:
        """Get workspace path for a session, or None if it has no workspace."""
        workspace_info = self.get_workspace(session_id)
        return workspace_info.workspace_path if workspace_info else None

    def list_sessions(self) -> List[str]:
        """List all active session IDs."""
        return list(self.workspaces.keys())

    def destroy_workspace(self, session_id: str, force: bool = False) -> bool:
        """
        Destroy a workspace for a session.

        Args:
            session_id: Session identifier
            force: Force removal even if files exist

        Returns:
            bool: True if destroyed successfully
        """
        if session_id not in self.workspaces:
            logger.warning(f"No workspace found for session {session_id}")
            return False

        workspace_info = self.workspaces[session_id]

        try:
            logger.info(f"Destroying workspace for session {session_id}")
            workspace_info.status = WorkspaceStatus.DESTROYING

            # Remove workspace directory
            if workspace_info.workspace_path.exists():
                if force or not any(workspace_info.workspace_path.iterdir()):
                    shutil.rmtree(workspace_info.workspace_path)
                    logger.info(f"Workspace directory removed: {workspace_info.workspace_path}")
                else:
                    logger.warning(f"Workspace contains files, use force=True to remove: {workspace_info.workspace_path}")
                    # BUGFIX: the workspace was NOT destroyed — restore ACTIVE
                    # so it isn't stranded in DESTROYING state.
                    workspace_info.status = WorkspaceStatus.ACTIVE
                    return False

            # Update status and remove from active workspaces
            workspace_info.status = WorkspaceStatus.DESTROYED
            del self.workspaces[session_id]

            logger.info(f"Workspace destroyed for session {session_id}")
            return True

        except Exception as e:
            workspace_info.status = WorkspaceStatus.ERROR
            workspace_info.error_message = str(e)
            logger.error(f"Failed to destroy workspace for session {session_id}: {e}")
            return False

    def write_file(self, session_id: str, file_path: str, content: Union[str, bytes]) -> bool:
        """Write content to a file in the workspace.

        Creates intermediate directories as needed. Text content is written
        UTF-8; bytes are written verbatim. Returns True on success.
        """
        workspace_info = self.get_workspace(session_id)
        if not workspace_info:
            logger.error(f"No workspace found for session {session_id}")
            return False

        try:
            full_path = workspace_info.workspace_path / file_path
            full_path.parent.mkdir(parents=True, exist_ok=True)

            if isinstance(content, str):
                with open(full_path, 'w', encoding='utf-8') as f:
                    f.write(content)
            else:
                with open(full_path, 'wb') as f:
                    f.write(content)

            # Update workspace files list
            self._update_workspace_files(session_id)

            logger.info(f"File written to workspace: {file_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to write file {file_path} in workspace {session_id}: {e}")
            return False

    def read_file(self, session_id: str, file_path: str) -> Optional[Union[str, bytes]]:
        """Read content from a file in the workspace.

        Tries UTF-8 text first and falls back to raw bytes on decode errors.
        Returns None when the workspace or file is missing, or on I/O error.
        """
        workspace_info = self.get_workspace(session_id)
        if not workspace_info:
            logger.error(f"No workspace found for session {session_id}")
            return None

        try:
            full_path = workspace_info.workspace_path / file_path

            if not full_path.exists():
                logger.error(f"File not found: {file_path}")
                return None

            # Try to read as text first
            try:
                with open(full_path, 'r', encoding='utf-8') as f:
                    return f.read()
            except UnicodeDecodeError:
                # If text reading fails, read as bytes
                with open(full_path, 'rb') as f:
                    return f.read()

        except Exception as e:
            logger.error(f"Failed to read file {file_path} from workspace {session_id}: {e}")
            return None

    def list_files(self, session_id: str, directory: str = "") -> List[str]:
        """List files (recursively) in the workspace directory.

        Hidden files (dot-prefixed names, e.g. the metadata file) are
        excluded. Paths are returned relative to the workspace root, sorted.
        """
        workspace_info = self.get_workspace(session_id)
        if not workspace_info:
            logger.error(f"No workspace found for session {session_id}")
            return []

        try:
            target_path = workspace_info.workspace_path / directory if directory else workspace_info.workspace_path

            if not target_path.exists():
                return []

            files = []
            for item in target_path.rglob('*'):
                if item.is_file() and not item.name.startswith('.'):
                    rel_path = item.relative_to(workspace_info.workspace_path)
                    files.append(str(rel_path))

            return sorted(files)

        except Exception as e:
            logger.error(f"Failed to list files in workspace {session_id}: {e}")
            return []

    def _update_workspace_files(self, session_id: str):
        """Refresh the cached file list for a session and persist metadata."""
        try:
            workspace_info = self.workspaces.get(session_id)
            if workspace_info:
                workspace_info.workspace_files = self.list_files(session_id)
                self._save_workspace_metadata(workspace_info)
        except Exception as e:
            # FIX: use plain %s placeholders — the previous "{%s}" format
            # wrapped every value in literal braces in the log output.
            logger.debug("Failed to update workspace files for session %s: %s", session_id, e)

    def cleanup_expired_workspaces(self, max_age_hours: int = 24):
        """Force-destroy workspaces inactive for longer than max_age_hours."""
        cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
        # Collect first: destroy_workspace mutates self.workspaces
        expired_sessions = [
            session_id
            for session_id, workspace_info in self.workspaces.items()
            if workspace_info.last_activity < cutoff_time
        ]

        for session_id in expired_sessions:
            logger.info(f"Cleaning up expired workspace for session {session_id}")
            self.destroy_workspace(session_id, force=True)

    def shutdown(self):
        """Shutdown the workspace manager.

        With cleanup_on_exit, force-destroys every workspace; otherwise just
        persists metadata so workspaces can be reloaded on the next start.
        """
        logger.info("Shutting down LocalWorkspaceManager...")

        if self.cleanup_on_exit:
            # Clean up all workspaces (copy keys: destroy mutates the dict)
            session_ids = list(self.workspaces.keys())
            for session_id in session_ids:
                self.destroy_workspace(session_id, force=True)
        else:
            # Just save metadata for all workspaces
            for workspace_info in self.workspaces.values():
                self._save_workspace_metadata(workspace_info)

        logger.info("LocalWorkspaceManager shutdown complete")
393
+
394
+
395
+ # Global instance
396
+ _workspace_manager: Optional[LocalWorkspaceManager] = None
397
+
398
+
399
def get_workspace_manager(base_workspace_dir: str = "workspaces") -> LocalWorkspaceManager:
    """Return the process-wide workspace manager, creating it lazily.

    The base directory argument is only honored on first call; later calls
    return the already-created singleton unchanged.
    """
    global _workspace_manager
    manager = _workspace_manager
    if manager is None:
        manager = LocalWorkspaceManager(base_workspace_dir)
        _workspace_manager = manager
    return manager
405
+
406
+
407
def initialize_workspace_manager(base_workspace_dir: str = "workspaces", **kwargs) -> LocalWorkspaceManager:
    """Replace the global workspace manager with a freshly configured one.

    Extra keyword arguments are forwarded to the LocalWorkspaceManager
    constructor (e.g. default_timeout, cleanup_on_exit).
    """
    global _workspace_manager
    manager = LocalWorkspaceManager(base_workspace_dir, **kwargs)
    _workspace_manager = manager
    logger.info(f"Workspace manager initialized with base directory: {base_workspace_dir}")
    return manager
413
+
414
+
415
def shutdown_workspace_manager():
    """Shut down and clear the global workspace manager, if one exists."""
    global _workspace_manager
    manager = _workspace_manager
    if manager:
        manager.shutdown()
        _workspace_manager = None
420
+ _workspace_manager = None
docs/openpangu-deepdiver-v2-tech-report.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9ff32d4bd7190ea26049ef1f5d009b9861c652a980623a5f5cd043e7dcec2a4
3
+ size 39847395
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "do_sample": true,
4
+ "bos_token_id": 1,
5
+ "pad_token_id": 0,
6
+ "eos_token_id": 45892,
7
+ "temperature": 1.0,
8
+ "top_k": 0,
9
+ "top_p": 0.8,
10
+ "transformers_version": "4.53.2"
11
+ }
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b8ec6cd94b1921560d37755c7c0c08280c1f9123195d14d352ad0607788f7f6
3
+ size 4926842416
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc05d80f52ce44d1433a942e867bf61ea49eb1eebb0700312f76d6b3a3dee917
3
+ size 4991686576
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ed37f38214c755b51bea06a71e154c9ea27670eb3b8506c06addcfbea2066f2
3
+ size 4886853760
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0145e255ba965ed0e75164a037b9a0137c5e5c12ffc42463ff82568054fe0186
3
+ size 1256456320
model.safetensors.index.json ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 16061784576
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00004-of-00004.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
20
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
21
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
28
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
29
+ "model.layers.1.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
30
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
31
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
32
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
33
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
34
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
35
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
36
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
41
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
42
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.10.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
44
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
46
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
48
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
49
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
53
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
55
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
56
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.11.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
58
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
60
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
62
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
63
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
65
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
70
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.12.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
72
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
74
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
76
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
77
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
79
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
84
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.13.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
86
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
88
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
89
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
90
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
91
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
92
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
98
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.14.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
100
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
101
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
102
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
103
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
104
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
105
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
110
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
112
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
113
+ "model.layers.15.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
114
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
115
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
116
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
118
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
119
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
122
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
124
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
125
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
126
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
127
+ "model.layers.16.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
128
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
129
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
130
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
131
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
132
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
133
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
134
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
135
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
136
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
137
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
138
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
139
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
140
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
141
+ "model.layers.17.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
142
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
143
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
144
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
145
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
146
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
147
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
148
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
149
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
150
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
151
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
152
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
153
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
154
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
155
+ "model.layers.18.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
156
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
157
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
158
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
159
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
160
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
161
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
162
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
163
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
164
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
165
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
166
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
167
+ "model.layers.19.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
168
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
169
+ "model.layers.19.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
170
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
171
+ "model.layers.19.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
172
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
173
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
174
+ "model.layers.19.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
175
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
176
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
177
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
178
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
179
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
180
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
181
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
182
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
183
+ "model.layers.2.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
184
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
185
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
186
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
187
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
188
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
189
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
190
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
191
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
192
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
193
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
194
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
195
+ "model.layers.20.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
196
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
197
+ "model.layers.20.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
198
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
199
+ "model.layers.20.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
200
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
201
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
202
+ "model.layers.20.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
203
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
204
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
206
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
207
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
209
+ "model.layers.21.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
210
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
211
+ "model.layers.21.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
212
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
213
+ "model.layers.21.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
214
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
215
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
216
+ "model.layers.21.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
217
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
218
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
220
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
221
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
223
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
224
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.22.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
226
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
228
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
230
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
231
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
233
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
235
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
238
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.23.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
240
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
241
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
242
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
244
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
245
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
247
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
252
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.24.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
254
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
256
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
257
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
258
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
259
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
261
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
262
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
263
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
264
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
265
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
266
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
267
+ "model.layers.25.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
268
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
269
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
270
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
271
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
272
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
273
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
274
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
275
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
276
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
277
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
278
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
279
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
280
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
281
+ "model.layers.26.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
282
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
283
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
284
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
285
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
286
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
287
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
288
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
289
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
290
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
291
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
292
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
293
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
294
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
295
+ "model.layers.27.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
296
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
297
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
298
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
299
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
300
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
301
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
302
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
303
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
304
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
305
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
306
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
307
+ "model.layers.28.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
308
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
309
+ "model.layers.28.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
310
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
311
+ "model.layers.28.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
312
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
313
+ "model.layers.28.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
314
+ "model.layers.28.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
315
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
316
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
317
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
318
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
319
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
320
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
321
+ "model.layers.29.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
322
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
323
+ "model.layers.29.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
324
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
325
+ "model.layers.29.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
326
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
327
+ "model.layers.29.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
328
+ "model.layers.29.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
329
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
330
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
331
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
332
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
333
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
334
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
335
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
336
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
337
+ "model.layers.3.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
338
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
339
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
340
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
341
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
342
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
343
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
344
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
345
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
346
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
347
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
348
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
349
+ "model.layers.30.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
350
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
351
+ "model.layers.30.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
352
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
353
+ "model.layers.30.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
354
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
355
+ "model.layers.30.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
356
+ "model.layers.30.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
357
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
358
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
359
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
360
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
361
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
362
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
363
+ "model.layers.31.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
364
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
365
+ "model.layers.31.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
366
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
367
+ "model.layers.31.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
368
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
369
+ "model.layers.31.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
370
+ "model.layers.31.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
371
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
372
+ "model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
373
+ "model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
374
+ "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
375
+ "model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
376
+ "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
377
+ "model.layers.32.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
378
+ "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
379
+ "model.layers.32.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
380
+ "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
381
+ "model.layers.32.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
382
+ "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
383
+ "model.layers.32.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
384
+ "model.layers.32.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
385
+ "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
386
+ "model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
387
+ "model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
388
+ "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
389
+ "model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
390
+ "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
391
+ "model.layers.33.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
392
+ "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
393
+ "model.layers.33.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
394
+ "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
395
+ "model.layers.33.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
396
+ "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
397
+ "model.layers.33.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
398
+ "model.layers.33.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
399
+ "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
400
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
401
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
402
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
403
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
404
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
405
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
406
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
407
+ "model.layers.4.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
408
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
409
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
410
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
411
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
412
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
413
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
414
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
415
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
416
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
417
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
418
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
419
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
420
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
421
+ "model.layers.5.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
422
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
423
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
424
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
425
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
426
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
427
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
428
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
429
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
430
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
431
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
432
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
433
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
434
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
435
+ "model.layers.6.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
436
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
437
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
438
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
439
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
440
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
441
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
442
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
443
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
444
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
445
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
446
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
447
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
448
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
449
+ "model.layers.7.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
450
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
451
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
452
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
453
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
454
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
455
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
456
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
457
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
458
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
459
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
460
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
461
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
462
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
463
+ "model.layers.8.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
464
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
465
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
466
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
467
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
468
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
469
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
470
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
471
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
472
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
473
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
474
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
475
+ "model.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
476
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
477
+ "model.layers.9.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
478
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
479
+ "model.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
480
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
481
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
482
+ "model.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
483
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
484
+ "model.norm.weight": "model-00003-of-00004.safetensors"
485
+ }
486
+ }
modeling_openpangu_dense.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from modular_openpangu_dense.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_openpangu_dense.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+
8
+ # coding=utf-8
9
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
10
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
11
+ #
12
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
13
+ # and OPT implementations in this library. It has been modified from its
14
+ # original forms to accommodate minor architectural differences compared
15
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+
29
+ from typing import Callable, Optional, Union
30
+
31
+ import torch
32
+ from torch import nn
33
+
34
+ import torch_npu
35
+ from torch_npu.contrib import transfer_to_npu
36
+ if "910" in torch.npu.get_device_name():
37
+ NPU_ATTN_INFR = True
38
+ print("[INFO] torch_npu detected. Using NPU fused infer attention.")
39
+ else:
40
+ NPU_ATTN_INFR = False
41
+
42
+ from transformers.activations import ACT2FN
43
+ from transformers.cache_utils import Cache, DynamicCache
44
+ from transformers.generation import GenerationMixin
45
+ from transformers.masking_utils import create_causal_mask
46
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
47
+ from transformers.modeling_layers import GradientCheckpointingLayer
48
+ from transformers.modeling_outputs import (
49
+ BaseModelOutputWithPast,
50
+ CausalLMOutputWithPast,
51
+ SequenceClassifierOutputWithPast,
52
+ )
53
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
54
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
55
+ from transformers.processing_utils import Unpack
56
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
57
+ from .configuration_openpangu_dense import PanguEmbeddedConfig
58
+
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+
63
+ class PanguEmbeddedRMSNorm(nn.Module):
64
+ def __init__(self, hidden_size, eps=1e-6):
65
+ """
66
+ PanguEmbeddedRMSNorm is equivalent to T5LayerNorm
67
+ """
68
+ super().__init__()
69
+ self.weight = nn.Parameter(torch.ones(hidden_size))
70
+ self.variance_epsilon = eps
71
+
72
+ def forward(self, hidden_states):
73
+ input_dtype = hidden_states.dtype
74
+ hidden_states = hidden_states.to(torch.float32)
75
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
76
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
77
+ return self.weight * hidden_states.to(input_dtype)
78
+
79
+ def extra_repr(self):
80
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
81
+
82
+
83
+ class PanguEmbeddedRotaryEmbedding(nn.Module):
84
+ def __init__(self, config: PanguEmbeddedConfig, device=None):
85
+ super().__init__()
86
+ # BC: "rope_type" was originally "type"
87
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
88
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
89
+ else:
90
+ self.rope_type = "default"
91
+ self.max_seq_len_cached = config.max_position_embeddings
92
+ self.original_max_seq_len = config.max_position_embeddings
93
+
94
+ self.config = config
95
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
96
+
97
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
98
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
99
+ self.original_inv_freq = self.inv_freq
100
+
101
+ @torch.no_grad()
102
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
103
+ def forward(self, x, position_ids):
104
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
105
+ position_ids_expanded = position_ids[:, None, :].float()
106
+
107
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
108
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
109
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
110
+ emb = torch.cat((freqs, freqs), dim=-1)
111
+ cos = emb.cos() * self.attention_scaling
112
+ sin = emb.sin() * self.attention_scaling
113
+
114
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
115
+
116
+
117
+ def rotate_half(x):
118
+ """Rotates half the hidden dims of the input."""
119
+ x1 = x[..., : x.shape[-1] // 2]
120
+ x2 = x[..., x.shape[-1] // 2 :]
121
+ return torch.cat((-x2, x1), dim=-1)
122
+
123
+
124
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
125
+ """Applies Rotary Position Embedding to the query and key tensors.
126
+
127
+ Args:
128
+ q (`torch.Tensor`): The query tensor.
129
+ k (`torch.Tensor`): The key tensor.
130
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
131
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
132
+ position_ids (`torch.Tensor`, *optional*):
133
+ Deprecated and unused.
134
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
135
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
136
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
137
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
138
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
139
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
140
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
141
+ Returns:
142
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
143
+ """
144
+ cos = cos.unsqueeze(unsqueeze_dim)
145
+ sin = sin.unsqueeze(unsqueeze_dim)
146
+ q_embed = (q * cos) + (rotate_half(q) * sin)
147
+ k_embed = (k * cos) + (rotate_half(k) * sin)
148
+ return q_embed, k_embed
149
+
150
+
151
+ class PanguEmbeddedMLP(nn.Module):
152
+ def __init__(self, config):
153
+ super().__init__()
154
+ self.config = config
155
+ self.hidden_size = config.hidden_size
156
+ self.intermediate_size = config.intermediate_size
157
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
158
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
159
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
160
+ self.act_fn = ACT2FN[config.hidden_act]
161
+
162
+ def forward(self, x):
163
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
164
+ return down_proj
165
+
166
+
167
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
168
+ """
169
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
170
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
171
+ """
172
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
173
+ if n_rep == 1:
174
+ return hidden_states
175
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
176
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
177
+
178
+
179
+ def eager_attention_forward(
180
+ module: nn.Module,
181
+ query: torch.Tensor,
182
+ key: torch.Tensor,
183
+ value: torch.Tensor,
184
+ attention_mask: Optional[torch.Tensor],
185
+ scaling: float,
186
+ dropout: float = 0.0,
187
+ **kwargs,
188
+ ):
189
+ key_states = repeat_kv(key, module.num_key_value_groups)
190
+ value_states = repeat_kv(value, module.num_key_value_groups)
191
+
192
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
193
+ if attention_mask is not None:
194
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
195
+ attn_weights = attn_weights + causal_mask
196
+
197
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
198
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
199
+ attn_output = torch.matmul(attn_weights, value_states)
200
+ attn_output = attn_output.transpose(1, 2).contiguous()
201
+
202
+ return attn_output, attn_weights
203
+
204
+
205
+ class PanguEmbeddedAttention(nn.Module):
206
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
207
+
208
+ def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
209
+ super().__init__()
210
+ self.config = config
211
+ self.layer_idx = layer_idx
212
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
213
+ self.num_heads = config.num_attention_heads
214
+ self.num_key_value_heads = config.num_key_value_heads
215
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
216
+ self.scaling = self.head_dim**-0.5
217
+ self.attention_dropout = config.attention_dropout
218
+ self.is_causal = True
219
+
220
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
221
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
222
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
223
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
224
+
225
+ def forward(
226
+ self,
227
+ hidden_states: torch.Tensor,
228
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
229
+ attention_mask: Optional[torch.Tensor],
230
+ past_key_value: Optional[Cache] = None,
231
+ cache_position: Optional[torch.LongTensor] = None,
232
+ **kwargs: Unpack[FlashAttentionKwargs],
233
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
234
+ input_shape = hidden_states.shape[:-1]
235
+ hidden_shape = (*input_shape, -1, self.head_dim)
236
+
237
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
238
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
239
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
240
+
241
+ cos, sin = position_embeddings
242
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
243
+
244
+ if past_key_value is not None:
245
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
246
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
247
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
248
+
249
+ attention_interface: Callable = eager_attention_forward
250
+ if self.config._attn_implementation != "eager":
251
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
252
+
253
+ if not self.training and NPU_ATTN_INFR:
254
+ q_len = input_shape[1]
255
+ if attention_mask is not None:
256
+ attention_mask = ~attention_mask.bool()
257
+ elif q_len > 1:
258
+ attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)
259
+
260
+ attn_output, _ = torch_npu.npu_fused_infer_attention_score(
261
+ query_states, key_states, value_states,
262
+ num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
263
+ input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
264
+ attn_output = attn_output.transpose(1, 2)
265
+ attn_weights = None
266
+ else:
267
+ attn_output, attn_weights = attention_interface(
268
+ self,
269
+ query_states,
270
+ key_states,
271
+ value_states,
272
+ attention_mask,
273
+ dropout=0.0 if not self.training else self.attention_dropout,
274
+ scaling=self.scaling,
275
+ **kwargs,
276
+ )
277
+
278
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
279
+ attn_output = self.o_proj(attn_output)
280
+ return attn_output, attn_weights
281
+
282
+
283
+ class PanguEmbeddedDecoderLayer(GradientCheckpointingLayer):
284
+ def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
285
+ super().__init__()
286
+ self.hidden_size = config.hidden_size
287
+ self.self_attn = PanguEmbeddedAttention(config=config, layer_idx=layer_idx)
288
+ self.mlp = PanguEmbeddedMLP(config)
289
+ self.input_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
290
+ self.post_attention_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
291
+
292
+ def forward(
293
+ self,
294
+ hidden_states: torch.Tensor,
295
+ attention_mask: Optional[torch.Tensor] = None,
296
+ position_ids: Optional[torch.LongTensor] = None,
297
+ past_key_value: Optional[Cache] = None,
298
+ output_attentions: Optional[bool] = False,
299
+ use_cache: Optional[bool] = False,
300
+ cache_position: Optional[torch.LongTensor] = None,
301
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
302
+ **kwargs: Unpack[FlashAttentionKwargs],
303
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
304
+ residual = hidden_states
305
+ hidden_states = self.input_layernorm(hidden_states)
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights = self.self_attn(
309
+ hidden_states=hidden_states,
310
+ attention_mask=attention_mask,
311
+ position_ids=position_ids,
312
+ past_key_value=past_key_value,
313
+ output_attentions=output_attentions,
314
+ use_cache=use_cache,
315
+ cache_position=cache_position,
316
+ position_embeddings=position_embeddings,
317
+ **kwargs,
318
+ )
319
+ hidden_states = residual + hidden_states
320
+
321
+ # Fully Connected
322
+ residual = hidden_states
323
+ hidden_states = self.post_attention_layernorm(hidden_states)
324
+ hidden_states = self.mlp(hidden_states)
325
+ hidden_states = residual + hidden_states
326
+
327
+ outputs = (hidden_states,)
328
+ if output_attentions:
329
+ outputs += (self_attn_weights,)
330
+
331
+ return outputs
332
+
333
+
334
+ @auto_docstring
335
+ class PanguEmbeddedPreTrainedModel(PreTrainedModel):
336
+ config_class = PanguEmbeddedConfig
337
+ base_model_prefix = "model"
338
+ supports_gradient_checkpointing = True
339
+ _no_split_modules = ["PanguEmbeddedDecoderLayer"]
340
+ _skip_keys_device_placement = ["past_key_values"]
341
+ _supports_flash_attn_3 = True
342
+ _supports_flash_attn_2 = True
343
+ _supports_sdpa = True
344
+ _supports_flex_attn = True
345
+ _supports_cache_class = True
346
+ _supports_quantized_cache = True
347
+ _supports_static_cache = True
348
+ _supports_attention_backend = True
349
+
350
+ def _init_weights(self, module):
351
+ std = self.config.initializer_range
352
+ if isinstance(module, nn.Linear):
353
+ module.weight.data.normal_(mean=0.0, std=std)
354
+ if module.bias is not None:
355
+ module.bias.data.zero_()
356
+ elif isinstance(module, nn.Embedding):
357
+ module.weight.data.normal_(mean=0.0, std=std)
358
+ if module.padding_idx is not None:
359
+ module.weight.data[module.padding_idx].zero_()
360
+ elif isinstance(module, PanguEmbeddedRMSNorm):
361
+ module.weight.data.fill_(1.0)
362
+
363
+
364
+ @auto_docstring
365
+ class PanguEmbeddedModel(PanguEmbeddedPreTrainedModel):
366
+ def __init__(self, config: PanguEmbeddedConfig):
367
+ super().__init__(config)
368
+ self.padding_idx = config.pad_token_id
369
+ self.vocab_size = config.vocab_size
370
+
371
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
372
+ self.layers = nn.ModuleList(
373
+ [PanguEmbeddedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
374
+ )
375
+ self.norm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
376
+ self.rotary_emb = PanguEmbeddedRotaryEmbedding(config=config)
377
+ self.gradient_checkpointing = False
378
+
379
+ # Initialize weights and apply final processing
380
+ self.post_init()
381
+
382
+ def get_input_embeddings(self):
383
+ return self.embed_tokens
384
+
385
+ def set_input_embeddings(self, value):
386
+ self.embed_tokens = value
387
+
388
+ @can_return_tuple
389
+ @auto_docstring
390
+ def forward(
391
+ self,
392
+ input_ids: Optional[torch.LongTensor] = None,
393
+ attention_mask: Optional[torch.Tensor] = None,
394
+ position_ids: Optional[torch.LongTensor] = None,
395
+ past_key_values: Optional[Cache] = None,
396
+ inputs_embeds: Optional[torch.FloatTensor] = None,
397
+ use_cache: Optional[bool] = None,
398
+ output_attentions: Optional[bool] = None,
399
+ output_hidden_states: Optional[bool] = None,
400
+ cache_position: Optional[torch.LongTensor] = None,
401
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
402
+ ) -> BaseModelOutputWithPast:
403
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
404
+ output_hidden_states = (
405
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
406
+ )
407
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
408
+
409
+ if (input_ids is None) ^ (inputs_embeds is not None):
410
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
411
+
412
+ if self.gradient_checkpointing and self.training and use_cache:
413
+ logger.warning_once(
414
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
415
+ )
416
+ use_cache = False
417
+
418
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
419
+ if not isinstance(past_key_values, (type(None), Cache)):
420
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
421
+
422
+ if inputs_embeds is None:
423
+ inputs_embeds = self.embed_tokens(input_ids)
424
+
425
+ if use_cache and past_key_values is None:
426
+ past_key_values = DynamicCache()
427
+
428
+ if cache_position is None:
429
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
430
+ cache_position = torch.arange(
431
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
432
+ )
433
+
434
+ if position_ids is None:
435
+ position_ids = cache_position.unsqueeze(0)
436
+
437
+ causal_mask = create_causal_mask(
438
+ config=self.config,
439
+ input_embeds=inputs_embeds,
440
+ attention_mask=attention_mask,
441
+ cache_position=cache_position,
442
+ past_key_values=past_key_values,
443
+ position_ids=position_ids,
444
+ )
445
+
446
+ hidden_states = inputs_embeds
447
+
448
+ # create position embeddings to be shared across the decoder layers
449
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
450
+
451
+ # decoder layers
452
+ all_hidden_states = () if output_hidden_states else None
453
+ all_self_attns = () if output_attentions else None
454
+
455
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
456
+ if output_hidden_states:
457
+ all_hidden_states += (hidden_states,)
458
+
459
+ layer_outputs = decoder_layer(
460
+ hidden_states,
461
+ attention_mask=causal_mask,
462
+ position_ids=position_ids,
463
+ past_key_value=past_key_values,
464
+ output_attentions=output_attentions,
465
+ use_cache=use_cache,
466
+ cache_position=cache_position,
467
+ position_embeddings=position_embeddings,
468
+ **flash_attn_kwargs,
469
+ )
470
+
471
+ hidden_states = layer_outputs[0]
472
+
473
+ if output_attentions:
474
+ all_self_attns += (layer_outputs[1],)
475
+
476
+ hidden_states = self.norm(hidden_states)
477
+
478
+ # add hidden states from the last decoder layer
479
+ if output_hidden_states:
480
+ all_hidden_states += (hidden_states,)
481
+
482
+ return BaseModelOutputWithPast(
483
+ last_hidden_state=hidden_states,
484
+ past_key_values=past_key_values if use_cache else None,
485
+ hidden_states=all_hidden_states,
486
+ attentions=all_self_attns,
487
+ )
488
+
489
+
490
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
491
+
492
+
493
+ @auto_docstring
494
+ class PanguEmbeddedForCausalLM(PanguEmbeddedPreTrainedModel, GenerationMixin):
495
+ _tied_weights_keys = ["lm_head.weight"]
496
+ _tp_plan = {"lm_head": "colwise_rep"}
497
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
498
+
499
+ def __init__(self, config):
500
+ super().__init__(config)
501
+ self.model = PanguEmbeddedModel(config)
502
+ self.vocab_size = config.vocab_size
503
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
504
+
505
+ # Initialize weights and apply final processing
506
+ self.post_init()
507
+
508
+ def get_input_embeddings(self):
509
+ return self.model.embed_tokens
510
+
511
+ def set_input_embeddings(self, value):
512
+ self.model.embed_tokens = value
513
+
514
+ def get_output_embeddings(self):
515
+ return self.lm_head
516
+
517
+ def set_output_embeddings(self, new_embeddings):
518
+ self.lm_head = new_embeddings
519
+
520
+ def set_decoder(self, decoder):
521
+ self.model = decoder
522
+
523
+ def get_decoder(self):
524
+ return self.model
525
+
526
+ @can_return_tuple
527
+ @auto_docstring
528
+ def forward(
529
+ self,
530
+ input_ids: Optional[torch.LongTensor] = None,
531
+ attention_mask: Optional[torch.Tensor] = None,
532
+ position_ids: Optional[torch.LongTensor] = None,
533
+ past_key_values: Optional[Cache] = None,
534
+ inputs_embeds: Optional[torch.FloatTensor] = None,
535
+ labels: Optional[torch.LongTensor] = None,
536
+ use_cache: Optional[bool] = None,
537
+ output_attentions: Optional[bool] = None,
538
+ output_hidden_states: Optional[bool] = None,
539
+ cache_position: Optional[torch.LongTensor] = None,
540
+ logits_to_keep: Union[int, torch.Tensor] = 0,
541
+ **kwargs: Unpack[KwargsForCausalLM],
542
+ ) -> CausalLMOutputWithPast:
543
+
544
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
545
+ output_hidden_states = (
546
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
547
+ )
548
+
549
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
550
+ outputs: BaseModelOutputWithPast = self.model(
551
+ input_ids=input_ids,
552
+ attention_mask=attention_mask,
553
+ position_ids=position_ids,
554
+ past_key_values=past_key_values,
555
+ inputs_embeds=inputs_embeds,
556
+ use_cache=use_cache,
557
+ output_attentions=output_attentions,
558
+ output_hidden_states=output_hidden_states,
559
+ cache_position=cache_position,
560
+ **kwargs,
561
+ )
562
+
563
+ hidden_states = outputs.last_hidden_state
564
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
565
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
566
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
567
+
568
+ loss = None
569
+ if labels is not None:
570
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
571
+
572
+ return CausalLMOutputWithPast(
573
+ loss=loss,
574
+ logits=logits,
575
+ past_key_values=outputs.past_key_values,
576
+ hidden_states=outputs.hidden_states,
577
+ attentions=outputs.attentions,
578
+ )
579
+
580
+
581
+ __all__ = [
582
+ "PanguEmbeddedForCausalLM",
583
+ "PanguEmbeddedModel",
584
+ "PanguEmbeddedPreTrainedModel",
585
+ ]
modular_openpangu_dense.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Tuple
23
+
24
+ import torch
25
+ from torch import nn
26
+
27
+ import torch_npu
28
+ from torch_npu.contrib import transfer_to_npu
29
+ if "910" in torch.npu.get_device_name():
30
+ NPU_ATTN_INFR = True
31
+ print("[INFO] torch_npu detected. Using NPU fused infer attention.")
32
+ else:
33
+ NPU_ATTN_INFR = False
34
+
35
+ from transformers.cache_utils import Cache
36
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
37
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
38
+ from transformers.processing_utils import Unpack
39
+ from transformers.utils import logging
40
+ from transformers.models.llama.modeling_llama import (
41
+ LlamaAttention,
42
+ LlamaDecoderLayer,
43
+ LlamaForCausalLM,
44
+ LlamaForSequenceClassification,
45
+ LlamaMLP,
46
+ LlamaModel,
47
+ apply_rotary_pos_emb,
48
+ eager_attention_forward,
49
+ )
50
+ from .configuration_openpangu_dense import PanguEmbeddedConfig
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+
56
class PanguEmbeddedMLP(LlamaMLP):
    """Llama-style gated MLP whose three projections carry no bias term."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the projections built by LlamaMLP with bias-free ones.
        # Creation order (gate, up, down) matches the original definition so
        # any RNG-driven weight initialisation is drawn in the same sequence.
        projection_specs = (
            ("gate_proj", self.hidden_size, self.intermediate_size),
            ("up_proj", self.hidden_size, self.intermediate_size),
            ("down_proj", self.intermediate_size, self.hidden_size),
        )
        for attr_name, fan_in, fan_out in projection_specs:
            setattr(self, attr_name, nn.Linear(fan_in, fan_out, bias=False))
62
+
63
+
64
class PanguEmbeddedAttention(LlamaAttention):
    """Multi-headed attention for PanguEmbedded.

    Mirrors ``LlamaAttention`` but lets the q/k/v/o projections honour
    ``config.bias`` and, at inference time on Ascend NPUs, routes the
    attention computation to ``torch_npu.npu_fused_infer_attention_score``.
    """

    def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
        # NOTE(review): LlamaAttention.__init__ expects (config, layer_idx);
        # this bare super().__init__() presumably relies on the HF "modular"
        # converter rewriting the call in the generated modeling file —
        # confirm this file is never instantiated directly.
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_heads when the config carries no
        # explicit head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing one key/value head (GQA ratio).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # Unlike stock Llama, all four projections share a single
        # config-controlled bias flag.
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Compute attention over ``hidden_states``.

        Args:
            hidden_states: input activations; the last dim is split into
                heads of size ``self.head_dim``.
            position_embeddings: (cos, sin) rotary embedding tensors.
            attention_mask: optional additive/boolean mask (implementation
                dependent); inverted for the NPU kernel.
            past_key_value: optional KV cache updated in place.
            cache_position: positions used by static caches.

        Returns:
            (attn_output, attn_weights); weights are ``None`` on the NPU
            fused path, which does not expose them.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project then reshape to (batch, heads, seq, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Pick the attention backend registered for this config.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if not self.training and NPU_ATTN_INFR:
            # Inference-only fast path on Ascend 910 NPUs.
            q_len = input_shape[1]
            if attention_mask is not None:
                # NOTE(review): the kernel's atten_mask appears to mark
                # positions to *exclude* with True, hence the inversion —
                # confirm against torch_npu documentation.
                attention_mask = ~attention_mask.bool()
            elif q_len > 1:
                # No mask supplied: build a causal (upper-triangular) mask.
                attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)

            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query_states, key_states, value_states,
                num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
                input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
            # Back to (batch, seq, heads, head_dim) before the flattening below.
            attn_output = attn_output.transpose(1, 2)
            attn_weights = None
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

        # Merge heads and project back to the model dimension.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
138
+
139
+
140
class PanguEmbeddedDecoderLayer(LlamaDecoderLayer):
    """Decoder layer identical to Llama's; presumably the modular converter
    substitutes the Pangu attention/MLP classes when generating the model."""
    pass
142
+
143
+
144
class PanguEmbeddedModel(LlamaModel):
    """Backbone transformer; inherits LlamaModel unchanged."""
    pass
146
+
147
+
148
class PanguEmbeddedForCausalLM(LlamaForCausalLM):
    """Causal-LM head on top of PanguEmbeddedModel; inherits LlamaForCausalLM unchanged."""
    pass
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "[unused10]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_openpangu.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ import os
23
+ from shutil import copyfile
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+
26
+ import sentencepiece as spm
27
+
28
+ from transformers.tokenization_utils import PreTrainedTokenizer
29
+ from transformers.utils import logging
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
35
+
36
+ PRETRAINED_VOCAB_FILES_MAP = {}
37
+
38
+
39
def convert_bool(string):
    """Coerce the strings "true"/"false" (any casing) to real booleans.

    Non-string values, and strings that are not boolean literals, are
    returned unchanged. Useful for values round-tripped through JSON
    configs where booleans may arrive as text.
    """
    if not isinstance(string, str):
        return string
    lowered = string.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    return string
49
+
50
+
51
class PanguTokenizer(PreTrainedTokenizer):
    """
    Construct a tokenizer. Based on byte-level Byte-Pair-Encoding.

    Thin wrapper around a SentencePiece model implementing the slow
    (pure-Python) ``PreTrainedTokenizer`` interface.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    # Registered with AutoTokenizer for trust_remote_code loading.
    _auto_class = "AutoTokenizer"

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="</s>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        decode_with_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Load the SentencePiece model before super().__init__, which may
        # already query vocab_size / token conversion during setup.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        self.vocab_file = vocab_file
        # add_bos_token may arrive as the string "true"/"false" from a saved
        # tokenizer_config.json; normalise it to a bool.
        self.add_bos_token = convert_bool(add_bos_token)
        self.add_eos_token = add_eos_token
        self.decode_with_prefix_space = decode_with_prefix_space
        # NOTE(review): the SentencePiece processor is re-created and the
        # model re-loaded here, duplicating the load above — confirm whether
        # the first load exists only to satisfy super().__init__.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        # Lazily built by the no_prefix_space_tokens property.
        self._no_prefix_space_tokens = None

    """ Initialisation"""

    @property
    def no_prefix_space_tokens(self):
        # Lazily computed set describing tokens that do NOT start with the
        # SentencePiece word-boundary marker "▁".
        # NOTE(review): the set stores enumerate *indices* (ints) while
        # _maybe_add_prefix_space tests token *strings* against it, so the
        # membership test can never match — confirm intended behaviour.
        if self._no_prefix_space_tokens is None:
            vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
            self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
        return self._no_prefix_space_tokens

    @property
    def vocab_size(self) -> int:
        """Returns vocab size (base SentencePiece vocabulary, excluding added tokens)."""
        return self.sp_model.get_piece_size()

    @property
    def bos_token_id(self) -> Optional[int]:
        # Taken directly from the SentencePiece model, not from the
        # configured bos_token string.
        return self.sp_model.bos_id()

    @property
    def eos_token_id(self) -> Optional[int]:
        # Defers to the base class, which resolves the configured eos token
        # (e.g. "[unused10]" per special_tokens_map.json) to an id.
        return super().eos_token_id

    def get_vocab(self):
        """Returns vocab as a dict, including tokens added on top of SentencePiece."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def _maybe_add_prefix_space(self, tokens, decoded):
        # Prepend a space when the first token is word-initial (see the
        # no_prefix_space_tokens NOTE about the id/string mismatch).
        if tokens and tokens[0] not in self.no_prefix_space_tokens:
            return " " + decoded
        else:
            return decoded

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        # NOTE(review): prev_is_special is tracked but never read.
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                # Decode the current sub-tokens first
                if current_sub_tokens:
                    out_string += self.sp_model.decode(current_sub_tokens)
                    current_sub_tokens = []
                # Append the special token without adding extra spaces
                out_string += token
                prev_is_special = True
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        # Decode any remaining sub-tokens
        if current_sub_tokens:
            out_string += self.sp_model.decode(current_sub_tokens)
        # Clean up leading and trailing spaces
        if self.clean_up_tokenization_spaces:
            out_string = self.clean_up_tokenization(out_string)
        out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
        # The first character (the possibly-added prefix space) is dropped.
        return out_string[1:]

    # Override decode so spaces_between_special_tokens defaults to False
    # (the base class defaults it to True).
    def decode(self,
               token_ids,
               spaces_between_special_tokens: bool = False,
               **kwargs):
        return super().decode(
            token_ids=token_ids,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return ("",)
        # NOTE(review): VOCAB_FILES_NAMES stores "./tokenizer.model", so the
        # joined path embeds a "./" segment (odd but harmless on POSIX) —
        # confirm on other platforms.
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source file is gone: serialise the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        # Optionally prefix BOS; pair sequences are concatenated without a
        # separator; EOS (if enabled) is appended only once, at the very end.
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []

        output = bos_token_ids + token_ids_0

        if token_ids_1 is not None:
            output = output + token_ids_1

        if self.add_eos_token:
            output = output + [self.eos_token_id]

        return output

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # NOTE(review): this mask assumes BOS + seq + EOS (and BOS + seq +
        # EOS, EOS-like separators for pairs), which does not match
        # build_inputs_with_special_tokens above — confirm which is correct.
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. This model does
        not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec
3
+ size 2477809
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"add_bos_token": true, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, 
"special": true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": 
"[unused25]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguTokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguTokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + 
'[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}