Upload folder using huggingface_hub
- .gitattributes +5 -0
- LICENSE +34 -0
- Open Source Software Notice +218 -0
- README.md +554 -6
- README_EN.md +554 -0
- checklist.chk +14 -0
- config.json +31 -0
- configuration_openpangu_dense.py +56 -0
- deepdiver_v2/cli/README.md +238 -0
- deepdiver_v2/cli/demo.py +668 -0
- deepdiver_v2/cli/run_demo.sh +171 -0
- deepdiver_v2/config/config.py +239 -0
- deepdiver_v2/env.template +44 -0
- deepdiver_v2/requirements.txt +8 -0
- deepdiver_v2/src/__init__.py +11 -0
- deepdiver_v2/src/agents/__init__.py +62 -0
- deepdiver_v2/src/agents/base_agent.py +692 -0
- deepdiver_v2/src/agents/objective_information_seeker.py +428 -0
- deepdiver_v2/src/agents/planner_agent.py +1203 -0
- deepdiver_v2/src/agents/subjective_information_seeker.py +417 -0
- deepdiver_v2/src/agents/writer_agent.py +477 -0
- deepdiver_v2/src/tools/__init__.py +36 -0
- deepdiver_v2/src/tools/mcp_client.py +814 -0
- deepdiver_v2/src/tools/mcp_server_standard.py +1751 -0
- deepdiver_v2/src/tools/mcp_tools.py +0 -0
- deepdiver_v2/src/tools/server_config.yaml +73 -0
- deepdiver_v2/src/utils/__init__.py +8 -0
- deepdiver_v2/src/utils/status_codes.py +12 -0
- deepdiver_v2/src/workspace/__init__.py +26 -0
- deepdiver_v2/src/workspace/local_workspace_manager.py +420 -0
- docs/openpangu-deepdiver-v2-tech-report.pdf +3 -0
- generation_config.json +11 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +486 -0
- modeling_openpangu_dense.py +585 -0
- modular_openpangu_dense.py +149 -0
- special_tokens_map.json +30 -0
- tokenization_openpangu.py +273 -0
- tokenizer.model +3 -0
- tokenizer_config.json +1 -0
.gitattributes
CHANGED

@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
+model-00004-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00001-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00002-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
+model-00003-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED

OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0

This OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 (the "Agreement") is a legal agreement between You and Huawei Technologies Co., Ltd. ("Huawei", "We" or "Us"), and it governs Your reproducing, use, modification, and distribution of openPangu as made available by Huawei under this Agreement.

By using, reproducing, modifying, distributing, performing or displaying any portion or element of openPangu, or otherwise accepting the terms of this Agreement, You agree to be bound by this Agreement.

1. Definitions.
1.1. “openPangu” or “Model” means openPangu large language models and software, including trained model weights, parameters (including optimizer states), accompanying source code and scripts released under this Agreement.
1.2. “Derivative Model” means all (1) modifications to the Model, (2) works based on the Model, and (3) any other derivative works of the Model. For clarity, information or content results from operating or otherwise using the Model is not a Derivative Model.
1.3. “You” or “Your” means an individual or Legal Entity exercising permissions granted by this Agreement and/or using the Model for any purpose.
1.4. “Third Party” or “Third Parties” means individuals or legal entities that are not under common control with Us or You.

2. License Grant. Subject to Your full compliance with the terms and conditions of this Agreement, We hereby grant to You a perpetual, worldwide, non-exclusive, non-transferable, no-charge, royalty-free license (except as stated in Section 3) to use, reproduce, modify, and distribute the Model.

3. Conditions for License Grant. You represent and warrant that You will not, access, download, install, run, deploy, integrate, modify, or otherwise use the Model, directly or indirectly, within the European Union.

4. Redistribution.
4.1. If You distribute the Model or Derivative Model, You shall retain in Your distribution (1) a copy of this agreement, and (2) all copyright notices and other notices of origin included in the Model that are applicable to Your distribution.
4.2. Further, if You distribute or make available to Third Parties a product or service (including another AI model) based on the Model, You are required to (1) display the acknowledgement “Powered by openPangu” and (2) include a trademark notice “openPangu is a trademark of Huawei Technologies Co., Ltd.” on related webpages, user manuals, product documentations or other advertising materials mentioning features of the Model.
4.3. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for Derivative Model made by You as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the terms and conditions of this Agreement.

5. Ownership. We do not claim ownership to any information or content generated using the Model or Derivative Model that are made by You. You are solely responsible for evaluating the accuracy and appropriateness of such information or content for Your use case.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of Huawei, except as required for complying with Section 4.2.

7. Indemnity. You will indemnify and hold harmless Huawei from and against any claim by any third party arising out of or related to Your use or distribution of the Model or Derivative Model made by You (e.g. a violation against Section 3). For avoidance of doubt, “third party” in this clause include supervisory authorities.

8. THE MODEL IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NONINFRINGEMENT, ACCURACY, OR THE ABSENCE OF LATENT OR OTHER DEFECTS OR ERRORS, WHETHER OR NOT DISCOVERABLE, ALL TO THE GREATEST EXTENT PERMISSIBLE UNDER APPLICABLE LAW.

9. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MODEL, IN WHOLE OR IN PART, NO MATTER HOW IT’S CAUSED OR THE LEGAL THEORY IT IS BASED ON, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

END OF THE TERMS AND CONDITIONS
Open Source Software Notice
ADDED

OPEN SOURCE SOFTWARE NOTICE

Please note we provide an open source software notice along with this product and/or this product firmware (in the following just “this product”). The open source software licenses are granted by the respective right holders. And the open source licenses prevail all other license information with regard to the respective open source software contained in the product, including but not limited to End User Software Licensing Agreement. This notice is provided on behalf of Huawei Technologies Co. Ltd. and any of its local subsidiaries which may have provided this product to you in your local country.

Warranty Disclaimer
THE OPEN SOURCE SOFTWARE IN THIS PRODUCT IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.

Copyright Notice and License Texts

Software: transformers 4.53.2
Copyright notice:
Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.

License Text:
----------------------------------------

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
README.md
CHANGED

@@ -1,6 +1,554 @@
# openPangu Embedded-7B-DeepDiver

中文 | [English](README_EN.md)

📑 [Technical Report](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf)

## 1. Introduction

DeepDiver is the agent in the openPangu family dedicated to deep information seeking and processing. It supports a native Multi-Agent System (MAS) and targets complex knowledge question answering and long-form research-report writing.

### Features

- 🔍 QA mode: answers complex knowledge questions that require 100+ steps
- ✍️ Long-form writing mode: produces articles and reports of 30,000+ characters
- 🔄 Adaptive mode: automatically selects QA or long-form writing mode based on the user's question

## 2. Evaluation Results

| Benchmark | Metric | openPangu-7B-DeepDiver |
| :------------: | :-----------------: | :--------: |
| **BrowseComp-zh** | Acc | 18.3 |
| **BrowseComp-en** | Acc | 8.3 |
| **XBench-DeepSearch** | Acc | 39.0 |

Note: the table above covers complex QA only. For the long-form research evaluation, see the [technical report](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf).

## 3. Quick Deployment

### 3.1 Environment Setup

```bash
# Clone and install
git clone <repository-url>
cd deepdiver_v2
pip install -r requirements.txt
```

### 3.2 Deploying the Inference Service

#### Pull the image

```
docker pull quay.io/ascend/vllm-ascend:v0.9.2rc1
```

Alternatively, build the docker container manually by following the [official documentation](https://vllm-ascend.readthedocs.io/en/stable/installation.html).

#### Run the container

```
docker run -itd --name vllm-deepdiver \
    --network host \
    --device /dev/davinci0 \
    --device /dev/davinci1 \
    --device /dev/davinci2 \
    --device /dev/davinci3 \
    --device /dev/davinci4 \
    --device /dev/davinci5 \
    --device /dev/davinci6 \
    --device /dev/davinci7 \
    -u root \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    -v /usr/local/dcmi:/usr/local/dcmi:ro \
    -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool:ro \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/:ro \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info:ro \
    -v /etc/ascend_install.info:/etc/ascend_install.info:ro \
    -v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware:ro \
    -v /data:/data:ro \
    -v /home/work:/home/work \
    quay.io/ascend/vllm-ascend:v0.9.2rc1
```

Note: the `-v /home/work:/home/work` mount configures a readable and writable working directory.

#### Enter the container

```
docker exec -itu root vllm-deepdiver bash
```

Note: `-itu root` is required.

#### Copy the Pangu modeling files

`open_pangu.py` and `__init__.py` can be found [here](https://ai.gitcode.com/ascend-tribe/openpangu-embedded-7b-model/tree/main/inference/vllm_ascend/models).

```
cp ./vllm_ascend/open_pangu.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
cp ./vllm_ascend/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
```

#### Launch the service

```
PRECHECKPOINT_PATH="path/to/deepdiver_model"

export VLLM_USE_V1=1
export VLLM_WORKER_MULTIPROC_METHOD=fork
# export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

vllm serve $PRECHECKPOINT_PATH \
    --served-model-name ${SERVED_MODEL_NAME:=pangu_auto} \
    --tensor-parallel-size ${tensor_parallel_size:=8} \
    --trust-remote-code \
    --host 127.0.0.1 \
    --port 8888 \
    --max-num-seqs 256 \
    --max-model-len ${MAX_MODEL_LEN:=131072} \
    --max-num-batched-tokens ${MAX_NUM_BATCHED_TOKENS:=4096} \
    --tokenizer-mode "slow" \
    --dtype bfloat16 \
    --distributed-executor-backend mp \
    --gpu-memory-utilization 0.93
```

#### Test the deployment

```
curl -X POST http://127.0.0.1:8888/v1/completions -H "Content-Type: application/json" -d '{
    "model": "pangu_auto",
    "prompt": ["Tell me who you are?"],
    "max_tokens": 50
}'
```

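The same smoke test can be issued from Python. A minimal sketch, assuming the server started above is listening on `127.0.0.1:8888` and returns the standard OpenAI-style completions response (`choices[0].text`):

```python
import json
import urllib.request

API_URL = "http://127.0.0.1:8888/v1/completions"  # matches the curl test above

def build_completion_request(prompt: str, model: str = "pangu_auto",
                             max_tokens: int = 50) -> bytes:
    """Build the JSON body for a /v1/completions request, mirroring the curl payload."""
    return json.dumps(
        {"model": model, "prompt": [prompt], "max_tokens": max_tokens}
    ).encode("utf-8")

def smoke_test(prompt: str = "Tell me who you are?") -> str:
    """POST to the running vLLM server and return the first completion's text."""
    req = urllib.request.Request(
        API_URL,
        data=build_completion_request(prompt),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=60) as resp:
        body = json.load(resp)
    return body["choices"][0]["text"]

# smoke_test() requires the server from the previous step to be running.
```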
### 3.3 Implementing the Required Tools

Before starting the server, you need to supply custom implementations for the web search and URL crawling tools.

#### Web Search (`_generic_search`)

Location: `src/tools/mcp_tools.py` - the `_generic_search` method

Replace the `NotImplementedError` with your own search tool implementation:

```python
def _generic_search(self, query: str, max_results: int, config: Dict[str, Any]) -> MCPToolResult:
    """Your custom search implementation - based on the commented code example"""
    try:
        # Example implementation for search API:
        url = config.get('base_url', 'https://api.search-provider.com/search')
        payload = json.dumps({"q": query, "num": max_results})
        api_keys = config.get('api_keys', [])
        headers = {
            'X-API-KEY': random.choice(api_keys),
            'Content-Type': 'application/json'
        }

        response = requests.post(url, data=payload, headers=headers)
        response.raise_for_status()

        # Transform your API response to required format
        search_results = {
            "organic": [
                {
                    "title": result["title"],
                    "link": result["link"],
                    "snippet": result["snippet"],
                    "date": result.get("date", "unknown")
                }
                for result in response.json().get("organic", [])
            ]
        }

        return MCPToolResult(success=True, data=search_results)

    except Exception as e:
        return MCPToolResult(success=False, error=f"Generic search failed: {e}")
```

#### URL Crawler(`url_crawler` 与 `_content_extractor`)
|
| 173 |
+
|
| 174 |
+
位置:`src/tools/mcp_tools.py` - `_content_extractor`
|
| 175 |
+
|
| 176 |
+
将 `NotImplementedError` 部分替换为你的网页抓取工具实现:
|
| 177 |
+
|
| 178 |
+
```python
|
| 179 |
+
# Example implementation for content extractor:
|
| 180 |
+
crawler_url = f"{crawler_config.get('base_url', 'https://api.content-extractor.com')}/{url}"
|
| 181 |
+
response = requests.get(crawler_url, headers=headers, timeout=crawler_config.get('timeout', 30))
|
| 182 |
+
response.raise_for_status()
|
| 183 |
+
|
| 184 |
+
content = response.text
|
| 185 |
+
|
| 186 |
+
# Truncate if needed
|
| 187 |
+
if max_tokens and len(content.split()) > max_tokens:
|
| 188 |
+
words = content.split()[:max_tokens]
|
| 189 |
+
content = ' '.join(words) + '...'
|
| 190 |
+
|
| 191 |
+
return MCPToolResult(success=True, data=content)
|
| 192 |
+
```

#### ⚠️ 第三方服务提示

重要:搜索与抓取工具使用的外部 API 由用户自行选择和实现。我们不对以下情况负责:
- 与第三方服务相关的隐私/安全问题
- 搜索/抓取活动的合规性
- 内容准确性或版权问题
- API 停机或变更

使用这些服务需自担风险。请查看其条款与隐私政策。

### 3.4 必要配置

#### 配置 .env 文件
复制 `env.template` 到 `config/.env` 并配置如下选项:

```bash
# LLM Service
MODEL_REQUEST_URL=http://localhost:8888/v1/chat/completions  # 你的 LLM endpoint

# Agent 限制
PLANNER_MODE=auto  # 在 auto、writing 或 qa 模式间切换

# 外部 API(先实现函数)
SEARCH_ENGINE_BASE_URL=  # 搜索 API endpoint
SEARCH_ENGINE_API_KEYS=  # 搜索 API keys
URL_CRAWLER_BASE_URL=  # URL Crawler API endpoint
URL_CRAWLER_API_KEYS=  # URL Crawler API keys
```

⚠️ 注意:
- 请将上一步部署的推理服务 URL 配置到 `MODEL_REQUEST_URL`
- 在 `PLANNER_MODE` 中指定模式。`auto` 会自动决策回答复杂问题或生成长文;若希望优先长文写作,可设置为 `writing`;若希望专注解决高难度问题,可设置为 `qa`

### 3.5 启动工具服务

```bash
python src/tools/mcp_server_standard.py
```

### 3.6 运行 Demo

```bash
# 交互模式
python cli/demo.py

# 单次查询
python cli/demo.py -q "$your_query"
```

基于上述步骤可以快速运行 DeepDiver。如果需要二次开发,可以参考[章节 4](#4-自定义工具开发指南)和[章节 5](#5-个性化配置)。

## 4. 自定义工具开发指南

当前工具主要分为内置工具和外部 MCP 工具:内置工具主要包括分发任务、思考/反思等;外部 MCP 工具则是延伸 LLM 能力的工具,如搜索互联网、爬取链接、下载和读写文件等。

### 4.1 已实现的工具类别

#### A. 外部 MCP 工具
Web Search 与数据采集:
- `batch_web_search`:多查询 web 搜索
- `url_crawler`:从 URL 抽取内容
- `download_files`:从 URL 下载文件

文件操作:
- `file_read`、`file_write`:基础文件 I/O
- `list_workspace`:目录列表

文档处理与内容创作:
- `document_qa`:针对特定文档问答
- `document_extract`:多格式文本抽取
- `section_writer`:结构化内容生成

#### B. 内置工具
- `think`、`reflect`:推理与规划
- `task_done`:任务完成汇报
- `assign_task_xxx`:分发任务并创建子智能体

### 4.2 开发并集成新的外部 MCP 工具

#### A. 实现新的 MCP 工具
位置:`src/tools/mcp_tools.py` - 在 `MCPTools` 类中添加方法

```python
def your_new_tool(self, param1: str, param2: int) -> MCPToolResult:
    """
    Description of what your tool does.

    Args:
        param1: Description of parameter 1
        param2: Description of parameter 2

    Returns:
        MCPToolResult: Standardized result format
    """
    try:
        # Your tool implementation here
        result_data = {
            "output": "Tool result",
            "processed_items": param2
        }

        return MCPToolResult(
            success=True,
            data=result_data,
            metadata={"tool_name": "your_new_tool"}
        )

    except Exception as e:
        logger.error(f"Tool execution failed: {e}")
        return MCPToolResult(
            success=False,
            error=f"Tool failed: {str(e)}"
        )
```

#### B. 在服务器中注册工具

##### 添加工具 Schema
位置:`src/tools/mcp_tools.py` - 添加到 `MCP_TOOL_SCHEMAS` 字典

```python
MCP_TOOL_SCHEMAS = {
    # ... existing tools ...

    "your_new_tool": {
        "name": "your_new_tool",
        "description": "Brief description of what your tool does",
        "inputSchema": {
            "type": "object",
            "properties": {
                "param1": {
                    "type": "string",
                    "description": "Description of parameter 1"
                },
                "param2": {
                    "type": "integer",
                    "default": 10,
                    "description": "Description of parameter 2"
                }
            },
            "required": ["param1"]
        }
    }
}
```
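注册好 schema 后,可以在真正调用工具前按 `inputSchema` 做一次轻量校验,并补齐默认值。以下为示意代码(`check_args` 为演示用的假想函数,非框架内置实现,仅演示 `required` 与 `default` 的语义):

```python
from typing import Any, Dict

def check_args(schema: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """按 inputSchema 校验必填项,并补齐默认值(简化示意)。"""
    spec = schema["inputSchema"]
    missing = [k for k in spec.get("required", []) if k not in kwargs]
    if missing:
        raise ValueError(f"missing required args: {missing}")
    out = dict(kwargs)
    for name, prop in spec.get("properties", {}).items():
        if name not in out and "default" in prop:
            out[name] = prop["default"]
    return out

schema = {
    "name": "your_new_tool",
    "inputSchema": {
        "type": "object",
        "properties": {
            "param1": {"type": "string"},
            "param2": {"type": "integer", "default": 10},
        },
        "required": ["param1"],
    },
}
print(check_args(schema, {"param1": "hello"}))  # {'param1': 'hello', 'param2': 10}
```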

##### 注册工具函数
位置:`src/tools/mcp_server_standard.py` - 添加到 `get_tool_function()`

```python
def get_tool_function(tool_name: str):
    """Get the actual function for a tool"""
    tool_map = {
        # ... existing tools ...
        "your_new_tool": lambda tools, **kwargs: tools.your_new_tool(**kwargs),
    }
    return tool_map.get(tool_name)
```

#### C. 让特定智能体可使用工具
工具对各智能体的可见性由 MCP client 中预定义的工具集控制。

位置:`src/tools/mcp_client.py` - 修改各智能体的工具集

```python
# Define which MCP server tools each agent can access
PLANNER_AGENT_TOOLS = [
    "download_files",
    "document_qa",
    "file_read",
    "file_write",
    "str_replace_based_edit_tool",
    "list_workspace",
    "file_find_by_name",
    "your_new_tool",  # Add your new tool here
]

INFORMATION_SEEKER_TOOLS = [
    "batch_web_search",
    "url_crawler",
    "document_extract",
    "document_qa",
    "download_files",
    "file_read",
    "file_write",
    "str_replace_based_edit_tool",
    "list_workspace",
    "file_find_by_name",
    "your_new_tool",  # Add your new tool here if needed
]

WRITER_AGENT_TOOLS = [
    "file_read",
    "list_workspace",
    "file_find_by_name",
    "search_result_classifier",
    "section_writer",
    "concat_section_files",
    # Add your tool if the writer agent needs it
]
```
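工具可见性本质上就是按名单过滤 schema。以下示意代码展示各智能体如何据此拿到自己可用的工具集(函数名与数据均为演示用假设,非源码原样):

```python
# 演示用的全量 schema 列表(实际来自 MCP server)
ALL_TOOL_SCHEMAS = [
    {"name": "batch_web_search"},
    {"name": "file_read"},
    {"name": "section_writer"},
    {"name": "your_new_tool"},
]

PLANNER_AGENT_TOOLS = ["file_read", "your_new_tool"]

def tools_for_agent(allowed):
    """仅保留该智能体名单内的工具 schema,顺序与全量列表一致。"""
    allowed = set(allowed)
    return [s for s in ALL_TOOL_SCHEMAS if s["name"] in allowed]

print([s["name"] for s in tools_for_agent(PLANNER_AGENT_TOOLS)])
# ['file_read', 'your_new_tool']
```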

### 4.3 添加内置智能体工具/函数

#### A. 带有真实返回的工具/函数
DeepDiver 中的 agent(如 planner)集成了 `assign_subjective_task_to_writer`、`assign_multi_objective_tasks_to_info_seeker` 等内置函数作为工具。这类函数除了具体实现之外,还需要通过 `_build_agent_specific_tool_schemas()` 添加专属的 tool schema。

位置:`src/agents/your_agent.py`

```python
def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
    """Add built-in agent functions (not MCP server tools)"""

    # Get base schemas from MCP server via client
    schemas = super()._build_agent_specific_tool_schemas()

    # Add agent-specific built-in functions like task assignment, completion reporting
    builtin_functions = [
        {
            "type": "function",
            "function": {
                "name": "agent_specific_task_done",
                "description": "Report task completion for this agent",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "result": {"type": "string", "description": "Task result"},
                        "status": {"type": "string", "description": "Completion status"}
                    },
                    "required": ["result", "status"]
                }
            }
        }
    ]

    schemas.extend(builtin_functions)
    return schemas
```

#### B. 带有伪返回的内置工具
DeepDiver 中的 cognitive tools(如 `think` 和 `reflect`)没有具体实现:agent 在生成工具入参的同时就已经完成了调用。可以在模型生成完入参后,直接用类似下面的方式构造返回,让模型继续后续工作(参考 `planner_agent.py` 中 `_execute_react_loop()` 的实现):

```python
if tool_call["name"] in ["think", "reflect"]:
    tool_result = {"tool_results": "You can proceed to invoke other tools if needed. "}
```

同理,这种内置工具也需要使用 `_build_agent_specific_tool_schemas()` 添加专属的 tool schema。
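例如,为 `think` 这类伪返回工具补充的 schema 可以像下面这样写(示意代码:字段内容为假设,仅描述入参,调用本身不会执行任何真实逻辑):

```python
# 伪返回工具的 schema:只约束入参,没有对应的服务端实现
think_schema = {
    "type": "function",
    "function": {
        "name": "think",
        "description": "Record intermediate reasoning; no side effects.",
        "parameters": {
            "type": "object",
            "properties": {
                "thought": {"type": "string", "description": "Current reasoning"}
            },
            "required": ["thought"],
        },
    },
}
```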

## 5. 个性化配置

### 5.1 Client 配置

复制 `env.template` 到 `config/.env` 并配置如下选项:

```bash
# LLM Service
MODEL_REQUEST_URL=http://localhost:8000  # 你的 LLM endpoint
MODEL_REQUEST_TOKEN=your-token  # LLM auth token
MODEL_NAME=pangu_auto  # 模型名
MODEL_TEMPERATURE=0.3  # 随机度(0.0-1.0)
MODEL_MAX_TOKENS=8192  # 最大回复长度
MODEL_REQUEST_TIMEOUT=60  # 请求超时(秒)

# Agent 限制
PLANNER_MAX_ITERATION=40  # Planner 最大 ReAct 步数
INFORMATION_SEEKER_MAX_ITERATION=30  # 信息搜集最大 ReAct 步数
WRITER_MAX_ITERATION=40  # Writer 最大 ReAct 步数
PLANNER_MODE=auto  # auto / 长文优先 / qa 优先

# MCP Server
MCP_SERVER_URL=http://localhost:6274/mcp  # MCP server endpoint
MCP_USE_STDIO=false  # 使用 stdio 或 HTTP

# 外部 API(先实现函数)
SEARCH_ENGINE_BASE_URL=  # 搜索 API endpoint
SEARCH_ENGINE_API_KEYS=  # 搜索 API keys
URL_CRAWLER_BASE_URL=  # URL Crawler API endpoint
URL_CRAWLER_API_KEYS=  # URL Crawler API keys
URL_CRAWLER_MAX_TOKENS=100000  # URL Crawler 内容最大长度

# 存储路径
TRAJECTORY_STORAGE_PATH=./workspace  # Agent 工作目录
REPORT_OUTPUT_PATH=./report  # 报告输出目录
DOCUMENT_ANALYSIS_PATH=./doc_analysis  # 文档分析目录

# 系统
DEBUG_MODE=false  # 是否开启调试日志
MAX_RETRIES=3  # API 重试次数
TIMEOUT=30  # 通用超时(秒)
```

### 5.2 Server 配置(server_config.yaml)

`server_config.yaml` 控制服务器行为、工具限流与运行设置:

#### 核心服务器设置

```yaml
server:
  host: "127.0.0.1"  # 服务器绑定地址
  port: 6274  # 端口
  debug_mode: false  # 调试日志
  session_ttl_seconds: 21600  # 会话过期(6小时)
  max_sessions: 1000  # 并发会话上限
```

#### 工具限流

对所有会话的外部 API 使用进行控制:

```yaml
tool_rate_limits:
  batch_web_search:
    requests_per_minute: 9000  # 每分钟限制
    burst_limit: 35  # 短时突发

  url_crawler:
    requests_per_minute: 9000
    burst_limit: 60
```
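`requests_per_minute` 与 `burst_limit` 的组合通常对应一个令牌桶:长期速率由前者决定,短时突发由后者封顶。以下为示意实现(演示语义用的草图,非服务端源码):

```python
import time

class TokenBucket:
    def __init__(self, requests_per_minute: int, burst_limit: int):
        self.rate = requests_per_minute / 60.0   # 每秒补充的令牌数
        self.capacity = burst_limit              # 桶容量 = 允许的突发量
        self.tokens = float(burst_limit)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # 按流逝时间补充令牌,但不超过桶容量
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False

bucket = TokenBucket(requests_per_minute=9000, burst_limit=35)
# 前 35 次立即放行,之后按 9000/min(150/s)的速率补充
```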

#### 会话管理

```yaml
server:
  cleanup_interval_seconds: 600  # 清理过期会话的间隔(10分钟)
  enable_session_keepalive: true  # 长时操作期间保活
  keepalive_touch_interval: 300  # 保活触发间隔(秒)
```

#### 安全与性能

```yaml
server:
  request_timeout_seconds: 1800  # 请求超时
  max_request_size_mb: 1000  # 最大请求体
  rate_limit_requests_per_minute: 300000  # 每 IP 限流
```

配置文件包含对每项设置的详细注释。请根据你的部署需求与外部 API 限额进行调整。

## 6. 模型许可证

除文件中对开源许可证另有约定外,openPangu-Embedded-7B-DeepDiver 模型根据 OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 授权,旨在允许使用并促进人工智能技术的进一步发展。有关详细信息,请参阅模型存储库根目录中的 [LICENSE](LICENSE) 文件。

## 7. 安全提示与免责声明
由于 openPangu-Embedded-7B-DeepDiver 模型和框架所依赖的技术固有的技术限制,以及人工智能生成的内容是由盘古自动生成的,华为无法对以下事项做出任何保证:

- 尽管该模型的输出由 AI 算法生成,但不能排除某些信息可能存在缺陷、不合理或引起不适的可能性,生成的内容不代表华为的态度或立场;
- 无法保证该模型 100% 准确、可靠、功能齐全、及时、安全、无错误、不间断、持续稳定或无任何故障;
- 该模型的输出内容不构成任何建议或决策,也不保证生成的内容的真实性、完整性、准确性、及时性、合法性、功能性或实用性。生成的内容不能替代医疗、法律等领域的专业人士回答您的问题。生成的内容仅供参考,不代表华为的任何态度、立场或观点。您需要根据实际情况做出独立判断,华为不承担任何责任;
- DeepDiver MAS 系统的组件间通信不包含内置的数据加密或认证(如 tokens、签名)。你需要自行评估安全需求并实施相应防护(例如运行在加密网络中、加入 SSL/TLS、强制组件身份校验);
- 由于缺乏加密/认证导致的任何安全事件(数据泄露、未授权访问、业务损失)由使用方自行承担,项目开发者不承担责任。

## 8. 反馈

如果有任何意见和建议,请提交 issue 或联系 openPangu@huawei.com。

---
README_EN.md ADDED
# openPangu-Embedded-7B-DeepDiver

[中文](README.md) | English

📑[Technical Report](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf)

## 1. Introduction
DeepDiver is an agentic solution in the openPangu series aimed at deep information seeking and processing. It natively supports a Multi-Agent System (MAS) and is designed for complex question answering and long-form report writing.

### Features
- 🔍 QA Mode: answers complex knowledge-based questions requiring 100+ steps.
- ✍️ Long-form Writing Mode: produces articles and reports of 30,000+ words.
- 🔄 Adaptive Mode: automatically selects between QA Mode and Long-form Writing Mode based on the user query.

## 2. Results

| Benchmark | Metric | openPangu-7B-DeepDiver|
| :------------: | :-----------------: | :--------: |
| **BrowseComp-zh** | Acc | 18.3 |
| **BrowseComp-en** | Acc | 8.3 |
| **XBench-DeepSearch** | Acc | 39.0 |

Note: The table above only shows results for complex QA. For the evaluation of long-form report writing, please refer to the [technical report](https://ai.gitcode.com/ascend-tribe/openPangu-Embedded-7B-DeepDiver/blob/main/docs/openpangu-deepdiver-v2-tech-report.pdf).

## 3. Quick Start

### 3.1 Setup

```bash
# Clone and install
git clone <repository-url>
cd deepdiver_v2
pip install -r requirements.txt
```

### 3.2 Deployment of the Inference Service

#### Pull Images

```
docker pull quay.io/ascend/vllm-ascend:v0.9.2rc1
```

Or follow the [official documentation](https://vllm-ascend.readthedocs.io/en/stable/installation.html) to build the docker container manually.

#### Run Docker Container

```
docker run -itd --name vllm-deepdiver \
    --network host \
    --device /dev/davinci0 \
    --device /dev/davinci1 \
    --device /dev/davinci2 \
    --device /dev/davinci3 \
    --device /dev/davinci4 \
    --device /dev/davinci5 \
    --device /dev/davinci6 \
    --device /dev/davinci7 \
    -u root \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    -v /usr/local/dcmi:/usr/local/dcmi:ro \
    -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool:ro \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/:ro \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info:ro \
    -v /etc/ascend_install.info:/etc/ascend_install.info:ro \
    -v /usr/local/Ascend/firmware:/usr/local/Ascend/firmware:ro \
    -v /data:/data:ro \
    -v /home/work:/home/work \
    quay.io/ascend/vllm-ascend:v0.9.2rc1
```

The `-v /home/work:/home/work` mount sets a working directory; adjust it to your environment.

#### Enter the Container

```
docker exec -itu root vllm-deepdiver bash
```
Note that `-itu root` is necessary.

#### Copy Pangu's Modeling Files

`open_pangu.py` and `__init__.py` can be found [here](https://ai.gitcode.com/ascend-tribe/openpangu-embedded-7b-model/tree/main/inference/vllm_ascend/models)

```
cp ./vllm_ascend/open_pangu.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
cp ./vllm_ascend/__init__.py /vllm-workspace/vllm-ascend/vllm_ascend/models/
```

#### Start Deployment

```
PRECHECKPOINT_PATH="path/to/deepdiver_model"

export VLLM_USE_V1=1

export VLLM_WORKER_MULTIPROC_METHOD=fork
# export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

vllm serve $PRECHECKPOINT_PATH \
    --served-model-name ${SERVED_MODEL_NAME:=pangu_auto} \
    --tensor-parallel-size ${tensor_parallel_size:=8} \
    --trust-remote-code \
    --host 127.0.0.1 \
    --port 8888 \
    --max-num-seqs 256 \
    --max-model-len ${MAX_MODEL_LEN:=131072} \
    --max-num-batched-tokens ${MAX_NUM_BATCHED_TOKENS:=4096} \
    --tokenizer-mode "slow" \
    --dtype bfloat16 \
    --distributed-executor-backend mp \
    --gpu-memory-utilization 0.93
```

#### Test Deployment

```
curl -X POST http://127.0.0.1:8888/v1/completions -H "Content-Type: application/json" -d '{
    "model": "pangu_auto",
    "prompt": ["Tell me who you are?"],
    "max_tokens": 50
}'
```
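You can also build the same smoke-test request in Python (a sketch: port and model name depend on your deployment; uncomment the last two lines to actually send it):

```python
import json
from urllib import request

# Request body equivalent to the curl command above
payload = {
    "model": "pangu_auto",
    "prompt": ["Tell me who you are?"],
    "max_tokens": 50,
}

req = request.Request(
    "http://127.0.0.1:8888/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# resp = request.urlopen(req)  # run once the service is up
# print(json.load(resp)["choices"][0]["text"])
```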

### 3.3 Implement Required Tools

Before starting the server, you must implement custom logic for the web search and URL crawling tools.

#### Web Search (`_generic_search`)

**Location**: `src/tools/mcp_tools.py` - `_generic_search` method

Replace the `NotImplementedError` with your search API integration:

```python
def _generic_search(self, query: str, max_results: int, config: Dict[str, Any]) -> MCPToolResult:
    """Your custom search implementation - based on the commented code example"""
    try:
        # Example implementation for search API:
        url = config.get('base_url', 'https://api.search-provider.com/search')
        payload = json.dumps({"q": query, "num": max_results})
        api_keys = config.get('api_keys', [])
        headers = {
            'X-API-KEY': random.choice(api_keys),
            'Content-Type': 'application/json'
        }

        response = requests.post(url, data=payload, headers=headers)
        response.raise_for_status()

        # Transform your API response to required format
        search_results = {
            "organic": [
                {
                    "title": result["title"],
                    "link": result["link"],
                    "snippet": result["snippet"],
                    "date": result.get("date", "unknown")
                }
                for result in response.json().get("organic", [])
            ]
        }

        return MCPToolResult(success=True, data=search_results)

    except Exception as e:
        return MCPToolResult(success=False, error=f"Generic search failed: {e}")
```

#### URL Crawler (`url_crawler` and `_content_extractor`)

**Location**: `src/tools/mcp_tools.py` - `_content_extractor`

Replace the `NotImplementedError` section with your crawler API integration:

```python
# Example implementation for content extractor:
crawler_url = f"{crawler_config.get('base_url', 'https://api.content-extractor.com')}/{url}"
response = requests.get(crawler_url, headers=headers, timeout=crawler_config.get('timeout', 30))
response.raise_for_status()

content = response.text

# Truncate if needed
if max_tokens and len(content.split()) > max_tokens:
    words = content.split()[:max_tokens]
    content = ' '.join(words) + '...'

return MCPToolResult(success=True, data=content)
```

#### ⚠️ **Third-Party Service Notice**

**Important**: The search and crawler tools use external APIs of your choice. We are not responsible for:
- Privacy/security issues with third-party services
- Legal compliance of search/crawling activities
- Content accuracy or copyright issues
- API downtime or changes

Use these services at your own risk. Check their terms and privacy policies.

### 3.4 Mandatory Configuration

#### Configure the .env file
Copy `env.template` to `config/.env` and configure these options:

```bash
# LLM Service
MODEL_REQUEST_URL=http://localhost:8888/v1/chat/completions  # Your LLM endpoint

# Agent Limits
PLANNER_MODE=auto  # Switch between auto, writing, or qa mode.

# External APIs (implement functions first)
SEARCH_ENGINE_BASE_URL=  # Search API endpoint
SEARCH_ENGINE_API_KEYS=  # Search API keys
URL_CRAWLER_BASE_URL=  # Crawler API endpoint
URL_CRAWLER_API_KEYS=  # Crawler API keys
```

**⚠️ Important:**
- Configure `MODEL_REQUEST_URL` with the URL of the inference service deployed in the previous step.
- Specify the mode in `PLANNER_MODE`. The `auto` mode automatically decides whether to answer a complex question or generate a long-form report. If you want to prioritize long-form writing, set it to `writing`; if you want to focus solely on solving highly complex problems, set it to `qa`.

### 3.5 Start the Tool Server

```bash
python src/tools/mcp_server_standard.py
```

### 3.6 Run the Demo

```bash
# Interactive mode
python cli/demo.py

# Single query
python cli/demo.py -q "$your_query"
```

Based on the above steps, DeepDiver can be run quickly. For further development, refer to [Section 4](#4-customized-tool-development-guide) and [Section 5](#5-customized-configuration).

## 4. Customized Tool Development Guide
Tools are currently divided into Built-in Tools and External MCP Tools. Built-in Tools cover task assignment, think/reflect, and the like; External MCP Tools extend LLM capabilities with operations such as web search, URL crawling, file download, and file read/write.

### 4.1 Implemented Tool Categories

#### A. External MCP Tools
Web Search and Data Collection:
- `batch_web_search`: Multi-query web search
- `url_crawler`: Extract content from URLs
- `download_files`: Download files from URLs

File Operations:
- `file_read`, `file_write`: Basic file I/O
- `list_workspace`: Directory listing

Document Processing and Content Creation:
- `document_qa`: Question answering on documents
- `document_extract`: Extract text from various formats
- `section_writer`: Structured content generation

#### B. Built-in Tools
- `think`, `reflect`: Reasoning and planning
- `task_done`: Task completion reporting
- `assign_task_xxx`: Assign tasks and create sub-agents

### 4.2 Develop and Integrate New External MCP Tools

#### A. Implementing a New MCP Tool
Location: `src/tools/mcp_tools.py` - Add a method to the `MCPTools` class

```python
def your_new_tool(self, param1: str, param2: int) -> MCPToolResult:
    """
    Description of what your tool does.

    Args:
        param1: Description of parameter 1
        param2: Description of parameter 2

    Returns:
        MCPToolResult: Standardized result format
    """
    try:
        # Your tool implementation here
        result_data = {
            "output": "Tool result",
            "processed_items": param2
        }

        return MCPToolResult(
            success=True,
            data=result_data,
            metadata={"tool_name": "your_new_tool"}
        )

    except Exception as e:
        logger.error(f"Tool execution failed: {e}")
        return MCPToolResult(
            success=False,
            error=f"Tool failed: {str(e)}"
        )
```

#### B. Registering the Tool on the Server

##### Adding Tool Schema
Location: `src/tools/mcp_tools.py` - Add to the `MCP_TOOL_SCHEMAS` dictionary

```python
MCP_TOOL_SCHEMAS = {
    # ... existing tools ...

    "your_new_tool": {
        "name": "your_new_tool",
        "description": "Brief description of what your tool does",
        "inputSchema": {
            "type": "object",
            "properties": {
                "param1": {
                    "type": "string",
                    "description": "Description of parameter 1"
                },
                "param2": {
                    "type": "integer",
                    "default": 10,
                    "description": "Description of parameter 2"
                }
            },
            "required": ["param1"]
        }
    }
}
```

##### Registering the Tool Function
Location: `src/tools/mcp_server_standard.py` - Add to `get_tool_function()`

```python
def get_tool_function(tool_name: str):
    """Get the actual function for a tool"""
    tool_map = {
        # ... existing tools ...
        "your_new_tool": lambda tools, **kwargs: tools.your_new_tool(**kwargs),
    }
    return tool_map.get(tool_name)
```
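The server dispatches a call by looking up the tool name and invoking the returned callable with the shared tools instance. A minimal sketch of that flow (`DummyTools` is a stand-in for illustration, not the actual `MCPTools` class):

```python
class DummyTools:
    """Stand-in for the MCPTools instance the server passes to each tool."""
    def your_new_tool(self, param1, param2=10):
        return {"output": f"{param1}:{param2}"}

def get_tool_function(tool_name):
    """Same lookup pattern as in mcp_server_standard.py."""
    tool_map = {
        "your_new_tool": lambda tools, **kwargs: tools.your_new_tool(**kwargs),
    }
    return tool_map.get(tool_name)

fn = get_tool_function("your_new_tool")
result = fn(DummyTools(), param1="hello")
print(result)  # {'output': 'hello:10'}
```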
|
| 351 |
+
|
| 352 |
+
#### C. Making the Tool Accessible to Specific Agents
|
| 353 |
+
The visibility of tools to each agent is controlled by the predefined tool sets in the MCP client.
|
| 354 |
+
|
| 355 |
+
Location: `src/tools/mcp_client.py` - Modify the tool sets for each agent
|
| 356 |
+
|
| 357 |
+
```python
|
| 358 |
+
# Define which MCP server tools each agent can access
|
| 359 |
+
PLANNER_AGENT_TOOLS = [
|
| 360 |
+
"download_files",
|
| 361 |
+
"document_qa",
|
| 362 |
+
"file_read",
|
| 363 |
+
"file_write",
|
| 364 |
+
"str_replace_based_edit_tool",
|
| 365 |
+
"list_workspace",
|
| 366 |
+
"file_find_by_name",
|
| 367 |
+
"your_new_tool", # Add your new tool here
|
| 368 |
+
]
|
| 369 |
+
|
| 370 |
+
INFORMATION_SEEKER_TOOLS = [
|
| 371 |
+
"batch_web_search",
|
| 372 |
+
"url_crawler",
|
| 373 |
+
"document_extract",
|
| 374 |
+
"document_qa",
|
| 375 |
+
"download_files",
|
| 376 |
+
"file_read",
|
| 377 |
+
"file_write",
|
| 378 |
+
"str_replace_based_edit_tool",
|
| 379 |
+
"list_workspace",
|
| 380 |
+
"file_find_by_name",
|
| 381 |
+
"your_new_tool", # Add your new tool here if needed
|
| 382 |
+
]
|
| 383 |
+
|
| 384 |
+
WRITER_AGENT_TOOLS = [
|
| 385 |
+
"file_read",
|
| 386 |
+
"list_workspace",
|
| 387 |
+
"file_find_by_name",
|
| 388 |
+
"search_result_classifier",
|
| 389 |
+
"section_writer",
|
| 390 |
+
"concat_section_files",
|
| 391 |
+
# Add your tool if the writer agent needs it
|
| 392 |
+
]
|
| 393 |
+
```
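Conceptually, the client then exposes to each agent only the schemas whose names appear in its list. A minimal sketch of that filtering (the names and data layout here are illustrative, not the actual client code):

```python
# Hypothetical registry of all server-side tool schemas.
ALL_TOOL_SCHEMAS = {
    "file_read": {"name": "file_read", "description": "Read a file"},
    "batch_web_search": {"name": "batch_web_search", "description": "Search the web"},
    "your_new_tool": {"name": "your_new_tool", "description": "Your new tool"},
}

WRITER_AGENT_TOOLS = ["file_read"]

def schemas_for(agent_tools: list) -> list:
    """Keep only the schemas an agent is allowed to see, preserving list order."""
    return [ALL_TOOL_SCHEMAS[n] for n in agent_tools if n in ALL_TOOL_SCHEMAS]

print([s["name"] for s in schemas_for(WRITER_AGENT_TOOLS)])
```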
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
### 4.3 Adding Built-in Agent Tools/Functions
|
| 397 |
+
|
| 398 |
+
#### A. Tools/Functions with Actual Return Values
|
| 399 |
+
Agents in DeepDiver (e.g., the Planner) integrate built-in functions as tools, such as `assign_subjective_task_to_writer` and `assign_multi_objective_tasks_to_info_seeker`. In addition to their specific implementations, these functions require adding **agent-specific tool schemas** using `_build_agent_specific_tool_schemas()`.
|
| 400 |
+
|
| 401 |
+
Location: `src/agents/your_agent.py`
|
| 402 |
+
|
| 403 |
+
```python
|
| 404 |
+
def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
|
| 405 |
+
"""Add built-in agent functions (not MCP server tools)"""
|
| 406 |
+
|
| 407 |
+
# Get base schemas from MCP server via client
|
| 408 |
+
schemas = super()._build_agent_specific_tool_schemas()
|
| 409 |
+
|
| 410 |
+
# Add agent-specific built-in functions like task assignment, completion reporting
|
| 411 |
+
builtin_functions = [
|
| 412 |
+
{
|
| 413 |
+
"type": "function",
|
| 414 |
+
"function": {
|
| 415 |
+
"name": "agent_specific_task_done",
|
| 416 |
+
"description": "Report task completion for this agent",
|
| 417 |
+
"parameters": {
|
| 418 |
+
"type": "object",
|
| 419 |
+
"properties": {
|
| 420 |
+
"result": {"type": "string", "description": "Task result"},
|
| 421 |
+
"status": {"type": "string", "description": "Completion status"}
|
| 422 |
+
},
|
| 423 |
+
"required": ["result", "status"]
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
}
|
| 427 |
+
]
|
| 428 |
+
|
| 429 |
+
schemas.extend(builtin_functions)
|
| 430 |
+
return schemas
|
| 431 |
+
```
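When the model emits a call to such a built-in function, the agent handles it locally instead of forwarding it to the MCP server. A hedged sketch of that dispatch (the actual loop in `planner_agent.py` differs, and the JSON-string argument encoding is an assumption):

```python
import json

BUILTIN_FUNCTIONS = {"agent_specific_task_done"}

def dispatch_tool_call(tool_call: dict) -> dict:
    """Handle built-in functions locally; everything else goes to the MCP server."""
    name = tool_call["name"]
    args = json.loads(tool_call["arguments"])
    if name in BUILTIN_FUNCTIONS:
        # Completion reporting: record the result and signal loop termination.
        return {"done": True,
                "result": args.get("result", ""),
                "status": args.get("status", "")}
    # Forwarding to the MCP server is omitted in this sketch.
    return {"done": False}

call = {"name": "agent_specific_task_done",
        "arguments": json.dumps({"result": "report written", "status": "success"})}
print(dispatch_tool_call(call))
```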
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
#### B. Built-in Tools with Pseudo Return Values
|
| 435 |
+
Cognitive tools in DeepDiver (e.g., `think` and `reflect`) have no concrete implementation. When an agent calls one of these tools, the invocation is considered complete as soon as the agent generates the tool's input parameters. You can return a canned result immediately after the model generates those parameters, letting it continue with subsequent tasks (refer to the implementation of `_execute_react_loop()` in `planner_agent.py`):
|
| 436 |
+
|
| 437 |
+
```python
|
| 438 |
+
if tool_call["name"] in ["think", "reflect"]:
|
| 439 |
+
tool_result = {"tool_results": "You can proceed to invoke other tools if needed. "}
|
| 440 |
+
```
|
| 441 |
+
|
| 442 |
+
Similarly, such built-in tools also require adding their exclusive tool schemas using `_build_agent_specific_tool_schemas()`.
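As an illustration, a cognitive tool such as `think` might declare a schema like the following; only the generated argument matters, since the call itself returns a canned result (this exact schema is an assumption, not the project's definition):

```python
# Illustrative tool schema for a `think` cognitive tool.
THINK_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "think",
        "description": "Write out intermediate reasoning before acting",
        "parameters": {
            "type": "object",
            "properties": {
                "thought": {"type": "string",
                            "description": "Free-form reasoning text"}
            },
            "required": ["thought"],
        },
    },
}
```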
|
| 443 |
+
|
| 444 |
+
## 5. Customized Configuration
|
| 445 |
+
|
| 446 |
+
### 5.1 Client Configuration
|
| 447 |
+
|
| 448 |
+
Copy `env.template` to `config/.env` and configure these options:
|
| 449 |
+
|
| 450 |
+
```bash
|
| 451 |
+
# LLM Service
|
| 452 |
+
MODEL_REQUEST_URL=http://localhost:8000 # Your LLM endpoint
|
| 453 |
+
MODEL_REQUEST_TOKEN=your-token # Auth token
|
| 454 |
+
MODEL_NAME=pangu_auto # Model name
|
| 455 |
+
MODEL_TEMPERATURE=0.3 # Response randomness (0.0-1.0)
|
| 456 |
+
MODEL_MAX_TOKENS=8192 # Max response length
|
| 457 |
+
MODEL_REQUEST_TIMEOUT=60 # Request timeout (seconds)
|
| 458 |
+
|
| 459 |
+
# Agent Limits
|
| 460 |
+
PLANNER_MAX_ITERATION=40 # Planner maximum ReAct steps
|
| 461 |
+
INFORMATION_SEEKER_MAX_ITERATION=30 # Info seeker maximum ReAct steps
|
| 462 |
+
WRITER_MAX_ITERATION=40 # Writer maximum ReAct steps
|
| 463 |
+
PLANNER_MODE=auto                    # auto, writing-priority, or qa-priority mode
|
| 464 |
+
|
| 465 |
+
# MCP Server
|
| 466 |
+
MCP_SERVER_URL=http://localhost:6274/mcp # MCP server endpoint
|
| 467 |
+
MCP_USE_STDIO=false # Use stdio vs HTTP
|
| 468 |
+
|
| 469 |
+
# External APIs (implement functions first)
|
| 470 |
+
SEARCH_ENGINE_BASE_URL= # Search API endpoint
|
| 471 |
+
SEARCH_ENGINE_API_KEYS= # Search API keys
|
| 472 |
+
URL_CRAWLER_BASE_URL= # Crawler API endpoint
|
| 473 |
+
URL_CRAWLER_API_KEYS= # Crawler API keys
|
| 474 |
+
URL_CRAWLER_MAX_TOKENS=100000 # Max crawled content length
|
| 475 |
+
|
| 476 |
+
# Storage Paths
|
| 477 |
+
TRAJECTORY_STORAGE_PATH=./workspace # Agent work directory
|
| 478 |
+
REPORT_OUTPUT_PATH=./report # Report output directory
|
| 479 |
+
DOCUMENT_ANALYSIS_PATH=./doc_analysis # Document analysis directory
|
| 480 |
+
|
| 481 |
+
# System
|
| 482 |
+
DEBUG_MODE=false # Enable debug logging
|
| 483 |
+
MAX_RETRIES=3 # API retry attempts
|
| 484 |
+
TIMEOUT=30 # General timeout (seconds)
|
| 485 |
+
```
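A minimal sketch of reading such a `.env` file with only the standard library (the project itself may use a library like python-dotenv; this just shows the idea):

```python
import os
import tempfile

def load_env(path: str) -> dict:
    """Parse KEY=VALUE lines, ignoring blanks, comments, and inline comments."""
    values = {}
    with open(path) as fh:
        for line in fh:
            line = line.split("#", 1)[0].strip()  # drop inline comments
            if "=" in line:
                key, _, value = line.partition("=")
                values[key.strip()] = value.strip()
    return values

# Example: write a tiny .env and read it back.
path = os.path.join(tempfile.gettempdir(), "deepdiver_demo.env")
with open(path, "w") as fh:
    fh.write("MODEL_NAME=pangu_auto  # Model name\nMODEL_MAX_TOKENS=8192\n")

cfg = load_env(path)
print(cfg["MODEL_NAME"], int(cfg["MODEL_MAX_TOKENS"]))
```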
|
| 486 |
+
|
| 487 |
+
### 5.2 Server Configuration (server_config.yaml)
|
| 488 |
+
|
| 489 |
+
The `server_config.yaml` file controls server behavior, tool rate limiting, and operational settings:
|
| 490 |
+
|
| 491 |
+
#### Core Server Settings
|
| 492 |
+
|
| 493 |
+
```yaml
|
| 494 |
+
server:
|
| 495 |
+
host: "127.0.0.1" # Server bind address
|
| 496 |
+
port: 6274 # Server port
|
| 497 |
+
debug_mode: false # Enable debug logging
|
| 498 |
+
session_ttl_seconds: 21600 # Session timeout (6 hours)
|
| 499 |
+
max_sessions: 1000 # Max concurrent sessions
|
| 500 |
+
```
|
| 501 |
+
|
| 502 |
+
#### Tool Rate Limiting
|
| 503 |
+
|
| 504 |
+
Controls external API usage across all sessions:
|
| 505 |
+
|
| 506 |
+
```yaml
|
| 507 |
+
tool_rate_limits:
|
| 508 |
+
batch_web_search:
|
| 509 |
+
requests_per_minute: 9000 # Per-minute limit
|
| 510 |
+
burst_limit: 35 # Short-term burst allowance
|
| 511 |
+
|
| 512 |
+
url_crawler:
|
| 513 |
+
requests_per_minute: 9000
|
| 514 |
+
burst_limit: 60
|
| 515 |
+
```
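One common way to realize `requests_per_minute` plus `burst_limit` is a token bucket, with the burst limit as bucket capacity and the per-minute rate as the refill rate. This is a sketch of the idea, not the server's actual limiter:

```python
import time

class TokenBucket:
    """Allow short bursts up to `burst_limit`, refilling at `requests_per_minute`."""

    def __init__(self, requests_per_minute: float, burst_limit: int):
        self.rate = requests_per_minute / 60.0   # tokens per second
        self.capacity = burst_limit
        self.tokens = float(burst_limit)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill proportionally to elapsed time, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False

bucket = TokenBucket(requests_per_minute=9000, burst_limit=3)
print([bucket.allow() for _ in range(5)])  # burst is exhausted after 3 calls
```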
|
| 516 |
+
|
| 517 |
+
#### Session Management
|
| 518 |
+
|
| 519 |
+
```yaml
|
| 520 |
+
server:
|
| 521 |
+
cleanup_interval_seconds: 600 # Clean expired sessions (every 10 min)
|
| 522 |
+
enable_session_keepalive: true # Keep sessions alive during long operations
|
| 523 |
+
keepalive_touch_interval: 300 # Touch session every N seconds
|
| 524 |
+
```
|
| 525 |
+
|
| 526 |
+
#### Security & Performance
|
| 527 |
+
|
| 528 |
+
```yaml
|
| 529 |
+
server:
|
| 530 |
+
request_timeout_seconds: 1800 # Request timeout
|
| 531 |
+
max_request_size_mb: 1000 # Maximum request size
|
| 532 |
+
rate_limit_requests_per_minute: 300000 # Requests per IP
|
| 533 |
+
```
|
| 534 |
+
|
| 535 |
+
The configuration file includes detailed comments explaining each setting. Modify values based on your deployment requirements and external API limits.
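Once the YAML is parsed (e.g. with PyYAML's `yaml.safe_load`), applying defaults and a few sanity checks could look like this sketch (keys mirror the examples above; this is not the actual server code):

```python
# Defaults matching the documented server settings.
DEFAULTS = {
    "host": "127.0.0.1",
    "port": 6274,
    "debug_mode": False,
    "session_ttl_seconds": 21600,
    "max_sessions": 1000,
}

def load_server_settings(parsed: dict) -> dict:
    """Merge the parsed `server:` section over defaults and validate it."""
    settings = {**DEFAULTS, **parsed.get("server", {})}
    if not (0 < settings["port"] < 65536):
        raise ValueError(f"invalid port: {settings['port']}")
    if settings["session_ttl_seconds"] <= 0:
        raise ValueError("session_ttl_seconds must be positive")
    return settings

cfg = load_server_settings({"server": {"port": 8080, "debug_mode": True}})
print(cfg["port"], cfg["max_sessions"])
```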
|
| 536 |
+
|
| 537 |
+
## 6. Model License
|
| 538 |
+
|
| 539 |
+
Unless otherwise noted, the openPangu-Embedded-7B-DeepDiver model is licensed under the terms and conditions of the OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0, which is intended to be used permissively and enable the further development of artificial intelligence technologies. Please refer to the [LICENSE](LICENSE) file located in the root directory of the model repository for details.
|
| 540 |
+
|
| 541 |
+
## 7. Security Notice and Disclaimer
|
| 542 |
+
|
| 543 |
+
Due to the inherent technical limitations of the technologies relied upon by the openPangu-Embedded-7B-DeepDiver model and its framework, as well as the fact that AI-generated content is automatically produced by Pangu, Huawei cannot make any warranties regarding the following matters:
|
| 544 |
+
|
| 545 |
+
- The output of this Model is automatically generated via AI algorithms; it cannot be ruled out that some of the information may be flawed, unreasonable, or cause discomfort, and the generated content does not represent Huawei's attitude or standpoint;
|
| 546 |
+
- There is no guarantee that this Model is 100% accurate, reliable, functional, timely, secure, safe, error-free, uninterrupted, continuously stable, or free of any faults;
|
| 547 |
+
- The output of this Model does not constitute any advice or decision for you, and it does not guarantee the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in medical, legal, and other fields in answering your questions. The generated content is for your reference only and does not represent any attitude, standpoint, or position of Huawei. You need to make independent judgments based on your actual situation, and Huawei does not assume any responsibilities;
|
| 548 |
+
- The inter-component communication of the DeepDiver MAS system does not include built-in data encryption or authentication mechanisms (e.g., tokens, signatures). You shall independently assess your security requirements and implement corresponding protective measures (such as deploying the system in an encrypted network, integrating SSL/TLS protocols, and enforcing component identity verification);
|
| 549 |
+
- Any security incidents (including but not limited to data leakage, unauthorized access, and business losses) arising from the lack of encryption/authentication mechanisms shall be borne by the user of the system. Huawei shall bear no responsibility therefor.
|
| 550 |
+
|
| 551 |
+
## 8. Contact Us
|
| 552 |
+
If you have any comments or suggestions, please submit an issue or contact openPangu@huawei.com.
|
| 553 |
+
|
| 554 |
+
---
|
checklist.chk
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
880418a2e195ea5221f700fabc446de4af401c777ea66288eedcfe3ca7861a58 *./config.json
|
| 2 |
+
7302a0fdc386e723b16ad5860787b50a7bff39d30549dae986e073b193f2beb4 *./configuration_openpangu_dense.py
|
| 3 |
+
5cbfc09f10ae85f0e9bebc1281541dcc7107d86e34282839277782cbb146117d *./generation_config.json
|
| 4 |
+
9bf645e8399be6d99000eae64bd172b5c457d6d2c44d2257b47eb97a3c41aeda *./model.safetensors.index.json
|
| 5 |
+
f15eaf322af8a0b0f16b26795eb68af836179413d3dbfa4dc44505db6c8b0d6f *./modeling_openpangu_dense.py
|
| 6 |
+
7b8ec6cd94b1921560d37755c7c0c08280c1f9123195d14d352ad0607788f7f6 *./model-00001-of-00004.safetensors
|
| 7 |
+
fc05d80f52ce44d1433a942e867bf61ea49eb1eebb0700312f76d6b3a3dee917 *./model-00002-of-00004.safetensors
|
| 8 |
+
1ed37f38214c755b51bea06a71e154c9ea27670eb3b8506c06addcfbea2066f2 *./model-00003-of-00004.safetensors
|
| 9 |
+
0145e255ba965ed0e75164a037b9a0137c5e5c12ffc42463ff82568054fe0186 *./model-00004-of-00004.safetensors
|
| 10 |
+
c1f2d87f855b994039c52b1e83c8a7f3d71a2d1eb52946c4a2e862e99f19d8b3 *./modular_openpangu_dense.py
|
| 11 |
+
b34cf5e7c7660889303b6e2d0a346c440356385c9db551d06f6615cf9fc600d1 *./special_tokens_map.json
|
| 12 |
+
6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec *./tokenizer.model
|
| 13 |
+
acb88eac57f8765fedf34e9c10bc16d55c46f0902b0fea74fbf041daca2667ae *./tokenizer_config.json
|
| 14 |
+
c98602d6d1f61792a8bd3393972bbbe7409a205c0bb6299394c74287c26bd723 *./tokenization_openpangu.py
|
config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"PanguEmbeddedForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_openpangu_dense.PanguEmbeddedConfig",
|
| 7 |
+
"AutoModel": "modeling_openpangu_dense.PanguEmbeddedModel",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_openpangu_dense.PanguEmbeddedForCausalLM"
|
| 9 |
+
},
|
| 10 |
+
"bias": true,
|
| 11 |
+
"attention_dropout": 0.0,
|
| 12 |
+
"bos_token_id": 1,
|
| 13 |
+
"pad_token_id": 0,
|
| 14 |
+
"eos_token_id": 45892,
|
| 15 |
+
"hidden_act": "silu",
|
| 16 |
+
"hidden_size": 4096,
|
| 17 |
+
"initializer_range": 0.02,
|
| 18 |
+
"intermediate_size": 12800,
|
| 19 |
+
"max_position_embeddings": 147456,
|
| 20 |
+
"model_type": "PanguEmbedded",
|
| 21 |
+
"num_attention_heads": 32,
|
| 22 |
+
"num_hidden_layers": 34,
|
| 23 |
+
"num_key_value_heads": 8,
|
| 24 |
+
"rms_norm_eps": 1e-05,
|
| 25 |
+
"rope_theta": 16000000.0,
|
| 26 |
+
"tie_word_embeddings": false,
|
| 27 |
+
"torch_dtype": "bfloat16",
|
| 28 |
+
"transformers_version": "4.53.2",
|
| 29 |
+
"use_cache": true,
|
| 30 |
+
"vocab_size": 153376
|
| 31 |
+
}
|
configuration_openpangu_dense.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
from transformers.utils import logging
|
| 5 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
logger = logging.get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PanguEmbeddedConfig(PretrainedConfig):
|
| 12 |
+
|
| 13 |
+
model_type = "PanguEmbedded"
|
| 14 |
+
_auto_class = "AutoConfig"
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
vocab_size=153376,
|
| 19 |
+
hidden_size=4096,
|
| 20 |
+
intermediate_size=12800,
|
| 21 |
+
num_hidden_layers=34,
|
| 22 |
+
num_attention_heads=32,
|
| 23 |
+
num_key_value_heads=8,
|
| 24 |
+
hidden_act="silu",
|
| 25 |
+
max_position_embeddings=147456,
|
| 26 |
+
initializer_range=0.02,
|
| 27 |
+
rms_norm_eps=1e-5,
|
| 28 |
+
use_cache=True,
|
| 29 |
+
pad_token_id=0,
|
| 30 |
+
bos_token_id=1,
|
| 31 |
+
eos_token_id=45892,
|
| 32 |
+
tie_word_embeddings=False,
|
| 33 |
+
rope_theta=16000000.0,
|
| 34 |
+
bias=True,
|
| 35 |
+
**kwargs,
|
| 36 |
+
):
|
| 37 |
+
self.vocab_size = vocab_size
|
| 38 |
+
self.max_position_embeddings = max_position_embeddings
|
| 39 |
+
self.hidden_size = hidden_size
|
| 40 |
+
self.intermediate_size = intermediate_size
|
| 41 |
+
self.num_hidden_layers = num_hidden_layers
|
| 42 |
+
self.num_attention_heads = num_attention_heads
|
| 43 |
+
self.num_key_value_heads = num_key_value_heads
|
| 44 |
+
self.hidden_act = hidden_act
|
| 45 |
+
self.initializer_range = initializer_range
|
| 46 |
+
self.rms_norm_eps = rms_norm_eps
|
| 47 |
+
self.use_cache = use_cache
|
| 48 |
+
self.rope_theta = rope_theta
|
| 49 |
+
self.bias = bias
|
| 50 |
+
super().__init__(
|
| 51 |
+
pad_token_id=pad_token_id,
|
| 52 |
+
bos_token_id=bos_token_id,
|
| 53 |
+
eos_token_id=eos_token_id,
|
| 54 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 55 |
+
**kwargs,
|
| 56 |
+
)
|
deepdiver_v2/cli/README.md
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLI Demo for DeepDiver Long Writer Multi-Agent System
|
| 2 |
+
|
| 3 |
+
This CLI demo showcases the multi-agent system that coordinates PlannerAgent, InformationSeekerAgent, and WriterAgent to handle complex queries and generate comprehensive long-form content.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- 🧠 **PlannerAgent**: Orchestrates the entire process and coordinates sub-agents
|
| 8 |
+
- 🔍 **InformationSeekerAgent**: Performs web research and gathers information
|
| 9 |
+
- ✍️ **WriterAgent**: Creates comprehensive long-form content
|
| 10 |
+
- 📊 **Real-time Visualization**: Shows tool calls, reasoning traces, and sub-agent responses
|
| 11 |
+
- ⚙️ **Configuration Management**: Loads settings from .env files
|
| 12 |
+
|
| 13 |
+
## Setup
|
| 14 |
+
|
| 15 |
+
### 1. Install Dependencies
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
cd deepdiver_v2
|
| 19 |
+
pip install -r requirements.txt
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### 2. Configure Environment
|
| 23 |
+
|
| 24 |
+
Create a `.env` file in the `config/` directory:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
# From the project root
|
| 28 |
+
cp env.template config/.env
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Then edit `config/.env` with your settings:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
# Custom LLM Service Configuration
|
| 35 |
+
MODEL_REQUEST_URL=http://your-llm-service-endpoint/v1/chat/completions
|
| 36 |
+
MODEL_REQUEST_TOKEN=your-service-token
|
| 37 |
+
MODEL_NAME=pangu_auto
|
| 38 |
+
|
| 39 |
+
# MCP Server Configuration
|
| 40 |
+
MCP_SERVER_URL=http://localhost:6274/mcp
|
| 41 |
+
MCP_AUTH_TOKEN=
|
| 42 |
+
MCP_USE_STDIO=true
|
| 43 |
+
|
| 44 |
+
# Agent Iteration Limits
|
| 45 |
+
PLANNER_MAX_ITERATION=20
|
| 46 |
+
INFORMATION_SEEKER_MAX_ITERATION=30
|
| 47 |
+
WRITER_MAX_ITERATION=20
|
| 48 |
+
|
| 49 |
+
# Mode
|
| 50 |
+
PLANNER_MODE=auto # auto, writing, qa
|
| 51 |
+
|
| 52 |
+
# Other settings...
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
### 3. Start Required Services
|
| 56 |
+
|
| 57 |
+
Make sure your MCP server is running:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# Start MCP server (if needed)
|
| 61 |
+
python src/tools/mcp_server_standard.py
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Usage
|
| 65 |
+
|
| 66 |
+
### Interactive Mode (Recommended)
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
python cli/demo.py
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
This will start an interactive session where you can enter queries and see the full execution flow.
|
| 73 |
+
|
| 74 |
+
### Single Query Mode
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
python cli/demo.py -q "Write a comprehensive analysis of artificial intelligence trends in 2024"
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### Configuration Only
|
| 81 |
+
|
| 82 |
+
```bash
|
| 83 |
+
python cli/demo.py --config-only
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### Debug Mode (Verbose Logging)
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
python cli/demo.py --debug -q "Debug a specific query"
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### Quiet Mode (Clean Output)
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
python cli/demo.py --quiet -q "Run with minimal output"
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
### Create Sample Configuration
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
python cli/demo.py --create-env
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## Example Queries
|
| 105 |
+
|
| 106 |
+
### For Information Seeking Tasks:
|
| 107 |
+
- "What are the latest developments in quantum computing?"
|
| 108 |
+
- "Research the current state of renewable energy adoption globally"
|
| 109 |
+
- "Find information about recent AI breakthroughs in healthcare"
|
| 110 |
+
|
| 111 |
+
### For Long-form Writing Tasks:
|
| 112 |
+
- "Write a comprehensive report on the impact of AI on education"
|
| 113 |
+
- "Create an in-depth analysis of climate change mitigation strategies"
|
| 114 |
+
- "Generate a detailed guide on sustainable business practices"
|
| 115 |
+
|
| 116 |
+
## Demo Flow Visualization
|
| 117 |
+
|
| 118 |
+
The demo provides rich visual feedback showing:
|
| 119 |
+
|
| 120 |
+
1. **🚀 Task Initiation**: Shows the user query and planner startup
|
| 121 |
+
2. **🧠 Agent Reasoning**: Displays the planner's reasoning at each step
|
| 122 |
+
3. **🔧 Tool Calls**: Shows what tools are being called with their arguments
|
| 123 |
+
4. **📋 Tool Results**: Displays the results from each tool execution
|
| 124 |
+
5. **🤝 Sub-Agent Execution**: Shows when sub-agents (InformationSeeker, Writer) are invoked
|
| 125 |
+
6. **📊 Sub-Agent Results**: Displays results from sub-agent executions
|
| 126 |
+
7. **🏁 Final Result**: Shows the complete execution summary
|
| 127 |
+
8. **🔍 Execution Trace**: Detailed step-by-step trace of the entire process
|
| 128 |
+
|
| 129 |
+
## Output Modes
|
| 130 |
+
|
| 131 |
+
The CLI demo supports different output modes for different use cases:
|
| 132 |
+
|
| 133 |
+
### Default Mode
|
| 134 |
+
Shows the full rich interface with welcome screen, progress bars, and detailed visualization of all agent interactions.
|
| 135 |
+
|
| 136 |
+
### Quiet Mode (`--quiet`)
|
| 137 |
+
Suppresses all non-essential output, showing only final results. Useful for:
|
| 138 |
+
- Integration with scripts or automation
|
| 139 |
+
- Focusing on results without process details
|
| 140 |
+
- Running in environments where rich output isn't needed
|
| 141 |
+
|
| 142 |
+
### Debug Mode (`--debug`)
|
| 143 |
+
Enables verbose logging with timestamps, showing all internal system messages. Useful for:
|
| 144 |
+
- Troubleshooting configuration issues
|
| 145 |
+
- Understanding detailed agent behavior
|
| 146 |
+
- Development and debugging
|
| 147 |
+
|
| 148 |
+
```bash
|
| 149 |
+
# Examples of different modes
|
| 150 |
+
python cli/demo.py --query "Test query" # Default rich mode
|
| 151 |
+
python cli/demo.py --quiet --query "Test query" # Minimal output
|
| 152 |
+
python cli/demo.py --debug --query "Test query" # Verbose debugging
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
## Troubleshooting
|
| 156 |
+
|
| 157 |
+
### Configuration Issues
|
| 158 |
+
|
| 159 |
+
If you see configuration errors:
|
| 160 |
+
|
| 161 |
+
1. Ensure `config/.env` exists and is properly formatted
|
| 162 |
+
2. Check that all required environment variables are set
|
| 163 |
+
3. Verify your LLM service endpoint is accessible
|
| 164 |
+
4. Confirm MCP server is running and reachable
|
| 165 |
+
5. Use `--debug` mode to see detailed error messages
|
| 166 |
+
|
| 167 |
+
### Agent Initialization Issues
|
| 168 |
+
|
| 169 |
+
If agent initialization fails:
|
| 170 |
+
|
| 171 |
+
1. Check MCP server connectivity
|
| 172 |
+
2. Verify model configuration is correct
|
| 173 |
+
3. Ensure required permissions for workspace directories
|
| 174 |
+
4. Check log output for specific error messages
|
| 175 |
+
|
| 176 |
+
### Tool Execution Issues
|
| 177 |
+
|
| 178 |
+
If tool calls fail:
|
| 179 |
+
|
| 180 |
+
1. Verify MCP server is running and has the required tools
|
| 181 |
+
2. Check network connectivity for web search/crawler tools
|
| 182 |
+
3. Ensure workspace directories exist and are writable
|
| 183 |
+
4. Review tool arguments for correctness
|
| 184 |
+
|
| 185 |
+
## Advanced Usage
|
| 186 |
+
|
| 187 |
+
### Custom Sub-Agent Configurations
|
| 188 |
+
|
| 189 |
+
You can customize sub-agent behavior by modifying the configurations in the demo script:
|
| 190 |
+
|
| 191 |
+
```python
|
| 192 |
+
sub_agent_configs = {
|
| 193 |
+
"information_seeker": {
|
| 194 |
+
"model": "your-model",
|
| 195 |
+
"max_iterations": 30,
|
| 196 |
+
},
|
| 197 |
+
"writer": {
|
| 198 |
+
"model": "your-model",
|
| 199 |
+
"max_iterations": 20,
|
| 200 |
+
"temperature": 0.3,
|
| 201 |
+
"max_tokens": 16384
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
### Monitoring and Debugging
|
| 207 |
+
|
| 208 |
+
Enable debug mode in your `.env` file:
|
| 209 |
+
|
| 210 |
+
```bash
|
| 211 |
+
DEBUG_MODE=true
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
This will provide more detailed logging and error information.
|
| 215 |
+
|
| 216 |
+
## Architecture Overview
|
| 217 |
+
|
| 218 |
+
The demo showcases a sophisticated multi-agent architecture:
|
| 219 |
+
|
| 220 |
+
```
|
| 221 |
+
User Query
|
| 222 |
+
↓
|
| 223 |
+
PlannerAgent (Coordinator)
|
| 224 |
+
↓
|
| 225 |
+
├── InformationSeekerAgent (Research)
|
| 226 |
+
│ ├── Web Search Tools
|
| 227 |
+
│ ├── URL Crawling Tools
|
| 228 |
+
│ ├── Document Analysis Tools
|
| 229 |
+
│ └── File Management Tools
|
| 230 |
+
│
|
| 231 |
+
└── WriterAgent (Content Generation)
|
| 232 |
+
├── File Reading Tools
|
| 233 |
+
├── Document QA Tools
|
| 234 |
+
├── Content Synthesis
|
| 235 |
+
└── Long-form Writing
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
Each agent follows the ReAct pattern (Reasoning + Acting) with iterative refinement until task completion.
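Schematically, that pattern alternates reasoning and acting until the task is done; the stand-in sketch below is illustrative, not the project's implementation:

```python
def react_loop(model_step, dispatch, max_iterations=5):
    """Alternate reasoning (model_step) and acting (dispatch) until done."""
    history = []
    for _ in range(max_iterations):
        thought, action = model_step(history)   # Reasoning step
        history.append(("thought", thought))
        if action is None:                      # Model signals completion
            break
        history.append(("observation", dispatch(action)))  # Acting step
    return history

# Toy run: one tool call, then the model finishes.
steps = iter([("search first", "web_search"), ("done", None)])
trace = react_loop(lambda h: next(steps), lambda a: f"result of {a}")
print([kind for kind, _ in trace])
```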
|
deepdiver_v2/cli/demo.py
ADDED
|
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
"""
|
| 4 |
+
CLI Demo for DeepDiver Long Writer Multi-Agent System
|
| 5 |
+
|
| 6 |
+
This demo showcases the multi-agent system that includes:
|
| 7 |
+
- PlannerAgent: Coordinates and orchestrates the entire process
|
| 8 |
+
- InformationSeekerAgent: Gathers and researches information
|
| 9 |
+
- WriterAgent: Creates long-form content
|
| 10 |
+
|
| 11 |
+
Features:
|
| 12 |
+
- Loads configuration from config/.env file
|
| 13 |
+
- Shows real-time tool calls and reasoning traces
|
| 14 |
+
- Displays sub-agent responses and interactions
|
| 15 |
+
- Visualizes the complete execution flow
|
| 16 |
+
- Query preprocessing for safety and task-suitability checks
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import json
|
| 22 |
+
import time
|
| 23 |
+
import logging
|
| 24 |
+
import argparse
|
| 25 |
+
import requests
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Dict, Any, List, Optional
|
| 28 |
+
from rich.console import Console
|
| 29 |
+
from rich.table import Table
|
| 30 |
+
from rich.panel import Panel
|
| 31 |
+
from rich.syntax import Syntax
|
| 32 |
+
from rich.markdown import Markdown
|
| 33 |
+
|
| 34 |
+
# Add project root to Python path
|
| 35 |
+
project_root = Path(__file__).parent.parent
|
| 36 |
+
sys.path.insert(0, str(project_root))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Configure logging to keep the CLI clean
|
| 40 |
+
def setup_clean_logging(debug_mode: bool = False):
|
| 41 |
+
"""Configure logging to show only relevant information for the demo"""
|
| 42 |
+
if debug_mode:
|
| 43 |
+
# Debug mode: show all logs with timestamps
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=logging.DEBUG,
|
| 46 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 47 |
+
datefmt='%H:%M:%S'
|
| 48 |
+
)
|
| 49 |
+
else:
|
| 50 |
+
# Clean demo mode: suppress verbose logs
|
| 51 |
+
|
| 52 |
+
# Suppress specific noisy loggers
|
| 53 |
+
noisy_loggers = [
|
| 54 |
+
'httpx',
|
| 55 |
+
'httpcore',
|
| 56 |
+
'urllib3',
|
| 57 |
+
'src.tools.mcp_client',
|
| 58 |
+
# 'src.agents.base_agent',
|
| 59 |
+
'config.config' # Also suppress config messages in quiet mode
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
for logger_name in noisy_loggers:
|
| 63 |
+
logging.getLogger(logger_name).setLevel(logging.ERROR)
|
| 64 |
+
|
| 65 |
+
# Set up default clean logging before any imports
|
| 66 |
+
setup_clean_logging(debug_mode=False)
|
| 67 |
+
|
| 68 |
+
# Import the multi-agent system components
|
| 69 |
+
from config.config import get_config, reload_config
|
| 70 |
+
from src.agents.planner_agent import create_planner_agent
|
| 71 |
+
from src.agents.base_agent import AgentResponse
|
| 72 |
+
|
| 73 |
+
console = Console()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
from rich.console import Group  # Syntax objects are renderables and must be composed, not embedded in f-strings


class DemoVisualizer:
    """Visualizes the execution of the multi-agent system"""

    def __init__(self, quiet_mode: bool = False):
        self.console = console
        self.execution_log = []
        self.quiet_mode = quiet_mode

    def _should_display(self, force: bool = False) -> bool:
        """Check if output should be displayed based on quiet mode"""
        return not self.quiet_mode or force

    def show_welcome(self):
        """Display welcome message and system info"""
        if not self._should_display():
            return

        welcome_text = """
# 🤖 DeepDiver Long Writer Multi-Agent System Demo

This demo showcases an advanced multi-agent system for research and long-form content generation.

## System Components:
- **🧠 PlannerAgent**: Orchestrates the entire process and coordinates sub-agents
- **🔍 InformationSeekerAgent**: Performs web research and gathers information
- **✍️ WriterAgent**: Creates comprehensive long-form content

## Features:
- Real-time tool execution visualization
- Sub-agent response tracking
- Complete reasoning trace display
- Configuration management
- Query safety and suitability pre-check
"""
        self.console.print(Panel(Markdown(welcome_text), title="[bold blue]Welcome", border_style="blue"))

    def show_config(self, config):
        """Display current configuration"""
        if not self._should_display():
            return

        config_table = Table(title="📋 System Configuration", show_header=True, header_style="bold magenta")
        config_table.add_column("Setting", style="cyan", no_wrap=True)
        config_table.add_column("Value", style="green")

        # Safe config display (hide sensitive values)
        safe_config = config.to_dict()
        for key, value in safe_config.items():
            if value is not None and str(value) != "None":
                display_value = str(value)
                if len(display_value) > 60:
                    display_value = display_value[:57] + "..."
                config_table.add_row(key, display_value)

        self.console.print(config_table)

    def show_planner_start(self, query: str):
        """Show planner starting execution"""
        self.console.print(Panel(
            f"[bold yellow]User Query:[/bold yellow] {query}\n\n"
            f"[bold green]🚀 Starting PlannerAgent execution...[/bold green]",
            title="[bold blue]Task Initiation",
            border_style="green"
        ))

    def show_reasoning_step(self, iteration: int, reasoning: str):
        """Display reasoning step"""
        self.console.print(Panel(
            Markdown(f"**Iteration {iteration} - Reasoning:**\n\n{reasoning}"),
            title=f"[bold yellow]🧠 Agent Reasoning (Step {iteration})",
            border_style="yellow"
        ))

    def show_tool_call(self, iteration: int, tool_name: str, arguments: Dict[str, Any]):
        """Display tool call"""
        args_json = json.dumps(arguments, indent=2, ensure_ascii=False)

        # Embedding a Syntax object in an f-string would print its repr,
        # so the panel body is composed with Group instead.
        self.console.print(Panel(
            Group(
                f"[bold cyan]Tool:[/bold cyan] {tool_name}\n",
                "[bold cyan]Arguments:[/bold cyan]",
                Syntax(args_json, 'json', theme='monokai', line_numbers=True),
            ),
            title=f"[bold cyan]🔧 Tool Call (Step {iteration})",
            border_style="cyan"
        ))

    def show_tool_result(self, iteration: int, tool_name: str, result: Dict[str, Any]):
        """Display tool result"""
        success = result.get("success", True)
        status_icon = "✅" if success else "❌"
        status_color = "green" if success else "red"

        # Format result for display
        if success and "data" in result:
            display_result = result["data"]
        elif "error" in result:
            display_result = {"error": result["error"]}
        else:
            display_result = result

        result_text = json.dumps(display_result, indent=2, ensure_ascii=False)
        if len(result_text) > 1000:
            result_text = result_text[:997] + "..."

        self.console.print(Panel(
            Group(
                f"[bold {status_color}]Status:[/bold {status_color}] {status_icon} {'Success' if success else 'Failed'}\n",
                f"[bold {status_color}]Result:[/bold {status_color}]",
                Syntax(result_text, 'json', theme='monokai', line_numbers=True),
            ),
            title=f"[bold {status_color}]📋 Tool Result: {tool_name} (Step {iteration})",
            border_style=status_color
        ))

    def show_sub_agent_execution(self, agent_name: str, task_content: str):
        """Show sub-agent starting execution"""
        self.console.print(Panel(
            f"[bold magenta]Agent:[/bold magenta] {agent_name}\n\n"
            f"[bold magenta]Task:[/bold magenta] {task_content[:500]}{'...' if len(task_content) > 500 else ''}",
            title="[bold magenta]🤝 Sub-Agent Execution",
            border_style="magenta"
        ))

    def show_sub_agent_result(self, agent_name: str, result: Dict[str, Any]):
        """Show sub-agent execution result"""
        success = result.get("success", True)
        status_icon = "✅" if success else "❌"
        status_color = "green" if success else "red"

        # Extract key information
        iterations = result.get("iterations", 0)
        execution_time = result.get("execution_time", 0)

        summary = f"[bold {status_color}]Status:[/bold {status_color}] {status_icon} {'Success' if success else 'Failed'}\n"
        summary += f"[bold blue]Iterations:[/bold blue] {iterations}\n"
        summary += f"[bold blue]Execution Time:[/bold blue] {execution_time:.2f}s\n\n"

        if success and "data" in result:
            data = result["data"]
            if isinstance(data, dict):
                for key, value in data.items():
                    if isinstance(value, str) and len(value) > 200:
                        summary += f"[bold blue]{key}:[/bold blue] {value[:197]}...\n"
                    else:
                        summary += f"[bold blue]{key}:[/bold blue] {value}\n"
        elif "error" in result:
            summary += f"[bold red]Error:[/bold red] {result['error']}\n"

        self.console.print(Panel(
            summary,
            title=f"[bold {status_color}]📊 Sub-Agent Result: {agent_name}",
            border_style=status_color
        ))

    def show_final_result(self, response: AgentResponse):
        """Display final execution result"""
        # Always show final results, even in quiet mode
        if not self._should_display(force=True):
            return

        success = response.success
        status_icon = "✅" if success else "❌"
        status_color = "green" if success else "red"

        summary = f"[bold {status_color}]Final Status:[/bold {status_color}] {status_icon} {'Completed Successfully' if success else 'Failed'}\n"
        summary += f"[bold blue]Total Iterations:[/bold blue] {response.iterations}\n"
        summary += f"[bold blue]Total Execution Time:[/bold blue] {response.execution_time:.2f}s\n"
        summary += f"[bold blue]Agent:[/bold blue] {response.agent_name}\n\n"

        if success and response.result:
            if isinstance(response.result, dict):
                for key, value in response.result.items():
                    if isinstance(value, str) and len(value) > 3000:
                        summary += f"[bold blue]{key}:[/bold blue] {value[:2997]}...\n\n"
                    else:
                        summary += f"[bold blue]{key}:[/bold blue] {value}\n\n"
        elif response.error:
            summary += f"[bold red]Error:[/bold red] {response.error}\n"

        self.console.print(Panel(
            summary,
            title=f"[bold {status_color}]🏁 Final Result",
            border_style=status_color
        ))

    def show_reasoning_trace(self, trace: List[Dict[str, Any]]):
        """Display detailed reasoning trace"""
        if not trace:
            return

        trace_table = Table(title="🔍 Detailed Execution Trace", show_header=True, header_style="bold cyan")
        trace_table.add_column("Step", style="cyan", width=8)
        trace_table.add_column("Type", style="magenta", width=12)
        trace_table.add_column("Details", style="white")

        for i, step in enumerate(trace, 1):
            step_type = step.get("type", "unknown")

            if step_type == "reasoning":
                content = step.get("content", "")[:100] + ("..." if len(step.get("content", "")) > 100 else "")
                trace_table.add_row(str(i), "🧠 Reasoning", content)

            elif step_type == "action":
                tool = step.get("tool", "")
                result_status = "✅" if step.get("result", {}).get("success", True) else "❌"
                trace_table.add_row(str(i), "🔧 Tool Call", f"{result_status} {tool}")

            elif step_type == "error":
                error = step.get("error", "")[:100] + ("..." if len(step.get("error", "")) > 100 else "")
                trace_table.add_row(str(i), "❌ Error", error)

        self.console.print(trace_table)

    def show_unsupported_response(self):
        """Display the fixed response for unsupported queries"""
        # Always show this response, even in quiet mode
        self.console.print(Panel(
            "Sorry, your question is not within the current scope of tasks for DeepDiver-V2. Please try asking a question related to long-form writing or complex knowledge Q&A instead.",
            title="[bold yellow]❌ Unsupported Query",
            border_style="yellow"
        ))

class AgentExecutionMonitor:
    """Monitors agent execution and provides real-time feedback"""

    def __init__(self, visualizer: DemoVisualizer):
        self.visualizer = visualizer
        self.current_iteration = 0

    def on_reasoning_step(self, iteration: int, reasoning: str):
        """Called when agent performs reasoning"""
        self.visualizer.show_reasoning_step(iteration, reasoning)

    def on_tool_call(self, iteration: int, tool_name: str, arguments: Dict[str, Any]):
        """Called when agent makes a tool call"""
        self.visualizer.show_tool_call(iteration, tool_name, arguments)

        # Check for sub-agent assignments
        if "assign_" in tool_name and "task" in tool_name:
            if "tasks" in arguments:
                for task in arguments.get("tasks", []):
                    task_content = task.get("task_content", "")
                    self.visualizer.show_sub_agent_execution("InformationSeeker", task_content)
            elif "task_content" in arguments:
                task_content = arguments.get("task_content", "")
                self.visualizer.show_sub_agent_execution("Writer", task_content)

    def on_tool_result(self, iteration: int, tool_name: str, result: Dict[str, Any]):
        """Called when tool execution completes"""
        self.visualizer.show_tool_result(iteration, tool_name, result)

        # Show sub-agent results if this was an assignment
        if "assign_" in tool_name and "task" in tool_name:
            if "data" in result and "tasks" in result["data"]:
                for task_result in result["data"]["tasks"]:
                    agent_name = task_result.get("agent_name", "InformationSeeker")
                    self.visualizer.show_sub_agent_result(agent_name, task_result)
            elif "data" in result:
                agent_name = result["data"].get("agent_name", "Writer")
                self.visualizer.show_sub_agent_result(agent_name, result["data"])

def classify_query(query: str, config) -> Dict[str, Any]:
    """
    Classify user query into one of three categories using LLM:
    1. SAFE_SENSITIVE: Contains unsafe content (insults, political risks, etc.)
    2. NON_KNOWLEDGE: Non-knowledge intensive (no need for research, e.g., greetings, simple calculations)
    3. NORMAL: Requires processing (long-form writing or complex knowledge Q&A)

    Returns:
        Dict with 'category' (str) and 'reasoning' (str)
    """
    logger = logging.getLogger(__name__)

    # Get model configuration
    model_config = config.get_custom_llm_config()
    pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
    model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')

    # Validate model configuration
    if not pangu_url:
        logger.error("Model URL not configured for query classification")
        # Fallback to NORMAL category if model config is missing
        return {
            "category": "NORMAL",
            "reasoning": "模型配置不完整,跳过分类检查,默认按正常任务处理"
        }

    headers = {'Content-Type': 'application/json', 'csb-token': model_token}

    # Classification prompt (detailed instructions for accurate categorization)
    prompt_template = """
你是一个Query分类器,需要将用户输入的查询分为以下三类,并给出明确的分类理由:

1. 【SAFE_SENSITIVE - 安全敏感内容】:包含以下任何一种情况的查询
- 辱骂、侮辱性语言(如脏话、人身攻击)
- 涉及政治敏感内容(如国家领导人、敏感政治事件、舆情风险话题)
- 违法违规内容(如暴力、色情、恐怖主义相关)
- 歧视性言论(种族、性别、宗教等歧视)

2. 【NON_KNOWLEDGE - 非知识密集型任务】:不需要进行信息搜索的简单查询
- 问候语(如"你好"、"早上好"、"嗨")
- 简单计算(如"1+1等于几"、"25乘以4是多少")
- 基础闲聊(如"你是谁")
- 指令性语句(如"退出"、"帮助"、"开始")
- 不需要信息收集的简单问题

3. 【NORMAL - 正常任务】:不包含安全敏感内容,需要进行信息搜索或长文写作的任务
- 简单的信息收集任务 (如"华为成立时间是什么时候")
- 复杂知识问答(如"ACL2025举办地有什么美食推荐")
- 长文写作任务(如"写一篇关于气候变化影响的5000字报告")
- 需要数据支持的分析(如"2023年全球经济增长数据及分析")
- 专业领域研究(如"机器学习在医疗诊断中的应用案例")

分类要求:
- 严格按照上述定义进行分类,不要遗漏任何关键特征
- 优先判断是否为SAFE_SENSITIVE,其次判断是否为NON_KNOWLEDGE,最后才是NORMAL
- 必须提供清晰的分类理由,说明为什么属于该类别
- 输出格式必须严格遵循:先输出分类理由的思考,然后换行输出分类结果(SAFE_SENSITIVE/NON_KNOWLEDGE/NORMAL)

示例1(SAFE_SENSITIVE):
该查询包含辱骂性语言"XXX",符合安全敏感内容的定义,属于需要拦截的内容
SAFE_SENSITIVE

示例2(NON_KNOWLEDGE):
该查询是简单的问候语"你好",不需要进行信息搜索,属于非知识密集型任务
NON_KNOWLEDGE

示例3(NORMAL):
该查询不包含安全敏感内容,要求撰写关于"区块链技术在金融领域的应用"的长文,需要进行信息收集、案例研究和深度分析,属于正常的长文写作任务
NORMAL

用户输入query:$query"""

    # Prepare conversation history
    conversation_history = [
        {"role": "user", "content": prompt_template.replace("$query", query) + " /no_think"}
    ]

    try:
        # Call LLM with retry logic
        retry_num = 1
        max_retry_num = 3
        while retry_num <= max_retry_num:
            try:
                response = requests.post(
                    url=pangu_url,
                    headers=headers,
                    json={
                        "model": config.model_name,
                        "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
                        "spaces_between_special_tokens": False,
                        "messages": conversation_history,
                        "temperature": 0.1,  # Low temperature for deterministic classification
                        "max_tokens": 5000,
                    },
                    timeout=model_config.get("timeout", 60)
                )

                response.raise_for_status()  # Surface HTTP errors before parsing the body
                response_json = response.json()
                logger.debug(f"Classification API response: {json.dumps(response_json, indent=2)}")

                # Extract and parse result: reasoning comes first, the category
                # is on the last non-empty line (per the prompt's output format)
                assistant_message = response_json["choices"][0]["message"]["content"].strip()
                parts = [line for line in assistant_message.split('\n') if line.strip()]

                if len(parts) < 2:
                    raise ValueError(f"Invalid response format: {assistant_message}")

                reasoning = '\n'.join(parts[:-1]).strip()
                category = parts[-1].strip()

                # Validate category (record the original invalid value before overwriting it)
                valid_categories = ["SAFE_SENSITIVE", "NON_KNOWLEDGE", "NORMAL"]
                if category not in valid_categories:
                    logger.warning(f"Invalid category '{category}', using fallback NORMAL")
                    reasoning = f"模型返回无效分类 '{category}',默认按正常任务处理。原始理由:{reasoning}"
                    category = "NORMAL"

                return {
                    "category": category,
                    "reasoning": reasoning
                }

            except Exception as e:
                logger.error(f"Classification attempt {retry_num} failed: {str(e)}")
                if retry_num == max_retry_num:
                    raise
                time.sleep(2)  # Wait before retry
                retry_num += 1

    except Exception as e:
        logger.error(f"Query classification failed: {str(e)}")
        # Fallback to NORMAL category if classification fails
        return {
            "category": "NORMAL",
            "reasoning": f"分类服务暂时不可用(错误:{str(e)[:100]}...),默认按正常任务处理"
        }

def load_environment_config(quiet: bool = False):
    """Load configuration from .env file"""
    try:
        # Check for .env file in config directory
        config_dir = Path(__file__).parent.parent / "config"
        env_file = config_dir / ".env"

        if not env_file.exists():
            if not quiet:
                console.print(f"[yellow]⚠️ No .env file found at {env_file}[/yellow]")
                console.print("[yellow]💡 Please copy env.template to config/.env and configure your settings[/yellow]")
            return None

        # Reload configuration to pick up the .env file
        reload_config()
        config = get_config()

        if not quiet:
            console.print("[green]✅ Configuration loaded successfully[/green]")
        return config

    except Exception as e:
        if not quiet:
            console.print(f"[red]❌ Failed to load configuration: {e}[/red]")
        return None

def create_sample_env_file():
    """Create a sample .env file for demo purposes"""
    config_dir = Path(__file__).parent.parent / "config"
    env_file = config_dir / ".env"

    if env_file.exists():
        return

    # Copy from template
    template_file = Path(__file__).parent.parent / "env.template"
    if template_file.exists():
        import shutil
        shutil.copy2(template_file, env_file)
        console.print(f"[green]✅ Created .env file from template at {env_file}[/green]")
        console.print("[yellow]⚠️ Please edit the .env file with your actual configuration values[/yellow]")
    else:
        console.print("[red]❌ Could not find env.template to copy[/red]")

def run_demo_query(planner, query: str, visualizer: DemoVisualizer, config) -> Optional[AgentResponse]:
    """Run a demo query through the planner with preprocessing"""

    # Step 1: Show query information
    visualizer.show_planner_start(query)

    # Step 2: Query classification (preprocessing)
    classification_result = classify_query(query, config)

    # Step 3: Branch processing based on classification
    unsupported_categories = ["SAFE_SENSITIVE", "NON_KNOWLEDGE"]
    if classification_result["category"] in unsupported_categories:
        # Show fixed response for unsupported queries
        visualizer.show_unsupported_response()
        return None

    # Step 4: Process normal query (original flow)
    try:
        # Execute the query
        with console.status("[bold green]Executing planner task...", spinner="dots"):
            response = planner.execute_task(query)

        # Show final results
        visualizer.show_final_result(response)

        # Show detailed trace if available
        if hasattr(response, 'reasoning_trace') and response.reasoning_trace:
            visualizer.show_reasoning_trace(response.reasoning_trace)

        return response

    except Exception as e:
        console.print(f"[red]❌ Error during execution: {e}[/red]")
        return None

def main():
    """Main CLI demo function"""
    parser = argparse.ArgumentParser(description="DeepDiver Multi-Agent System Demo")
    parser.add_argument("--query", "-q", type=str, help="Query to execute (interactive mode if not provided)")
    parser.add_argument("--config-only", "-c", action="store_true", help="Only show configuration and exit")
    parser.add_argument("--create-env", "-e", action="store_true", help="Create sample .env file from template")
    parser.add_argument("--debug", "-d", action="store_true", help="Enable debug mode with verbose logging")
    # action="store_true" is required; without it argparse expects a value after --quiet
    parser.add_argument("--quiet", action="store_true", help="Suppress all non-essential output")

    args = parser.parse_args()

    # Setup logging based on arguments (re-configure if debug mode is requested)
    if args.debug:
        setup_clean_logging(debug_mode=True)

    # Initialize visualizer
    visualizer = DemoVisualizer(quiet_mode=args.quiet)
    if not args.quiet:
        visualizer.show_welcome()

    # Create sample .env file if requested
    if args.create_env:
        create_sample_env_file()
        return 0

    # Load configuration
    config = load_environment_config(quiet=args.quiet)
    if not config:
        if not args.quiet:
            console.print("[red]❌ Cannot proceed without valid configuration[/red]")
            console.print("[yellow]💡 Use --create-env to create a sample configuration file[/yellow]")
        return 1

    # Show configuration
    visualizer.show_config(config)

    if args.config_only:
        return 0

    # Initialize planner agent
    try:
        if not args.quiet:
            console.print("[blue]🔄 Initializing PlannerAgent...[/blue]")

        # Create planner with sub-agent configurations
        sub_agent_configs = {
            "information_seeker": {
                "model": config.model_name,
                "max_iterations": config.information_seeker_max_iterations or 30,
            },
            "writer": {
                "model": config.model_name,
                "max_iterations": config.writer_max_iterations or 30,
                "temperature": config.model_temperature,
                "max_tokens": config.model_max_tokens
            }
        }

        planner = create_planner_agent(
            model=config.model_name,
            max_iterations=config.planner_max_iterations or 40,
            sub_agent_configs=sub_agent_configs
        )

        if not args.quiet:
            console.print("[green]✅ PlannerAgent initialized successfully[/green]")

    except Exception as e:
        if not args.quiet:
            console.print(f"[red]❌ Failed to initialize PlannerAgent: {e}[/red]")
        return 1

    # Handle query execution
    if args.query:
        # Single query mode
        run_demo_query(planner, args.query, visualizer, config)
    else:
        # Interactive mode
        if not args.quiet:
            console.print("\n[bold blue]🎯 Interactive Mode[/bold blue]")
            console.print("Enter your queries below. Type 'quit' or 'exit' to leave.")

        while True:
            try:
                prompt_text = "\n[bold cyan]Enter your query:[/bold cyan] " if not args.quiet else "Query: "
                query = console.input(prompt_text).strip()

                if query.lower() in ['quit', 'exit', 'q']:
                    if not args.quiet:
                        console.print("[green]👋 Goodbye![/green]")
                    break

                if not query:
                    continue

                if not args.quiet:
                    console.print("\n" + "=" * 80 + "\n")
                run_demo_query(planner, query, visualizer, config)
                if not args.quiet:
                    console.print("\n" + "=" * 80 + "\n")

            except KeyboardInterrupt:
                if not args.quiet:
                    console.print("\n[yellow]⚠️ Interrupted by user[/yellow]")
                break
            except EOFError:
                if not args.quiet:
                    console.print("\n[green]👋 Goodbye![/green]")
                break

    return 0


if __name__ == "__main__":
    sys.exit(main())
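The bounded-retry loop inside `classify_query` above (try up to three times, re-raise on the final failure) can be sketched in isolation. This is a minimal illustration of the pattern only; `call_with_retries` and `flaky` are hypothetical names, not part of the demo:

```python
import time

def call_with_retries(fn, max_retries=3, delay_s=2.0):
    """Invoke fn(), retrying on any exception; re-raise the last error."""
    for attempt in range(1, max_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_retries:
                raise  # exhausted all attempts
            time.sleep(delay_s)  # brief pause before the next attempt

# Example: a flaky callable that succeeds only on the third attempt.
calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

result = call_with_retries(flaky, max_retries=3, delay_s=0.0)
```

In `classify_query` the same structure is written with an explicit `while retry_num <= max_retry_num` loop and a fixed two-second sleep; an outer `try/except` then converts the final re-raised error into the NORMAL fallback result.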
deepdiver_v2/cli/run_demo.sh
ADDED
#!/bin/bash

# DeepDiver Multi-Agent System CLI Demo Runner
# This script makes it easier to run the CLI demo with different options

set -e

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Function to print colored output
print_status() {
    echo -e "${GREEN}[INFO]${NC} $1"
}

print_warning() {
    echo -e "${YELLOW}[WARNING]${NC} $1"
}

print_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

# Get script directory
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"

# Function to show help
show_help() {
    echo "DeepDiver Multi-Agent System CLI Demo Runner"
    echo ""
    echo "Usage: $0 [OPTIONS] [QUERY]"
    echo ""
    echo "Options:"
    echo "  -h, --help           Show this help message"
    echo "  -i, --interactive    Start interactive mode (default)"
    echo "  -c, --config-only    Show configuration and exit"
    echo "  -e, --create-env     Create sample .env file from template"
    echo "  -q, --query \"QUERY\"  Execute a specific query"
    echo "  -d, --debug          Enable debug mode with verbose logging"
    echo "  --quiet              Suppress all non-essential output"
    echo "  --setup              Install dependencies and setup"
    echo ""
    echo "Examples:"
    echo "  $0 --interactive"
    echo "  $0 --query \"Research the latest trends in AI\""
    echo "  $0 --config-only"
    echo "  $0 --debug --query \"Debug a specific query\""
    echo "  $0 --quiet --query \"Run quietly\""
    echo "  $0 --setup"
    echo ""
}

# Function to setup the demo
setup_demo() {
    print_status "Setting up DeepDiver CLI Demo..."

    # Check if we're in the right directory
    if [ ! -f "$PROJECT_ROOT/cli/demo.py" ]; then
        print_error "Cannot find demo.py. Please run this script from the CLI directory or project root."
        exit 1
    fi

    # Install dependencies
    print_status "Installing Python dependencies..."
    cd "$PROJECT_ROOT"

    if [ -f "cli/requirements.txt" ]; then
        pip install -r cli/requirements.txt
        print_status "Dependencies installed successfully"
    else
        print_warning "requirements.txt not found, skipping dependency installation"
    fi

    # Check for .env file
    if [ ! -f "config/.env" ]; then
        print_warning "No .env file found in config/ directory"
        print_status "Creating sample .env file from template..."

        if [ -f "env.template" ]; then
            cp env.template config/.env
            print_status "Sample .env file created at config/.env"
            print_warning "Please edit config/.env with your actual configuration values"
        else
            print_error "No env.template found. Please create config/.env manually"
        fi
    else
        print_status ".env file found at config/.env"
    fi

    # Make demo script executable
    chmod +x "$PROJECT_ROOT/cli/demo.py"
    print_status "Made demo.py executable"

    print_status "Setup complete! You can now run the demo with:"
    echo "  $0 --interactive"
}

# Function to run the demo
run_demo() {
    local args=("$@")

    # Change to project root
    cd "$PROJECT_ROOT"

    print_status "Starting DeepDiver CLI Demo..."
    python cli/demo.py "${args[@]}"
}

# Parse command line arguments
DEMO_ARGS=()

while [[ $# -gt 0 ]]; do
    case $1 in
        -h|--help)
            show_help
            exit 0
            ;;
        --setup)
            setup_demo
            exit 0
            ;;
        -c|--config-only)
            DEMO_ARGS+=("--config-only")
            shift
            ;;
        -e|--create-env)
            DEMO_ARGS+=("--create-env")
            shift
            ;;
        -q|--query)
            if [ -z "${2:-}" ]; then
                print_error "Query argument is required with --query option"
                show_help
                exit 1
            fi
            DEMO_ARGS+=("--query" "$2")
            shift 2
            ;;
        -d|--debug)
            DEMO_ARGS+=("--debug")
            shift
            ;;
        --quiet)
            DEMO_ARGS+=("--quiet")
            shift
            ;;
        -i|--interactive)
            # Interactive is default, no need to add args
            shift
            ;;
        *)
            # If it's not a flag, treat it as a query
            if [[ "$1" != -* ]]; then
                DEMO_ARGS+=("--query" "$1")
                shift
            else
                print_error "Unknown option: $1"
                show_help
                exit 1
            fi
            ;;
+
esac
|
| 168 |
+
done
|
| 169 |
+
|
| 170 |
+
# Run the demo with collected arguments
|
| 171 |
+
run_demo "${DEMO_ARGS[@]}"
|
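The argument loop in run_demo.sh can be tried in isolation. The sketch below (standalone, not part of the script; `ARGS` is a hypothetical accumulator standing in for `DEMO_ARGS`) shows the core convention: a flag that takes a value consumes two positional parameters via `shift 2`, while a bare flag consumes one via `shift`.

```shell
# Simulate an invocation: $1=--query, $2="hello world", $3=--debug
set -- --query "hello world" --debug
ARGS=""
while [ $# -gt 0 ]; do
    case $1 in
        --query) ARGS="$ARGS --query '$2'"; shift 2 ;;   # flag + value: consume two args
        --debug) ARGS="$ARGS --debug"; shift ;;          # bare flag: consume one arg
        *) shift ;;                                      # skip anything unrecognized
    esac
done
echo "$ARGS"
```

Forgetting the `shift 2` on a value-taking flag is the classic bug here: the loop would then re-match the value string on the next iteration.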
deepdiver_v2/config/config.py ADDED
@@ -0,0 +1,239 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
import os
from typing import Optional, Dict, Any
from dataclasses import dataclass
import logging
from pathlib import Path
from dotenv import load_dotenv


load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class APIConfig:
    """Configuration class for API keys and settings"""

    # Custom LLM Service Configuration
    # Your own deployed LLM service accessed via requests
    model_request_url: Optional[str] = None
    model_request_token: Optional[str] = None
    model_name: str = "pangu_auto"  # Default model name

    # Custom Planner Mode
    planner_mode: str = "auto"  # Default planner mode

    # MCP Server Configuration
    mcp_server_url: Optional[str] = None
    mcp_auth_token: Optional[str] = None
    mcp_use_stdio: bool = True  # Default to stdio for backward compatibility

    # Search Engine Configuration (Generic)
    search_engine_base_url: Optional[str] = None
    search_engine_api_keys: Optional[str] = None  # Can be comma-separated for rotation

    # URL Crawler Configuration (Generic)
    url_crawler_base_url: Optional[str] = None
    url_crawler_api_keys: Optional[str] = None  # Can be comma-separated for rotation
    url_crawler_max_tokens: int = 100000

    # Model Interaction Configuration
    model_temperature: float = 0.3
    model_max_tokens: int = 8192
    model_request_timeout: int = 180

    # Tool Trajectory and Output Configuration
    trajectory_storage_path: str = "./workspace"
    report_output_path: str = "./report"
    document_analysis_path: str = "./doc_analysis"

    # Per-agent iteration controls (optional; resolved by agent factories)
    planner_max_iterations: Optional[int] = None
    information_seeker_max_iterations: Optional[int] = None
    writer_max_iterations: Optional[int] = None

    # General Settings
    debug_mode: bool = False
    max_retries: int = 3
    timeout: int = 30

    def __post_init__(self):
        """Load configuration from environment variables"""
        self.load_from_env()

    def load_from_env(self):
        """Load API keys and settings from environment variables"""
        # Custom LLM Service
        self.model_request_url = os.getenv("MODEL_REQUEST_URL")
        self.model_request_token = os.getenv("MODEL_REQUEST_TOKEN")
        # Fall back to the dataclass default ("pangu_auto"); a hard-coded
        # 'pangu-auto' fallback here would disagree with that default.
        self.model_name = os.getenv("MODEL_NAME", self.model_name)

        # Custom Planner Mode
        self.planner_mode = os.getenv("PLANNER_MODE", self.planner_mode)

        # MCP Server
        self.mcp_server_url = os.getenv("MCP_SERVER_URL")
        self.mcp_auth_token = os.getenv("MCP_AUTH_TOKEN")
        self.mcp_use_stdio = os.getenv("MCP_USE_STDIO", "true").lower() == "true"

        # Search Engine Configuration
        self.search_engine_base_url = os.getenv("SEARCH_ENGINE_BASE_URL")
        self.search_engine_api_keys = os.getenv("SEARCH_ENGINE_API_KEYS")

        # URL Crawler Configuration
        self.url_crawler_base_url = os.getenv("URL_CRAWLER_BASE_URL")
        self.url_crawler_api_keys = os.getenv("URL_CRAWLER_API_KEYS")
        self.url_crawler_max_tokens = int(os.getenv("URL_CRAWLER_MAX_TOKENS", self.url_crawler_max_tokens))

        # Model Interaction Configuration
        self.model_temperature = float(os.getenv("MODEL_TEMPERATURE", self.model_temperature))
        self.model_max_tokens = int(os.getenv("MODEL_MAX_TOKENS", self.model_max_tokens))
        self.model_request_timeout = int(os.getenv("MODEL_REQUEST_TIMEOUT", self.model_request_timeout))

        # Tool Trajectory and Output Configuration
        self.trajectory_storage_path = os.getenv("TRAJECTORY_STORAGE_PATH", self.trajectory_storage_path)
        self.report_output_path = os.getenv("REPORT_OUTPUT_PATH", self.report_output_path)
        self.document_analysis_path = os.getenv("DOCUMENT_ANALYSIS_PATH", self.document_analysis_path)

        # Per-agent iteration controls
        self.planner_max_iterations = (
            int(os.getenv("PLANNER_MAX_ITERATION")) if os.getenv("PLANNER_MAX_ITERATION") else None
        )
        self.information_seeker_max_iterations = (
            int(os.getenv("INFORMATION_SEEKER_MAX_ITERATION")) if os.getenv("INFORMATION_SEEKER_MAX_ITERATION") else None
        )
        self.writer_max_iterations = (
            int(os.getenv("WRITER_MAX_ITERATION")) if os.getenv("WRITER_MAX_ITERATION") else None
        )

        # General Settings
        self.debug_mode = os.getenv("DEBUG_MODE", "false").lower() == "true"
        self.max_retries = int(os.getenv("MAX_RETRIES", self.max_retries))
        self.timeout = int(os.getenv("TIMEOUT", self.timeout))

    def get_custom_llm_config(self) -> Dict[str, Any]:
        """Get configuration for custom LLM service"""
        return {
            "url": self.model_request_url,
            "token": self.model_request_token,
            "model": self.model_name,
            "temperature": self.model_temperature,
            "max_tokens": self.model_max_tokens,
            "timeout": self.model_request_timeout,
            "base_url": self.model_request_url  # For backward compatibility with model_config.get('base_url')
        }

    def get_available_search_providers(self) -> list:
        """Get list of available search providers based on API keys"""
        providers = []
        if self.search_engine_api_keys:
            providers.append("custom")
        return providers

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary (excluding sensitive data)"""
        config_dict = {}
        for key, value in self.__dict__.items():
            # Tokens are sensitive too, so mask them alongside API keys and passwords
            if "api_key" in key.lower() or "token" in key.lower() or "password" in key.lower():
                config_dict[key] = "***" if value else None
            else:
                config_dict[key] = value
        return config_dict

# Global configuration instance
config = APIConfig()


def get_config() -> APIConfig:
    """Get the global configuration instance"""
    return config


def reload_config():
    """Reload configuration from environment variables"""
    global config
    config = APIConfig()
    logger.info("Configuration reloaded")


def validate_api_key(api_key: Optional[str], service_name: str) -> bool:
    """Validate that an API key is present and not empty"""
    if not api_key or api_key.strip() == "":
        logger.error(f"Missing or empty API key for {service_name}")
        return False
    return True


def get_url_crawler_config() -> Dict[str, Any]:
    """Get generic URL crawler configuration"""
    api_keys = config.url_crawler_api_keys
    base_url = config.url_crawler_base_url

    if not api_keys:
        return {}

    # Parse comma-separated API keys for rotation
    api_key_list = [key.strip() for key in api_keys.split(",")] if isinstance(api_keys, str) else [api_keys]

    return {
        "api_keys": api_key_list,
        "base_url": base_url,
        "max_tokens": config.url_crawler_max_tokens,
        "timeout": config.timeout
    }


def get_search_engine_config() -> Dict[str, Any]:
    """Get generic search engine configuration"""
    api_keys = config.search_engine_api_keys
    base_url = config.search_engine_base_url

    if not api_keys:
        return {}

    # Parse comma-separated API keys for rotation
    api_key_list = [key.strip() for key in api_keys.split(",")] if isinstance(api_keys, str) else [api_keys]

    return {
        "api_keys": api_key_list,
        "base_url": base_url,
        "timeout": config.timeout
    }


def get_model_config() -> Dict[str, Any]:
    """Get model interaction configuration for custom LLM service"""
    return config.get_custom_llm_config()


def get_storage_config() -> Dict[str, Any]:
    """Get storage and trajectory configuration"""
    return {
        "trajectory_storage_path": config.trajectory_storage_path,
        "report_output_path": config.report_output_path,
        "document_analysis_path": config.document_analysis_path
    }


def get_mcp_config() -> Dict[str, Any]:
    """Get MCP server specific configuration"""
    return {
        "server_url": config.mcp_server_url,
        "auth_token": config.mcp_auth_token,
        "use_stdio": config.mcp_use_stdio,
        "timeout": config.timeout
    }


# Example usage and testing
if __name__ == "__main__":
    print("=== Multi Agent System Configuration ===")
    print(f"Debug Mode: {config.debug_mode}")
    print(f"Custom LLM Service URL: {config.model_request_url}")
    print(f"Available Search Providers: {config.get_available_search_providers()}")
    print("\nConfiguration Summary:")
    for key, value in config.to_dict().items():
        print(f"  {key}: {value}")
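Every scalar field in `APIConfig.load_from_env()` follows the same precedence rule: an environment variable, when set, overrides the dataclass default; otherwise the default stands. A minimal standalone sketch of that rule (`DEFAULT_TEMPERATURE` and `resolve_temperature` are hypothetical names for illustration, not part of config.py):

```python
import os

# Dataclass default in config.py is model_temperature = 0.3
DEFAULT_TEMPERATURE = 0.3

def resolve_temperature() -> float:
    # os.getenv returns the string from the environment when set,
    # or the supplied default otherwise; float() accepts both.
    return float(os.getenv("MODEL_TEMPERATURE", DEFAULT_TEMPERATURE))

os.environ.pop("MODEL_TEMPERATURE", None)
print(resolve_temperature())                 # default applies when the variable is unset
os.environ["MODEL_TEMPERATURE"] = "0.6"
print(resolve_temperature())                 # environment value wins
```

Note that because `APIConfig.__post_init__` always calls `load_from_env()`, any value passed to the constructor for an env-backed field is immediately overwritten by this rule; `reload_config()` simply reruns it.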
deepdiver_v2/env.template ADDED
@@ -0,0 +1,44 @@
# ===================================
# DeepDiver Configuration Template
# ===================================
# Copy this file to .env and fill in your values

# Custom LLM Service Configuration
# Your own deployed LLM service endpoint

MODEL_REQUEST_URL=
MODEL_REQUEST_TOKEN=
MODEL_NAME=pangu_auto
PLANNER_MAX_ITERATION=40
INFORMATION_SEEKER_MAX_ITERATION=30
WRITER_MAX_ITERATION=40
PLANNER_MODE=auto  # auto, writing, qa

# Model Interaction Settings
MODEL_TEMPERATURE=0.6
MODEL_MAX_TOKENS=8192
MODEL_REQUEST_TIMEOUT=180

# MCP Server Configuration
MCP_SERVER_URL=http://localhost:6274/mcp
MCP_AUTH_TOKEN=
MCP_USE_STDIO=false

# Search Engine Configuration
SEARCH_ENGINE_BASE_URL=
SEARCH_ENGINE_API_KEYS=

# URL Crawler Configuration
URL_CRAWLER_BASE_URL=
URL_CRAWLER_API_KEYS=
URL_CRAWLER_MAX_TOKENS=100000

# Tool Trajectory and Output Paths
TRAJECTORY_STORAGE_PATH=./workspace
REPORT_OUTPUT_PATH=./report
DOCUMENT_ANALYSIS_PATH=./doc_analysis

# General Settings
DEBUG_MODE=false
MAX_RETRIES=3
TIMEOUT=30
deepdiver_v2/requirements.txt ADDED
@@ -0,0 +1,8 @@
beautifulsoup4==4.13.5
httpx[http2]==0.28.1
python-dotenv==1.1.1
python-dateutil==2.9.0.post0
requests==2.32.5
rich==14.1.0
starlette==0.47.3
uvicorn==0.35.0
deepdiver_v2/src/__init__.py ADDED
@@ -0,0 +1,11 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
"""
DeepDiver Multi-Agent System

A comprehensive multi-agent system with MCP integration, local workspace management,
and advanced knowledge management capabilities.
"""

__version__ = "2.0.0"
__author__ = "DeepDiver Team"
__description__ = "Multi-Agent System with MCP and Local Workspace Integration"
deepdiver_v2/src/agents/__init__.py ADDED
@@ -0,0 +1,62 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
"""
Multi-Agent System - Agent Module

This module provides the core agents for the multi-agent system:
- BaseAgent: Abstract base class with common functionality
- InformationSeekerAgent: Research and information gathering
- WriterAgent: Content creation and writing
- PlannerAgent: Top-level orchestrator

All agents follow the ReAct pattern and use standardized TaskInput format.
"""

from .base_agent import (
    BaseAgent,
    AgentConfig,
    AgentResponse,
    TaskInput,
    create_agent_config
)

from .subjective_information_seeker import (
    InformationSeekerAgent,
    create_subjective_information_seeker
)

# Note: both seeker modules export a class named InformationSeekerAgent, so
# this import shadows the subjective one. Use the create_* factory functions
# to obtain a specific seeker.
from .objective_information_seeker import (
    InformationSeekerAgent,
    create_objective_information_seeker
)

from .writer_agent import (
    WriterAgent,
    create_writer_agent
)

from .planner_agent import (
    PlannerAgent,
    create_planner_agent
)

__all__ = [
    # Base classes
    "BaseAgent",
    "AgentConfig",
    "AgentResponse",
    "TaskInput",
    "create_agent_config",

    # Specific agents
    "InformationSeekerAgent",
    "create_subjective_information_seeker",
    "create_objective_information_seeker",
    "WriterAgent",
    "create_writer_agent",
    "PlannerAgent",
    "create_planner_agent"
]

# Version info
__version__ = "0.1.0"
__author__ = "DeepDiver Multi-Agent System"
deepdiver_v2/src/agents/base_agent.py ADDED
@@ -0,0 +1,692 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
import logging
import time
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field

# Import MCP client availability flag without binding unused symbols
try:
    from ..tools import mcp_client as _mcp_client_module  # noqa: F401
    MCP_CLIENT_AVAILABLE = True
except ImportError:
    MCP_CLIENT_AVAILABLE = False


@dataclass
class AgentConfig:
    """Configuration for agents - session management handled entirely by MCP server"""
    agent_name: str = "base_agent"
    planner_mode: str = "auto"
    model: Optional[str] = None
    max_iterations: int = 10
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # Paths used by writer and other agents
    trajectory_storage_path: Optional[str] = None
    report_output_path: Optional[str] = None
    document_analysis_path: Optional[str] = None


@dataclass
class AgentResponse:
    """Standardized response format for all agents"""
    success: bool
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    iterations: int = 0
    reasoning_trace: List[Dict[str, Any]] = field(default_factory=list)
    agent_name: str = ""
    execution_time: float = 0.0


@dataclass
class TaskInput:
    """Standardized task input format for all agents"""
    task_content: str  # The specific task content
    task_steps_for_reference: Optional[str] = None  # Reference steps for execution
    deliverable_contents: Optional[str] = None  # Format of the final deliverable
    current_task_status: Optional[str] = None  # Description of the current task status
    task_executor: str = "info_seeker"  # Name of the task executor (info_seeker, writer)
    workspace_id: Optional[str] = None  # Workspace ID for stored files and memory
    acceptance_checking_criteria: Optional[str] = None  # Criteria for determining task completion and quality

    def to_dict(self) -> Dict[str, Any]:
        """Convert TaskInput to dictionary format"""
        return {
            "task_content": self.task_content,
            "task_steps_for_reference": self.task_steps_for_reference,
            "deliverable_contents": self.deliverable_contents,
            "current_task_status": self.current_task_status,
            "task_executor": self.task_executor,
            "workspace_id": self.workspace_id,
            "acceptance_checking_criteria": self.acceptance_checking_criteria
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TaskInput':
        """Create TaskInput from dictionary"""
        return cls(
            task_content=data.get("task_content", ""),
            task_steps_for_reference=data.get("task_steps_for_reference"),
            deliverable_contents=data.get("deliverable_contents"),
            current_task_status=data.get("current_task_status"),
            task_executor=data.get("task_executor", "info_seeker"),
            workspace_id=data.get("workspace_id"),
            acceptance_checking_criteria=data.get("acceptance_checking_criteria")
        )

    def format_for_prompt(self) -> str:
        """Format the task input for use in prompts"""
        prompt = f"Task Content:\n{self.task_content}\n\n"

        if self.task_steps_for_reference:
            prompt += f"Task Steps for Reference:\n{self.task_steps_for_reference}\n\n"

        if self.deliverable_contents:
            prompt += f"Deliverable Contents:\n{self.deliverable_contents}\n\n"

        if self.current_task_status:
            prompt += f"Current Task Status:\n{self.current_task_status}\n\n"

        if self.acceptance_checking_criteria:
            prompt += f"Acceptance Checking Criteria:\n{self.acceptance_checking_criteria}\n\n"

        prompt += f"Task Executor: {self.task_executor}\n"

        if self.workspace_id:
            prompt += f"Workspace ID: {self.workspace_id}\n"

        return prompt


class SectionWriterTaskInput(TaskInput):
    """
    Specialized TaskInput for section writing tasks

    Only stores the essential parameters. The section_writer agent
    will handle prompt assembly internally.
    """

    def __init__(
        self,
        task_content: str,
        user_query: str,
        write_file_path: str,
        overall_outline: str,
        current_chapter_outline: str,
        key_files: List[Dict[str, Any]],
        written_chapters: str = "",
        workspace_id: Optional[str] = None
    ):
        # Store the section writer specific parameters
        self.write_file_path = write_file_path
        self.user_query = user_query
        self.current_chapter_outline = current_chapter_outline
        self.key_files = key_files
        self.written_chapters = written_chapters
        self.overall_outline = overall_outline

        # Initialize parent TaskInput with minimal required fields
        super().__init__(
            task_content=task_content,
            task_executor="section_writer",
            workspace_id=workspace_id,
        )


class WriterAgentTaskInput(TaskInput):
    """
    Specialized TaskInput for writer agent tasks

    Only stores the essential parameters. The writer agent
    will handle prompt assembly internally.
    """

    def __init__(
        self,
        task_content: str,
        user_query: str,
        key_files: List[Dict[str, Any]],
        workspace_id: Optional[str] = None
    ):
        # Store the writer agent specific parameters
        self.user_query = user_query
        self.key_files = key_files

        # Initialize parent TaskInput with minimal required fields
        super().__init__(
            task_content=task_content,
            task_executor="writer_agent",
            workspace_id=workspace_id,
        )


class BaseAgent(ABC):
    """
    Base class for all agents with MCP server-managed sessions.

    Session management is now entirely handled by the MCP server:
    - Server assigns session IDs on connection
    - Server creates workspace folders with UUID names
    - All tool operations are performed in server-managed workspaces
    """

    def __init__(self, config: AgentConfig, shared_mcp_client=None):
        self.execution_stats = None
        self.reasoning_trace = None
        self.config = config
        self.logger = logging.getLogger(f"{__name__}.{config.agent_name}")

        # Session info is populated by the MCP server
        self.session_info = None

        # Tool management
        self.mcp_tools = None
        self.available_tools = {}

        self.reset_trace()

        # Initialize MCP tools (server will handle session creation or use shared client)
        self._initialize(shared_mcp_client)

    def _initialize(self, shared_mcp_client=None):
        """Initialize agent with MCP server connection or shared client"""
        try:
            self.logger.info(f"Initializing agent {self.config.agent_name}")

            if shared_mcp_client:
                # Use shared MCP client with agent-specific tool filtering
                agent_type = self._get_agent_type()
                self.mcp_tools = self._create_filtered_mcp_tools(shared_mcp_client, agent_type)
                self.logger.info(f"Agent {self.config.agent_name} using shared MCP client with {agent_type} tools")
            else:
                # Create MCP tools with agent-specific filtering (no more unfiltered access)
                self.mcp_tools = self._create_filtered_mcp_tools_standalone()

            # Discover available tools
            self.available_tools = self._discover_mcp_tools()

            # Build tool schemas for function calling
            self.tool_schemas = self._build_tool_schemas()

            self.logger.info(f"Agent {self.config.agent_name} initialized successfully")
            self.logger.info(f"Available tools: {list(self.available_tools.keys())}")

        except Exception as e:
            self.logger.error(f"Failed to initialize agent {self.config.agent_name}: {e}")
            raise

    def _discover_mcp_tools(self) -> Dict[str, Any]:
        """Discover available tools from MCP server or fallback tools"""
        available_tools = {}

        # Try to get tools from the MCP client first
        if hasattr(self.mcp_tools, 'get_available_tools'):
            try:
                mcp_tools_dict = self.mcp_tools.get_available_tools()
                for tool_name, tool_info in mcp_tools_dict.items():
                    # For proper MCP architecture, store tool info for direct client calls
                    # instead of creating wrapper lambda functions
                    available_tools[tool_name] = tool_info

                if available_tools:
                    self.logger.info(f"Discovered {len(available_tools)} tools from MCP server")
                    return available_tools
            except Exception as e:
                self.logger.warning(f"Failed to discover MCP tools: {e}")

        # Fallback: if MCP client not available, use direct method access.
        # This should rarely be needed with proper MCP setup.
        if hasattr(self.mcp_tools, '__dict__'):
            for attr_name in dir(self.mcp_tools):
                if not attr_name.startswith('_') and callable(getattr(self.mcp_tools, attr_name)):
                    available_tools[attr_name] = getattr(self.mcp_tools, attr_name)

        return available_tools
|
| 247 |
+
|
| 248 |
+
def _get_agent_type(self) -> str:
|
| 249 |
+
"""Get agent type for tool filtering"""
|
| 250 |
+
agent_name = self.config.agent_name.lower()
|
| 251 |
+
if "planner" in agent_name:
|
| 252 |
+
return "planner"
|
| 253 |
+
elif "information" in agent_name or "seeker" in agent_name:
|
| 254 |
+
return "information_seeker"
|
| 255 |
+
elif "writer" in agent_name:
|
| 256 |
+
return "writer"
|
| 257 |
+
else:
|
| 258 |
+
# Default to planner tools for unknown agent types
|
| 259 |
+
return "planner"
|
| 260 |
+
|
| 261 |
+
def _create_filtered_mcp_tools(self, shared_client, agent_type: str):
|
| 262 |
+
"""Create filtered MCP tools adapter using shared client"""
|
| 263 |
+
try:
|
| 264 |
+
from src.tools.mcp_client import create_filtered_mcp_tools_adapter
|
| 265 |
+
return create_filtered_mcp_tools_adapter(shared_client, agent_type)
|
| 266 |
+
except ImportError:
|
| 267 |
+
# Fallback if FilteredMCPToolsAdapter not available
|
| 268 |
+
self.logger.warning("FilteredMCPToolsAdapter not available, using regular adapter")
|
| 269 |
+
from src.tools.mcp_client import MCPToolsAdapter
|
| 270 |
+
adapter = MCPToolsAdapter.__new__(MCPToolsAdapter)
|
| 271 |
+
adapter.client = shared_client
|
| 272 |
+
return adapter
|
| 273 |
+
|
| 274 |
+
def _create_filtered_mcp_tools_standalone(self):
|
| 275 |
+
"""Create filtered MCP tools adapter with its own client connection"""
|
| 276 |
+
try:
|
| 277 |
+
# Get agent type for filtering
|
| 278 |
+
agent_type = self._get_agent_type()
|
| 279 |
+
|
| 280 |
+
# Create a new MCP client
|
| 281 |
+
client = self._create_new_mcp_client()
|
| 282 |
+
|
| 283 |
+
# Apply filtering based on agent type
|
| 284 |
+
from src.tools.mcp_client import create_filtered_mcp_tools_adapter
|
| 285 |
+
filtered_adapter = create_filtered_mcp_tools_adapter(client, agent_type)
|
| 286 |
+
|
| 287 |
+
self.logger.info(f"Agent {self.config.agent_name} created filtered MCP adapter with {agent_type} tools")
|
| 288 |
+
return filtered_adapter
|
| 289 |
+
|
| 290 |
+
except Exception as e:
|
| 291 |
+
self.logger.error(f"Failed to create filtered MCP tools: {e}")
|
| 292 |
+
raise RuntimeError(f"Failed to create filtered MCP client for {self.config.agent_name}: {e}")
|
| 293 |
+
|
| 294 |
+
def _create_new_mcp_client(self):
|
| 295 |
+
"""Create a new MCP client connection"""
|
| 296 |
+
try:
|
| 297 |
+
# Get MCP configuration
|
| 298 |
+
from config.config import get_mcp_config
|
| 299 |
+
mcp_config = get_mcp_config()
|
| 300 |
+
|
| 301 |
+
# Create MCP client
|
| 302 |
+
from src.tools.mcp_client import MCPClient
|
| 303 |
+
|
| 304 |
+
if mcp_config.get("server_url") and not mcp_config.get("use_stdio", True):
|
| 305 |
+
# HTTP-based MCP server
|
| 306 |
+
client = MCPClient(server_url=mcp_config["server_url"])
|
| 307 |
+
self.logger.info(
|
| 308 |
+
f"Agent {self.config.agent_name} connected to HTTP MCP server: {mcp_config['server_url']}")
|
| 309 |
+
else:
|
| 310 |
+
# Default to the expected HTTP MCP server on port 6274
|
| 311 |
+
client = MCPClient(server_url="http://localhost:6274/mcp")
|
| 312 |
+
self.logger.info(
|
| 313 |
+
f"Agent {self.config.agent_name} connected to default HTTP MCP server: http://localhost:6274/mcp")
|
| 314 |
+
|
| 315 |
+
return client
|
| 316 |
+
|
| 317 |
+
except Exception as e:
|
| 318 |
+
self.logger.error(f"Failed to create MCP client: {e}")
|
| 319 |
+
raise RuntimeError(f"MCP client creation failed for {self.config.agent_name}: {e}")
|
| 320 |
+
|
| 321 |
+
# NOTE: _create_mcp_tools() method removed to prevent unfiltered tool access.
|
| 322 |
+
# All agents now use _create_filtered_mcp_tools_standalone() or _create_filtered_mcp_tools()
|
| 323 |
+
# to ensure proper tool isolation and security.
|
| 324 |
+
|
| 325 |
+
def get_session_info(self) -> Optional[Dict[str, Any]]:
|
| 326 |
+
"""Get information about the current server-managed session"""
|
| 327 |
+
try:
|
| 328 |
+
# First try the adapter's get_session_info method if available
|
| 329 |
+
if hasattr(self.mcp_tools, 'get_session_info'):
|
| 330 |
+
session_info = self.mcp_tools.get_session_info()
|
| 331 |
+
if session_info:
|
| 332 |
+
# Add agent-specific information
|
| 333 |
+
session_info.update({
|
| 334 |
+
"server_managed": True,
|
| 335 |
+
"agent_name": self.config.agent_name
|
| 336 |
+
})
|
| 337 |
+
return session_info
|
| 338 |
+
|
| 339 |
+
# Fallback: Check if we have an MCP tools adapter with a client
|
| 340 |
+
if hasattr(self.mcp_tools, 'client'):
|
| 341 |
+
client = self.mcp_tools.client
|
| 342 |
+
|
| 343 |
+
# Check if client has session ID and connection status
|
| 344 |
+
if hasattr(client, '_session_id') and hasattr(client, 'is_connected'):
|
| 345 |
+
return {
|
| 346 |
+
"session_id": client._session_id,
|
| 347 |
+
"server_managed": True,
|
| 348 |
+
"agent_name": self.config.agent_name,
|
| 349 |
+
"connected": client.is_connected()
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
# Fallback: check if mcp_tools has session info directly
|
| 353 |
+
if hasattr(self.mcp_tools, '_session_id'):
|
| 354 |
+
return {
|
| 355 |
+
"session_id": self.mcp_tools._session_id,
|
| 356 |
+
"server_managed": True,
|
| 357 |
+
"agent_name": self.config.agent_name,
|
| 358 |
+
"connected": getattr(self.mcp_tools, 'is_connected', lambda: True)()
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
# If no session info available, return basic info
|
| 362 |
+
return {
|
| 363 |
+
"session_id": None,
|
| 364 |
+
"server_managed": True,
|
| 365 |
+
"agent_name": self.config.agent_name,
|
| 366 |
+
"connected": hasattr(self.mcp_tools, 'client') and getattr(self.mcp_tools.client, 'is_connected',
|
| 367 |
+
lambda: False)()
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
except Exception as e:
|
| 371 |
+
self.logger.warning(f"Failed to get session info: {e}")
|
| 372 |
+
return {
|
| 373 |
+
"session_id": None,
|
| 374 |
+
"server_managed": True,
|
| 375 |
+
"agent_name": self.config.agent_name,
|
| 376 |
+
"connected": False,
|
| 377 |
+
"error": str(e)
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
def _build_tool_schemas(self) -> List[Dict[str, Any]]:
|
| 381 |
+
"""Build tool schemas for function calling"""
|
| 382 |
+
schemas = []
|
| 383 |
+
|
| 384 |
+
# Get agent-specific tool schemas
|
| 385 |
+
agent_schemas = self._build_agent_specific_tool_schemas()
|
| 386 |
+
schemas.extend(agent_schemas)
|
| 387 |
+
|
| 388 |
+
return schemas
|
| 389 |
+
|
| 390 |
+
def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
|
| 391 |
+
"""
|
| 392 |
+
Build agent-specific tool schemas using proper MCP architecture.
|
| 393 |
+
Schemas come from MCP server via client, not direct imports.
|
| 394 |
+
"""
|
| 395 |
+
schemas = []
|
| 396 |
+
|
| 397 |
+
# Proper MCP way: Get schemas from MCP client (which got them from server)
|
| 398 |
+
try:
|
| 399 |
+
if hasattr(self.mcp_tools, 'get_tool_schemas'):
|
| 400 |
+
# Use the MCP client to get schemas (proper MCP architecture)
|
| 401 |
+
schemas = self.mcp_tools.get_tool_schemas()
|
| 402 |
+
self.logger.info(f"Retrieved {len(schemas)} tool schemas from MCP server")
|
| 403 |
+
else:
|
| 404 |
+
# Fallback for adapters that don't have the new method yet
|
| 405 |
+
self.logger.warning("MCP adapter doesn't support get_tool_schemas, using fallback")
|
| 406 |
+
schemas = self._build_fallback_schemas()
|
| 407 |
+
except Exception as e:
|
| 408 |
+
self.logger.warning(f"Failed to get schemas from MCP client: {e}, using fallback")
|
| 409 |
+
schemas = self._build_fallback_schemas()
|
| 410 |
+
|
| 411 |
+
return schemas
|
| 412 |
+
|
| 413 |
+
def _build_fallback_schemas(self) -> List[Dict[str, Any]]:
|
| 414 |
+
"""Fallback schema building if MCP client method fails"""
|
| 415 |
+
schemas = []
|
| 416 |
+
|
| 417 |
+
# Try to get tool info from MCP client
|
| 418 |
+
if hasattr(self.mcp_tools, 'get_available_tools'):
|
| 419 |
+
try:
|
| 420 |
+
available_tools = self.mcp_tools.get_available_tools()
|
| 421 |
+
for tool_name, tool_info in available_tools.items():
|
| 422 |
+
schema = {
|
| 423 |
+
"type": "function",
|
| 424 |
+
"function": {
|
| 425 |
+
"name": tool_name,
|
| 426 |
+
"description": getattr(tool_info, 'description', f"Tool: {tool_name}"),
|
| 427 |
+
"parameters": getattr(tool_info, 'input_schema', {"type": "object", "properties": {}, "required": []})
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
schemas.append(schema)
|
| 431 |
+
self.logger.info(f"Built {len(schemas)} schemas using fallback method")
|
| 432 |
+
except Exception as e:
|
| 433 |
+
self.logger.warning(f"Fallback schema building failed: {e}")
|
| 434 |
+
|
| 435 |
+
return schemas
|
| 436 |
+
|
| 437 |
+
def execute_tool_call(self, tool_call) -> Dict[str, Any]:
|
| 438 |
+
"""Execute a tool call and return results using proper MCP architecture"""
|
| 439 |
+
tool_name = tool_call["name"]
|
| 440 |
+
|
| 441 |
+
try:
|
| 442 |
+
# Parse arguments
|
| 443 |
+
arguments = tool_call["arguments"]
|
| 444 |
+
|
| 445 |
+
# Check if tool is available
|
| 446 |
+
if tool_name not in self.available_tools:
|
| 447 |
+
return {
|
| 448 |
+
"success": False,
|
| 449 |
+
"error": f"Tool '{tool_name}' not available for this agent"
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
# Route tool execution based on tool type
|
| 453 |
+
# Built-in tools (like assign_task_to_*) are callable methods, not MCP server tools
|
| 454 |
+
if callable(self.available_tools[tool_name]):
|
| 455 |
+
# Built-in tool: execute locally
|
| 456 |
+
tool_function = self.available_tools[tool_name]
|
| 457 |
+
result = tool_function(**arguments)
|
| 458 |
+
|
| 459 |
+
# Convert result to standard format
|
| 460 |
+
if hasattr(result, 'to_dict'):
|
| 461 |
+
return result.to_dict()
|
| 462 |
+
elif isinstance(result, dict):
|
| 463 |
+
return result
|
| 464 |
+
else:
|
| 465 |
+
return {
|
| 466 |
+
"success": True,
|
| 467 |
+
"data": result,
|
| 468 |
+
"error": None,
|
| 469 |
+
"metadata": {}
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
elif hasattr(self.mcp_tools, 'client') and hasattr(self.mcp_tools.client, 'call_tool'):
|
| 473 |
+
# MCP server tool: execute via client
|
| 474 |
+
result = self.mcp_tools.client.call_tool(tool_name, arguments)
|
| 475 |
+
|
| 476 |
+
# Convert MCPClientResult to standard format
|
| 477 |
+
if hasattr(result, 'success'):
|
| 478 |
+
return {
|
| 479 |
+
"success": result.success,
|
| 480 |
+
"data": result.data,
|
| 481 |
+
"error": result.error,
|
| 482 |
+
"metadata": getattr(result, 'metadata', {})
|
| 483 |
+
}
|
| 484 |
+
else:
|
| 485 |
+
return result
|
| 486 |
+
else:
|
| 487 |
+
return {
|
| 488 |
+
"success": False,
|
| 489 |
+
"error": f"Tool '{tool_name}' is not executable (neither built-in nor MCP)"
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
except Exception as e:
|
| 493 |
+
self.logger.error(f"Error executing tool {tool_name}: {e}")
|
| 494 |
+
return {
|
| 495 |
+
"success": False,
|
| 496 |
+
"error": f"Tool execution failed: {str(e)}"
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
def log_reasoning(self, iteration: int, reasoning: str):
|
| 500 |
+
"""Log reasoning step in the trace"""
|
| 501 |
+
self.reasoning_trace.append({
|
| 502 |
+
"type": "reasoning",
|
| 503 |
+
"iteration": iteration,
|
| 504 |
+
"content": reasoning,
|
| 505 |
+
"timestamp": time.time()
|
| 506 |
+
})
|
| 507 |
+
self.execution_stats["reasoning_steps"] += 1
|
| 508 |
+
self.execution_stats["total_steps"] += 1
|
| 509 |
+
self.logger.info(f"Reasoning (Iter {iteration}): {reasoning[:100]}...")
|
| 510 |
+
|
| 511 |
+
def log_action(self, iteration: int, tool: str, arguments: Dict[str, Any], result: Dict[str, Any]):
|
| 512 |
+
"""Log action step in the trace"""
|
| 513 |
+
self.reasoning_trace.append({
|
| 514 |
+
"type": "action",
|
| 515 |
+
"iteration": iteration,
|
| 516 |
+
"tool": tool,
|
| 517 |
+
"arguments": arguments,
|
| 518 |
+
"result": result,
|
| 519 |
+
"timestamp": time.time()
|
| 520 |
+
})
|
| 521 |
+
self.execution_stats["action_steps"] += 1
|
| 522 |
+
self.execution_stats["total_steps"] += 1
|
| 523 |
+
|
| 524 |
+
# Log success/failure
|
| 525 |
+
success = result.get("success", True)
|
| 526 |
+
status = "Success" if success else "Failed"
|
| 527 |
+
self.logger.info(f"Action (Iter {iteration}): {tool} -> {status} -> {str(arguments)[:400]}...")
|
| 528 |
+
|
| 529 |
+
def log_error(self, iteration: int, error: str):
|
| 530 |
+
"""Log error in the trace"""
|
| 531 |
+
self.reasoning_trace.append({
|
| 532 |
+
"type": "error",
|
| 533 |
+
"iteration": iteration,
|
| 534 |
+
"error": error,
|
| 535 |
+
"timestamp": time.time()
|
| 536 |
+
})
|
| 537 |
+
self.execution_stats["error_steps"] += 1
|
| 538 |
+
self.execution_stats["total_steps"] += 1
|
| 539 |
+
self.logger.error(f"Error (Iter {iteration}): {error}")
|
| 540 |
+
|
| 541 |
+
def reset_trace(self):
|
| 542 |
+
"""Reset the reasoning trace for a new task"""
|
| 543 |
+
self.reasoning_trace = []
|
| 544 |
+
self.execution_stats = {
|
| 545 |
+
"total_steps": 0,
|
| 546 |
+
"reasoning_steps": 0,
|
| 547 |
+
"action_steps": 0,
|
| 548 |
+
"error_steps": 0,
|
| 549 |
+
"tool_usage": {},
|
| 550 |
+
"success_rate": 1.0
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
def get_execution_stats(self) -> Dict[str, Any]:
|
| 554 |
+
"""Get execution statistics"""
|
| 555 |
+
# Calculate success rate
|
| 556 |
+
if self.execution_stats["action_steps"] > 0:
|
| 557 |
+
failed_actions = sum(1 for step in self.reasoning_trace
|
| 558 |
+
if step.get("type") == "action"
|
| 559 |
+
and not step.get("result", {}).get("success", True))
|
| 560 |
+
self.execution_stats["success_rate"] = (
|
| 561 |
+
(self.execution_stats["action_steps"] - failed_actions) /
|
| 562 |
+
self.execution_stats["action_steps"]
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
return self.execution_stats.copy()
|
| 566 |
+
|
| 567 |
+
def create_response(self, success: bool, result: Dict[str, Any] = None,
|
| 568 |
+
error: str = None, iterations: int = 0,
|
| 569 |
+
execution_time: float = 0.0) -> AgentResponse:
|
| 570 |
+
"""Create a standardized agent response"""
|
| 571 |
+
return AgentResponse(
|
| 572 |
+
success=success,
|
| 573 |
+
result=result,
|
| 574 |
+
error=error,
|
| 575 |
+
iterations=iterations,
|
| 576 |
+
reasoning_trace=self.reasoning_trace.copy(),
|
| 577 |
+
agent_name=self.config.agent_name,
|
| 578 |
+
execution_time=execution_time
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
def validate_config(self) -> bool:
|
| 582 |
+
"""Validate agent configuration"""
|
| 583 |
+
try:
|
| 584 |
+
# Check required fields
|
| 585 |
+
if not self.config.agent_name:
|
| 586 |
+
return False
|
| 587 |
+
if not self.config.model:
|
| 588 |
+
return False
|
| 589 |
+
if self.config.max_iterations <= 0:
|
| 590 |
+
return False
|
| 591 |
+
if not (0.0 <= self.config.temperature <= 2.0):
|
| 592 |
+
return False
|
| 593 |
+
if self.config.max_tokens <= 0:
|
| 594 |
+
return False
|
| 595 |
+
|
| 596 |
+
return True
|
| 597 |
+
except Exception:
|
| 598 |
+
return False
|
| 599 |
+
|
| 600 |
+
@abstractmethod
|
| 601 |
+
def execute_task(self, task_input: TaskInput) -> AgentResponse:
|
| 602 |
+
"""
|
| 603 |
+
Execute a task using the standardized TaskInput format
|
| 604 |
+
|
| 605 |
+
Args:
|
| 606 |
+
task_input: TaskInput object with standardized task information
|
| 607 |
+
|
| 608 |
+
Returns:
|
| 609 |
+
AgentResponse with results and process trace
|
| 610 |
+
"""
|
| 611 |
+
pass
|
| 612 |
+
|
| 613 |
+
@abstractmethod
|
| 614 |
+
def _build_system_prompt(self) -> str:
|
| 615 |
+
"""Build the system prompt for this agent"""
|
| 616 |
+
pass
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# Simple factory function for creating agent configurations
|
| 620 |
+
|
| 621 |
+
def create_agent_config(
|
| 622 |
+
agent_name: str,
|
| 623 |
+
model: Optional[str] = None,
|
| 624 |
+
max_iterations: Optional[int] = None,
|
| 625 |
+
temperature: Optional[float] = None,
|
| 626 |
+
max_tokens: Optional[int] = None
|
| 627 |
+
) -> AgentConfig:
|
| 628 |
+
"""
|
| 629 |
+
Create an AgentConfig instance for server-managed sessions.
|
| 630 |
+
|
| 631 |
+
Args:
|
| 632 |
+
agent_name: Name of the agent
|
| 633 |
+
model: LLM model to use
|
| 634 |
+
max_iterations: Maximum number of iterations
|
| 635 |
+
temperature: LLM temperature setting
|
| 636 |
+
max_tokens: Maximum tokens for LLM response
|
| 637 |
+
|
| 638 |
+
Returns:
|
| 639 |
+
Configured AgentConfig instance
|
| 640 |
+
"""
|
| 641 |
+
# Load env-backed defaults
|
| 642 |
+
try:
|
| 643 |
+
from config.config import get_config
|
| 644 |
+
api_cfg = get_config()
|
| 645 |
+
except Exception as e:
|
| 646 |
+
raise ValueError(f"Failed to load global configuration: {e}")
|
| 647 |
+
|
| 648 |
+
planner_mode = getattr(api_cfg, "planner_mode", "auto")
|
| 649 |
+
|
| 650 |
+
resolved_model = model if model is not None else getattr(api_cfg, "model_name", None)
|
| 651 |
+
if not resolved_model:
|
| 652 |
+
raise ValueError("Model is not specified and MODEL_NAME is not set in environment")
|
| 653 |
+
|
| 654 |
+
resolved_temperature = temperature if temperature is not None else getattr(api_cfg, "model_temperature", None)
|
| 655 |
+
if resolved_temperature is None:
|
| 656 |
+
raise ValueError("Temperature is not specified and MODEL_TEMPERATURE is not set in environment")
|
| 657 |
+
|
| 658 |
+
resolved_max_tokens = max_tokens if max_tokens is not None else getattr(api_cfg, "model_max_tokens", None)
|
| 659 |
+
if resolved_max_tokens is None:
|
| 660 |
+
raise ValueError("Max tokens is not specified and MODEL_MAX_TOKENS is not set in environment")
|
| 661 |
+
|
| 662 |
+
# Optional paths used by writer and others
|
| 663 |
+
trajectory_storage_path = getattr(api_cfg, "trajectory_storage_path", None)
|
| 664 |
+
report_output_path = getattr(api_cfg, "report_output_path", None)
|
| 665 |
+
document_analysis_path = getattr(api_cfg, "document_analysis_path", None)
|
| 666 |
+
|
| 667 |
+
# Resolve max_iterations per agent type
|
| 668 |
+
if max_iterations is None:
|
| 669 |
+
agent_lower = (agent_name or "").lower()
|
| 670 |
+
resolved_max_iterations = None
|
| 671 |
+
if "planner" in agent_lower:
|
| 672 |
+
resolved_max_iterations = getattr(api_cfg, "planner_max_iterations", None)
|
| 673 |
+
elif "writer" in agent_lower:
|
| 674 |
+
resolved_max_iterations = getattr(api_cfg, "writer_max_iterations", None)
|
| 675 |
+
elif "information" in agent_lower or "seeker" in agent_lower:
|
| 676 |
+
resolved_max_iterations = getattr(api_cfg, "information_seeker_max_iterations", None)
|
| 677 |
+
# if not found in env, raise
|
| 678 |
+
if resolved_max_iterations is None:
|
| 679 |
+
raise ValueError("Max iterations not specified and no env override (PLANNER_MAX_ITERATION/WRITER_MAX_ITERATION/INFORMATION_SEEKER_MAX_ITERATION)")
|
| 680 |
+
max_iterations = resolved_max_iterations
|
| 681 |
+
|
| 682 |
+
return AgentConfig(
|
| 683 |
+
agent_name=agent_name,
|
| 684 |
+
planner_mode=planner_mode,
|
| 685 |
+
model=resolved_model,
|
| 686 |
+
max_iterations=int(max_iterations),
|
| 687 |
+
temperature=resolved_temperature,
|
| 688 |
+
max_tokens=resolved_max_tokens,
|
| 689 |
+
trajectory_storage_path=trajectory_storage_path,
|
| 690 |
+
report_output_path=report_output_path,
|
| 691 |
+
document_analysis_path=document_analysis_path
|
| 692 |
+
)
|
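The trace bookkeeping above (`log_reasoning`/`log_action` entries feeding the success-rate calculation in `get_execution_stats()`) can be illustrated standalone. This is a minimal sketch, independent of the repository's MCP modules; the helper name `success_rate` and the sample trace are illustrative only:

```python
import time

def success_rate(reasoning_trace):
    """Fraction of successful action steps in a trace (1.0 if there are none),
    mirroring the success-rate bookkeeping in get_execution_stats()."""
    actions = [s for s in reasoning_trace if s.get("type") == "action"]
    if not actions:
        return 1.0
    failed = sum(1 for s in actions
                 if not s.get("result", {}).get("success", True))
    return (len(actions) - failed) / len(actions)

# A trace shaped like the entries produced by log_reasoning()/log_action()
trace = [
    {"type": "reasoning", "iteration": 1, "content": "plan the search",
     "timestamp": time.time()},
    {"type": "action", "iteration": 1, "tool": "batch_web_search",
     "result": {"success": True}, "timestamp": time.time()},
    {"type": "action", "iteration": 2, "tool": "url_crawler",
     "result": {"success": False}, "timestamp": time.time()},
]
print(success_rate(trace))  # -> 0.5
```

Note that, as in the class, a missing `success` key counts as a success, so tools returning bare dicts do not drag the rate down.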
deepdiver_v2/src/agents/objective_information_seeker.py
ADDED
@@ -0,0 +1,428 @@
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict, Any, List
|
| 4 |
+
import time
|
| 5 |
+
import requests
|
| 6 |
+
import os
|
| 7 |
+
from .base_agent import BaseAgent, AgentConfig, AgentResponse, TaskInput
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class InformationSeekerAgent(BaseAgent):
|
| 12 |
+
"""
|
| 13 |
+
Information Seeker Agent that follows ReAct pattern (Reasoning + Acting)
|
| 14 |
+
|
| 15 |
+
This agent takes decomposed sub-questions or tasks from parent agents,
|
| 16 |
+
thinks interleaved (reasoning -> action -> reasoning -> action),
|
| 17 |
+
uses MCP tools to gather information, and returns structured results.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
|
| 21 |
+
# Set default agent name if not specified
|
| 22 |
+
if config is None:
|
| 23 |
+
config = AgentConfig(agent_name="InformationSeekerAgent")
|
| 24 |
+
elif config.agent_name == "base_agent":
|
| 25 |
+
config.agent_name = "InformationSeekerAgent"
|
| 26 |
+
|
| 27 |
+
super().__init__(config, shared_mcp_client)
|
| 28 |
+
|
| 29 |
+
def _build_system_prompt(self) -> str:
|
| 30 |
+
"""Build the system prompt for the ReAct agent"""
|
| 31 |
+
tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)
|
| 32 |
+
system_prompt_template = """You are an Information Seeker Agent that follows the ReAct pattern (Reasoning + Acting).
|
| 33 |
+
|
| 34 |
+
Your role is to:
|
| 35 |
+
1. Take decomposed sub-questions or tasks from parent agents
|
| 36 |
+
2. Think step-by-step through reasoning
|
| 37 |
+
3. Use available tools to gather information when needed
|
| 38 |
+
4. Continue reasoning based on tool results
|
| 39 |
+
5. Repeat this process until you have sufficient information
|
| 40 |
+
6. Call info_seeker_objective_task_done to provide a structured summary
|
| 41 |
+
|
| 42 |
+
### Optimized Workflow:
|
| 43 |
+
Follow this optimized workflow for information gathering:
|
| 44 |
+
|
| 45 |
+
1. INITIAL RESEARCH:
|
| 46 |
+
- Use `batch_web_search` to find relevant URLs for your queries. When calling the search statement, consider the language of the user's question. For example, for a Chinese question, generate a part of the search statement in Chinese.
|
| 47 |
+
- Analyze the search results (titles, snippets, URLs) to identify promising sources
|
| 48 |
+
|
| 49 |
+
2. CONTENT EXTRACTION:
|
| 50 |
+
- For important URLs, use `url_crawler` to:
|
| 51 |
+
a) Extract full content from the webpage
|
| 52 |
+
b) Save the content to a file in the workspace
|
| 53 |
+
- Store results with meaningful file paths (e.g., \"research/ai_trends_2024.txt\")
|
| 54 |
+
|
| 55 |
+
3. CONTENT ANALYSIS:
|
| 56 |
+
- Use `document_qa` to ask specific questions about the saved files:
|
| 57 |
+
a) Formulate focused questions to extract key insights
|
| 58 |
+
b) Use answers to deepen your understanding
|
| 59 |
+
- You can ask multiple questions about the same file
|
| 60 |
+
|
| 61 |
+
4. FILE MANAGEMENT:
|
| 62 |
+
- Use `file_write` to save important findings or summaries
|
| 63 |
+
- For reviewing saved content:
|
| 64 |
+
a) Prefer `document_qa` to ask specific questions about the content
|
| 65 |
+
b) Use `file_read` ONLY for small files (<1000 tokens) when you need the entire content
|
| 66 |
+
c) Avoid reading large files directly as it may exceed context limits
|
| 67 |
+
|
| 68 |
+
5. TASK COMPLETION:
|
| 69 |
+
- When ready to report, call `info_seeker_objective_task_done` with:
|
| 70 |
+
a) Comprehensive markdown summary of your process and findings
|
| 71 |
+
b) List of key files created with descriptions
|
| 72 |
+
|
| 73 |
+
### Usage of Systematic Tool:
|
| 74 |
+
- `think` is a systematic tool. After receiving the response from the complex tool or before invoking any other tools, you must **first invoke the `think` tool**: to deeply reflect on the results of previous tool invocations (if any), and to thoroughly consider and plan the user's task. The `think` tool does not acquire new information; it only saves your thoughts into memory.
|
| 75 |
+
- `reflect` is a systematic tool. When encountering a failure in tool execution, it is necessary to invoke the reflect tool to conduct a review and revise the task plan. It does not acquire new information; it only saves your thoughts into memory.
|
| 76 |
+
|
| 77 |
+
Always provide clear reasoning for your actions and synthesize information effectively.
|
| 78 |
+
|
| 79 |
+
Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
|
| 80 |
+
<tools>
|
| 81 |
+
$tool_schemas
|
| 82 |
+
</tools>
|
| 83 |
+
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
|
| 84 |
+
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]
|
| 85 |
+
"""
|
| 86 |
+
return system_prompt_template.replace("$tool_schemas", tool_schemas_str)
|
| 87 |
+
|
| 88 |
+
@staticmethod
|
| 89 |
+
def _build_initial_message_from_task_input(task_input: TaskInput) -> str:
|
| 90 |
+
"""Build the initial user message from TaskInput"""
|
| 91 |
+
message = task_input.format_for_prompt()
|
| 92 |
+
|
| 93 |
+
message += "\nPlease analyze this task and start your ReAct process:\n"
|
| 94 |
+
message += "1. Reason about what information you need to gather\n"
|
| 95 |
+
message += "2. Use appropriate tools to get that information\n"
|
| 96 |
+
message += "3. Continue reasoning and acting until you have sufficient information\n"
|
| 97 |
+
message += "4. Call task_done when ready to provide your complete findings\n\n"
|
| 98 |
+
message += "Begin with your initial reasoning about the task."
|
| 99 |
+
|
| 100 |
+
return message
|
| 101 |
+
|
| 102 |
+
def execute_task(self, task_input: TaskInput) -> AgentResponse:
|
| 103 |
+
"""
|
| 104 |
+
Execute a task using ReAct pattern (Reasoning + Acting)
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
task_input: TaskInput object with standardized task information
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
            AgentResponse with results and process trace
        """
        start_time = time.time()

        try:
            self.logger.info(f"Starting information seeker task: {task_input.task_content}")

            # Reset trace for new task
            self.reset_trace()

            # Initialize conversation history
            conversation_history = []

            # Build initial system prompt for ReAct
            system_prompt = self._build_system_prompt()

            # Build initial user message from TaskInput
            user_message = self._build_initial_message_from_task_input(task_input)

            # Add to conversation
            conversation_history.append({"role": "system", "content": system_prompt})
            conversation_history.append({"role": "user", "content": user_message + " /no_think"})

            iteration = 0
            task_completed = False

            # Get model endpoint configuration from env-backed config
            from config.config import get_config
            config = get_config()
            model_config = config.get_custom_llm_config()

            pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
            model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
            headers = {'Content-Type': 'application/json', 'csb-token': model_token}

            # ReAct loop: Reasoning -> Acting -> Reasoning -> Acting...
            while iteration < self.config.max_iterations and not task_completed:
                iteration += 1
                self.logger.info(f"Planning iteration {iteration}")

                try:
                    # Get LLM response (reasoning + potential tool calls), retrying on transient failures
                    retry_num = 1
                    max_retry_num = 10
                    while retry_num < max_retry_num:
                        try:
                            response = requests.post(
                                url=pangu_url,
                                headers=headers,
                                json={
                                    "model": self.config.model,
                                    "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
                                    "messages": conversation_history,
                                    "temperature": self.config.temperature,
                                    "spaces_between_special_tokens": False,
                                    "max_tokens": self.config.max_tokens,
                                },
                                timeout=model_config.get("timeout", 180)
                            )
                            response = response.json()
                            self.logger.debug("API response received")
                            break
                        except Exception as e:
                            time.sleep(3)
                            retry_num += 1
                            if retry_num == max_retry_num:
                                raise ValueError(str(e))
                            continue

                    assistant_message = response["choices"][0]["message"]

                    # Log the reasoning
                    try:
                        if assistant_message["content"]:
                            reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
                            if len(reasoning_content) > 0:
                                self.log_reasoning(iteration, reasoning_content)
                    except Exception as e:
                        self.logger.warning(f"Reasoning parsing error: {e}")
                        # Parse error: ask the model to regenerate
                        followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
                        conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
                        continue

                    def extract_tool_calls(content):
                        import re
                        if not content:
                            return []
                        tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
                        if len(tool_call_str) > 0:
                            try:
                                tool_calls = json.loads(tool_call_str[0].strip())
                            except json.JSONDecodeError:
                                return []
                        else:
                            return []
                        return tool_calls

                    # Add assistant message to conversation
                    conversation_history.append({
                        "role": "assistant",
                        "content": assistant_message["content"]
                    })

                    tool_calls = extract_tool_calls(assistant_message["content"])

                    # Execute tool calls if any (Acting phase)
                    for tool_call in tool_calls:
                        arguments = tool_call["arguments"]

                        # Check if the task is complete
                        if tool_call["name"] in ["info_seeker_objective_task_done"]:
                            task_completed = True
                            self.log_action(iteration, tool_call["name"], arguments, arguments)
                            break
                        if tool_call["name"] in ["think", "reflect"]:
                            tool_result = {"tool_results": "You can proceed to invoke other tools if needed."}
                        else:
                            tool_result = self.execute_tool_call(tool_call)

                        # Log the action using the base class method
                        self.log_action(iteration, tool_call["name"], arguments, tool_result)

                        # Add tool result to conversation
                        conversation_history.append({
                            "role": "tool",
                            "content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
                        })

                    # If no tool calls, encourage continued planning
                    if len(tool_calls) == 0:
                        # Add follow-up prompt to encourage action or completion
                        followup_prompt = (
                            "Continue your planning process. Use available tools to assign tasks to agents, "
                            "search for information, or coordinate work. When you have a complete answer, "
                            "call info_seeker_objective_task_done. /no_think"
                        )
                        conversation_history.append({"role": "user", "content": followup_prompt})
                    if iteration == self.config.max_iterations - 3:
                        followup_prompt = "Due to length and number of rounds restrictions, you must now call the `info_seeker_objective_task_done` tool to report the completion of your task. /no_think"
                        conversation_history.append({"role": "user", "content": followup_prompt})

                except Exception as e:
                    error_msg = f"Error in planning iteration {iteration}: {e}"
                    self.log_error(iteration, error_msg)
                    break

            execution_time = time.time() - start_time
            # Extract final result
            if task_completed:
                # Find the info_seeker_objective_task_done result in the trace
                task_done_result = None
                for step in reversed(self.reasoning_trace):
                    if step.get("type") == "action" and step.get("tool") == "info_seeker_objective_task_done":
                        task_done_result = step.get("result")
                        break

                return self.create_response(
                    success=True,
                    result=task_done_result,
                    iterations=iteration,
                    execution_time=execution_time
                )
            else:
                return self.create_response(
                    success=False,
                    error=f"Task not completed within {self.config.max_iterations} iterations",
                    iterations=iteration,
                    execution_time=execution_time
                )

        except Exception as e:
            execution_time = time.time() - start_time
            self.logger.error(f"Error in execute_task: {e}")
            return self.create_response(
                success=False,
                error=str(e),
                iterations=iteration if 'iteration' in locals() else 0,
                execution_time=execution_time
            )

    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build tool schemas for InformationSeekerAgent using proper MCP architecture.
        Schemas come from the MCP server via the client, not direct imports.
        """
        # Get MCP tool schemas from the server via the client (proper MCP architecture)
        schemas = super()._build_agent_specific_tool_schemas()

        # Add schemas for built-in task assignment tools
        builtin_assignment_schemas = [
            {
                "type": "function",
                "function": {
                    "name": "think",
                    "description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "thought": {
                                "type": "string",
                                "description": "Your thoughts."
                            }
                        },
                        "required": ["thought"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "reflect",
                    "description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "reflect": {
                                "type": "string",
                                "description": "The specific content of your reflection"
                            }
                        },
                        "required": ["reflect"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "info_seeker_objective_task_done",
                    "description": "Structured reporting of task completion details including summary, decisions, outputs, and status",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "task_summary": {
                                "type": "string",
                                "description": "Comprehensive markdown covering what the agent was asked to do, steps taken, tools used, key findings, files created, challenges, and final deliverables.",
                                "format": "markdown"
                            },
                            "task_name": {
                                "type": "string",
                                "description": "The name of the task currently assigned to the agent, usually with underscores (e.g., 'web_research_ai_trends')"
                            },
                            "key_files": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "file_path": {
                                            "type": "string",
                                            "description": "Relative path to created/modified file"
                                        },
                                        "desc": {
                                            "type": "string",
                                            "description": "File contents and creation purpose"
                                        },
                                        "is_final_output_file": {
                                            "type": "boolean",
                                            "description": "Whether the file is a primary deliverable"
                                        }
                                    },
                                    "required": ["file_path", "desc", "is_final_output_file"]
                                },
                                "description": "List of key files generated or modified during the task, with their details."
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final task status"
                            }
                        },
                        "required": ["task_summary", "task_name", "key_files", "completion_status"]
                    }
                }
            },
        ]

        schemas.extend(builtin_assignment_schemas)

        return schemas


# Factory function for creating the agent
def create_objective_information_seeker(
    model: Any = None,
    max_iterations: Any = None,
    shared_mcp_client=None,
    **kwargs
) -> InformationSeekerAgent:
    """
    Create an InformationSeekerAgent instance with server-managed sessions.

    Args:
        model: The LLM model to use
        max_iterations: Maximum number of iterations
        shared_mcp_client: Optional shared MCP client from the parent agent (prevents extra sessions)
        **kwargs: Additional configuration options

    Returns:
        Configured InformationSeekerAgent instance with appropriate tools
    """
    # Import the enhanced config function
    from ..agents.base_agent import create_agent_config

    # Create agent configuration (session managed by MCP server)
    config = create_agent_config(
        agent_name="InformationSeekerAgent",
        model=model,
        max_iterations=max_iterations,
        **kwargs
    )

    # Create agent instance with shared MCP client (filtered tools for information seeking)
    agent = InformationSeekerAgent(config=config, shared_mcp_client=shared_mcp_client)

    return agent
deepdiver_v2/src/agents/planner_agent.py
ADDED
@@ -0,0 +1,1203 @@
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 2 |
+
"""
|
| 3 |
+
Planner Agent for Multi-Agent Task Coordination
|
| 4 |
+
|
| 5 |
+
This agent serves as a coordinator for complex tasks that require multiple agents
|
| 6 |
+
working together. It implements the ReAct pattern for reasoning and action.
|
| 7 |
+
"""
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
+
import requests
|
| 11 |
+
import os
|
| 12 |
+
from typing import Dict, Any, List
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 14 |
+
|
| 15 |
+
# Base imports
|
| 16 |
+
from .base_agent import BaseAgent, AgentConfig, AgentResponse, WriterAgentTaskInput
|
| 17 |
+
# Import agent creators for built-in task assignment
|
| 18 |
+
from .writer_agent import create_writer_agent
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PlannerAgent(BaseAgent):
|
| 22 |
+
"""
|
| 23 |
+
PlannerAgent coordinates multiple agents to handle complex user queries.
|
| 24 |
+
|
| 25 |
+
The agent uses the ReAct pattern (Reasoning + Acting) to analyze user requests,
|
| 26 |
+
break them down into manageable tasks, and coordinate the appropriate agents
|
| 27 |
+
to complete the work.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
|
| 31 |
+
# Set default agent name if not specified
|
| 32 |
+
if config and not config.agent_name:
|
| 33 |
+
config.agent_name = "PlannerAgent"
|
| 34 |
+
elif not config:
|
| 35 |
+
config = AgentConfig(agent_name="PlannerAgent")
|
| 36 |
+
|
| 37 |
+
super().__init__(config, shared_mcp_client)
|
| 38 |
+
|
| 39 |
+
# Planner-specific state
|
| 40 |
+
self.execution_plan = []
|
| 41 |
+
self.task_queue = []
|
| 42 |
+
|
| 43 |
+
# Add built-in task assignment methods to available tools
|
| 44 |
+
self._add_builtin_assignment_tools()
|
| 45 |
+
|
| 46 |
+
# Regenerate tool schemas with built-in assignment tools
|
| 47 |
+
self.tool_schemas = self._build_tool_schemas()
|
| 48 |
+
|
| 49 |
+
self.sub_agent_configs = {}
|
| 50 |
+
|
| 51 |
+
def _add_builtin_assignment_tools(self):
|
| 52 |
+
"""Add built-in task assignment methods as available tools"""
|
| 53 |
+
# Add assignment methods that share the MCP client connection
|
| 54 |
+
self.available_tools.update({
|
| 55 |
+
"assign_subjective_task_to_writer": self.assign_subjective_task_to_writer, # assign_subjective_task_to_writer
|
| 56 |
+
"assign_multi_objective_tasks_to_info_seeker": self.assign_multi_objective_tasks_to_info_seeker,
|
| 57 |
+
"assign_multi_subjective_tasks_to_info_seeker": self.assign_multi_subjective_tasks_to_info_seeker
|
| 58 |
+
})
|
| 59 |
+
|
| 60 |
+
def assign_multi_objective_tasks_to_info_seeker(
|
| 61 |
+
self,
|
| 62 |
+
tasks: List[Dict[str, str]],
|
| 63 |
+
max_workers: int = 5
|
| 64 |
+
) -> Dict[str, Any]:
|
| 65 |
+
"""
|
| 66 |
+
Creates multiple TaskInput objects and routes them to info_seeker agents for concurrent execution.
|
| 67 |
+
This tool enables the PlannerAgent to assign multiple research tasks through the MCP tool interface.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
tasks: List of task dictionaries with the following keys:
|
| 71 |
+
- task_content (required): The specific task content
|
| 72 |
+
- task_steps_for_reference: Optional reference steps for execution
|
| 73 |
+
- deliverable_contents: Format of expected deliverable
|
| 74 |
+
- acceptance_checking_criteria: Criteria for task completion and quality
|
| 75 |
+
- workspace_id: Workspace ID for stored files and memory
|
| 76 |
+
- current_task_status: Description of current task status
|
| 77 |
+
|
| 78 |
+
max_workers: Maximum concurrent threads (default=4)
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
MCPToolResult with execution results for all tasks
|
| 82 |
+
"""
|
| 83 |
+
try:
|
| 84 |
+
# Validate task count (1-4 tasks)
|
| 85 |
+
if not (1 <= len(tasks) <= 5):
|
| 86 |
+
return {
|
| 87 |
+
"success": False,
|
| 88 |
+
"error": f"Invalid task count ({len(tasks)}). Must assign 1~5 tasks. Please re-plan the task execution schedule or re-decompose the task."
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# Import here to avoid circular imports
|
| 92 |
+
try:
|
| 93 |
+
from agents import TaskInput, create_objective_information_seeker
|
| 94 |
+
except ImportError:
|
| 95 |
+
from ..agents import TaskInput, create_objective_information_seeker
|
| 96 |
+
|
| 97 |
+
results = []
|
| 98 |
+
import threading
|
| 99 |
+
lock = threading.Lock()
|
| 100 |
+
|
| 101 |
+
def process_task(task: Dict[str, str]):
|
| 102 |
+
"""Process a single task with thread-safe result collection"""
|
| 103 |
+
try:
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Create TaskInput object
|
| 107 |
+
task_input = TaskInput(
|
| 108 |
+
task_content=task["task_content"],
|
| 109 |
+
task_steps_for_reference=task.get("task_steps_for_reference"),
|
| 110 |
+
deliverable_contents=task.get("deliverable_contents"),
|
| 111 |
+
current_task_status=task.get("current_task_status"),
|
| 112 |
+
workspace_id=None, # Session/workspace is managed by the server; no need to set explicitly
|
| 113 |
+
acceptance_checking_criteria=task.get("acceptance_checking_criteria")
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Create and execute with info seeker agent - use shared MCP client for session consistency
|
| 117 |
+
info_seeker_config = getattr(self, 'sub_agent_configs', {}).get('information_seeker', {})
|
| 118 |
+
info_seeker = create_objective_information_seeker(
|
| 119 |
+
model=info_seeker_config.get('model', self.config.model),
|
| 120 |
+
max_iterations=info_seeker_config.get('max_iterations', 30),
|
| 121 |
+
shared_mcp_client=self.mcp_tools.client if hasattr(self.mcp_tools, 'client') else self.mcp_tools
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
self.logger.info(f"Assigning task to InformationSeekerAgent: {task['task_content'][:8000]}...")
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# Execute the task
|
| 128 |
+
response = info_seeker.execute_task(task_input)
|
| 129 |
+
|
| 130 |
+
if response.success:
|
| 131 |
+
response_data = {
|
| 132 |
+
"task_content": task.get("task_content", "Unknown task"),
|
| 133 |
+
"success": True,
|
| 134 |
+
"data": response.result,
|
| 135 |
+
"agent_name": response.agent_name,
|
| 136 |
+
"iterations": response.iterations,
|
| 137 |
+
"execution_time": response.execution_time,
|
| 138 |
+
# "reasoning_trace": response.reasoning_trace
|
| 139 |
+
}
|
| 140 |
+
else:
|
| 141 |
+
response_data = {
|
| 142 |
+
"task_content": task.get("task_content", "Unknown task"),
|
| 143 |
+
"success": False,
|
| 144 |
+
"error": response.error,
|
| 145 |
+
"agent_name": response.agent_name
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
# Thread-safe result collection
|
| 149 |
+
with lock:
|
| 150 |
+
results.append(response_data)
|
| 151 |
+
|
| 152 |
+
return response_data
|
| 153 |
+
|
| 154 |
+
except Exception as e:
|
| 155 |
+
error_msg = f"Task processing failed: {str(e)}"
|
| 156 |
+
self.logger.error(error_msg)
|
| 157 |
+
with lock:
|
| 158 |
+
results.append({
|
| 159 |
+
"task_content": task.get("task_content", "Unknown task"),
|
| 160 |
+
"success": False,
|
| 161 |
+
"error": error_msg
|
| 162 |
+
})
|
| 163 |
+
return None
|
| 164 |
+
|
| 165 |
+
# Execute tasks in parallel with thread pool
|
| 166 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 167 |
+
futures = [executor.submit(process_task, task) for task in tasks]
|
| 168 |
+
# Wait for all tasks to complete
|
| 169 |
+
for future in futures:
|
| 170 |
+
future.result() # Raise exceptions if any
|
| 171 |
+
|
| 172 |
+
# Check overall success
|
| 173 |
+
all_success = all(task_result.get("success", False) for task_result in results)
|
| 174 |
+
|
| 175 |
+
return {
|
| 176 |
+
"success": all_success,
|
| 177 |
+
"data": {"tasks": results},
|
| 178 |
+
"error": None if all_success else "Some tasks failed",
|
| 179 |
+
"metadata": {
|
| 180 |
+
"tool_name": "assign_multi_objective_tasks_to_info_seeker",
|
| 181 |
+
"task_count": len(tasks),
|
| 182 |
+
"success_count": sum(1 for r in results if r.get("success")),
|
| 183 |
+
"failure_count": sum(1 for r in results if not r.get("success"))
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
self.logger.error(f"Multi-task assignment failed: {e}")
|
| 189 |
+
return {
|
| 190 |
+
"success": False,
|
| 191 |
+
"error": f"Multi-task assignment failed: {str(e)}"
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def assign_multi_subjective_tasks_to_info_seeker(
|
| 196 |
+
self,
|
| 197 |
+
tasks: List[Dict[str, str]],
|
| 198 |
+
max_workers: int = 5
|
| 199 |
+
) -> Dict[str, Any]:
|
| 200 |
+
"""
|
| 201 |
+
Creates multiple TaskInput objects and routes them to info_seeker agents for concurrent execution.
|
| 202 |
+
This tool enables the PlannerAgent to assign multiple research tasks through the MCP tool interface.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
tasks: List of task dictionaries with the following keys:
|
| 206 |
+
- task_content (required): The specific task content
|
| 207 |
+
- task_steps_for_reference: Optional reference steps for execution
|
| 208 |
+
- deliverable_contents: Format of expected deliverable
|
| 209 |
+
- acceptance_checking_criteria: Criteria for task completion and quality
|
| 210 |
+
- workspace_id: Workspace ID for stored files and memory
|
| 211 |
+
- current_task_status: Description of current task status
|
| 212 |
+
|
| 213 |
+
max_workers: Maximum concurrent threads (default=4)
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
MCPToolResult with execution results for all tasks
|
| 217 |
+
"""
|
| 218 |
+
try:
|
| 219 |
+
# Validate task count (1-4 tasks)
|
| 220 |
+
if not (1 <= len(tasks) <= 6):
|
| 221 |
+
return {
|
| 222 |
+
"success": False,
|
| 223 |
+
"error": f"Invalid task count ({len(tasks)}). Must assign 1-6 tasks."
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
# Import here to avoid circular imports
|
| 227 |
+
try:
|
| 228 |
+
from agents import TaskInput, create_subjective_information_seeker
|
| 229 |
+
except ImportError:
|
| 230 |
+
from ..agents import TaskInput, create_subjective_information_seeker
|
| 231 |
+
|
| 232 |
+
results = []
|
| 233 |
+
import threading
|
| 234 |
+
lock = threading.Lock()
|
| 235 |
+
|
| 236 |
+
def process_task(task: Dict[str, str]):
|
| 237 |
+
"""Process a single task with thread-safe result collection"""
|
| 238 |
+
try:
|
| 239 |
+
# Create TaskInput object
|
| 240 |
+
task_input = TaskInput(
|
| 241 |
+
task_content=task["task_content"],
|
| 242 |
+
task_steps_for_reference=task.get("task_steps_for_reference"),
|
| 243 |
+
deliverable_contents=task.get("deliverable_contents"),
|
| 244 |
+
current_task_status=task.get("current_task_status"),
|
| 245 |
+
workspace_id=self.get_session_info()["session_id"], # Session/workspace is managed by the server; no need to set explicitly
|
| 246 |
+
acceptance_checking_criteria=task.get("acceptance_checking_criteria")
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
                    # Create and execute with info seeker agent - use shared MCP client for session consistency
                    info_seeker_config = getattr(self, 'sub_agent_configs', {}).get('information_seeker', {})
                    info_seeker = create_subjective_information_seeker(
                        model=info_seeker_config.get('model', self.config.model),
                        max_iterations=info_seeker_config.get('max_iterations', 30),
                        shared_mcp_client=self.mcp_tools.client if hasattr(self.mcp_tools, 'client') else self.mcp_tools
                    )

                    self.logger.info(f"Assigning task to InformationSeekerAgent: {task['task_content'][:8000]}...")

                    # Execute the task
                    response = info_seeker.execute_task(task_input)

                    if response.success:
                        response_data = {
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": True,
                            "data": response.result,
                            "agent_name": response.agent_name,
                            "iterations": response.iterations,
                            "execution_time": response.execution_time,
                            # "reasoning_trace": response.reasoning_trace
                        }
                    else:
                        response_data = {
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": False,
                            "error": response.error,
                            "agent_name": response.agent_name
                        }

                    # Thread-safe result collection
                    with lock:
                        results.append(response_data)

                    return response_data

                except Exception as e:
                    error_msg = f"Task processing failed: {str(e)}"
                    self.logger.error(error_msg)
                    with lock:
                        results.append({
                            "task_content": task.get("task_content", "Unknown task"),
                            "success": False,
                            "error": error_msg
                        })
                    return None

            # Execute tasks in parallel with thread pool
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(process_task, task) for task in tasks]
                # Wait for all tasks to complete
                for future in futures:
                    future.result()  # Raise exceptions if any

            # Check overall success
            all_success = all(task_result.get("success", False) for task_result in results)

            return {
                "success": all_success,
                "data": {"tasks": results},
                "error": None if all_success else "Some tasks failed",
                "metadata": {
                    "tool_name": "assign_multi_subjective_tasks_to_info_seeker",
                    "task_count": len(tasks),
                    "success_count": sum(1 for r in results if r.get("success")),
                    "failure_count": sum(1 for r in results if not r.get("success"))
                }
            }

        except Exception as e:
            self.logger.error(f"Multi-task assignment failed: {e}")
            return {
                "success": False,
                "error": f"Multi-task assignment failed: {str(e)}"
            }

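The nested `process_task` helper and the thread-pool loop above form a fan-out/fan-in pattern: independent tasks are submitted to a `ThreadPoolExecutor`, each worker appends its result envelope to a shared list under a `Lock`, and the envelopes are aggregated into success/failure counts. A self-contained sketch of the same pattern (the names `run_parallel` and `worker` are illustrative, not part of the agent code):

```python
from concurrent.futures import ThreadPoolExecutor
from threading import Lock


def run_parallel(tasks, worker, max_workers=4):
    """Run worker(task) for each task in parallel; collect envelopes thread-safely."""
    results = []
    lock = Lock()

    def process(task):
        try:
            data = {"task": task, "success": True, "data": worker(task)}
        except Exception as e:
            data = {"task": task, "success": False, "error": str(e)}
        with lock:  # thread-safe result collection
            results.append(data)
        return data

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for future in [executor.submit(process, t) for t in tasks]:
            future.result()  # surface unexpected exceptions

    return {
        "success": all(r["success"] for r in results),
        "success_count": sum(1 for r in results if r["success"]),
        "failure_count": sum(1 for r in results if not r["success"]),
        "results": results,
    }
```

Because each worker catches its own exceptions and writes an error envelope instead of raising, one failed task does not abort the batch; the aggregate `success` flag simply turns false, mirroring the "Some tasks failed" behavior above.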
    def assign_subjective_task_to_writer(
        self,
        task_content: str,
        user_query: str,
        key_files: List[Dict[str, str]]
    ) -> Dict[str, Any]:
        """
        Assign a writing or content creation task to the WriterAgent

        Args:
            task_content: Detailed description of the writing task to be performed
            user_query: The original user query together with the intact summaries of previously completed information seeker subtasks, preserving the information from each research task
            key_files: Curated list of relevant files, each with a file_path and desc

        Returns:
            Dictionary with task assignment results
        """
        try:
            self.logger.info("Assigning task to WriterAgent")

            # Create task input
            task_input = WriterAgentTaskInput(
                task_content=task_content,
                user_query=user_query,
                key_files=key_files,
                workspace_id=self.get_session_info()["session_id"],
            )

            # Create writer agent with shared MCP client and sub-agent configuration
            writer_config = getattr(self, 'sub_agent_configs', {}).get('writer', {})
            writer = create_writer_agent(
                shared_mcp_client=self.mcp_tools.client,
                model=writer_config.get('model', self.config.model),
                max_iterations=writer_config.get('max_iterations', 20),
                temperature=writer_config.get('temperature', 0.3),
                max_tokens=writer_config.get('max_tokens', 16384)
            )

            self.logger.info(f"Assigning task to WriterAgent: {task_content[:800]}...")

            # Execute the task with the shared connection
            response = writer.execute_task(task_input)

            if response.success:
                return {
                    "success": True,
                    "data": response.result,
                    "agent_name": response.agent_name,
                    "iterations": response.iterations,
                    "execution_time": response.execution_time,
                    # "reasoning_trace": response.reasoning_trace
                }
            else:
                return {
                    "success": False,
                    "error": response.error,
                    "agent_name": response.agent_name
                }

        except Exception as e:
            self.logger.error(f"Failed to assign task to WriterAgent: {e}")
            return {
                "success": False,
                "error": f"Task assignment failed: {str(e)}"
            }

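Both assignment methods resolve sub-agent settings with the same two-level fallback: a value from `sub_agent_configs` wins, otherwise the planner's own config (or a hard-coded default) applies. A minimal standalone sketch of that merge, with hypothetical names:

```python
def resolve_sub_agent_config(sub_agent_configs: dict, name: str,
                             planner_defaults: dict) -> dict:
    """Merge per-sub-agent overrides over planner-level defaults."""
    overrides = (sub_agent_configs or {}).get(name, {})
    merged = dict(planner_defaults)  # start from the planner's defaults
    merged.update(overrides)         # sub-agent settings take precedence
    return merged
```

This is the dict equivalent of the chained `writer_config.get('model', self.config.model)` lookups above: an explicit override beats the planner-wide value, and an absent sub-agent section simply falls through to the defaults.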
    def _build_system_prompt(self) -> str:
        """Build the system prompt for the planner agent"""
        tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)

        auto_system_prompt_template = """# PlannerAgent: Multi-Agent Task Coordinator
**Role:** Analyze complex queries: first determine the query type (long-form writing vs. objective question), then create structured plans and coordinate specialized agents to deliver comprehensive solutions. Call the tools that correspond to the query type, and invoke the writer only for long-form writing queries.

#### Available Sub-Agents:
- **`information_seeker`**: Research, data gathering, web search (supports single/parallel multi-task; long-form writing queries use assign_multi_subjective_tasks_to_info_seeker, other types use assign_multi_objective_tasks_to_info_seeker)
- **`writer`**: Only invoke this sub-agent when long-form writing is required.

---

## Optimized Workflow
### 1. Query Type Judgment & Analysis & Planning Phase
**Goal:** Use the `think` tool to analyze the problem and determine whether it is a simple task (one that does not require calling the information-seeking agent or tools) or a complex task (one that requires calling the info seeker). If it is a complex task, further analyze whether it is an objective question (does not require calling the writer agent) or a long-form writing question (requires long-form expression and a later call to the writer agent).
- **Simple Tasks:** For simple tasks that do not require info seeker invocation, you can directly call the `planner_objective_task_done` tool and write the answer in the `final_answer` field without creating a todo.md file.
- **Complex Tasks:**
    - For objective tasks, you must use `assign_multi_objective_tasks_to_info_seeker`
    - For long-form writing tasks, you must use `assign_multi_subjective_tasks_to_info_seeker`, then call the writer agent to integrate the collected information into a very long text
- **Task Decomposition Rules:**
    - Construct a task tree whose root node represents the user's input query. Each subtask is marked with its depth in the task tree, and the entire task tree is executed from shallow to deep. Tasks at the same depth in the task tree must be independent and executable in parallel (via `assign_multi_xxx_tasks_to_info_seeker`) without mutual dependencies.
    - At the first level of the task tree, thoroughly design subtasks that can be executed in parallel to explore various potential background information, thereby providing more specific clues for the next step of planning.
    - Competitive Redundancy Mechanism:
        - For key subtasks that significantly affect subsequent reasoning and planning, establish a redundancy mechanism: duplicate the task at the same depth level of the task tree so that nearly identical tasks execute in parallel, improving the completion rate and robustness of task execution.
- **Task Parallel Sending Requirements:**
    - When using `assign_multi_xxx_tasks_to_info_seeker`, all parallel-sent subtasks must be independent of each other; the description of each subtask must not contain any references to, or dependencies on, other subtasks.
    - There is no sequential execution relationship among parallel-sent subtasks.

- **Mandatory Documentation:** Create and write `todo.md` (e.g., `todo_v1.md`) with fields:
```markdown
# Task Planning Document
## task_name: [Clear identifier]
## task_desc: [Detailed requirements - focus on WHAT not HOW]
## deliverable_contents: [Exact output format specs]
## success_criteria: [Measurable 100% completion metrics]
## context: [Background, constraints, prior results]
## task_steps_for_reference: [Tree-structured preliminary execution plan, tag tasks with the depth in task tree `[DEPTH:xx]`]
```

### 2. Execution & Iteration Phase
#### A. Unified Iteration Triggers (Shared by Both Types)
- Based on upper-layer task results, refine the next layer of planning and document it in a new version of `todo.md` (e.g., `todo_v2.md`).
- If upper-layer tasks fail or encounter challenges: invoke the `reflect` tool for introspection (it acquires no new information, only saves thoughts), adjust the plan, and re-invoke the corresponding `information_seeker` method (objective: `assign_multi_objective_tasks_to_info_seeker`; long-form writing: `assign_multi_subjective_tasks_to_info_seeker`).
- If current tasks require information from prior rounds: clearly specify the context of each task and the referenced files (e.g., `./data/agent_output_v1.json`) when calling `information_seeker`.
- Decompose and refine clues from upper-layer results, then execute verification in parallel.

#### B. Query-Type-Specific Operations
- **Objective tasks**: No additional operations (strictly no writer invocation). Continue iterating until the information meets the `success_criteria`.
- **Long-form writing tasks**: Add an **information sufficiency check before writer invocation**:
    1. Evaluate collected information on two dimensions: quantity (e.g., "Enough case studies for 3 chapters") and comprehensiveness (e.g., "Covers both positive and negative impacts of AI on education").
    2. If information is insufficient: adjust subtask directions (e.g., "Supplement AI education failure cases") and re-invoke `assign_multi_subjective_tasks_to_info_seeker` for targeted collection.
    3. If information is sufficient: invoke the writer via `assign_subjective_task_to_writer` (provide all collected materials and `todo.md` as context).
    4. If the writer returns an incomplete result: do not assist in completing it; only feed back the current completion status to the user.

### 3. Completion & Synthesis Phase
#### A. Unified Validation & Integration (Shared by Both Types)
- **Validation**: Cross-check multi-source `information_seeker` outputs for consistency (e.g., "NBS and World Bank GDP data differ by ≤1%").
- **Integration**: Combine parallel outputs into a unified deliverable (e.g., "Merge two GDP data sources into a single table" or "Integrate the writer's report with supplementary case studies").
- **Delivery**: The output language must match the user's query language (e.g., Chinese query → Chinese deliverable).

#### B. Query-Type-Specific Task Completion (Critical)
- **Objective tasks**: Call the `planner_objective_task_done` tool **only when** all planned tasks are completed and the final deliverable (e.g., verified data, clear answers) is ready for user delivery.
- **Long-form writing tasks**: Call the `planner_subjective_task_done` tool **only when** the writer has finished executing and the final long-form content meets the `success_criteria` in `todo.md`.

---

## Critical Protocols
1. **Dependency Management:**
   - Prohibit parallel dispatch of sequentially dependent tasks unless using the competitive redundancy mechanism
   - Convert sequential chains to parallel where possible (e.g., Hypothesis_A vs Hypothesis_B testing)
2. **File Traceability:**
   - All output references use relative paths (`./data/agent_output_1.json`)
   - Version `todo.md` after each iteration (e.g., `todo_v2.md`)
3. **Local File Reading Recommendations:**
   - For natively crawled files, it is not recommended to read the entire content directly with the `file_read` tool (it may be too long); instead, use the `document_qa` tool to extract and verify the required information.
   - For task deliverables and summary documents from sub-agents, the `file_read` tool can be used to read them.
4. The final deliverable presented to the user must be in the same language as the user's question.
5. **Writer invocation**: Strictly prohibit calling the writer for objective tasks; for long-form writing tasks, **never answer directly from the collected information**: you must invoke the writer to generate the final long-form content.

Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
<tools>
$tool_schemas
</tools>
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]"""

        writing_system_prompt_template = """### PlannerAgent: Multi-Agent Task Coordinator
**Role:** Analyze complex queries, create structured plans, and coordinate specialized agents to deliver comprehensive solutions.

#### Available Sub-Agents:
- **`information_seeker`**: Research, data gathering, web search (supports single/parallel multi-task)
- **`writer`**: Creates content (e.g., reports, analyses) and synthesizes from existing materials

---

### Optimized Workflow
#### 1. Analysis & Planning Phase
**Goal:** Analyze the problem and determine whether it is a simple task or a complex task. If it is a complex task, further analyze whether it is a subject-driven or an objective-driven question, so that the problem can be decomposed into multiple clear, executable subtasks according to its type. The defining characteristic of objective-driven questions is that their answers are clear, verifiable entities; otherwise they are subject-driven questions.
- **Simple Tasks:** For simple tasks that do not require sub-agent invocation, you can directly answer without creating a todo.md file
- **Complex Tasks:**
    - For objective-driven tasks, adopt a *diverge-converge* strategy:
        1. Use an `assign_multi_subjective_tasks_to_info_seeker` call for divergent background research
        2. Converge findings to define specific sub-problems
    - For subject-driven tasks, adopt a *multi-perspective* strategy:
        1. Use an `assign_multi_subjective_tasks_to_info_seeker` call for divergent multi-source exploration (each task targets independent dimensions)
        2. Converge findings to define focused sub-problems addressing distinct knowledge gaps
        3. Once the information seeker has collected the information, call the writer agent to integrate it into a very long text
- **Task Decomposition Rules:**
    - Construct a task tree whose root node represents the user's input query. Each subtask is marked with its depth in the task tree, and the entire task tree is executed from shallow to deep. Tasks at the same depth in the task tree must be independent and executable in parallel (via `assign_multi_subjective_tasks_to_info_seeker`) without mutual dependencies.
    - At the first level of the task tree, thoroughly design subtasks that can be executed in parallel to explore various potential background information, thereby providing more specific clues for the next step of planning.
    - Competitive Redundancy Mechanism:
        - For key subtasks that significantly affect subsequent reasoning and planning, establish a redundancy mechanism: duplicate the task at the same depth level of the task tree so that nearly identical tasks execute in parallel, improving the completion rate and robustness of task execution.
- **Task Parallel Sending Requirements:**
    - When using `assign_multi_subjective_tasks_to_info_seeker`, all parallel-sent subtasks must be independent of each other; the description of each subtask must not contain any references to, or dependencies on, other subtasks.
    - There is no sequential execution relationship among parallel-sent subtasks.

- **Mandatory Documentation:** Create and write `todo.md` (e.g., `todo_v1.md`) with fields:
```markdown
# Task Planning Document
## task_name: [Clear identifier]
## task_desc: [Detailed requirements - focus on WHAT not HOW]
## deliverable_contents: [Exact output format specs]
## success_criteria: [Measurable 100% completion metrics]
## context: [Background, constraints, prior results]
## task_steps_for_reference: [Tree-structured preliminary execution plan, tag tasks with the depth in task tree `[DEPTH:xx]`]
```

#### 2. Execution & Iteration Phase
- **Iteration Triggers:**
    - Based on the execution results of the upper layer of the task tree, specify and refine the next layer and subsequent task planning, and document them in a new `todo.md` file (e.g., `todo_v2.md`).
    - If tasks in the previous layer have failed or encountered challenges, invoke `reflect` for introspection, consider more possibilities, draw up a new task plan, and invoke `assign_multi_subjective_tasks_to_info_seeker` again.
    - If the tasks sent in the current round need to reference task information from previous rounds, clearly specify the context of each task, and any files that may need to be used or referenced, when calling `assign_multi_subjective_tasks_to_info_seeker`.
    - Decompose and refine the multiple clues in the previous layer's execution results, then execute verification in parallel.
- **Information check required before calling writer:**
    - Before invoking the writer, analyze the collected information for sufficiency: evaluate both quantity and comprehensiveness to ensure adequate material for long article generation
    - If information is insufficient, adjust subtask directions and initiate additional targeted information collection
    - **When information is sufficient, invoke the writer agent** via `assign_subjective_task_to_writer`

#### 3. Completion & Synthesis Phase
- **Validation:** Cross-check multi-source outputs for consistency, and check whether the information sources are sufficient
- **Integration:** Combine parallel outputs into a unified deliverable
- **Delivery:** Output language must match the user's query language
- When the writer agent has finished executing, the `planner_subjective_task_done` tool must be called to end the current task

---

### Critical Protocols
1. **Dependency Management:**
   - Prohibit parallel dispatch of sequentially dependent tasks unless using the competitive redundancy mechanism
   - Convert sequential chains to parallel where possible (e.g., Hypothesis_A vs Hypothesis_B testing)
2. **File Traceability:**
   - All output references use relative paths (`./data/agent_output_1.json`)
   - Version `todo.md` after each iteration (e.g., `todo_v2.md`)
3. **Iteration Discipline:**
   - Minimum 2 parallel agents for critical hypothesis-validation tasks
   - Terminate only when ALL success criteria are met at 100%
4. **Usage of Think Tool:**
   - `think` is a systematic tool. After receiving the response from a complex tool, or before invoking any other tools, you must **first invoke the `think` tool** to deeply reflect on the results of previous tool invocations (if any), and to thoroughly consider and plan the user's task. The `think` tool does not acquire new information; it only saves your thoughts into memory.
5. **Usage of Reflect Tool:**
   - `reflect` is a systematic tool. When a tool execution fails, invoke the reflect tool to conduct a review and revise the task plan. It does not acquire new information; it only saves your thoughts into memory.
6. Always prioritize complete solutions over partial delivery. Use parallel redundancy for critical-path tasks, and convert agent disagreements into new parallel investigation branches.
7. **CRITICAL:** When you determine that the information_seeker has gathered sufficient information, you must invoke the writer agent to draft the final article in response to the user's query. You are not allowed to reply directly based on the collected information!
8. Also note that when the writer agent returns a result showing it is not complete, you do not need to help it finish; only feed back the current completion status to the user.

Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
<tools>
$tool_schemas
</tools>
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]"""

        qa_system_prompt_template = """### PlannerAgent: Multi-Agent Task Coordinator
**Role:** Analyze complex queries, create structured plans, and coordinate specialized agents to deliver comprehensive solutions.

#### Available Sub-Agents:
- **`information_seeker`**: Research, data gathering, web search (supports single/parallel multi-task)

---

### Optimized Workflow
#### 1. Analysis & Planning Phase
**Goal:** Decompose problems into executable units with clear dependencies
- **Simple Tasks:** For simple tasks that do not require sub-agent invocation, you can directly answer and call `planner_objective_task_done` without creating a todo.md file
- **Complex Tasks:**
    - **Task Decomposition Rules:**
        - Construct a task tree whose root node represents the user's input query. Each subtask is marked with its depth in the task tree, and the entire task tree is executed from shallow to deep. Tasks at the same depth in the task tree must be independent and executable in parallel (via `assign_multi_objective_tasks_to_info_seeker`) without mutual dependencies.
        - At the first level of the task tree, thoroughly design subtasks that can be executed in parallel to explore various potential background information, thereby providing more specific clues for the next step of planning.
        - Competitive Redundancy Mechanism:
            - For key subtasks that significantly affect subsequent reasoning and planning, establish a redundancy mechanism: duplicate the task at the same depth level of the task tree so that nearly identical tasks execute in parallel, improving the completion rate and robustness of task execution.
    - **Task Parallel Sending Requirements:**
        - When using `assign_multi_objective_tasks_to_info_seeker`, all parallel-sent subtasks must be independent of each other; the description of each subtask must not contain any references to, or dependencies on, other subtasks.
        - There is no sequential execution relationship among parallel-sent subtasks.

- **Mandatory Documentation:** Create and write `todo.md` (e.g., `todo_v1.md`) with fields:
```markdown
# Task Planning Document
## task_name: [Clear identifier]
## task_desc: [Detailed requirements - focus on WHAT not HOW]
## deliverable_contents: [Exact output format specs]
## success_criteria: [Measurable 100% completion metrics]
## context: [Background, constraints, prior results]
## task_steps_for_reference: [Tree-structured preliminary execution plan, tag tasks with the depth in task tree `[DEPTH:xx]`]
```

#### 2. Execution & Iteration Phase
- **Iteration Triggers:**
    - Based on the execution results of the upper layer of the task tree, specify and refine the next layer and subsequent task planning, and document them in a new `todo.md` file (e.g., `todo_v2.md`).
    - If tasks in the previous layer have failed or encountered challenges, invoke `reflect` for introspection, consider more possibilities, draw up a new task plan, and invoke `assign_multi_objective_tasks_to_info_seeker` again.
    - If the tasks sent in the current round need to reference task information from previous rounds, clearly specify the context of each task, and any files that may need to be used or referenced, when calling `assign_multi_objective_tasks_to_info_seeker`.
    - Decompose and refine the multiple clues in the previous layer's execution results, then execute verification in parallel.

#### 3. Completion & Synthesis Phase
- **Validation:** Cross-check multi-source outputs for consistency
- **Integration:** Combine parallel outputs into a unified deliverable
- **Delivery:** Output language must match the user's query language
- **Task Completed:** The `planner_objective_task_done` tool can only be called when all planned tasks have been completed and the final results are ready to be delivered to the user.

---

### Critical Protocols
1. **Dependency Management:**
   - Prohibit parallel dispatch of sequentially dependent tasks unless using the competitive redundancy mechanism
   - Convert sequential chains to parallel where possible (e.g., Hypothesis_A vs Hypothesis_B testing)
2. **File Traceability:**
   - All output references use relative paths (`./data/agent_output_1.json`)
   - Version `todo.md` after each iteration (e.g., `todo_v2.md`)
3. **Local File Reading Recommendations:**
   - For natively crawled files, it is not recommended to read the entire content directly with the `file_read` tool (it may be too long); instead, use the `document_qa` tool to extract and verify the required information.
   - For task deliverables and summary documents from sub-agents, the `file_read` tool can be used to read them.
4. The final deliverable presented to the user must be in the same language as the user's question.

Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
<tools>
$tool_schemas
</tools>
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]"""

        planner_mode_system_prompt_map = {
            "auto": auto_system_prompt_template,
            "writing": writing_system_prompt_template,
            "qa": qa_system_prompt_template
        }

        system_prompt = planner_mode_system_prompt_map[self.config.planner_mode].replace("$tool_schemas", tool_schemas_str)

        return system_prompt

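`_build_system_prompt` selects one of three templates by `planner_mode` and splices the JSON-encoded tool schemas over the literal `$tool_schemas` placeholder. A minimal standalone sketch of that select-and-substitute step (the template texts here are placeholders, not the real prompts):

```python
import json

# Hypothetical, abbreviated templates keyed by planner mode
TEMPLATES = {
    "auto": "Planner (auto mode)\n<tools>\n$tool_schemas\n</tools>",
    "qa": "Planner (qa mode)\n<tools>\n$tool_schemas\n</tools>",
}


def build_system_prompt(mode: str, tool_schemas: list) -> str:
    # ensure_ascii=False keeps non-ASCII tool descriptions readable in the prompt
    schemas_str = json.dumps(tool_schemas, ensure_ascii=False)
    try:
        template = TEMPLATES[mode]
    except KeyError:
        raise ValueError(f"Unknown planner mode: {mode!r}") from None
    return template.replace("$tool_schemas", schemas_str)
```

Plain `str.replace` works here because `$tool_schemas` appears exactly once and the schemas string cannot itself contain the placeholder; `string.Template.substitute` would be the more defensive alternative.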
    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build tool schemas for PlannerAgent using proper MCP architecture.
        Schemas come from the MCP server via the client, not direct imports.
        """

        # Get MCP tool schemas from the server via the client (proper MCP architecture)
        schemas = super()._build_agent_specific_tool_schemas()

        # Add schemas for built-in task assignment tools
        planner_mode_builtin_tools_map = {
            "auto": ["think", "reflect", "assign_multi_subjective_tasks_to_info_seeker", "assign_multi_objective_tasks_to_info_seeker", "assign_subjective_task_to_writer", "writer_subjective_task_done", "planner_subjective_task_done", "planner_objective_task_done"],
            "writing": ["think", "reflect", "assign_multi_subjective_tasks_to_info_seeker", "assign_subjective_task_to_writer", "writer_subjective_task_done", "planner_subjective_task_done"],
            "qa": ["think", "reflect", "assign_multi_objective_tasks_to_info_seeker", "planner_objective_task_done"],
        }
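The `planner_mode_builtin_tools_map` above presumably restricts which built-in tool schemas each planner mode advertises; the wiring that applies it is not shown in this excerpt. A standalone sketch of that per-mode filtering, under that assumption:

```python
def filter_tools_by_mode(schemas: list, mode: str, mode_tools_map: dict) -> list:
    """Keep only the function schemas whose name is allowed for this planner mode."""
    allowed = set(mode_tools_map.get(mode, []))
    return [s for s in schemas if s.get("function", {}).get("name") in allowed]
```

With the map above, "qa" mode would drop the writer-related schemas entirely, so the model never sees tools it is forbidden to call in that mode.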
builtin_assignment_schemas = [
|
| 658 |
+
{
|
| 659 |
+
"type": "function",
|
| 660 |
+
"function": {
|
| 661 |
+
"name": "think",
|
| 662 |
+
"description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
|
| 663 |
+
"parameters": {
|
                    "type": "object",
                    "properties": {
                        "thought": {
                            "type": "string",
                            "description": "Your thoughts."
                        }
                    },
                    "required": ["thought"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "reflect",
                "description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "reflect": {
                            "type": "string",
                            "description": "The specific content of your reflection"
                        }
                    },
                    "required": ["reflect"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "assign_multi_subjective_tasks_to_info_seeker",
                "description": "Assign 1~6 research or information gathering tasks to different InformationSeekerAgents for parallel execution; each task description must be semantically complete and clearly provide contextual information and potentially important reference documents.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "tasks": {
                            "type": "array",
                            "description": "List of tasks to be assigned to multiple InformationSeekerAgents",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "task_content": {
                                        "type": "string",
                                        "description": "Detailed description of the task to be performed"
                                    },
                                    "task_steps_for_reference": {
                                        "type": "string",
                                        "description": "Optional reference steps for task execution"
                                    },
                                    "deliverable_contents": {
                                        "type": "string",
                                        "description": "Expected format and content of deliverables"
                                    },
                                    "current_task_status": {
                                        "type": "string",
                                        "description": "Current status and context of the task, important documents that may be used and referenced"
                                    },
                                    "acceptance_checking_criteria": {
                                        "type": "string",
                                        "description": "Criteria for determining task completion and quality"
                                    },
                                },
                                "required": ["task_content"]
                            }
                        }
                    },
                    "required": ["tasks"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "assign_multi_objective_tasks_to_info_seeker",
                "description": "Assign 1~5 research or information gathering tasks to different InformationSeekerAgents for parallel execution; each task description must be semantically complete and clearly provide contextual information and potentially important reference documents.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "tasks": {
                            "type": "array",
                            "description": "List of tasks to be assigned to multiple InformationSeekerAgents",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "task_content": {
                                        "type": "string",
                                        "description": "Detailed description of the task to be performed; the task description must be semantically complete"
                                    },
                                    "task_steps_for_reference": {
                                        "type": "string",
                                        "description": "Optional reference steps for task execution"
                                    },
                                    "deliverable_contents": {
                                        "type": "string",
                                        "description": "Expected format and content of deliverables"
                                    },
                                    "current_task_status": {
                                        "type": "string",
                                        "description": "Current status and context of the task, important documents that may be used and referenced"
                                    },
                                    "acceptance_checking_criteria": {
                                        "type": "string",
                                        "description": "Criteria for determining task completion and quality, and the requirements in the event of task completion failure"
                                    },
                                },
                                "required": ["task_content"]
                            }
                        }
                    },
                    "required": ["tasks"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "assign_subjective_task_to_writer",
                "description": "Assign a writing or content creation task to the WriterAgent",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "user_query": {
                            "type": "string",
                            "description": "Pass in the original user question."
                        },
                        "task_content": {
                            "type": "string",
                            "description": "Integrate and synthesize provided materials to generate comprehensive long-form content exceeding 10,000 words. Be especially careful not to give specific details such as an outline plan; only provide the writer with a general description of the task."
                        },
                        "key_files": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "file_path": {
                                        "type": "string",
                                        "description": "Relative path to the file containing research content"
                                    }
                                },
                                "required": ["file_path"]
                            },
                            "description": "Collect all key_files returned by the information seeker for long-form content creation."
                        }
                    },
                    "required": ["user_query", "task_content", "key_files"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "writer_subjective_task_done",
                "description": "Writer Agent task completion reporting for complete long-form content. Called after all chapters/sections are written to provide a summary of the complete long article, the final completion status and analysis, and the storage path of the final consolidated article.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "final_article_path": {
                            "type": "string",
                            "description": "The file path where the final article is saved."
                        },
                        "article_summary": {
                            "type": "string",
                            "description": "Comprehensive summary of the complete long-form article, including main themes, key points covered, and overall narrative structure.",
                            "format": "markdown"
                        },
                        "completion_status": {
                            "type": "string",
                            "enum": ["completed", "partial", "failed"],
                            "description": "Final status of the complete long-form writing task"
                        },
                        "completion_analysis": {
                            "type": "string",
                            "description": "Analysis of the overall writing project completion including: assessment of article coherence and quality, evaluation of content organization and flow, identification of any challenges in the writing process, and overall evaluation of the long-form content creation success."
                        }
                    },
                    "required": ["final_article_path", "article_summary", "completion_status", "completion_analysis"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "planner_subjective_task_done",
                "description": "After the writer agent has executed, call this task-done tool to end the planner's task.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "final_article_path": {
                            "type": "string",
                            "description": "The file path where the final article is saved."
                        },
                        "task_summary": {
                            "type": "string",
                            "description": "Briefly summarizes the main content of the article and finally indicates the path where the final article is saved.",
                            "format": "markdown"
                        },
                        "task_name": {
                            "type": "string",
                            "description": "The name of the task currently assigned to the agent, usually with underscores (e.g., 'web_research_ai_trends')"
                        },
                        "completion_status": {
                            "type": "string",
                            "enum": ["completed", "partial", "failed"],
                            "description": "Final task status"
                        }
                    },
                    "required": ["final_article_path", "task_summary", "task_name", "completion_status"]
                }
            }
        },
        {
            "type": "function",
            "function": {
                "name": "planner_objective_task_done",
                "description": "Structured reporting of task completion details including summary, decisions, and final answer",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "task_summary": {
                            "type": "string",
                            "description": "Comprehensive markdown covering what the agent was asked to do, steps taken, tools used, key findings, files created, challenges",
                            "format": "markdown"
                        },
                        "task_name": {
                            "type": "string",
                            "description": "The name of the task currently assigned to the agent, usually with underscores (e.g., 'web_research_ai_trends')"
                        },
                        "key_files": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "file_path": {
                                        "type": "string",
                                        "description": "Relative path to created/modified file"
                                    },
                                    "desc": {
                                        "type": "string",
                                        "description": "File contents and creation purpose"
                                    },
                                    "is_final_output_file": {
                                        "type": "boolean",
                                        "description": "Whether file is primary deliverable"
                                    }
                                },
                                "required": ["file_path", "desc", "is_final_output_file"]
                            },
                            "description": "List of key files generated or modified during the task, with their details."
                        },
                        "completion_status": {
                            "type": "string",
                            "enum": ["completed", "partial", "failed"],
                            "description": "Final task status"
                        },
                        "final_answer": {
                            "type": "string",
                            "description": "The final response displayed to the user",
                        }
                    },
                    "required": ["task_summary", "task_name", "key_files", "completion_status", "final_answer"]
                }
            }
        },
    ]

        used_builtin_schemas = [schema for schema in builtin_assignment_schemas if schema["function"]["name"] in planner_mode_builtin_tools_map[self.config.planner_mode]]
        schemas.extend(used_builtin_schemas)

        return schemas

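The comprehension above keeps only the built-in schemas whose tool name is enabled for the current planner mode. That selection step can be exercised in isolation; the schemas and mode map below are toy stand-ins, not the project's actual configuration:

```python
# Illustrative sketch of the planner-mode schema filtering used above.
# Tool names and the mode map are hypothetical stand-ins.
builtin_assignment_schemas = [
    {"type": "function", "function": {"name": "think"}},
    {"type": "function", "function": {"name": "reflect"}},
    {"type": "function", "function": {"name": "planner_objective_task_done"}},
]

planner_mode_builtin_tools_map = {
    "objective": {"think", "reflect", "planner_objective_task_done"},
    "subjective": {"think", "reflect"},
}

def select_schemas(mode: str) -> list:
    """Keep only the built-in schemas enabled for the given planner mode."""
    allowed = planner_mode_builtin_tools_map[mode]
    return [s for s in builtin_assignment_schemas if s["function"]["name"] in allowed]

print([s["function"]["name"] for s in select_schemas("subjective")])  # ['think', 'reflect']
```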
    def _execute_react_loop(self, initial_message: str, max_iterations: int = 20) -> Dict[str, Any]:
        """
        Execute the ReAct loop for planning tasks

        Args:
            initial_message: Initial message to start the planning process
            max_iterations: Maximum number of iterations to perform

        Returns:
            Dictionary with execution results and trace
        """
        start_time = time.time()
        try:
            # Reset trace for new task
            self.reset_trace()
            # Initialize conversation history
            conversation_history = []

            # Build system prompt for planning
            system_prompt = self._build_system_prompt()
            # Add to conversation
            conversation_history.append({"role": "system", "content": system_prompt})
            conversation_history.append({"role": "user", "content": initial_message + " /no_think"})

            iteration = 0
            task_completed = False

            # Get model endpoint configuration from env-backed config
            from config.config import get_config
            config = get_config()
            model_config = config.get_custom_llm_config()

            pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
            model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
            headers = {'Content-Type': 'application/json', 'csb-token': model_token}
            # ReAct loop: Reasoning -> Acting -> Reasoning -> Acting...
            while iteration < self.config.max_iterations and not task_completed:
                iteration += 1
                self.logger.info(f"Planning iteration {iteration}")

                try:
                    # Get LLM response (reasoning + potential tool calls)
                    retry_num = 1
                    max_retry_num = 10
                    while retry_num < max_retry_num:
                        try:
                            response = requests.post(
                                url=pangu_url,
                                headers=headers,
                                json={
                                    "model": self.config.model,
                                    "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
                                    "spaces_between_special_tokens": False,
                                    "messages": conversation_history,
                                    "temperature": self.config.temperature,
                                    "max_tokens": self.config.max_tokens,
                                },
                                timeout=model_config.get("timeout", 180)
                            )
                            response = response.json()
                            self.logger.debug("API response received")
                            break
                        except Exception as e:
                            time.sleep(3)
                            retry_num += 1
                            if retry_num == max_retry_num:
                                raise ValueError(str(e))
                            continue
                    assistant_message = response["choices"][0]["message"]

                    # Log the reasoning
                    try:
                        if assistant_message["content"]:
                            reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
                            if len(reasoning_content) > 0:
                                self.log_reasoning(iteration, reasoning_content)
                    except Exception as e:
                        self.logger.warning(f"Tool call parsing error: {e}")
                        # Parse error, rerun
                        followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
                        conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
                        continue

                    def extract_tool_calls(content):
                        import re
                        if not content:
                            return []
                        tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
                        if len(tool_call_str) > 0:
                            try:
                                tool_calls = json.loads(tool_call_str[0].strip())
                            except Exception:
                                return []
                        else:
                            return []
                        return tool_calls

                    # Add assistant message to conversation
                    conversation_history.append({
                        "role": "assistant",
                        "content": assistant_message["content"]
                    })

                    tool_calls = extract_tool_calls(assistant_message["content"])

                    # Execute tool calls if any (Acting phase)

                    for tool_call in tool_calls:
                        arguments = tool_call["arguments"]
                        self.logger.debug(f"Arguments is string: {isinstance(arguments, str)}")

                        # Check if planning is complete
                        if tool_call["name"] in ["planner_subjective_task_done", "planner_objective_task_done", "writer_subjective_task_done"]:
                            task_completed = True
                            self.log_action(iteration, tool_call["name"], arguments, arguments)
                            break
                        if tool_call["name"] in ["think", "reflect"]:
                            tool_result = {"tool_results": "You can proceed to invoke other tools if needed. "}
                        else:
                            tool_result = self.execute_tool_call(tool_call)

                        # Log the action using base class method
                        self.log_action(iteration, tool_call["name"], arguments, tool_result)

                        # Add tool result to conversation
                        conversation_history.append({
                            "role": "tool",
                            "content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
                        })

                    # If no tool calls, encourage continued planning
                    if len(tool_calls) == 0:
                        # Add follow-up prompt to encourage action or completion
                        followup_prompt = (
                            "Continue your planning process. Use available tools to assign tasks to agents, "
                            "search for information, or coordinate work. When you have a complete answer, "
                            "call planner_subjective_task_done or planner_objective_task_done. /no_think"
                        )
                        conversation_history.append({"role": "user", "content": followup_prompt})

                except Exception as e:
                    error_msg = f"Error in planning iteration {iteration}: {e}"
                    self.log_error(iteration, error_msg)
                    break

            execution_time = time.time() - start_time

            # Extract final result
            if task_completed:
                # Find the completion result in the trace
                completion_result = None
                for step in reversed(self.reasoning_trace):
                    if step.get("type") == "action" and step.get("tool") in ["planner_subjective_task_done",
                                                                             "planner_objective_task_done"]:
                        completion_result = step.get("result")
                        break

                return {
                    "success": True,
                    "data": completion_result,
                    "reasoning_trace": self.reasoning_trace,
                    "iterations": iteration,
                    "execution_time": execution_time
                }
            else:
                return {
                    "success": False,
                    "error": f"Planning task not completed within {max_iterations} iterations",
                    "reasoning_trace": self.reasoning_trace,
                    "iterations": iteration,
                    "execution_time": execution_time
                }
        except Exception as e:
            execution_time = time.time() - start_time if 'start_time' in locals() else 0
            self.logger.error(f"Error in execute_react_loop: {e}")
            return {
                "success": False,
                "error": str(e),
                "reasoning_trace": self.reasoning_trace,
                "iterations": iteration if 'iteration' in locals() else 0,
                "execution_time": execution_time
            }
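The model returns tool calls inline between `[unused11]`/`[unused12]` marker tokens rather than in a structured `tool_calls` field, so the loop above parses them out of raw text. The `extract_tool_calls` helper it defines can be sketched standalone (the reply string is illustrative):

```python
import json
import re

def extract_tool_calls(content: str) -> list:
    """Parse JSON tool calls wrapped in [unused11]...[unused12] marker tokens."""
    if not content:
        return []
    matches = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
    if not matches:
        return []
    try:
        return json.loads(matches[0].strip())
    except json.JSONDecodeError:
        return []

# Hypothetical assistant reply mixing prose with one marker-wrapped call.
reply = '思考完毕。[unused11][{"name": "think", "arguments": {"thought": "plan first"}}][unused12]'
calls = extract_tool_calls(reply)
print(calls[0]["name"])  # think
```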


    def execute_task(self, user_query: str) -> AgentResponse:
        """
        Execute a planning task for the given user query

        Args:
            user_query: The user's query or request

        Returns:
            AgentResponse with planning results and process trace
        """
        start_time = time.time()

        try:
            self.logger.info(f"Starting planner task: {user_query}")

            # Execute the planning task using ReAct pattern
            result = self._execute_react_loop(
                initial_message=user_query,
                max_iterations=self.config.max_iterations  # Reasonable limit for planning tasks
            )

            execution_time = time.time() - start_time

            return AgentResponse(
                success=result.get("success", False),
                result=result.get("data"),
                error=result.get("error"),
                reasoning_trace=result.get("reasoning_trace", []),
                iterations=result.get("iterations", 0),
                execution_time=execution_time,
                agent_name=self.config.agent_name
            )

        except Exception as e:
            execution_time = time.time() - start_time
            self.logger.error(f"Planner execution failed: {e}")

            return AgentResponse(
                success=False,
                error=f"Planner execution failed: {str(e)}",
                reasoning_trace=[],
                iterations=0,
                execution_time=execution_time,
                agent_name=self.config.agent_name
            )


def create_planner_agent(
    model: Any = None,
    sub_agent_configs: Dict[str, Dict[str, Any]] = None,
    shared_mcp_client=None,
    **kwargs
) -> PlannerAgent:
    """
    Create a PlannerAgent instance with server-managed sessions.

    Args:
        model: The LLM model to use
        sub_agent_configs: Configuration for sub-agents (information_seeker, writer)
        shared_mcp_client: Optional shared MCP client to prevent duplicate connections
        **kwargs: Additional configuration options

    Returns:
        Configured PlannerAgent instance
    """
    # Import the enhanced config function
    from .base_agent import create_agent_config

    # Create agent configuration (session managed by MCP server)
    config = create_agent_config(
        agent_name="PlannerAgent",
        model=model,
        **kwargs
    )

    # Create planner agent with optional shared MCP client
    planner = PlannerAgent(config=config, shared_mcp_client=shared_mcp_client)

    # Store sub-agent configurations for use when creating sub-agents
    planner.sub_agent_configs = sub_agent_configs or {
        "information_seeker": {},
        "writer": {}
    }

    return planner
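Both agents wrap `requests.post` in a counting retry loop with a fixed sleep, raising once the retry budget is spent. A simplified, self-contained sketch of that pattern (the flaky callable and zero delay are illustrative; the real loop sleeps 3 s between attempts and caps at 10):

```python
import time

def call_with_retry(fn, max_retry_num: int = 10, delay_s: float = 0.0):
    """Retry fn() up to max_retry_num times with a fixed delay, mirroring
    the retry loop around requests.post in the agents above (simplified)."""
    last_error = None
    for attempt in range(1, max_retry_num + 1):
        try:
            return fn()
        except Exception as e:  # broad catch, matching the source
            last_error = e
            time.sleep(delay_s)
    raise ValueError(str(last_error))

# A flaky callable that fails twice, then succeeds.
state = {"n": 0}
def flaky():
    state["n"] += 1
    if state["n"] < 3:
        raise RuntimeError("transient")
    return "ok"

print(call_with_retry(flaky))  # ok
```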
deepdiver_v2/src/agents/subjective_information_seeker.py (ADDED)
@@ -0,0 +1,417 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
import json
from typing import Dict, Any, List
import time
import requests
import os
from .base_agent import BaseAgent, AgentConfig, AgentResponse, TaskInput


class InformationSeekerAgent(BaseAgent):
    """
    Information Seeker Agent that follows the ReAct pattern (Reasoning + Acting).

    This agent takes decomposed sub-questions or tasks from parent agents,
    interleaves thinking and acting (reasoning -> action -> reasoning -> action),
    uses MCP tools to gather information, and returns structured results.
    """

    def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
        # Set default agent name if not specified
        if config is None:
            config = AgentConfig(agent_name="InformationSeekerAgent")
        elif config.agent_name == "base_agent":
            config.agent_name = "InformationSeekerAgent"

        super().__init__(config, shared_mcp_client)

    def _build_system_prompt(self) -> str:
        """Build the system prompt for the ReAct agent"""
        tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)
        system_prompt_template = """You are an Information Seeker Agent that follows the ReAct pattern (Reasoning + Acting).

Your role is to:
1. Take decomposed sub-questions or tasks from parent agents
2. Think step-by-step through reasoning
3. Use available tools to gather information when needed
4. Continue reasoning based on tool results
5. Repeat this process until you have sufficient information
6. Call info_seeker_subjective_task_done to provide a structured summary and key files

TOOL USAGE STRATEGY:
Follow this optimized workflow for information gathering:

1. INITIAL RESEARCH:
   - Generate focused search queries (≤10): limit to no more than 10 initial search queries to avoid the increased failure rates that come with excessive decomposition.
   - Use `batch_web_search` to find relevant URLs for your queries. When composing search queries, consider the language of the user's question; for example, for a Chinese question, write some of the search queries in Chinese.
   - Analyze the search results (titles, snippets, URLs) to identify promising sources

2. CONTENT EXTRACTION:
   - For important URLs, use `url_crawler` to:
     a) Extract full content from the webpage
     b) Save the content to a file in the workspace **under the relative path `./url_crawler_save_files/`**
   - Store results with meaningful file paths (e.g., `url_crawler_save_files/research/ai_trends_2024.txt`)

3. CONTENT ANALYSIS:
   - Use `document_extract` for multi-dimensional analysis of saved files:
     a) Provides structured analysis across five key dimensions: document time, source authority, core content, and task relevance

4. FILE MANAGEMENT:
   - For reviewing saved content:
     a) Prefer `document_extract` to get comprehensive multi-dimensional analysis of saved files
     b) Use `file_read` ONLY for small files (<1000 tokens) when you need the entire content
     c) Avoid reading large files directly, as this may exceed context limits

### Usage of the Systematic Tool:
- `think` is a systematic tool. After receiving the response from a complex tool, or before invoking any other tool, you must **first invoke the `think` tool** to deeply reflect on the results of previous tool invocations (if any), and to thoroughly consider and plan the user's task. The `think` tool does not acquire new information; it only saves your thoughts into memory.

Always provide clear reasoning for your actions and synthesize information effectively.

Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
<tools>
$tool_schemas
</tools>
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]
"""
        return system_prompt_template.replace("$tool_schemas", tool_schemas_str)
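The prompt uses a `$tool_schemas` placeholder filled with `str.replace` rather than `str.format`; one plausible reason is that the prompt body contains literal braces (in the `[unused11]...[unused12]` call example and the injected JSON), which `str.format` would misinterpret as format fields. A sketch with an abbreviated, hypothetical template, showing that the stdlib `string.Template` behaves the same way for this case:

```python
import json
from string import Template

# Abbreviated stand-in for the system prompt; note the literal braces,
# which make str.format() unsuitable but are harmless to $-substitution.
template = 'Tools:\n<tools>\n$tool_schemas\n</tools>\nCall as [unused11][{"name": ...}][unused12]'

schemas = [{"type": "function", "function": {"name": "batch_web_search"}}]
schemas_str = json.dumps(schemas, ensure_ascii=False)

via_replace = template.replace("$tool_schemas", schemas_str)
via_template = Template(template).safe_substitute(tool_schemas=schemas_str)

print(via_replace == via_template)  # True
```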

    @staticmethod
    def _build_initial_message_from_task_input(task_input: TaskInput) -> str:
        """Build the initial user message from TaskInput"""
        message = task_input.format_for_prompt()

        message += "\nPlease analyze this task and start your ReAct process:\n"
        message += "1. Reason about what information you need to gather\n"
        message += "2. Use appropriate tools to get that information\n"
        message += "3. Continue reasoning and acting until you have sufficient information\n"
        message += "4. Call info_seeker_subjective_task_done when ready to provide your complete findings\n\n"
        message += "Begin with your initial reasoning about the task."

        return message

    def execute_task(self, task_input: TaskInput) -> AgentResponse:
        """
        Execute a task using the ReAct pattern (Reasoning + Acting)

        Args:
            task_input: TaskInput object with standardized task information

        Returns:
            AgentResponse with results and process trace
        """
        start_time = time.time()

        try:
            self.logger.info(f"Starting information seeker task: {task_input.task_content}")

            # Reset trace for new task
            self.reset_trace()

            # Initialize conversation history
            conversation_history = []

            # Build initial system prompt for ReAct
            system_prompt = self._build_system_prompt()

            # Build initial user message from TaskInput
            user_message = self._build_initial_message_from_task_input(task_input)

            # Add to conversation
            conversation_history.append({"role": "system", "content": system_prompt})
            conversation_history.append({"role": "user", "content": user_message + " /no_think"})

            iteration = 0
            task_completed = False
            # Get model configuration from config
            from config.config import get_config
            config = get_config()
            model_config = config.get_custom_llm_config()

            pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
            model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
            headers = {'Content-Type': 'application/json', 'csb-token': model_token}

            # ReAct loop: Reasoning -> Acting -> Reasoning -> Acting...
            self.config.max_iterations = 30
            while iteration < self.config.max_iterations and not task_completed:
                iteration += 1
                self.logger.info(f"Planning iteration {iteration}")

                try:
                    # Get LLM response (reasoning + potential tool calls)
                    retry_num = 1
                    max_retry_num = 10
                    while retry_num < max_retry_num:
                        try:
                            response = requests.post(
                                url=pangu_url,
                                headers=headers,
                                json={
                                    "model": model_config.get('model', 'pangu_auto'),
                                    "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
                                    "messages": conversation_history,
                                    "spaces_between_special_tokens": False,
                                    "temperature": self.config.temperature,
                                },
                                timeout=model_config.get("timeout", 180)
                            )
                            response = response.json()

                            self.logger.debug("API response received")
                            break
                        except Exception as e:
                            time.sleep(3)
                            retry_num += 1
                            if retry_num == max_retry_num:
                                raise ValueError(str(e))
                            continue

                    assistant_message = response["choices"][0]["message"]
                    # Log the reasoning
                    try:
                        if assistant_message["content"]:
                            reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
                            if len(reasoning_content) > 0:
                                self.log_reasoning(iteration, reasoning_content)
                    except Exception as e:
                        self.logger.warning(f"Tool call parsing error: {e}")
                        # Parse error, rerun
                        followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
                        conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
                        continue

                    def extract_tool_calls(content):
                        import re
                        if not content:
                            return []
                        tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
                        if len(tool_call_str) > 0:
                            try:
                                tool_calls = json.loads(tool_call_str[0].strip())
                            except Exception as ee:
                                return ["fail_tools_load", ee]
                        else:
                            return []
                        return tool_calls

                    # Add assistant message to conversation
                    conversation_history.append({
                        "role": "assistant",
                        "content": assistant_message["content"]
                    })
| 205 |
+
|
| 206 |
+
tool_calls = extract_tool_calls(assistant_message["content"])
|
| 207 |
+
|
| 208 |
+
if tool_calls[0] == "fail_tools_load":
|
| 209 |
+
# Parse error, rerun
|
| 210 |
+
followup_prompt = f"There was a parsing error in the format of the tool call" \
|
| 211 |
+
f" you generated:{tool_calls[1]} Please regenerate it."
|
| 212 |
+
conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# Execute tool calls if any (Acting phase)
|
| 217 |
+
|
| 218 |
+
for tool_call in tool_calls:
|
| 219 |
+
arguments = tool_call["arguments"]
|
| 220 |
+
|
| 221 |
+
# Check if planning is complete
|
| 222 |
+
if tool_call["name"] in ["info_seeker_subjective_task_done"]:
|
| 223 |
+
task_completed = True
|
| 224 |
+
self.log_action(iteration, tool_call["name"], arguments, arguments)
|
| 225 |
+
break
|
| 226 |
+
if tool_call["name"] in ["think", "reflect"]:
|
| 227 |
+
tool_result = {"tool_results": "You can proceed to invoke other tools if needed."}
|
| 228 |
+
else:
|
| 229 |
+
tool_result = self.execute_tool_call(tool_call)
|
| 230 |
+
|
| 231 |
+
# Log the action using base class method
|
| 232 |
+
self.log_action(iteration, tool_call["name"], arguments, tool_result)
|
| 233 |
+
|
| 234 |
+
# Add tool result to conversation
|
| 235 |
+
conversation_history.append({
|
| 236 |
+
"role": "tool",
|
| 237 |
+
"content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
|
| 238 |
+
})
|
| 239 |
+
|
| 240 |
+
# If no tool calls, encourage continued planning
|
| 241 |
+
if len(tool_calls) == 0:
|
| 242 |
+
# Add follow-up prompt to encourage action or completion
|
| 243 |
+
followup_prompt = (
|
| 244 |
+
"Continue your analysis. If you need more information, use available tools. "
|
| 245 |
+
"If you have enough information to answer the question, call info_seeker_subjective_task_done with your complete context."
|
| 246 |
+
)
|
| 247 |
+
conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
|
| 248 |
+
if iteration == self.config.max_iterations-3:
|
| 249 |
+
followup_prompt = "Due to length and number of rounds restrictions, you must now call the `info_seeker_subjective_task_done` tool to report the completion of your task."
|
| 250 |
+
conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
except Exception as e:
|
| 254 |
+
error_msg = f"Error in planning iteration {iteration}: {e}"
|
| 255 |
+
self.log_error(iteration, error_msg)
|
| 256 |
+
break
|
| 257 |
+
|
| 258 |
+
execution_time = time.time() - start_time
|
| 259 |
+
# Extract final result
|
| 260 |
+
if task_completed:
|
| 261 |
+
# Find the task_done result in the trace
|
| 262 |
+
task_done_result = None
|
| 263 |
+
for step in reversed(self.reasoning_trace):
|
| 264 |
+
if step.get("type") == "action" and step.get("tool") == "info_seeker_subjective_task_done":
|
| 265 |
+
task_done_result = step.get("result")
|
| 266 |
+
break
|
| 267 |
+
|
| 268 |
+
return self.create_response(
|
| 269 |
+
success=True,
|
| 270 |
+
result=task_done_result,
|
| 271 |
+
iterations=iteration,
|
| 272 |
+
execution_time=execution_time
|
| 273 |
+
)
|
| 274 |
+
else:
|
| 275 |
+
return self.create_response(
|
| 276 |
+
success=False,
|
| 277 |
+
error=f"Task not completed within {self.config.max_iterations} iterations",
|
| 278 |
+
iterations=iteration,
|
| 279 |
+
execution_time=execution_time
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
except Exception as e:
|
| 283 |
+
execution_time = time.time() - start_time
|
| 284 |
+
self.logger.error(f"Error in execute_task: {e}")
|
| 285 |
+
return self.create_response(
|
| 286 |
+
success=False,
|
| 287 |
+
error=str(e),
|
| 288 |
+
iterations=iteration if 'iteration' in locals() else 0,
|
| 289 |
+
execution_time=execution_time
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
    def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Build tool schemas for InformationSeekerAgent using proper MCP architecture.
        Schemas come from the MCP server via the client, not direct imports.
        """
        # Get MCP tool schemas from server via client (proper MCP architecture)
        schemas = super()._build_agent_specific_tool_schemas()

        # Add schemas for built-in task assignment tools
        builtin_assignment_schemas = [
            {
                "type": "function",
                "function": {
                    "name": "think",
                    "description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "thought": {
                                "type": "string",
                                "description": "Your thoughts."
                            }
                        },
                        "required": ["thought"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "reflect",
                    "description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "reflect": {
                                "type": "string",
                                "description": "The specific content of your reflection"
                            }
                        },
                        "required": ["reflect"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "info_seeker_subjective_task_done",
                    "description": "Information Seeker Agent task completion reporting with information collection summary and related files.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "task_summary": {
                                "type": "string",
                                "description": "Simple summary of what information has been collected for the current task and what new discoveries have been made.",
                                "format": "markdown"
                            },
                            "key_files": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "file_path": {
                                            "type": "string",
                                            "description": "Relative path to the file with collected content"
                                        },
                                    },
                                    "required": ["file_path"]
                                },
                                "description": "Collect files highly relevant to this task."
                            },
                            "completion_status": {
                                "type": "string",
                                "enum": ["completed", "partial", "failed"],
                                "description": "Final status of the information gathering task"
                            },
                            "completion_analysis": {
                                "type": "string",
                                "description": "Brief analysis of task completion quality, information thoroughness, and any limitations or gaps."
                            }
                        },
                        "required": ["task_summary", "key_files", "completion_status", "completion_analysis"]
                    }
                }
            },
        ]

        schemas.extend(builtin_assignment_schemas)

        return schemas


# Factory function for creating the agent
def create_subjective_information_seeker(
    model: str = "pangu_auto",
    max_iterations: int = 10,
    shared_mcp_client=None,
    **kwargs
) -> InformationSeekerAgent:
    """
    Create an InformationSeekerAgent instance with server-managed sessions.

    Args:
        model: The LLM model to use
        max_iterations: Maximum number of iterations
        shared_mcp_client: Optional shared MCP client from parent agent (prevents extra sessions)
        **kwargs: Additional configuration options

    Returns:
        Configured InformationSeekerAgent instance with appropriate tools
    """
    # Import the enhanced config function
    from .base_agent import create_agent_config

    # Create agent configuration (session managed by MCP server)
    config = create_agent_config(
        agent_name="InformationSeekerAgent",
        model=model,
        max_iterations=max_iterations,
        **kwargs
    )

    # Create agent instance with shared MCP client (filtered tools for information seeking)
    agent = InformationSeekerAgent(config=config, shared_mcp_client=shared_mcp_client)

    return agent
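The agents above parse tool calls out of the model's raw text by matching the `[unused11]...[unused12]` sentinel tokens described in the system prompt, rather than relying on a structured tool-call field. The sketch below is a minimal standalone version of that extraction logic with a hypothetical model reply for illustration; the sentinel tokens and the `"fail_tools_load"` failure marker follow the source, while the example message content is invented.

```python
import json
import re


def extract_tool_calls(content):
    """Extract the JSON tool-call array wrapped in [unused11]...[unused12] tags.

    Returns [] when no tool call is present, and ["fail_tools_load", <error>]
    when the wrapped payload is not valid JSON, mirroring the agent code above.
    """
    if not content:
        return []
    matches = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
    if not matches:
        return []
    try:
        return json.loads(matches[0].strip())
    except Exception as err:
        return ["fail_tools_load", err]


# Hypothetical model reply: reasoning between [unused16]/[unused17],
# followed by a tool call between [unused11]/[unused12].
reply = (
    "[unused16]I should think first.[unused17]"
    '[unused11][{"name": "think", "arguments": {"thought": "plan the search"}}][unused12]'
)
calls = extract_tool_calls(reply)
print(calls[0]["name"])  # think
```

Only the first `[unused11]...[unused12]` span is parsed, so a reply containing several tool-call blocks yields just the first array; malformed JSON is surfaced to the loop via the `"fail_tools_load"` marker, which triggers a retry prompt rather than an exception.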
deepdiver_v2/src/agents/writer_agent.py
ADDED
@@ -0,0 +1,477 @@
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict, Any, List
|
| 4 |
+
import time
|
| 5 |
+
import requests
|
| 6 |
+
import os
|
| 7 |
+
from .base_agent import BaseAgent, AgentConfig, AgentResponse, WriterAgentTaskInput
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class WriterAgent(BaseAgent):
|
| 12 |
+
"""
|
| 13 |
+
Writer Agent that follows ReAct pattern for content synthesis and generation
|
| 14 |
+
|
| 15 |
+
This agent takes writing tasks from parent agents, searches through existing
|
| 16 |
+
files and knowledge base, and creates long-form content through iterative
|
| 17 |
+
reasoning and refinement. It does NOT access internet resources, only
|
| 18 |
+
local files and memories.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, config: AgentConfig = None, shared_mcp_client=None):
|
| 22 |
+
# Set default agent name if not specified
|
| 23 |
+
if config is None:
|
| 24 |
+
config = AgentConfig(agent_name="WriterAgent")
|
| 25 |
+
elif config.agent_name == "base_agent":
|
| 26 |
+
config.agent_name = "WriterAgent"
|
| 27 |
+
|
| 28 |
+
super().__init__(config, shared_mcp_client)
|
| 29 |
+
|
| 30 |
+
# Rebuild tool schemas with writer-specific tools only
|
| 31 |
+
self.tool_schemas = self._build_tool_schemas()
|
| 32 |
+
|
| 33 |
+
def _build_agent_specific_tool_schemas(self) -> List[Dict[str, Any]]:
|
| 34 |
+
"""
|
| 35 |
+
Build tool schemas for WriterAgent using proper MCP architecture.
|
| 36 |
+
Schemas come from MCP server via client, not direct imports.
|
| 37 |
+
"""
|
| 38 |
+
# Get MCP tool schemas from server via client (proper MCP architecture)
|
| 39 |
+
schemas = super()._build_agent_specific_tool_schemas()
|
| 40 |
+
|
| 41 |
+
# Add schemas for built-in task assignment tools
|
| 42 |
+
builtin_assignment_schemas = [
|
| 43 |
+
{
|
| 44 |
+
"type": "function",
|
| 45 |
+
"function": {
|
| 46 |
+
"name": "think",
|
| 47 |
+
"description": "Use the tool to think about something. It will not obtain new information or make any changes to the repository, but just log the thought. Use it when complex reasoning or brainstorming is needed.",
|
| 48 |
+
"parameters": {
|
| 49 |
+
"type": "object",
|
| 50 |
+
"properties": {
|
| 51 |
+
"thought": {
|
| 52 |
+
"type": "string",
|
| 53 |
+
"description": "Your thoughts."
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
"required": ["thought"]
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"type": "function",
|
| 62 |
+
"function": {
|
| 63 |
+
"name": "reflect",
|
| 64 |
+
"description": "When multiple attempts yield no progress, use this tool to reflect on previous reasoning and planning, considering possible overlooked clues and exploring more possibilities. It will not obtain new information or make any changes to the repository.",
|
| 65 |
+
"parameters": {
|
| 66 |
+
"type": "object",
|
| 67 |
+
"properties": {
|
| 68 |
+
"reflect": {
|
| 69 |
+
"type": "string",
|
| 70 |
+
"description": "The specific content of your reflection"
|
| 71 |
+
}
|
| 72 |
+
},
|
| 73 |
+
"required": ["reflect"]
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"type": "function",
|
| 79 |
+
"function": {
|
| 80 |
+
"name": "writer_subjective_task_done",
|
| 81 |
+
"description": "Writer Agent task completion reporting for complete long-form content. Called after all chapters/sections are written to provide a summary of the complete long article, final completion status and analysis, and the storage path of the final consolidated article.",
|
| 82 |
+
"parameters": {
|
| 83 |
+
"type": "object",
|
| 84 |
+
"properties": {
|
| 85 |
+
"final_article_path": {
|
| 86 |
+
"type": "string",
|
| 87 |
+
"description": "The file path where the final article is saved."
|
| 88 |
+
},
|
| 89 |
+
"article_summary": {
|
| 90 |
+
"type": "string",
|
| 91 |
+
"description": "Comprehensive summary of the complete long-form article, including main themes, key points covered, and overall narrative structure.",
|
| 92 |
+
"format": "markdown"
|
| 93 |
+
},
|
| 94 |
+
"completion_status": {
|
| 95 |
+
"type": "string",
|
| 96 |
+
"enum": ["completed", "partial", "failed"],
|
| 97 |
+
"description": "Final status of the complete long-form writing task"
|
| 98 |
+
},
|
| 99 |
+
"completion_analysis": {
|
| 100 |
+
"type": "string",
|
| 101 |
+
"description": "Analysis of the overall writing project completion including: assessment of article coherence and quality, evaluation of content organization and flow, identification of any challenges in the writing process, and overall evaluation of the long-form content creation success."
|
| 102 |
+
}
|
| 103 |
+
},
|
| 104 |
+
"required": ["final_article_path", "article_summary", "completion_status",
|
| 105 |
+
"completion_analysis"]
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
},
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
schemas.extend(builtin_assignment_schemas)
|
| 112 |
+
|
| 113 |
+
return schemas
|
| 114 |
+
|
| 115 |
+
def _build_system_prompt(self) -> str:
|
| 116 |
+
"""Build the system prompt for the writer agent"""
|
| 117 |
+
tool_schemas_str = json.dumps(self.tool_schemas, ensure_ascii=False)
|
| 118 |
+
system_prompt_template = """You are a professional writing master. You will receive key files and user problems. Your task is to generate an outline highly consistent with the user problem, classify files into sections, and iteratively call section_writer tool to create comprehensive content. Then you strictly follow the steps given below:
|
| 119 |
+
|
| 120 |
+
MANDATORY WORKFLOW:
|
| 121 |
+
|
| 122 |
+
1. OUTLINE GENERATION
|
| 123 |
+
Based on the core content of the provided key files collection(file_core_content), generate a high-quality outline suitable for long-form writing. Strictly adhere to the following requirements during generation:
|
| 124 |
+
- Before generating the outline, carefully review the provided **file_core_content**, prioritizing sections with:
|
| 125 |
+
1.**Higher authority** (credible sources)
|
| 126 |
+
2.**Greater information richness** (substantive, detailed content)
|
| 127 |
+
3.**Stronger relevance** (direct alignment with user query)
|
| 128 |
+
4.**Timeliness** (if user’s query is time-sensitive, prioritize recent/updated content)
|
| 129 |
+
Select these segments as the basis for outline generation. Note that we only focus on relevance to the question, so when generating the outline, do not add unrelated sections just for the sake of length. Additionally, the sections should flow logically and not be too disjointed, as this would harm the readability of the final output.
|
| 130 |
+
- The overall structure must be **logically clear**, with **no repetition or redundancy** between chapters.
|
| 131 |
+
- **Note1:** The generated outline must not only have chapter-level headings (Level 1) highly relevant to the user’s question, but the subheadings (Level 2) must also be highly relevant to the user’s question. It is not permitted to generate chapter titles with weak relevance, whether Level 1 or Level 2.
|
| 132 |
+
- **Note2:** The number of chapters must not exceed 7, dynamic evaluation can be performed based on the collected content. For example, if there is a lot of content, more chapters can be generated, and vice versa. But each chapter should only include Level 1 and Level 2 headings. Also, be careful not to generate too many Level 2 headings, limit them to 4. However, if the first chapter is an abstract or introduction, do not generate subheadings (level-2 headings)—only include the main heading (level-1). Additionally, tailor the outline style based on the type of document. For example, in a research report, the first chapter should preferably be titled \"Abstract\" or \"Introduction.\"
|
| 133 |
+
|
| 134 |
+
2. FILE CLASSIFICATION
|
| 135 |
+
- Use the search_result_classifier tool to reasonably split the outline generated above and accurately assign key files to each chapter of the outline.
|
| 136 |
+
- Ensure optimal distribution of reference materials across chapters based on content relevance.
|
| 137 |
+
|
| 138 |
+
3. ITERATIVE SECTION WRITING
|
| 139 |
+
- Call section_writer tool sequentially for each chapter
|
| 140 |
+
- CRITICAL: Must wait for previous chapter completion before starting the next chapter
|
| 141 |
+
- Pass only the specific chapter outline , target file path and corresponding classified files to each section writer
|
| 142 |
+
- Generate save path for each chapter using \"./report/part_X.md\" format (e.g., \"./report/part_1.md\" for first chapter)
|
| 143 |
+
- Check section writer results after completion; retry up to 2 times per chapter if quality is insufficient based on returned fields (do not read saved files)
|
| 144 |
+
- When you call the section_writer tool, pay special attention to the fact that the parameter value of written_chapters_summary is a summary of the content returned by all previously completed chapters. Be careful not to make any changes to the summary content, including compressing the content.
|
| 145 |
+
|
| 146 |
+
4. TASK COMPLETION
|
| 147 |
+
- After all chapters are written, you must first call the concat_section_files tool to merge the saved chapter files into one file, then call writer_subjective_task_done to finalize and return.
|
| 148 |
+
|
| 149 |
+
CRITICAL REQUIREMENTS:
|
| 150 |
+
- The creation of the outline is crucial! Therefore, you must strictly adhere to the above requirements for generating the outline.
|
| 151 |
+
- No parallel writing - strictly sequential chapter execution
|
| 152 |
+
- Wait for each section writer completion before proceeding to next chapter
|
| 153 |
+
- Classify files appropriately to support each chapter's content needs
|
| 154 |
+
- Note again that to merge all the written chapter files, you must use the concat_section_files tool!!! You are not allowed to call any other tools for merging!!!
|
| 155 |
+
|
| 156 |
+
FORBIDDEN CONTENT PATTERNS:
|
| 157 |
+
- NEVER generate meta-structural chapters that describe how the article is organized
|
| 158 |
+
- AVOID introductory sections that outline \"Chapter 1 will cover..., Chapter 2 will discuss...\"
|
| 159 |
+
- DO NOT create chapters that explain the report structure or methodology
|
| 160 |
+
- Each chapter must contain SUBSTANTIVE CONTENT, not descriptions of what other chapters contain
|
| 161 |
+
- When generating an outline, if it is not a professional term, the language should remain consistent with the user's question.\"
|
| 162 |
+
|
| 163 |
+
Usage of TOOLS:
|
| 164 |
+
- search_result_classifier: Classify key files into outline sections
|
| 165 |
+
- section_writer: Write individual chapters sequentially
|
| 166 |
+
- writer_subjective_task_done: Complete the writing task
|
| 167 |
+
- concat_section_files: Concatenate the content of the saved section files into a single file
|
| 168 |
+
- think tool: \"Think\" is a systematic tool requiring its use during key steps. Before executing actions like generating an outline, you must first call this tool to deeply consider the given content and key requirements, ensuring the output meets specifications. Similarly, during iterative chapter generation, after receiving feedback and before writing the next chapter, call \"think\" to reflect on the current chapter. This provides guidance to avoid content repetition and ensure smooth transitions between chapters.
|
| 169 |
+
|
| 170 |
+
Execute workflow systematically to produce high-quality, coherent long-form content with substantive chapters.
|
| 171 |
+
|
| 172 |
+
Below, within the <tools></tools> tags, are the descriptions of each tool and the required fields for invocation:
|
| 173 |
+
<tools>
|
| 174 |
+
$tool_schemas
|
| 175 |
+
</tools>
|
| 176 |
+
For each function call, return a JSON object placed within the [unused11][unused12] tags, which includes the function name and the corresponding function arguments:
|
| 177 |
+
[unused11][{\"name\": <function name>, \"arguments\": <args json object>}][unused12]
|
| 178 |
+
"""
|
| 179 |
+
return system_prompt_template.replace("$tool_schemas", tool_schemas_str)
|
| 180 |
+
|
| 181 |
+
def _build_initial_message_from_task_input(self, task_input: WriterAgentTaskInput) -> str:
|
| 182 |
+
"""Build the initial user message from TaskInput"""
|
| 183 |
+
message = ""
|
| 184 |
+
|
| 185 |
+
# Add key files information with reliability dimensions
|
| 186 |
+
def load_json_from_server(file_path):
|
| 187 |
+
"""Load JSONL file from MCP server using unlimited internal tool"""
|
| 188 |
+
res = []
|
| 189 |
+
try:
|
| 190 |
+
# Use json read tool directly through raw MCP client
|
| 191 |
+
raw_result = self.mcp_tools.client.call_tool("load_json", {"file_path": file_path})
|
| 192 |
+
|
| 193 |
+
if not raw_result.success:
|
| 194 |
+
self.logger.error(f"Failed to read file from server: {raw_result.error}")
|
| 195 |
+
return res
|
| 196 |
+
|
| 197 |
+
res = json.loads(raw_result.data["content"][0]["text"])["data"]
|
| 198 |
+
|
| 199 |
+
except Exception as e:
|
| 200 |
+
self.logger.error(f"Error loading file {file_path} from MCP server: {e}")
|
| 201 |
+
import traceback
|
| 202 |
+
self.logger.debug(f"Full traceback: {traceback.format_exc()}")
|
| 203 |
+
|
| 204 |
+
return res
|
| 205 |
+
|
| 206 |
+
key_files_dict = {}
|
| 207 |
+
|
| 208 |
+
server_analysis_path = f"doc_analysis/file_analysis.jsonl"
|
| 209 |
+
self.logger.debug(f"Loading analysis from MCP server: {server_analysis_path}")
|
| 210 |
+
file_analysis_list = load_json_from_server(server_analysis_path)
|
| 211 |
+
|
| 212 |
+
for file_info in file_analysis_list:
|
| 213 |
+
if file_info.get('file_path'):
|
| 214 |
+
key_files_dict[file_info.get('file_path')] = file_info
|
| 215 |
+
|
| 216 |
+
file_core_content = ""
|
| 217 |
+
if hasattr(task_input, 'key_files') and task_input.key_files:
|
| 218 |
+
message += "Key Files:\n"
|
| 219 |
+
for i, file_ in enumerate(task_input.key_files, 1):
|
| 220 |
+
file_path = file_.get('file_path')
|
| 221 |
+
if file_path in key_files_dict:
|
| 222 |
+
file_info = key_files_dict[file_path]
|
| 223 |
+
doc_time = file_info.get('doc_time', 'Not specified')
|
| 224 |
+
source_authority = file_info.get('source_authority', 'Not assessed')
|
| 225 |
+
task_relevance = file_info.get('task_relevance', 'Not assessed')
|
| 226 |
+
information_richness = file_info.get('information_richness', 'Not assessed')
|
| 227 |
+
message += f"{i}. File: {file_path}\n"
|
| 228 |
+
|
| 229 |
+
file_core_content += f"[{str(i)}]doc_time:{doc_time}|||source_authority:{source_authority}|||task_relevance:{task_relevance}|||information_richness:{information_richness}|||summary_content:{file_info.get('core_content', '')}\n"
|
| 230 |
+
message += "\n"
|
| 231 |
+
message += f"file_core_content: {file_core_content}\n"
|
| 232 |
+
else:
|
| 233 |
+
message += "Key Files: None provided\n"
|
| 234 |
+
|
| 235 |
+
message += "\n"
|
| 236 |
+
# Add user query
|
| 237 |
+
if hasattr(task_input, 'user_query') and task_input.user_query:
|
| 238 |
+
message += f"User Query: {task_input.user_query}\n"
|
| 239 |
+
else:
|
| 240 |
+
message += "User Query: Not provided\n"
|
| 241 |
+
|
| 242 |
+
return message
|
| 243 |
+
|
| 244 |
+
def execute_task(self, task_input: WriterAgentTaskInput) -> AgentResponse:
|
| 245 |
+
"""
|
| 246 |
+
Execute a writing task using ReAct pattern
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
task_input: TaskInput object with standardized task information
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
AgentResponse with writing results and process trace
|
| 253 |
+
"""
|
| 254 |
+
start_time = time.time()
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
self.logger.info(f"Starting writing task: {task_input.task_content}")
|
| 258 |
+
|
| 259 |
+
# Reset trace for new task
|
| 260 |
+
self.reset_trace()
|
| 261 |
+
|
| 262 |
+
# Initialize conversation history
|
| 263 |
+
conversation_history = []
|
| 264 |
+
|
| 265 |
+
# Build system prompt for writing
|
| 266 |
+
system_prompt = self._build_system_prompt()
|
| 267 |
+
|
| 268 |
+
# Build initial user message from TaskInput
|
| 269 |
+
user_message = self._build_initial_message_from_task_input(task_input)
|
| 270 |
+
|
| 271 |
+
# Add to conversation
|
| 272 |
+
conversation_history.append({"role": "system", "content": system_prompt})
|
| 273 |
+
conversation_history.append({"role": "user", "content": user_message + " /no_think"})
|
| 274 |
+
|
| 275 |
+
iteration = 0
|
| 276 |
+
task_completed = False
|
| 277 |
+
|
| 278 |
+
self.logger.debug("Checking conversation history before model call")
|
| 279 |
+
self.logger.debug(f"Conversation history: {conversation_history}")
|
| 280 |
+
# ReAct Loop for Writing: Research → Plan → Write → Refine → Complete
|
| 281 |
+
# Get model configuration from config
|
| 282 |
+
from config.config import get_config
|
| 283 |
+
config = get_config()
|
| 284 |
+
model_config = config.get_custom_llm_config()
|
| 285 |
+
|
| 286 |
+
pangu_url = model_config.get('url') or os.getenv('MODEL_REQUEST_URL', '')
|
| 287 |
+
model_token = model_config.get('token') or os.getenv('MODEL_REQUEST_TOKEN', '')
|
| 288 |
+
headers = {'Content-Type': 'application/json', 'csb-token': model_token}
|
| 289 |
+
|
| 290 |
+
while iteration < self.config.max_iterations and not task_completed:
|
| 291 |
+
iteration += 1
|
| 292 |
+
self.logger.info(f"Writing iteration {iteration}")
|
| 293 |
+
|
| 294 |
+
try:
|
| 295 |
+
# Get LLM response (reasoning + potential tool calls) with retry
|
| 296 |
+
|
| 297 |
+
max_retries = 10
|
| 298 |
+
response = None
|
| 299 |
+
|
| 300 |
+
for attempt in range(max_retries):
|
| 301 |
+
try:
|
| 302 |
+
|
| 303 |
+
response = requests.post(
|
| 304 |
+
url=pangu_url,
|
| 305 |
+
headers=headers,
|
| 306 |
+
json={
|
| 307 |
+
"model": self.config.model,
|
| 308 |
+
"chat_template":"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<s>[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{'<s>[unused9]系统:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}",
|
| 309 |
+
"messages": conversation_history,
|
| 310 |
+
"temperature": self.config.temperature,
|
| 311 |
+
"max_tokens": self.config.max_tokens,
|
| 312 |
+
"spaces_between_special_tokens": False,
|
| 313 |
+
},
|
| 314 |
+
timeout=model_config.get("timeout", 180)
|
| 315 |
+
)
|
| 316 |
+
response = response.json()
|
| 317 |
+
|
| 318 |
+
self.logger.debug("API response received")
|
| 319 |
+
break # Success, exit retry loop
|
| 320 |
+
|
| 321 |
+
except Exception as e:
|
| 322 |
+
self.logger.warning(f"LLM API call attempt {attempt + 1} failed: {e}")
|
| 323 |
+
if attempt == max_retries - 1:
|
| 324 |
+
raise e # Last attempt, re-raise the exception
|
| 325 |
+
time.sleep(6)  # Fixed 6-second delay between retries
|
| 326 |
+
|
| 327 |
+
if response is None:
|
| 328 |
+
raise Exception("Failed to get response after all retries")
|
| 329 |
+
|
| 330 |
+
assistant_message = response["choices"][0]["message"]
|
| 331 |
+
|
| 332 |
+
try:
|
| 333 |
+
if assistant_message["content"]:
|
| 334 |
+
reasoning_content = assistant_message["content"].split("[unused16]")[-1].split("[unused17]")[0]
|
| 335 |
+
if len(reasoning_content) > 0:
|
| 336 |
+
self.log_reasoning(iteration, reasoning_content)
|
| 337 |
+
except Exception as e:
|
| 338 |
+
self.logger.warning(f"Reasoning content parsing error: {e}")
|
| 339 |
+
# Parse error, rerun
|
| 340 |
+
followup_prompt = f"There is a problem with the format of model generation: {e}. Please try again."
|
| 341 |
+
conversation_history.append({"role": "user", "content": followup_prompt + " /no_think"})
|
| 342 |
+
continue
|
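The reasoning extraction above can be sketched as a standalone helper. The `[unused16]`/`[unused17]` markers and the split-based parsing come from the code itself; the helper name and sample string are illustrative. Note that when the markers are absent the input comes back unchanged, which is why the caller also checks the extracted length before logging.

```python
# Hypothetical standalone mirror of the reasoning-content extraction:
# openPangu wraps chain-of-thought between the [unused16] and [unused17]
# special tokens, so splitting on both markers isolates it.
def extract_reasoning(content: str) -> str:
    return content.split("[unused16]")[-1].split("[unused17]")[0]

sample = "prefix[unused16]I should gather sources first.[unused17]answer"
print(extract_reasoning(sample))  # -> I should gather sources first.
```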
| 343 |
+
|
| 344 |
+
def extract_tool_calls(content):
|
| 345 |
+
import re
|
| 346 |
+
tool_call_str = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
|
| 347 |
+
if len(tool_call_str) > 0:
|
| 348 |
+
try:
|
| 349 |
+
tool_calls = json.loads(tool_call_str[0])
|
| 350 |
+
except json.JSONDecodeError:
|
| 351 |
+
return []
|
| 352 |
+
else:
|
| 353 |
+
return []
|
| 354 |
+
return tool_calls
|
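A self-contained version of the parser above, handy for quick testing. The `[unused11]`/`[unused12]` markers and the JSON-array payload format match the code; the sample tool call (`file_read`) is hypothetical.

```python
import json
import re

# Sketch of the tool-call parser: openPangu emits tool calls as a JSON
# array wrapped in the [unused11]/[unused12] special tokens.
def parse_tool_calls(content: str) -> list:
    matches = re.findall(r"\[unused11\]([\s\S]*?)\[unused12\]", content)
    if not matches:
        return []
    try:
        return json.loads(matches[0])
    except json.JSONDecodeError:
        return []

sample = '[unused11][{"name": "file_read", "arguments": {"path": "draft.md"}}][unused12]'
print(parse_tool_calls(sample)[0]["name"])  # -> file_read
```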
| 355 |
+
|
| 356 |
+
# Add assistant message to conversation
|
| 357 |
+
conversation_history.append({
|
| 358 |
+
"role": "assistant",
|
| 359 |
+
"content": assistant_message["content"]
|
| 360 |
+
})
|
| 361 |
+
|
| 362 |
+
tool_calls = extract_tool_calls(assistant_message["content"])
|
| 363 |
+
|
| 364 |
+
# Execute tool calls if any (Acting phase)
|
| 365 |
+
for tool_call in tool_calls:
|
| 366 |
+
# Arguments may arrive as a JSON string or a dict
|
| 367 |
+
arguments = tool_call["arguments"]
|
| 368 |
+
self.logger.debug(f"Arguments is string: {isinstance(arguments, str)}")
|
| 369 |
+
|
| 370 |
+
# Check if planning is complete
|
| 371 |
+
if tool_call["name"] in ["writer_subjective_task_done"]:
|
| 372 |
+
task_completed = True
|
| 373 |
+
self.log_action(iteration, tool_call["name"], arguments, arguments)
|
| 374 |
+
break
|
| 375 |
+
if tool_call["name"] in ["think"]:
|
| 376 |
+
tool_result = {
|
| 377 |
+
"tool_results": "You may proceed to invoke other tools if needed, but the next step must not call the reflect tool again."}
|
| 378 |
+
else:
|
| 379 |
+
tool_result = self.execute_tool_call(tool_call)
|
| 380 |
+
|
| 381 |
+
# Log the action using base class method
|
| 382 |
+
self.log_action(iteration, tool_call["name"], arguments, tool_result)
|
| 383 |
+
|
| 384 |
+
# Add tool result to conversation
|
| 385 |
+
conversation_history.append({
|
| 386 |
+
"role": "tool",
|
| 387 |
+
"content": json.dumps(tool_result, ensure_ascii=False, indent=2) + " /no_think"
|
| 388 |
+
})
|
| 389 |
+
|
| 390 |
+
# If no tool calls, encourage continued writing
|
| 391 |
+
if len(tool_calls) == 0:
|
| 392 |
+
# Add follow-up prompt to encourage action or completion
|
| 393 |
+
followup_prompt = (
|
| 394 |
+
"Continue your writing process. If you need to research more, use available tools. "
|
| 395 |
+
"If you need to write or edit content, use file operations. "
|
| 396 |
+
"If your writing is complete and meets requirements, call writer_subjective_task_done. /no_think"
|
| 397 |
+
)
|
| 398 |
+
conversation_history.append({"role": "user", "content": followup_prompt})
|
| 399 |
+
|
| 400 |
+
except Exception as e:
|
| 401 |
+
error_msg = f"Error in writing iteration {iteration}: {e}"
|
| 402 |
+
self.log_error(iteration, error_msg)
|
| 403 |
+
break
|
| 404 |
+
|
| 405 |
+
execution_time = time.time() - start_time
|
| 406 |
+
# Extract final result
|
| 407 |
+
if task_completed:
|
| 408 |
+
# Find the completion result in the trace
|
| 409 |
+
completion_result = None
|
| 410 |
+
for step in reversed(self.reasoning_trace):
|
| 411 |
+
if step.get("type") == "action" and step.get("tool") in ["writer_subjective_task_done"]:
|
| 412 |
+
completion_result = step.get("result")
|
| 413 |
+
break
|
| 414 |
+
return self.create_response(
|
| 415 |
+
success=True,
|
| 416 |
+
result=completion_result,
|
| 417 |
+
iterations=iteration,
|
| 418 |
+
execution_time=execution_time
|
| 419 |
+
)
|
| 420 |
+
else:
|
| 421 |
+
|
| 422 |
+
return self.create_response(
|
| 423 |
+
success=False,
|
| 424 |
+
error=f"Writing task not completed within {self.config.max_iterations} iterations",
|
| 425 |
+
iterations=iteration,
|
| 426 |
+
execution_time=execution_time
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
except Exception as e:
|
| 430 |
+
execution_time = time.time() - start_time if 'start_time' in locals() else 0
|
| 431 |
+
self.logger.error(f"Error in execute_react_loop: {e}")
|
| 432 |
+
|
| 433 |
+
return self.create_response(
|
| 434 |
+
success=False,
|
| 435 |
+
error=str(e),
|
| 436 |
+
iterations=iteration if 'iteration' in locals() else 0,
|
| 437 |
+
execution_time=execution_time
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# Factory function for creating the writer agent
|
| 442 |
+
def create_writer_agent(
|
| 443 |
+
model: Any = None,
|
| 444 |
+
max_iterations: int = 15, # More iterations for writing tasks
|
| 445 |
+
temperature: Any = None, # Resolved from env if not provided
|
| 446 |
+
max_tokens: Any = None,
|
| 447 |
+
shared_mcp_client=None
|
| 448 |
+
) -> WriterAgent:
|
| 449 |
+
"""
|
| 450 |
+
Create a WriterAgent instance with server-managed sessions.
|
| 451 |
+
|
| 452 |
+
Args:
|
| 453 |
+
model: The LLM model to use
|
| 454 |
+
max_iterations: Maximum number of iterations for writing tasks
|
| 455 |
+
temperature: Temperature setting for creativity
|
| 456 |
+
max_tokens: Maximum tokens for the AI response
|
| 457 |
+
shared_mcp_client: Optional shared MCP client from parent agent (prevents extra sessions)
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
Configured WriterAgent instance with writing-focused tools
|
| 461 |
+
"""
|
| 462 |
+
# Import the enhanced config function
|
| 463 |
+
from .base_agent import create_agent_config
|
| 464 |
+
|
| 465 |
+
# Create agent configuration (session managed by MCP server)
|
| 466 |
+
config = create_agent_config(
|
| 467 |
+
agent_name="WriterAgent",
|
| 468 |
+
model=model,
|
| 469 |
+
max_iterations=max_iterations,
|
| 470 |
+
temperature=temperature,
|
| 471 |
+
max_tokens=max_tokens,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Create agent instance with shared MCP client (filtered tools for writing)
|
| 475 |
+
agent = WriterAgent(config=config, shared_mcp_client=shared_mcp_client)
|
| 476 |
+
|
| 477 |
+
return agent
|
deepdiver_v2/src/tools/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 2 |
+
"""
|
| 3 |
+
Model Context Protocol (MCP) Integration
|
| 4 |
+
|
| 5 |
+
This package contains MCP server implementations, tools, and integrations
|
| 6 |
+
for the DeepDiver multi-agent system.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from .mcp_tools import MCPTools
|
| 10 |
+
|
| 11 |
+
# Server imports
|
| 12 |
+
try:
|
| 13 |
+
from .mcp_server_standard import create_app as create_standard_app
|
| 14 |
+
from .mcp_server_simple import app as simple_app
|
| 15 |
+
MCP_STANDARD_AVAILABLE = True
|
| 16 |
+
except ImportError:
|
| 17 |
+
MCP_STANDARD_AVAILABLE = False
|
| 18 |
+
create_standard_app = None
|
| 19 |
+
simple_app = None
|
| 20 |
+
|
| 21 |
+
# For backward compatibility
|
| 22 |
+
try:
|
| 23 |
+
standard_app = simple_app # Keep simple app for basic compatibility
|
| 24 |
+
MCP_AVAILABLE = MCP_STANDARD_AVAILABLE
|
| 25 |
+
except Exception as e:
|
| 26 |
+
MCP_AVAILABLE = False
|
| 27 |
+
standard_app = None
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
'MCPTools',
|
| 31 |
+
'create_standard_app',
|
| 32 |
+
'simple_app',
|
| 33 |
+
'standard_app', # Backward compatibility
|
| 34 |
+
'MCP_AVAILABLE',
|
| 35 |
+
'MCP_STANDARD_AVAILABLE'
|
| 36 |
+
]
|
deepdiver_v2/src/tools/mcp_client.py
ADDED
|
@@ -0,0 +1,814 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
"""
|
| 4 |
+
MCP Client for Agent-to-Server Communication
|
| 5 |
+
Provides a proper MCP client that uses the official MCP package
|
| 6 |
+
to connect to and communicate with MCP servers through the Model Context Protocol.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import time
|
| 12 |
+
from typing import Dict, Any, List, Optional
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import sys
|
| 16 |
+
sys.path.append(str(Path(__file__).parent.parent.parent))
|
| 17 |
+
from ..utils.status_codes import JsonRpcErr
|
| 18 |
+
from http import HTTPStatus
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import httpx
|
| 22 |
+
MCP_AVAILABLE = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
MCP_AVAILABLE = False
|
| 25 |
+
logging.warning("HTTP client dependencies not available. Falling back to direct tools.")
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class MCPClientResult:
|
| 32 |
+
"""Standard result format for MCP client operations"""
|
| 33 |
+
success: bool
|
| 34 |
+
data: Any = None
|
| 35 |
+
error: Optional[str] = None
|
| 36 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 37 |
+
|
| 38 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 39 |
+
return {
|
| 40 |
+
"success": self.success,
|
| 41 |
+
"data": self.data,
|
| 42 |
+
"error": self.error,
|
| 43 |
+
"metadata": self.metadata
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class MCPTool:
|
| 49 |
+
"""Simple representation of an MCP tool"""
|
| 50 |
+
name: str
|
| 51 |
+
description: str = ""
|
| 52 |
+
input_schema: Dict[str, Any] = field(default_factory=dict)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class RetryConfig:
|
| 57 |
+
"""Configuration for retry behavior on rate limiting"""
|
| 58 |
+
max_retries: int = 20 # Maximum number of retry attempts
|
| 59 |
+
base_delay: float = 2.0 # Base delay between retries (seconds)
|
| 60 |
+
max_delay: float = 60.0 # Maximum delay between retries (seconds)
|
| 61 |
+
exponential_backoff: bool = True # Use exponential backoff
|
| 62 |
+
respect_retry_after: bool = True # Respect server's Retry-After header
|
| 63 |
+
retry_on_rate_limit: bool = True # Enable automatic retry on rate limits
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class MCPClient:
|
| 67 |
+
"""
|
| 68 |
+
Simple HTTP-based MCP Client for dynamic tool discovery and execution.
|
| 69 |
+
|
| 70 |
+
This client makes direct HTTP JSON-RPC calls to the MCP server,
|
| 71 |
+
avoiding the complexity of streaming connections.
|
| 72 |
+
|
| 73 |
+
Session management is handled entirely by the server:
|
| 74 |
+
- Server assigns session IDs on connection
|
| 75 |
+
- Server manages workspace creation and isolation
|
| 76 |
+
- All tool operations use server-managed workspaces
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, server_url: str = "http://localhost:6274/mcp", retry_config: Optional[RetryConfig] = None):
|
| 80 |
+
self.server_url = server_url.rstrip('/')
|
| 81 |
+
self.retry_config = retry_config or RetryConfig()
|
| 82 |
+
self._tools: Dict[str, MCPTool] = {}
|
| 83 |
+
self._connected = False
|
| 84 |
+
self._request_id = 0
|
| 85 |
+
self._session_id = None
|
| 86 |
+
|
| 87 |
+
if not MCP_AVAILABLE:
|
| 88 |
+
logger.warning("HTTP client not available. Some functionality may be limited.")
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
# Initialize connection and discover tools
|
| 92 |
+
self._initialize_connection()
|
| 93 |
+
|
| 94 |
+
def _get_next_id(self) -> int:
|
| 95 |
+
"""Get next request ID"""
|
| 96 |
+
self._request_id += 1
|
| 97 |
+
return self._request_id
|
| 98 |
+
|
| 99 |
+
@staticmethod
|
| 100 |
+
def _parse_sse_response(sse_text: str) -> Dict[str, Any]:
|
| 101 |
+
"""Parse Server-Sent Events response and extract JSON data"""
|
| 102 |
+
try:
|
| 103 |
+
# SSE format: "event: message\ndata: {json}\n\n"
|
| 104 |
+
lines = sse_text.strip().split('\n')
|
| 105 |
+
|
| 106 |
+
for line in lines:
|
| 107 |
+
if line.startswith('data: '):
|
| 108 |
+
json_data = line[6:] # Remove "data: " prefix
|
| 109 |
+
return json.loads(json_data)
|
| 110 |
+
|
| 111 |
+
# If no data line found, try parsing entire response as JSON
|
| 112 |
+
return json.loads(sse_text)
|
| 113 |
+
|
| 114 |
+
except json.JSONDecodeError as e:
|
| 115 |
+
logger.error(f"Failed to parse SSE response: {e}")
|
| 116 |
+
logger.error(f"SSE text: {sse_text[:200]}...")
|
| 117 |
+
return {"error": {"code": JsonRpcErr.PARSE_ERROR, "message": f"Parse error: {e}"}}
|
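The SSE handling above can be exercised standalone. The frame format (`event: message` followed by a `data:` line carrying the JSON payload) and the fallback to parsing the whole body as plain JSON are taken from the method; the sample frame is illustrative.

```python
import json

# Sketch of the SSE parsing: extract the JSON payload from a
# "data: {json}" line, falling back to a plain-JSON body.
def parse_sse(sse_text: str) -> dict:
    for line in sse_text.strip().split("\n"):
        if line.startswith("data: "):
            return json.loads(line[6:])  # strip the "data: " prefix
    return json.loads(sse_text)

frame = 'event: message\ndata: {"jsonrpc": "2.0", "id": 1, "result": {"ok": true}}\n\n'
print(parse_sse(frame)["result"])  # -> {'ok': True}
```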
| 118 |
+
|
| 119 |
+
def _make_request(self, method: str, params: Dict[str, Any] = None) -> MCPClientResult:
|
| 120 |
+
"""Make a JSON-RPC request to the MCP server with automatic retry on rate limits"""
|
| 121 |
+
if not MCP_AVAILABLE:
|
| 122 |
+
return MCPClientResult(success=False, error="HTTP client not available")
|
| 123 |
+
|
| 124 |
+
# Prepare JSON-RPC request
|
| 125 |
+
request_data = {
|
| 126 |
+
"jsonrpc": "2.0",
|
| 127 |
+
"id": self._get_next_id(),
|
| 128 |
+
"method": method,
|
| 129 |
+
"params": params or {}
|
| 130 |
+
}
|
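The request envelope assembled above follows the JSON-RPC 2.0 shape that MCP uses. A minimal sketch, with the helper name and the `file_read` arguments as illustrative assumptions:

```python
# Build a JSON-RPC 2.0 request envelope with an incrementing id,
# as used for MCP methods like "tools/list" and "tools/call".
def make_rpc_request(request_id: int, method: str, params: dict = None) -> dict:
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": method,
        "params": params or {},
    }

req = make_rpc_request(1, "tools/call", {"name": "file_read", "arguments": {"path": "notes.md"}})
print(req["method"])  # -> tools/call
```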
| 131 |
+
|
| 132 |
+
# Make HTTP request with proper MCP headers
|
| 133 |
+
headers = {
|
| 134 |
+
"Content-Type": "application/json",
|
| 135 |
+
"Accept": "application/json, text/event-stream"
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
# Add session ID if available
|
| 139 |
+
if self._session_id:
|
| 140 |
+
headers["X-Session-ID"] = self._session_id
|
| 141 |
+
|
| 142 |
+
last_error = None
|
| 143 |
+
retry_count = 0
|
| 144 |
+
|
| 145 |
+
while retry_count <= self.retry_config.max_retries:
|
| 146 |
+
try:
|
| 147 |
+
# Disable proxy for localhost/127.0.0.1 connections to avoid proxy interference
|
| 148 |
+
import os
|
| 149 |
+
from urllib.parse import urlparse
|
| 150 |
+
parsed_url = urlparse(self.server_url)
|
| 151 |
+
is_localhost = parsed_url.hostname in ['localhost', '127.0.0.1', '::1']
|
| 152 |
+
|
| 153 |
+
# Add localhost to NO_PROXY for localhost connections
|
| 154 |
+
original_no_proxy = None
|
| 155 |
+
if is_localhost:
|
| 156 |
+
original_no_proxy = os.environ.get('NO_PROXY', os.environ.get('no_proxy', ''))
|
| 157 |
+
# Add localhost and 127.0.0.1 to NO_PROXY
|
| 158 |
+
no_proxy_hosts = ['localhost', '127.0.0.1', '::1']
|
| 159 |
+
if original_no_proxy:
|
| 160 |
+
existing_hosts = [h.strip() for h in original_no_proxy.split(',')]
|
| 161 |
+
no_proxy_hosts.extend(existing_hosts)
|
| 162 |
+
os.environ['NO_PROXY'] = ','.join(no_proxy_hosts)
|
| 163 |
+
os.environ['no_proxy'] = ','.join(no_proxy_hosts)
|
| 164 |
+
|
| 165 |
+
try:
|
| 166 |
+
# Create client with connection pooling for high-concurrency
|
| 167 |
+
limits = httpx.Limits(
|
| 168 |
+
max_keepalive_connections=3000, # Keep more connections alive
|
| 169 |
+
max_connections=3000, # Allow more concurrent connections
|
| 170 |
+
keepalive_expiry=1000.0 # Keep connections alive longer
|
| 171 |
+
)
|
| 172 |
+
timeout = httpx.Timeout(
|
| 173 |
+
connect=100.0,
|
| 174 |
+
read=None,
|
| 175 |
+
write=60.0,
|
| 176 |
+
pool=30.0
|
| 177 |
+
)
|
| 178 |
+
with httpx.Client(
|
| 179 |
+
timeout=timeout, # Higher timeout for high-concurrency scenarios
|
| 180 |
+
limits=limits, # Connection pooling for better performance
|
| 181 |
+
trust_env=False,
|
| 182 |
+
http2=True # Enable HTTP/2 for better multiplexing
|
| 183 |
+
) as client:
|
| 184 |
+
response = client.post(
|
| 185 |
+
self.server_url,
|
| 186 |
+
json=request_data,
|
| 187 |
+
headers=headers
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# Check for rate limiting (HTTP 429)
|
| 191 |
+
if response.status_code == 429:
|
| 192 |
+
if not self.retry_config.retry_on_rate_limit:
|
| 193 |
+
return MCPClientResult(
|
| 194 |
+
success=False,
|
| 195 |
+
error="Rate limit exceeded (HTTP 429) - retries disabled",
|
| 196 |
+
metadata={"status_code": 429, "retry_count": retry_count}
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
if retry_count >= self.retry_config.max_retries:
|
| 200 |
+
return MCPClientResult(
|
| 201 |
+
success=False,
|
| 202 |
+
error=f"Rate limit exceeded (HTTP 429) - max retries ({self.retry_config.max_retries}) reached",
|
| 203 |
+
metadata={"status_code": 429, "retry_count": retry_count}
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Calculate retry delay
|
| 207 |
+
delay = self._calculate_retry_delay(response, retry_count)
|
| 208 |
+
|
| 209 |
+
logger.warning(f"Rate limit exceeded for {method} (attempt {retry_count + 1}/{self.retry_config.max_retries + 1}). Retrying in {delay:.1f}s...")
|
| 210 |
+
|
| 211 |
+
# Wait before retry
|
| 212 |
+
time.sleep(delay)
|
| 213 |
+
retry_count += 1
|
| 214 |
+
continue
|
| 215 |
+
|
| 216 |
+
# Handle other HTTP errors
|
| 217 |
+
if response.status_code != HTTPStatus.OK:
|
| 218 |
+
return MCPClientResult(
|
| 219 |
+
success=False,
|
| 220 |
+
error=f"HTTP {response.status_code}: {response.text}",
|
| 221 |
+
metadata={"status_code": response.status_code, "retry_count": retry_count}
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Parse successful response (could be JSON or SSE format)
|
| 225 |
+
if response.headers.get("content-type", "").startswith("text/event-stream"):
|
| 226 |
+
# Parse SSE format
|
| 227 |
+
response_data = self._parse_sse_response(response.text)
|
| 228 |
+
else:
|
| 229 |
+
# Parse regular JSON
|
| 230 |
+
response_data = response.json()
|
| 231 |
+
|
| 232 |
+
if "error" in response_data:
|
| 233 |
+
return MCPClientResult(
|
| 234 |
+
success=False,
|
| 235 |
+
error=f"MCP Error: {response_data['error']}",
|
| 236 |
+
metadata={"retry_count": retry_count}
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Capture session ID from response data (for all methods, not just initialize)
|
| 240 |
+
if "session_id" in response_data:
|
| 241 |
+
self._session_id = response_data["session_id"]
|
| 242 |
+
logger.info(f"Captured session ID from response: {self._session_id}")
|
| 243 |
+
|
| 244 |
+
# Success! Log retry info if this wasn't the first attempt
|
| 245 |
+
if retry_count > 0:
|
| 246 |
+
logger.info(f"Request {method} succeeded after {retry_count} retries")
|
| 247 |
+
|
| 248 |
+
return MCPClientResult(
|
| 249 |
+
success=True,
|
| 250 |
+
data=response_data.get("result"),
|
| 251 |
+
metadata={
|
| 252 |
+
"method": method,
|
| 253 |
+
"server_url": self.server_url,
|
| 254 |
+
"session_id": self._session_id,
|
| 255 |
+
"retry_count": retry_count
|
| 256 |
+
}
|
| 257 |
+
)
|
| 258 |
+
finally:
|
| 259 |
+
# Restore original NO_PROXY environment variable
|
| 260 |
+
if is_localhost:
|
| 261 |
+
if original_no_proxy is not None:
|
| 262 |
+
if original_no_proxy:
|
| 263 |
+
os.environ['NO_PROXY'] = original_no_proxy
|
| 264 |
+
os.environ['no_proxy'] = original_no_proxy
|
| 265 |
+
else:
|
| 266 |
+
# Remove NO_PROXY if it wasn't set originally
|
| 267 |
+
os.environ.pop('NO_PROXY', None)
|
| 268 |
+
os.environ.pop('no_proxy', None)
|
| 269 |
+
|
| 270 |
+
except Exception as e:
|
| 271 |
+
last_error = str(e)
|
| 272 |
+
logger.error(f"MCP request failed for {method} (attempt {retry_count + 1}): {e}")
|
| 273 |
+
|
| 274 |
+
# Only retry on certain exceptions (network issues, timeouts)
|
| 275 |
+
if not self._should_retry_exception(e) or retry_count >= self.retry_config.max_retries:
|
| 276 |
+
break
|
| 277 |
+
|
| 278 |
+
# Calculate retry delay for exceptions
|
| 279 |
+
delay = self._calculate_exception_retry_delay(retry_count)
|
| 280 |
+
logger.warning(f"Request {method} failed, retrying in {delay:.1f}s... (attempt {retry_count + 1}/{self.retry_config.max_retries + 1})")
|
| 281 |
+
|
| 282 |
+
time.sleep(delay)
|
| 283 |
+
retry_count += 1
|
| 284 |
+
|
| 285 |
+
# All retries exhausted
|
| 286 |
+
return MCPClientResult(
|
| 287 |
+
success=False,
|
| 288 |
+
error=f"Request failed after {retry_count} retries. Last error: {last_error}",
|
| 289 |
+
metadata={"retry_count": retry_count}
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
def _calculate_retry_delay(self, response, retry_count: int) -> float:
|
| 293 |
+
"""Calculate delay before retry based on server response and retry count"""
|
| 294 |
+
delay = self.retry_config.base_delay
|
| 295 |
+
|
| 296 |
+
# Respect server's Retry-After header if available
|
| 297 |
+
if self.retry_config.respect_retry_after and "Retry-After" in response.headers:
|
| 298 |
+
try:
|
| 299 |
+
retry_after = float(response.headers["Retry-After"])
|
| 300 |
+
delay = min(retry_after, self.retry_config.max_delay)
|
| 301 |
+
logger.debug("Using server Retry-After: %ss", delay)
|
| 302 |
+
except (ValueError, TypeError):
|
| 303 |
+
logger.warning(f"Invalid Retry-After header: {response.headers.get('Retry-After')}")
|
| 304 |
+
|
| 305 |
+
# Apply exponential backoff if enabled
|
| 306 |
+
elif self.retry_config.exponential_backoff:
|
| 307 |
+
delay = min(
|
| 308 |
+
self.retry_config.base_delay * (2 ** retry_count),
|
| 309 |
+
self.retry_config.max_delay
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
return delay
|
| 313 |
+
|
| 314 |
+
def _calculate_exception_retry_delay(self, retry_count: int) -> float:
|
| 315 |
+
"""Calculate delay for exception-based retries"""
|
| 316 |
+
if self.retry_config.exponential_backoff:
|
| 317 |
+
return min(
|
| 318 |
+
self.retry_config.base_delay * (2 ** retry_count),
|
| 319 |
+
self.retry_config.max_delay
|
| 320 |
+
)
|
| 321 |
+
return self.retry_config.base_delay
|
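With the `RetryConfig` defaults above (`base_delay=2.0`, `max_delay=60.0`), the capped exponential backoff produces the following schedule; the helper is a minimal sketch of the computation in both delay methods:

```python
# Capped exponential backoff: base * 2^n, clamped to the max delay.
def backoff_delay(retry_count: int, base: float = 2.0, cap: float = 60.0) -> float:
    return min(base * (2 ** retry_count), cap)

delays = [backoff_delay(i) for i in range(6)]
print(delays)  # -> [2.0, 4.0, 8.0, 16.0, 32.0, 60.0]
```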
| 322 |
+
|
| 323 |
+
@staticmethod
|
| 324 |
+
def _should_retry_exception(exception: Exception) -> bool:
|
| 325 |
+
"""Determine if an exception warrants a retry"""
|
| 326 |
+
# Retry on network-related exceptions
|
| 327 |
+
if isinstance(exception, (httpx.RequestError, httpx.TimeoutException, httpx.ConnectError)):
|
| 328 |
+
return True
|
| 329 |
+
|
| 330 |
+
# Don't retry on other exceptions (parsing errors, etc.)
|
| 331 |
+
return False
|
| 332 |
+
|
| 333 |
+
def _initialize_connection(self):
|
| 334 |
+
"""Initialize MCP client connection and fetch available tools"""
|
| 335 |
+
if not MCP_AVAILABLE:
|
| 336 |
+
return
|
| 337 |
+
|
| 338 |
+
try:
|
| 339 |
+
# Initialize session
|
| 340 |
+
init_result = self._make_request("initialize", {
|
| 341 |
+
"protocolVersion": "2025-06-18",
|
| 342 |
+
"capabilities": {},
|
| 343 |
+
"clientInfo": {
|
| 344 |
+
"name": "DeepDiver-MCP-Client",
|
| 345 |
+
"version": "1.0.0"
|
| 346 |
+
}
|
| 347 |
+
})
|
| 348 |
+
logger.debug(f"Initialize result: {init_result}")
|
| 349 |
+
if not init_result.success:
|
| 350 |
+
logger.error(f"MCP initialization failed: {init_result.error}")
|
| 351 |
+
return
|
| 352 |
+
|
| 353 |
+
logger.info("MCP client initialized successfully")
|
| 354 |
+
|
| 355 |
+
# Fetch available tools
|
| 356 |
+
tools_result = self._make_request("tools/list")
|
| 357 |
+
|
| 358 |
+
if tools_result.success and tools_result.data:
|
| 359 |
+
tools_data = tools_result.data.get("tools", [])
|
| 360 |
+
self._tools = {}
|
| 361 |
+
|
| 362 |
+
for tool_data in tools_data:
|
| 363 |
+
tool = MCPTool(
|
| 364 |
+
name=tool_data.get("name", ""),
|
| 365 |
+
description=tool_data.get("description", ""),
|
| 366 |
+
input_schema=tool_data.get("inputSchema", {})
|
| 367 |
+
)
|
| 368 |
+
self._tools[tool.name] = tool
|
| 369 |
+
|
| 370 |
+
logger.info(f"Discovered {len(self._tools)} tools from MCP server: {list(self._tools.keys())}")
|
| 371 |
+
|
| 372 |
+
self._connected = True
|
| 373 |
+
|
| 374 |
+
except Exception as e:
|
| 375 |
+
logger.error(f"Failed to initialize MCP client: {e}")
|
| 376 |
+
self._connected = False
|
| 377 |
+
|
| 378 |
+
def _ensure_connection(self):
|
| 379 |
+
"""Ensure MCP client is connected"""
|
| 380 |
+
if not MCP_AVAILABLE:
|
| 381 |
+
raise RuntimeError("HTTP client not available")
|
| 382 |
+
|
| 383 |
+
if not self._connected:
|
| 384 |
+
self._initialize_connection()
|
| 385 |
+
|
| 386 |
+
if not self._connected:
|
| 387 |
+
raise RuntimeError("MCP client not connected to server")
|
| 388 |
+
|
| 389 |
+
def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPClientResult:
|
| 390 |
+
"""
|
| 391 |
+
Generic method to call any tool available on the MCP server.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
tool_name: Name of the tool to call
|
| 395 |
+
arguments: Dictionary of arguments to pass to the tool
|
| 396 |
+
|
| 397 |
+
Returns:
|
| 398 |
+
MCPClientResult with the tool execution result
|
| 399 |
+
"""
|
| 400 |
+
try:
|
| 401 |
+
self._ensure_connection()
|
| 402 |
+
|
| 403 |
+
if tool_name not in self._tools:
|
| 404 |
+
return MCPClientResult(
|
| 405 |
+
success=False,
|
| 406 |
+
error=f"Tool '{tool_name}' not available on server. Available tools: {list(self._tools.keys())}"
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# Call the tool via JSON-RPC
|
| 410 |
+
result = self._make_request("tools/call", {
|
| 411 |
+
"name": tool_name,
|
| 412 |
+
"arguments": arguments
|
| 413 |
+
})
|
| 414 |
+
|
| 415 |
+
return result
|
| 416 |
+
|
| 417 |
+
except Exception as e:
|
| 418 |
+
logger.error(f"Error calling tool '{tool_name}': {e}")
|
| 419 |
+
return MCPClientResult(
|
| 420 |
+
success=False,
|
| 421 |
+
error=str(e)
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
def get_available_tools(self) -> Dict[str, MCPTool]:
|
| 425 |
+
"""Get dictionary of available tools from the server"""
|
| 426 |
+
return self._tools.copy()
|
| 427 |
+
|
| 428 |
+
def list_tools(self) -> List[str]:
|
| 429 |
+
"""Get list of available tool names"""
|
| 430 |
+
return list(self._tools.keys())
|
| 431 |
+
|
| 432 |
+
def get_tool_info(self, tool_name: str) -> Optional[MCPTool]:
|
| 433 |
+
"""Get detailed information about a specific tool"""
|
| 434 |
+
return self._tools.get(tool_name)
|
| 435 |
+
|
| 436 |
+
def is_connected(self) -> bool:
|
| 437 |
+
"""Check if client is connected to MCP server"""
|
| 438 |
+
return self._connected and MCP_AVAILABLE
|
| 439 |
+
|
| 440 |
+
def refresh_tools(self):
|
| 441 |
+
"""Refresh the list of available tools from the server"""
|
| 442 |
+
try:
|
| 443 |
+
# Fetch available tools
|
| 444 |
+
tools_result = self._make_request("tools/list")
|
| 445 |
+
|
| 446 |
+
if tools_result.success and tools_result.data:
|
| 447 |
+
tools_data = tools_result.data.get("tools", [])
|
| 448 |
+
self._tools = {}
|
| 449 |
+
logger.debug("Cleared cached tool list before refresh")
|
| 450 |
+
|
| 451 |
+
for tool_data in tools_data:
|
| 452 |
+
tool = MCPTool(
|
| 453 |
+
name=tool_data.get("name", ""),
|
| 454 |
+
description=tool_data.get("description", ""),
|
| 455 |
+
input_schema=tool_data.get("inputSchema", {})
|
| 456 |
+
)
|
| 457 |
+
self._tools[tool.name] = tool
|
| 458 |
+
|
| 459 |
+
logger.info(f"Refreshed {len(self._tools)} tools from MCP server")
|
| 460 |
+
else:
|
| 461 |
+
logger.error(f"Failed to refresh tools: {tools_result.error}")
|
| 462 |
+
|
| 463 |
+
except Exception as e:
|
| 464 |
+
logger.error(f"Error refreshing tools: {e}")
|
| 465 |
+
|
| 466 |
+
def close(self):
|
| 467 |
+
"""Close MCP client connection"""
|
| 468 |
+
# Since we create connections per request, just mark as disconnected
|
| 469 |
+
self._connected = False
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class MCPToolsAdapter:
    """
    Adapter class that provides the MCPTools interface while using the generic MCP client.

    This adapter provides backward compatibility with existing agents by mapping
    MCPTools method calls to generic MCP client tool calls.
    """

    def __init__(self, server_url: str = "http://localhost:6274/mcp", retry_config: Optional[RetryConfig] = None):
        self.client = MCPClient(server_url, retry_config)

    def _call_tool(self, tool_name: str, **kwargs) -> MCPClientResult:
        """Internal method to call tools through the MCP client"""
        return self.client.call_tool(tool_name, kwargs)

    def __getattr__(self, name: str):
        """
        Dynamic method creation for any tool available on the server.
        This allows calling tools like adapter.batch_web_search(...) or adapter.file_read(...)
        """
        if name.startswith('_'):
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

        # Create a dynamic method that calls the tool
        def tool_method(**kwargs):
            result = self._call_tool(name, **kwargs)
            # For backward compatibility, return the data portion
            return result.data if result.success else {"error": result.error}

        return tool_method
    def is_connected(self) -> bool:
        """Check if the MCP client is connected to the server."""
        return self.client.is_connected()

    def get_available_tools(self) -> Dict[str, MCPTool]:
        """Get available tools from the MCP server."""
        return self.client.get_available_tools()

    def list_tools(self) -> List[str]:
        """Get list of available tool names."""
        return self.client.list_tools()

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Get tool schemas for all available tools.
        This is the proper MCP way - schemas come from server, not direct imports.
        """
        schemas = []
        available_tools = self.get_available_tools()

        for tool_name, tool_info in available_tools.items():
            schema = {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": tool_info.description,
                    "parameters": tool_info.input_schema
                }
            }
            schemas.append(schema)

        return schemas

    def refresh_tools(self):
        """Refresh the list of available tools from the server."""
        self.client.refresh_tools()

    def get_session_info(self) -> Optional[Dict[str, Any]]:
        """Get session information from the underlying MCP client."""
        try:
            if hasattr(self.client, '_session_id'):
                return {
                    "session_id": self.client._session_id,
                    "connected": self.client.is_connected(),
                    "server_url": getattr(self.client, 'server_url', 'unknown')
                }
            return None
        except Exception:
            return None

    def close(self):
        """Close the MCP client connection."""
        self.client.close()

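The dynamic `__getattr__` dispatch used by the adapters above can be illustrated with a self-contained sketch. `ToyClient` and `ToyAdapter` are hypothetical stand-ins for `MCPClient` and `MCPToolsAdapter` (no real transport), showing only the attribute-to-tool-call mechanism:

```python
class ToyClient:
    """Hypothetical stand-in for MCPClient: just echoes the tool call."""
    def call_tool(self, name, arguments):
        return {"tool": name, "args": arguments}


class ToyAdapter:
    """Same dynamic-dispatch idea as MCPToolsAdapter.__getattr__."""
    def __init__(self, client):
        self.client = client

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails,
        # so self.client above is never shadowed.
        if name.startswith('_'):
            raise AttributeError(name)

        def tool_method(**kwargs):
            return self.client.call_tool(name, kwargs)
        return tool_method


adapter = ToyAdapter(ToyClient())
result = adapter.batch_web_search(query="mcp")
# result == {"tool": "batch_web_search", "args": {"query": "mcp"}}
```

Any attribute name becomes a callable tool proxy, which is why the filtered variant below must gate `__getattr__` against an allow-list.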
class FilteredMCPToolsAdapter:
    """
    Filtered adapter that shares MCP client connection but restricts tool access per agent type.

    This allows agents to:
    - Share the same session/workspace (via shared client)
    - Have different tool sets appropriate for their role
    - Maintain proper separation of concerns
    """

    def __init__(self, shared_client: MCPClient, allowed_tools: List[str]):
        """
        Initialize with shared client and allowed tools list

        Args:
            shared_client: Shared MCPClient instance (same session)
            allowed_tools: List of tools this agent can access
        """
        self.client = shared_client
        self.allowed_tools = set(allowed_tools)

        # Validate that allowed tools exist on server
        available_tools = set(self.client.list_tools())
        invalid_tools = self.allowed_tools - available_tools
        if invalid_tools:
            logger.warning(f"Requested tools not available on server: {invalid_tools}")
            self.allowed_tools = self.allowed_tools & available_tools

    def _call_tool(self, tool_name: str, **kwargs) -> MCPClientResult:
        """Call tool if allowed, otherwise return error"""
        if tool_name not in self.allowed_tools:
            return MCPClientResult(
                success=False,
                error=f"Tool '{tool_name}' not allowed for this agent. Allowed tools: {list(self.allowed_tools)}"
            )

        # Remove any workspace_path if accidentally passed - server handles workspace
        kwargs.pop('workspace_path', None)
        return self.client.call_tool(tool_name, kwargs)

    def __getattr__(self, name: str):
        """
        Dynamic method resolution with tool filtering.

        Only allows access to tools in the allowed_tools list.
        """
        if name in self.allowed_tools:
            def tool_method(**kwargs):
                return self._call_tool(name, **kwargs)
            return tool_method

        if name in self.client.list_tools():
            # Tool exists but not allowed for this agent
            raise AttributeError(f"Tool '{name}' not allowed for this agent. Allowed tools: {list(self.allowed_tools)}")
        else:
            # Tool doesn't exist on server
            raise AttributeError(f"Tool '{name}' not available on server. Available tools: {self.client.list_tools()}")

    # ================ CLIENT MANAGEMENT ================
    def is_connected(self) -> bool:
        """Check if client is connected to MCP server"""
        return self.client.is_connected()

    def get_available_tools(self) -> Dict[str, MCPTool]:
        """Get filtered list of available tools for this agent"""
        all_tools = self.client.get_available_tools()
        return {name: tool for name, tool in all_tools.items() if name in self.allowed_tools}

    def list_tools(self) -> List[str]:
        """Get list of allowed tool names for this agent"""
        return list(self.allowed_tools)

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Get tool schemas for tools allowed for this agent.
        This is the proper MCP way - schemas come from server, not direct imports.
        """
        schemas = []
        available_tools = self.get_available_tools()

        for tool_name, tool_info in available_tools.items():
            schema = {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": tool_info.description,
                    "parameters": tool_info.input_schema
                }
            }
            schemas.append(schema)

        return schemas

    def refresh_tools(self):
        """Refresh the underlying client's tools"""
        self.client.refresh_tools()

        # Re-validate allowed tools after refresh
        available_tools = set(self.client.list_tools())
        invalid_tools = self.allowed_tools - available_tools
        if invalid_tools:
            logger.warning(f"Some allowed tools no longer available after refresh: {invalid_tools}")
            self.allowed_tools = self.allowed_tools & available_tools

    def close(self):
        """Close MCP client connection"""
        self.client.close()

# ================ AGENT TOOL SETS ================
# Define what tools each agent type should have access to

PLANNER_AGENT_TOOLS = [
    "download_files",
    "document_qa",

    "file_read",
    "file_write",
    "str_replace_based_edit_tool",

    "list_workspace",
    "file_find_by_name",
]


INFORMATION_SEEKER_TOOLS = [
    "batch_web_search",
    "url_crawler",
    "document_extract",
    "document_qa",
    "download_files",
    "file_read",
    "file_write",
    "str_replace_based_edit_tool",
    "list_workspace",
    "file_find_by_name",
]

WRITER_AGENT_TOOLS = [
    "file_read",
    "list_workspace",
    "file_find_by_name",

    "search_result_classifier",
    "section_writer",
    "concat_section_files",
]


def create_filtered_mcp_tools_adapter(
    shared_client: MCPClient,
    agent_type: str
) -> FilteredMCPToolsAdapter:
    """
    Create a filtered MCP tools adapter for specific agent type

    Args:
        shared_client: Shared MCPClient instance
        agent_type: Type of agent ("planner", "information_seeker", "writer")

    Returns:
        FilteredMCPToolsAdapter with appropriate tools for agent type
    """
    tool_sets = {
        "planner": PLANNER_AGENT_TOOLS,
        "information_seeker": INFORMATION_SEEKER_TOOLS,
        "writer": WRITER_AGENT_TOOLS
    }

    allowed_tools = tool_sets.get(agent_type, PLANNER_AGENT_TOOLS)

    return FilteredMCPToolsAdapter(
        shared_client=shared_client,
        allowed_tools=allowed_tools
    )
def create_agent_mcp_tools(
    agent_type: str,
    server_url: str = "http://localhost:6274/mcp",
    retry_config: Optional[RetryConfig] = None
) -> FilteredMCPToolsAdapter:
    """
    Convenience factory to create a filtered MCP tools adapter with retry support.
    This is the RECOMMENDED way to create MCP tools for agents.

    Args:
        agent_type: Type of agent ("planner", "information_seeker", "writer")
        server_url: URL of the MCP server (default: http://localhost:6274/mcp)
        retry_config: Optional retry configuration for handling rate limits

    Returns:
        FilteredMCPToolsAdapter with appropriate tools and retry support for the agent type
    """
    # Create client with retry support
    client = create_mcp_client(server_url=server_url, retry_config=retry_config)

    # Create filtered adapter for the agent type
    return create_filtered_mcp_tools_adapter(client, agent_type)


def create_mcp_client(
    server_url: str = "http://localhost:6274/mcp",
    retry_config: Optional[RetryConfig] = None
) -> MCPClient:
    """
    Factory function to create a generic MCP Client with optional retry configuration

    Args:
        server_url: URL of the MCP server (default: http://localhost:6274/mcp)
        retry_config: Optional retry configuration for handling rate limits

    Returns:
        MCPClient instance for direct tool calling with automatic retry on rate limits
    """
    return MCPClient(server_url=server_url, retry_config=retry_config)


def create_mcp_tools_adapter(
    server_url: str = "http://localhost:6274/mcp",
    retry_config: Optional[RetryConfig] = None
) -> MCPToolsAdapter:
    """
    Factory function to create an MCP Tools Adapter for backward compatibility with retry support.

    Args:
        server_url: URL of the MCP server (default: http://localhost:6274/mcp)
        retry_config: Optional retry configuration for handling rate limits

    Returns:
        MCPToolsAdapter instance that behaves like MCPTools but uses MCP client with automatic retries
    """
    return MCPToolsAdapter(server_url=server_url, retry_config=retry_config)


# Export for compatibility
__all__ = [
    'MCPClientResult',
    'MCPClient',
    'MCPTool',
    'RetryConfig',
    'MCPToolsAdapter',
    'FilteredMCPToolsAdapter',
    'create_mcp_client',
    'create_mcp_tools_adapter',
    'create_filtered_mcp_tools_adapter',
    'create_agent_mcp_tools',  # RECOMMENDED for agents
    'PLANNER_AGENT_TOOLS',
    'INFORMATION_SEEKER_TOOLS',
    'WRITER_AGENT_TOOLS'
]
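The allow-list validation performed in `FilteredMCPToolsAdapter.__init__` is a simple set intersection: unknown names are warned about and silently dropped. A minimal standalone sketch of that step (the function name `validate_allowed` is illustrative, not part of the module API):

```python
def validate_allowed(requested, available):
    """Keep only requested tool names that actually exist on the server."""
    requested, available = set(requested), set(available)
    invalid = requested - available
    if invalid:
        # The real adapter uses logger.warning here
        print(f"Requested tools not available on server: {sorted(invalid)}")
    return requested & available


allowed = validate_allowed(
    ["file_read", "file_write", "no_such_tool"],
    ["file_read", "file_write", "batch_web_search"],
)
# allowed == {"file_read", "file_write"}
```

Dropping rather than failing keeps an agent usable when the server's tool set drifts, at the cost of a quieter error mode; the warning log is the only signal.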
deepdiver_v2/src/tools/mcp_server_standard.py
ADDED
@@ -0,0 +1,1751 @@
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 2 |
+
#!/usr/bin/env python3
|
| 3 |
+
"""
|
| 4 |
+
Demo-Ready MCP Server - New Standard Implementation
|
| 5 |
+
Combines robust session management with comprehensive tool definitions.
|
| 6 |
+
Features: workspace isolation, tool call tracking, rate limiting, security, and full tool suite.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import asyncio
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
import time
|
| 14 |
+
import uuid
|
| 15 |
+
import yaml
|
| 16 |
+
from collections import defaultdict, deque
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from datetime import datetime, timedelta
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from threading import Thread, Event
|
| 21 |
+
from typing import Any, Dict, List, Optional
|
| 22 |
+
|
| 23 |
+
# Third-party imports
|
| 24 |
+
from starlette.applications import Starlette
|
| 25 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 26 |
+
from starlette.requests import Request
|
| 27 |
+
from starlette.responses import JSONResponse, StreamingResponse
|
| 28 |
+
import uvicorn
|
| 29 |
+
|
| 30 |
+
# Add project root to Python path for imports
|
| 31 |
+
import sys
|
| 32 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 33 |
+
from src.utils.status_codes import JsonRpcErr
|
| 34 |
+
from http import HTTPStatus
|
| 35 |
+
|
| 36 |
+
# Handle both relative and absolute imports
|
| 37 |
+
try:
|
| 38 |
+
from .mcp_tools import MCPTools, get_tool_schemas
|
| 39 |
+
from .mcp_tools_async import AsyncMCPTools
|
| 40 |
+
except ImportError:
|
| 41 |
+
# Fallback for direct script execution
|
| 42 |
+
from src.tools.mcp_tools import MCPTools, get_tool_schemas
|
| 43 |
+
try:
|
| 44 |
+
from src.tools.mcp_tools_async import AsyncMCPTools
|
| 45 |
+
except ImportError:
|
| 46 |
+
AsyncMCPTools = None
|
| 47 |
+
|
| 48 |
+
# Workspace knowledge manager disabled
|
| 49 |
+
WORKSPACE_KNOWLEDGE_AVAILABLE = False
|
| 50 |
+
|
| 51 |
+
# Configure structured logging
|
| 52 |
+
logging.basicConfig(
|
| 53 |
+
level=logging.INFO,
|
| 54 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
|
| 55 |
+
handlers=[
|
| 56 |
+
logging.StreamHandler(sys.stdout),
|
| 57 |
+
logging.FileHandler('mcp_server.log')
|
| 58 |
+
]
|
| 59 |
+
)
|
| 60 |
+
logger = logging.getLogger(__name__)
|
| 61 |
+
|
| 62 |
+
# ================ CONFIGURATION ================
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
|
| 66 |
+
class ServerConfig:
|
| 67 |
+
"""Server configuration with only actually implemented options"""
|
| 68 |
+
# Server Core Settings
|
| 69 |
+
host: str = "127.0.0.1"
|
| 70 |
+
port: int = 6274
|
| 71 |
+
debug_mode: bool = False
|
| 72 |
+
|
| 73 |
+
# Session Management
|
| 74 |
+
session_ttl_seconds: int = 3600 # 1 hour default
|
| 75 |
+
max_sessions: int = 1000
|
| 76 |
+
cleanup_interval_seconds: int = 300 # 5 minutes
|
| 77 |
+
enable_session_keepalive: bool = True
|
| 78 |
+
keepalive_touch_interval: int = 300
|
| 79 |
+
|
| 80 |
+
# Request Handling
|
| 81 |
+
request_timeout_seconds: int = 120
|
| 82 |
+
max_request_size_mb: int = 10
|
| 83 |
+
|
| 84 |
+
# Client Rate Limiting (per IP)
|
| 85 |
+
rate_limit_requests_per_minute: int = 300
|
| 86 |
+
|
| 87 |
+
# Workspace Management
|
| 88 |
+
base_workspace_dir: str = "workspaces"
|
| 89 |
+
|
| 90 |
+
# Tool Call Tracking & Logging
|
| 91 |
+
enable_tool_tracking: bool = True
|
| 92 |
+
max_tracked_calls_per_session: int = 1000
|
| 93 |
+
track_detailed_errors: bool = True
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Per-tool Rate Limiting Configuration
|
| 98 |
+
tool_rate_limits: Dict[str, Dict[str, int]] = field(default_factory=dict)
|
| 99 |
+
|
| 100 |
+
@classmethod
|
| 101 |
+
def from_yaml(cls, config_path: str) -> 'ServerConfig':
|
| 102 |
+
"""Load configuration from YAML file"""
|
| 103 |
+
try:
|
| 104 |
+
with open(config_path, 'r') as f:
|
| 105 |
+
config_data = yaml.safe_load(f)
|
| 106 |
+
|
| 107 |
+
# Extract configuration sections with defaults
|
| 108 |
+
server_config = config_data.get('server', {})
|
| 109 |
+
tracking_config = config_data.get('tracking', {})
|
| 110 |
+
tool_rate_limits = config_data.get('tool_rate_limits', {})
|
| 111 |
+
|
| 112 |
+
return cls(
|
| 113 |
+
# Server Core Settings
|
| 114 |
+
host=server_config.get('host', "127.0.0.1"),
|
| 115 |
+
port=server_config.get('port', 6274),
|
| 116 |
+
debug_mode=server_config.get('debug_mode', False),
|
| 117 |
+
|
| 118 |
+
# Session Management
|
| 119 |
+
session_ttl_seconds=server_config.get('session_ttl_seconds', 3600),
|
| 120 |
+
max_sessions=server_config.get('max_sessions', 1000),
|
| 121 |
+
cleanup_interval_seconds=server_config.get('cleanup_interval_seconds', 300),
|
| 122 |
+
enable_session_keepalive=server_config.get('enable_session_keepalive', True),
|
| 123 |
+
keepalive_touch_interval=server_config.get('keepalive_touch_interval', 300),
|
| 124 |
+
|
| 125 |
+
# Request Handling
|
| 126 |
+
request_timeout_seconds=server_config.get('request_timeout_seconds', 120),
|
| 127 |
+
max_request_size_mb=server_config.get('max_request_size_mb', 10),
|
| 128 |
+
|
| 129 |
+
# Client Rate Limiting
|
| 130 |
+
rate_limit_requests_per_minute=server_config.get('rate_limit_requests_per_minute', 300),
|
| 131 |
+
|
| 132 |
+
# Workspace Management
|
| 133 |
+
base_workspace_dir=server_config.get('base_workspace_dir', "workspaces"),
|
| 134 |
+
|
| 135 |
+
# Tool Call Tracking & Logging
|
| 136 |
+
enable_tool_tracking=tracking_config.get('enable_tool_tracking', True),
|
| 137 |
+
max_tracked_calls_per_session=tracking_config.get('max_tracked_calls_per_session', 1000),
|
| 138 |
+
track_detailed_errors=tracking_config.get('track_detailed_errors', True),
|
| 139 |
+
|
| 140 |
+
# Per-tool Rate Limiting
|
| 141 |
+
tool_rate_limits=tool_rate_limits
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error(f"Failed to load configuration from {config_path}: {e}")
|
| 146 |
+
logger.info("Using default configuration")
|
| 147 |
+
return cls()
|
| 148 |
+
|
| 149 |
+
# Global configuration instance - will be set during startup
|
| 150 |
+
config: Optional[ServerConfig] = None

# ================ GLOBAL PER-TOOL RATE LIMITING ================


@dataclass
class ToolRateLimit:
    """Rate limit configuration for a specific tool"""
    requests_per_minute: float
    requests_per_hour: float
    burst_limit: int


class GlobalToolRateLimiter:
    """
    Global rate limiter that controls QPS to external APIs per tool.
    This is shared across all sessions and clients to manage upstream service load.
    """

    def __init__(self, tool_rate_limits: Dict[str, Dict[str, int]]):
        self.tool_limits: Dict[str, ToolRateLimit] = {}
        self.tool_requests: Dict[str, deque] = defaultdict(deque)
        self.lock = asyncio.Lock()

        # Initialize rate limits for each tool
        for tool_name, limits_config in tool_rate_limits.items():
            self.tool_limits[tool_name] = ToolRateLimit(
                requests_per_minute=limits_config.get('requests_per_minute', float('inf')),
                requests_per_hour=limits_config.get('requests_per_hour', float('inf')),
                burst_limit=limits_config.get('burst_limit', 10)
            )
            self.tool_requests[tool_name] = deque()

        logger.info(f"Initialized global tool rate limiter for {len(self.tool_limits)} tools")

    async def is_allowed(self, tool_name: str) -> tuple[bool, Optional[str]]:
        """
        Check if a request to the specified tool is allowed based on global rate limits.

        Returns:
            tuple[bool, Optional[str]]: (allowed, reason_if_denied)
        """
        if tool_name not in self.tool_limits:
            # Tool not configured for rate limiting - allow
            return True, None

        async with self.lock:
            now = time.time()
            limits = self.tool_limits[tool_name]
            requests = self.tool_requests[tool_name]

            # Clean old requests outside the time windows
            self._cleanup_old_requests(requests, now)

            recent_requests = list(requests)

            # Check burst limit (rapid requests in the last second) - only if specified
            if limits.burst_limit != float('inf'):
                burst_count = sum(1 for req_time in recent_requests if now - req_time < 1.0)
                if burst_count >= limits.burst_limit:
                    return False, f"Tool '{tool_name}' burst limit exceeded ({limits.burst_limit} requests/burst)"

            # Check per-minute limit - only if specified
            if limits.requests_per_minute != float('inf'):
                minute_count = sum(1 for req_time in recent_requests if now - req_time < 60.0)
                if minute_count >= limits.requests_per_minute:
                    return False, f"Tool '{tool_name}' per-minute limit exceeded ({limits.requests_per_minute} requests/minute)"

            # Check per-hour limit - only if specified
            if limits.requests_per_hour != float('inf'):
                hour_count = sum(1 for req_time in recent_requests if now - req_time < 3600.0)
                if hour_count >= limits.requests_per_hour:
                    return False, f"Tool '{tool_name}' per-hour limit exceeded ({limits.requests_per_hour} requests/hour)"

            return True, None

    async def record_request(self, tool_name: str):
        """Record a successful request for rate limiting tracking"""
        if tool_name not in self.tool_limits:
            return

        async with self.lock:
            now = time.time()
            self.tool_requests[tool_name].append(now)

            # Keep deque size manageable (only keep the last hour of requests)
            self._cleanup_old_requests(self.tool_requests[tool_name], now)

    @staticmethod
    def _cleanup_old_requests(requests: deque, now: float):
        """Remove requests older than 1 hour to keep memory usage bounded"""
        while requests and now - requests[0] > 3600.0:  # 1 hour
            requests.popleft()

    async def get_tool_stats(self, tool_name: str) -> Dict[str, Any]:
        """Get current usage statistics for a tool"""
        if tool_name not in self.tool_limits:
            return {"error": f"Tool '{tool_name}' not configured for rate limiting"}

        async with self.lock:
            now = time.time()
            requests = self.tool_requests[tool_name]
            limits = self.tool_limits[tool_name]

            # Clean old requests first
            self._cleanup_old_requests(requests, now)

            recent_requests = list(requests)

            return {
                "tool_name": tool_name,
                "current_usage": {
                    "last_second": sum(1 for req_time in recent_requests if now - req_time < 1.0),
                    "last_minute": sum(1 for req_time in recent_requests if now - req_time < 60.0),
                    "last_hour": sum(1 for req_time in recent_requests if now - req_time < 3600.0)
                },
                "limits": {
                    "requests_per_minute": limits.requests_per_minute if limits.requests_per_minute != float('inf') else None,
                    "requests_per_hour": limits.requests_per_hour if limits.requests_per_hour != float('inf') else None,
                    "burst_limit": limits.burst_limit if limits.burst_limit != float('inf') else None
                },
                "utilization": {
                    "minute_utilization": sum(1 for req_time in recent_requests if now - req_time < 60.0) / limits.requests_per_minute if limits.requests_per_minute != float('inf') else 0,
                    "hour_utilization": sum(1 for req_time in recent_requests if now - req_time < 3600.0) / limits.requests_per_hour if limits.requests_per_hour != float('inf') else 0
                }
            }

    async def get_all_stats(self) -> Dict[str, Any]:
        """Get usage statistics for all tools"""
        # get_tool_stats is a coroutine, so this method must be async and await it;
        # a sync comprehension would fill the dict with unawaited coroutine objects.
        return {
            tool_name: await self.get_tool_stats(tool_name)
            for tool_name in self.tool_limits.keys()
        }


# Global tool rate limiter instance - will be initialized during startup
global_tool_rate_limiter: Optional[GlobalToolRateLimiter] = None
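The limiter above keeps one timestamp deque per tool and answers each window check by counting recent entries after trimming anything older than the longest window. A minimal standalone sketch of that sliding-window idea (the function names here are illustrative, not the server's API):

```python
import time
from collections import deque


def trim_window(stamps: deque, now: float, window: float = 3600.0) -> None:
    # Drop timestamps older than the longest window so memory stays bounded.
    while stamps and now - stamps[0] > window:
        stamps.popleft()


def check_and_record(stamps: deque, now: float, per_minute: int, burst_limit: int) -> bool:
    """Return True and record the request if it fits both windows."""
    trim_window(stamps, now)
    # Burst window: requests in the last second.
    if sum(1 for t in stamps if now - t < 1.0) >= burst_limit:
        return False
    # Per-minute window: requests in the last 60 seconds.
    if sum(1 for t in stamps if now - t < 60.0) >= per_minute:
        return False
    stamps.append(now)
    return True
```

Unlike the server class, this sketch records the request inside the check; the class above splits `is_allowed` and `record_request` so a denied call is never counted against the quota.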


# ================ TOOL DEFINITIONS ================

# Tool execution function mapping - maps tool names to their implementation functions


def get_tool_function(tool_name: str):
    """Get the actual function for a tool"""
    tool_map = {
        "batch_web_search": lambda tools, **kwargs: tools.batch_web_search(**kwargs),
        "url_crawler": lambda tools, **kwargs: tools.url_crawler(**kwargs),
        "download_files": lambda tools, **kwargs: tools.download_files(**kwargs),
        "list_workspace": lambda tools, **kwargs: tools.list_workspace(**kwargs),
        "str_replace_based_edit_tool": lambda tools, **kwargs: tools.str_replace_based_edit_tool(**kwargs),
        "file_stats": lambda tools, **kwargs: tools.file_stats(**kwargs),
        "file_read": lambda tools, **kwargs: tools.file_read(**kwargs),
        "file_read_lines": lambda tools, **kwargs: tools.file_read_lines(**kwargs),
        "content_preview": lambda tools, **kwargs: tools.content_preview(**kwargs),
        "file_write": lambda tools, **kwargs: tools.file_write(**kwargs),
        "file_grep_search": lambda tools, **kwargs: tools.file_grep_search(**kwargs),
        "file_grep_with_context": lambda tools, **kwargs: tools.file_grep_with_context(**kwargs),
        "file_find_by_name": lambda tools, **kwargs: tools.file_find_by_name(**kwargs),
        "bash": lambda tools, **kwargs: tools.bash(**kwargs),
        "task_done": lambda tools, **kwargs: tools.task_done(**kwargs),
        "think": lambda tools, **kwargs: tools.think(**kwargs),
        "reflect": lambda tools, **kwargs: tools.reflect(**kwargs),
        "document_qa": lambda tools, **kwargs: tools.document_qa(**kwargs),
        "extract_markdown_toc": lambda tools, **kwargs: tools.extract_markdown_toc(**kwargs),
        "extract_markdown_section": lambda tools, **kwargs: tools.extract_markdown_section(**kwargs),

        "document_extract": lambda tools, **kwargs: tools.document_extract(**kwargs),
        "search_result_classifier": lambda tools, **kwargs: tools.search_result_classifier(**kwargs),
        "info_seeker_subjective_task_done": None,
        "writer_subjective_task_done": None,
        "section_writer": lambda tools, **kwargs: tools.section_writer(**kwargs),
        "concat_section_files": lambda tools, **kwargs: tools.concat_section_files(**kwargs),

        # Internal tools - available to server but NOT exposed to agents via tool schemas
        "internal_file_read_unlimited": lambda tools, **kwargs: tools.internal_file_read_unlimited(**kwargs),
    }
    return tool_map.get(tool_name)


# ================ TOOL CALL TRACKING ================


@dataclass
class ToolCallLog:
    """Individual tool call log entry"""
    call_id: str
    timestamp: datetime
    tool_name: str
    input_args: Dict[str, Any]
    output_result: Dict[str, Any]
    success: bool
    duration_ms: float
    error_details: Optional[str] = None
    session_id: str = ""
    agent_info: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization"""
        return {
            "call_id": self.call_id,
            "timestamp": self.timestamp.isoformat(),
            "tool_name": self.tool_name,
            "input_args": self.input_args,
            "output_result": self.output_result,
            "success": self.success,
            "duration_ms": self.duration_ms,
            "error_details": self.error_details,
            "session_id": self.session_id,
            "agent_info": self.agent_info
        }


class ToolCallTracker:
    """Tracks and saves tool calls to workspace-specific files"""

    def __init__(self, workspace_path: Path, session_id: str):
        self.workspace_path = workspace_path
        self.session_id = session_id
        self.logs_dir = workspace_path / "tool_call_logs"
        self.logs_dir.mkdir(exist_ok=True)

        # Create daily log file
        today = datetime.now().strftime("%Y-%m-%d")
        self.current_log_file = self.logs_dir / f"tool_calls_{today}.jsonl"
        self.summary_file = self.logs_dir / "session_summary.json"

        # Track call counts
        self.call_count = 0
        self.tool_usage_stats = defaultdict(int)

        # Initialize session summary
        self._initialize_session_summary()

    def _initialize_session_summary(self):
        """Initialize or update session summary file"""
        summary = {
            "session_id": self.session_id,
            "session_start": datetime.now().isoformat(),
            "last_updated": datetime.now().isoformat(),
            "total_tool_calls": 0,
            "tool_usage_stats": {},
            "agent_activity": {},
            "workspace_path": str(self.workspace_path)
        }

        # Load existing summary if it exists
        if self.summary_file.exists():
            try:
                with open(self.summary_file, 'r') as f:
                    existing_summary = json.load(f)
                summary.update(existing_summary)
                # Don't overwrite session_start if it already exists
                if "session_start" in existing_summary:
                    summary["session_start"] = existing_summary["session_start"]
            except Exception as e:
                logger.warning(f"Could not load existing session summary: {e}")

        self._save_summary(summary)

    def _save_summary(self, summary: Dict[str, Any]):
        """Save session summary to file"""
        try:
            with open(self.summary_file, 'w') as f:
                json.dump(summary, f, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(f"Failed to save session summary: {e}")

    def log_tool_call(self,
                      tool_name: str,
                      input_args: Dict[str, Any],
                      output_result: Dict[str, Any],
                      success: bool,
                      duration_ms: float,
                      error_details: Optional[str] = None,
                      agent_info: Optional[Dict[str, Any]] = None) -> str:
        """Log a tool call and return the call ID"""

        if not config.enable_tool_tracking:
            return ""

        # Respect max call limit per session
        if self.call_count >= config.max_tracked_calls_per_session:
            logger.warning(f"Max tracked calls reached for session {self.session_id}")
            return ""

        call_id = str(uuid.uuid4())
        timestamp = datetime.now()

        # Create log entry
        log_entry = ToolCallLog(
            call_id=call_id,
            timestamp=timestamp,
            tool_name=tool_name,
            input_args=self._sanitize_args(input_args),
            output_result=self._sanitize_result(output_result),
            success=success,
            duration_ms=duration_ms,
            error_details=error_details if config.track_detailed_errors else None,
            session_id=self.session_id,
            agent_info=agent_info
        )

        # Save to JSONL file (one JSON object per line)
        try:
            with open(self.current_log_file, 'a', encoding="utf-8") as f:
                f.write(json.dumps(log_entry.to_dict(), ensure_ascii=False) + '\n')
        except Exception as e:
            logger.error(f"Failed to save tool call log: {e}")

        # Update session summary
        self._update_session_summary(log_entry)

        self.call_count += 1
        self.tool_usage_stats[tool_name] += 1

        return call_id

    @staticmethod
    def _sanitize_args(args: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize arguments for logging (remove sensitive data)"""
        sanitized = {}
        for key, value in args.items():
            # Redact secrets first: checking truncation first would log the
            # first 1000 characters of a long secret value.
            if key.lower() in ['password', 'token', 'secret', 'key']:
                sanitized[key] = "[REDACTED]"
            elif isinstance(value, str) and len(value) > 1000:
                sanitized[key] = value[:1000] + "... [truncated]"
            else:
                sanitized[key] = value
        return sanitized

    def _sanitize_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize result for logging (remove large content)"""
        if not isinstance(result, dict):
            return result

        sanitized = {}
        for key, value in result.items():
            if isinstance(value, str) and len(value) > 2000:
                sanitized[key] = value[:2000] + "... [truncated]"
            elif isinstance(value, dict):
                sanitized[key] = self._sanitize_result(value)
            else:
                sanitized[key] = value
        return sanitized

    def _update_session_summary(self, log_entry: ToolCallLog):
        """Update session summary with new tool call"""
        try:
            summary = {
                "session_id": self.session_id,
                "last_updated": datetime.now().isoformat(),
                "total_tool_calls": self.call_count + 1,
                "tool_usage_stats": dict(self.tool_usage_stats),
                "workspace_path": str(self.workspace_path)
            }

            # Load existing summary
            if self.summary_file.exists():
                with open(self.summary_file, 'r') as f:
                    existing_summary = json.load(f)
                summary.update(existing_summary)

            # Update with new data
            summary["last_updated"] = datetime.now().isoformat()
            summary["total_tool_calls"] = self.call_count + 1
            summary["tool_usage_stats"] = dict(self.tool_usage_stats)
            summary["tool_usage_stats"][log_entry.tool_name] = self.tool_usage_stats[log_entry.tool_name] + 1

            # Track agent activity
            if log_entry.agent_info:
                agent_type = log_entry.agent_info.get('type', 'unknown')
                if 'agent_activity' not in summary:
                    summary['agent_activity'] = {}
                if agent_type not in summary['agent_activity']:
                    summary['agent_activity'][agent_type] = {
                        'tool_calls': 0,
                        'last_active': log_entry.timestamp.isoformat()
                    }
                summary['agent_activity'][agent_type]['tool_calls'] += 1
                summary['agent_activity'][agent_type]['last_active'] = log_entry.timestamp.isoformat()

            self._save_summary(summary)

        except Exception as e:
            logger.error(f"Failed to update session summary: {e}")
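The two `_sanitize_*` helpers above truncate large strings and redact secret-looking keys before anything reaches the JSONL log. The same idea isolated (key list and truncation limit copied from above; the function name is illustrative), with the ordering point made explicit - redaction must be checked before truncation:

```python
from typing import Any, Dict

SECRET_KEYS = {'password', 'token', 'secret', 'key'}


def sanitize_for_log(args: Dict[str, Any], max_len: int = 1000) -> Dict[str, Any]:
    """Redact secret-looking keys and truncate oversized string values."""
    sanitized = {}
    for key, value in args.items():
        if key.lower() in SECRET_KEYS:
            # Redact first: truncation alone would still leak a prefix of the secret.
            sanitized[key] = "[REDACTED]"
        elif isinstance(value, str) and len(value) > max_len:
            sanitized[key] = value[:max_len] + "... [truncated]"
        else:
            sanitized[key] = value
    return sanitized
```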

# ================ SESSION KEEP-ALIVE FOR LONG OPERATIONS ================


class KeepAliveSessionWrapper:
    """Wrapper that keeps a session alive during long-running operations"""

    def __init__(self, session: 'Session', touch_interval: int = 300):  # Touch every 5 minutes
        self.session = session
        self.touch_interval = touch_interval
        self.keep_alive_thread = None
        self.stop_event = Event()
        self.active = False

    def start_keep_alive(self):
        """Start the keep-alive mechanism"""
        if self.active:
            return

        self.active = True
        self.stop_event.clear()

        def keep_alive_worker():
            while not self.stop_event.wait(self.touch_interval):
                try:
                    self.session.touch()
                    logger.debug("Keep-alive: touched session %s", self.session.id)
                except Exception as e:
                    logger.error(f"Keep-alive error for session {self.session.id}: {e}")
                    break

        self.keep_alive_thread = Thread(target=keep_alive_worker, daemon=True)
        self.keep_alive_thread.start()
        logger.info(f"Started keep-alive for session {self.session.id}")

    def stop_keep_alive(self):
        """Stop the keep-alive mechanism"""
        if not self.active:
            return

        self.active = False
        self.stop_event.set()

        if self.keep_alive_thread and self.keep_alive_thread.is_alive():
            self.keep_alive_thread.join(timeout=1.0)

        # Final touch
        try:
            self.session.touch()
        except Exception as e:
            logger.error(f"Final keep-alive touch error for session {self.session.id}: {e}")

        logger.info(f"Stopped keep-alive for session {self.session.id}")

    def __enter__(self):
        self.start_keep_alive()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_keep_alive()


# ================ SESSION MANAGEMENT ================


@dataclass
class Session:
    """Thread-safe session data structure with workspace management"""
    id: str
    created_at: datetime
    last_accessed: datetime
    initialized: bool = False
    request_count: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)
    workspace_path: Optional[Path] = None
    mcp_tools: Optional[MCPTools] = None
    tool_tracker: Optional[ToolCallTracker] = None

    def is_expired(self, ttl_seconds: int) -> bool:
        """Check if session has expired"""
        return datetime.now() - self.last_accessed > timedelta(seconds=ttl_seconds)

    def touch(self):
        """Update last accessed time"""
        self.last_accessed = datetime.now()
        self.request_count += 1

    def get_mcp_tools(self, prefer_async: bool = True) -> MCPTools:
        """Get or create MCP tools instance for this session"""
        if self.mcp_tools is None:
            # Use async tools if available and preferred
            if prefer_async and AsyncMCPTools is not None:
                self.mcp_tools = AsyncMCPTools(workspace_path=str(self.workspace_path) if self.workspace_path else None)
            else:
                self.mcp_tools = MCPTools(workspace_path=str(self.workspace_path) if self.workspace_path else None)
        return self.mcp_tools

    def get_tool_tracker(self) -> Optional[ToolCallTracker]:
        """Get or create tool call tracker for this session"""
        if config.enable_tool_tracking and self.workspace_path:
            if self.tool_tracker is None:
                self.tool_tracker = ToolCallTracker(self.workspace_path, self.id)
            return self.tool_tracker
        return None
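Session expiry above is just a wall-clock comparison against `last_accessed`, and `touch()` pushes that window forward on every request. The same check isolated, with `now` passed explicitly so it can be tested deterministically (the helper name is illustrative):

```python
from datetime import datetime, timedelta


def session_is_expired(last_accessed: datetime, ttl_seconds: int, now: datetime) -> bool:
    # A session expires once it has been idle for longer than the TTL.
    return now - last_accessed > timedelta(seconds=ttl_seconds)
```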


class AsyncRLock:
    """An async re-entrant lock, mirroring threading.RLock for coroutines"""

    def __init__(self):
        self._lock = asyncio.Lock()
        self._owner: Optional[asyncio.Task] = None  # Task that currently holds the lock
        self._count = 0  # Re-entrancy count

    async def acquire(self):
        current_task = asyncio.current_task()
        # If the current task already holds the lock, just bump the re-entrancy count
        if self._owner is current_task:
            self._count += 1
            return
        # Otherwise wait to acquire the underlying lock
        await self._lock.acquire()
        self._owner = current_task
        self._count = 1

    async def release(self):
        if self._owner is not asyncio.current_task():
            raise RuntimeError("Cannot release a lock held by another task")
        self._count -= 1
        if self._count == 0:  # Release the underlying lock once the re-entrancy count drops to zero
            self._owner = None
            self._lock.release()

    # Support the `async with` syntax
    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.release()
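The point of the re-entrancy bookkeeping is that the owning coroutine can nest `async with` on the same lock without deadlocking, which a plain `asyncio.Lock` does not allow. A compressed standalone sketch of the same pattern (class and function names here are illustrative):

```python
import asyncio
from typing import Optional


class MiniRLock:
    """Minimal re-entrant async lock: the owning task may re-acquire freely."""

    def __init__(self):
        self._lock = asyncio.Lock()
        self._owner: Optional[asyncio.Task] = None
        self._count = 0

    async def __aenter__(self):
        task = asyncio.current_task()
        if self._owner is task:
            self._count += 1  # re-entrant acquire: just bump the count
        else:
            await self._lock.acquire()
            self._owner = task
            self._count = 1
        return self

    async def __aexit__(self, *exc):
        self._count -= 1
        if self._count == 0:  # last exit by the owner releases the real lock
            self._owner = None
            self._lock.release()


async def nested() -> str:
    lock = MiniRLock()
    async with lock:
        async with lock:  # would deadlock forever with a plain asyncio.Lock
            return "ok"
```

`asyncio.run(nested())` completes instead of hanging, which is exactly the behavior the session manager relies on when `create_session` calls `_cleanup_oldest_sessions` while already holding the lock.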


class ThreadSafeSessionManager:
    """Thread-safe session manager with workspace management"""

    def __init__(self, ttl_seconds: int = 3600, max_sessions: int = 1000, base_workspace_dir: str = "workspaces"):
        self.ttl_seconds = ttl_seconds
        self.max_sessions = max_sessions
        self.base_workspace_dir = Path(base_workspace_dir)
        self.base_workspace_dir.mkdir(exist_ok=True)

        # Thread-safe session storage
        self.sessions: Dict[str, Session] = {}
        self.lock = AsyncRLock()

        # Start cleanup thread
        self._start_cleanup_thread()

    async def create_session(self) -> str:
        """Create a new session and return its session ID"""
        session_id = str(uuid.uuid4())

        async with self.lock:
            # Check session limits
            if len(self.sessions) >= self.max_sessions:
                await self._cleanup_oldest_sessions()

            # Create workspace directory
            workspace_path = self.base_workspace_dir / session_id
            workspace_path.mkdir(exist_ok=True, parents=True)

            # Create session
            session = Session(
                id=session_id,
                created_at=datetime.now(),
                last_accessed=datetime.now(),
                workspace_path=workspace_path
            )

            self.sessions[session_id] = session

            logger.info(f"Created session {session_id} with workspace {workspace_path}")
            return session_id

    async def get_session(self, session_id: str) -> Optional[Session]:
        """Get session by ID if it exists and is not expired"""
        async with self.lock:
            session = self.sessions.get(session_id)
            if session and not session.is_expired(self.ttl_seconds):
                session.touch()
                return session
            elif session:
                # Remove expired session
                del self.sessions[session_id]
                logger.info(f"Removed expired session {session_id}")
            return None

    async def get_or_create_session(self, session_id: Optional[str] = None) -> Session:
        """Get existing session or create a new one"""
        if session_id:
            session = await self.get_session(session_id)
            if session:
                return session

        # Create new session
        new_session_id = await self.create_session()
        return self.sessions[new_session_id]

    async def _cleanup_expired_sessions(self):
        """Remove expired sessions"""
        async with self.lock:
            expired_sessions = []
            for session_id, session in self.sessions.items():
                if session.is_expired(self.ttl_seconds):
                    expired_sessions.append(session_id)

            for session_id in expired_sessions:
                del self.sessions[session_id]
                logger.info(f"Cleaned up expired session {session_id}")

    async def _cleanup_oldest_sessions(self):
        """Remove the oldest sessions when the limit is reached"""
        async with self.lock:
            if len(self.sessions) < self.max_sessions:
                return

            # Sort by last accessed time and remove the oldest
            sorted_sessions = sorted(
                self.sessions.items(),
                key=lambda x: x[1].last_accessed
            )

            sessions_to_remove = len(self.sessions) - self.max_sessions + 10  # Remove a few extra
            for i in range(sessions_to_remove):
                if i < len(sorted_sessions):
                    session_id = sorted_sessions[i][0]
                    del self.sessions[session_id]
                    logger.info(f"Removed old session {session_id} due to session limit")

    def _start_cleanup_thread(self):
        """Start background cleanup thread"""
        def cleanup_worker():
            while True:
                try:
                    time.sleep(config.cleanup_interval_seconds)
                    # Run the async cleanup method from this sync thread
                    loop = asyncio.new_event_loop()
                    loop.run_until_complete(self._cleanup_expired_sessions())
                    loop.close()
                except Exception as e:
                    logger.error(f"Error in cleanup thread: {e}")

        import threading
        cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
        cleanup_thread.start()
        logger.info("Started session cleanup thread")

    async def get_stats(self) -> Dict[str, Any]:
        """Get session manager statistics"""
        async with self.lock:
            return {
                "total_sessions": len(self.sessions),
                "max_sessions": self.max_sessions,
                "ttl_seconds": self.ttl_seconds,
                "session_ids": list(self.sessions.keys())
            }
|
| 804 |
+
|
| 805 |
+
# ================ MIDDLEWARE AND SECURITY ================
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
class RateLimiter:
|
| 809 |
+
"""Simple rate limiter with time-window tracking"""
|
| 810 |
+
|
| 811 |
+
def __init__(self, requests_per_minute: int = 60):
|
| 812 |
+
self.requests_per_minute = requests_per_minute
|
| 813 |
+
self.requests: Dict[str, List[float]] = defaultdict(list)
|
| 814 |
+
self.lock = asyncio.Lock()
|
| 815 |
+
|
| 816 |
+
async def is_allowed(self, client_id: str) -> bool:
|
| 817 |
+
"""Check if request is allowed for client"""
|
| 818 |
+
async with self.lock:
|
| 819 |
+
now = time.time()
|
| 820 |
+
minute_ago = now - 60
|
| 821 |
+
|
| 822 |
+
# Clean old requests
|
| 823 |
+
self.requests[client_id] = [
|
| 824 |
+
req_time for req_time in self.requests[client_id]
|
| 825 |
+
if req_time > minute_ago
|
| 826 |
+
]
|
| 827 |
+
|
| 828 |
+
# Check rate limit
|
| 829 |
+
if len(self.requests[client_id]) >= self.requests_per_minute:
|
| 830 |
+
return False
|
| 831 |
+
|
| 832 |
+
# Add current request
|
| 833 |
+
self.requests[client_id].append(now)
|
| 834 |
+
return True
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
class RequestValidator:
|
| 838 |
+
"""Validates incoming MCP requests"""
|
| 839 |
+
|
| 840 |
+
@staticmethod
|
| 841 |
+
def validate_mcp_request(data: Dict[str, Any]) -> tuple[bool, Optional[str]]:
|
| 842 |
+
"""Validate basic MCP request structure"""
|
| 843 |
+
if not isinstance(data, dict):
|
| 844 |
+
return False, "Request must be a JSON object"
|
| 845 |
+
|
| 846 |
+
if "method" not in data:
|
| 847 |
+
return False, "Missing 'method' field"
|
| 848 |
+
|
| 849 |
+
if "id" not in data:
|
| 850 |
+
return False, "Missing 'id' field"
|
| 851 |
+
|
| 852 |
+
return True, None
|
| 853 |
+
|
| 854 |
+
@staticmethod
|
| 855 |
+
def validate_tool_call(params: Dict[str, Any]) -> tuple[bool, Optional[str]]:
|
| 856 |
+
"""Validate tool call parameters"""
|
| 857 |
+
if not isinstance(params, dict):
|
| 858 |
+
return False, "Tool parameters must be a JSON object"
|
| 859 |
+
|
| 860 |
+
if "name" not in params:
|
| 861 |
+
return False, "Missing tool 'name'"
|
| 862 |
+
|
| 863 |
+
if "arguments" not in params:
|
| 864 |
+
return False, "Missing tool 'arguments'"
|
| 865 |
+
|
| 866 |
+
tool_name = params["name"]
|
| 867 |
+
|
| 868 |
+
# Get detailed schemas
|
| 869 |
+
detailed_schemas = get_tool_schemas()
|
| 870 |
+
|
| 871 |
+
if tool_name not in detailed_schemas:
|
| 872 |
+
return False, f"Unknown tool: {tool_name}. Available tools: {sorted(list(detailed_schemas.keys()))}"
|
| 873 |
+
|
| 874 |
+
return True, None
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
class SecurityMiddleware(BaseHTTPMiddleware):
    """Security middleware for basic protection"""

    async def dispatch(self, request: Request, call_next):
        # Check content length
        content_length = request.headers.get("content-length")
        if content_length and int(content_length) > config.max_request_size_mb * 1024 * 1024:
            return JSONResponse(
                status_code=HTTPStatus.REQUEST_ENTITY_TOO_LARGE,
                content={"error": "Request too large"}
            )

        # Add security headers
        response = await call_next(request)
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"

        return response


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware"""

    def __init__(self, app, input_rate_limiter: RateLimiter):
        super().__init__(app)
        self.rate_limiter = input_rate_limiter

    async def dispatch(self, request: Request, call_next):
        # Get client identifier (IP address)
        client_ip = request.client.host if request.client else "unknown"

        if not await self.rate_limiter.is_allowed(client_ip):
            return JSONResponse(
                status_code=HTTPStatus.TOO_MANY_REQUESTS,
                content={"error": "Rate limit exceeded"}
            )

        return await call_next(request)

# Global session manager
session_manager = None
rate_limiter = None

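The `RateLimiter` consumed by `RateLimitMiddleware` is defined elsewhere in this file as an async component. As a rough illustration of the per-client sliding-window idea behind it, a synchronous sketch (class name and internals here are assumptions, not the real implementation) might look like:

```python
import time
from collections import deque

class SlidingWindowLimiter:
    """Synchronous sliding-window sketch; the real RateLimiter in this file is async."""

    def __init__(self, requests_per_minute: int):
        self.limit = requests_per_minute
        self.hits = {}  # client_id -> deque of request timestamps

    def is_allowed(self, client_id: str) -> bool:
        now = time.monotonic()
        window = self.hits.setdefault(client_id, deque())
        # Drop timestamps older than the 60-second window
        while window and now - window[0] > 60.0:
            window.popleft()
        if len(window) >= self.limit:
            return False
        window.append(now)
        return True

limiter = SlidingWindowLimiter(requests_per_minute=2)
print(limiter.is_allowed("1.2.3.4"), limiter.is_allowed("1.2.3.4"), limiter.is_allowed("1.2.3.4"))
```

Each client IP gets its own window, so one noisy client cannot exhaust another client's quota.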
@dataclass
class RateLimitViolation:
    """Represents a rate limit violation with standardized error information"""
    tool_name: str
    limit_type: str  # "burst", "second", "minute", "hour"
    current_usage: int
    limit_value: float
    retry_after_seconds: float

    def to_user_friendly_message(self) -> str:
        """Generate user-friendly error message"""
        if self.limit_type == "burst":
            return f"Service temporarily unavailable: Too many rapid requests to {self.tool_name}. Please wait {self.retry_after_seconds:.0f} seconds before trying again."
        elif self.limit_type == "second":
            return f"Service temporarily unavailable: {self.tool_name} request rate exceeded ({self.limit_value}/second). Please wait {self.retry_after_seconds:.0f} seconds before trying again."
        elif self.limit_type == "minute":
            return f"Service temporarily unavailable: {self.tool_name} quota exceeded ({self.limit_value}/minute). Please try again in {self.retry_after_seconds:.0f} seconds."
        elif self.limit_type == "hour":
            return f"Service temporarily unavailable: {self.tool_name} hourly quota exceeded ({self.limit_value}/hour). Please try again in {self.retry_after_seconds / 60:.0f} minutes."
        else:
            return f"Service temporarily unavailable: {self.tool_name} rate limit exceeded. Please try again later."

    def to_technical_message(self) -> str:
        """Generate technical error message for debugging"""
        return f"Tool '{self.tool_name}' {self.limit_type} limit exceeded ({self.current_usage}/{self.limit_value} {self.limit_type})"

def _parse_rate_limit_denial(tool_name: str, denial_reason: str) -> RateLimitViolation:
    """Parse rate limit denial reason into structured violation information"""
    import re

    # Default values
    limit_type = "unknown"
    current_usage = 0
    limit_value = 0.0
    retry_after_seconds = 60.0  # Default retry after 1 minute

    # Parse different types of rate limit violations
    if "burst limit exceeded" in denial_reason:
        limit_type = "burst"
        retry_after_seconds = 1.0  # Burst limits reset quickly
        match = re.search(r'\((\d+) requests/burst\)', denial_reason)
        if match:
            limit_value = float(match.group(1))
            current_usage = int(limit_value)  # Approximation

    elif "per-second limit exceeded" in denial_reason:
        limit_type = "second"
        retry_after_seconds = 1.0  # Wait 1 second
        match = re.search(r'\(([0-9.]+) requests/second\)', denial_reason)
        if match:
            limit_value = float(match.group(1))
            current_usage = int(limit_value)  # Approximation

    elif "per-minute limit exceeded" in denial_reason:
        limit_type = "minute"
        retry_after_seconds = 10.0  # Wait 10 seconds for minute limits
        match = re.search(r'\(([0-9.]+) requests/minute\)', denial_reason)
        if match:
            limit_value = float(match.group(1))
            current_usage = int(limit_value)  # Approximation

    elif "per-hour limit exceeded" in denial_reason:
        limit_type = "hour"
        retry_after_seconds = 300.0  # Wait 5 minutes for hour limits
        match = re.search(r'\(([0-9.]+) requests/hour\)', denial_reason)
        if match:
            limit_value = float(match.group(1))
            current_usage = int(limit_value)  # Approximation

    return RateLimitViolation(
        tool_name=tool_name,
        limit_type=limit_type,
        current_usage=current_usage,
        limit_value=limit_value,
        retry_after_seconds=retry_after_seconds
    )

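The parser above recovers the numeric limit from the denial string with a regex. A self-contained sketch of that extraction step (the denial string here is a hypothetical example of the format the parser expects):

```python
import re

# Hypothetical denial string in the format _parse_rate_limit_denial expects
denial = "per-minute limit exceeded (30.0 requests/minute)"

# Same regex as the per-minute branch above: capture the float inside the parentheses
match = re.search(r'\(([0-9.]+) requests/minute\)', denial)
limit_value = float(match.group(1)) if match else 0.0
print(limit_value)
```

Because the regex is anchored on the literal `requests/minute` suffix, a denial string for a different window falls through to the defaults rather than being misparsed.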
async def _call_session_tool_async(session: Session, tool_name: str, tool_args: Dict[str, Any],
                                   client_ip: str = "unknown") -> Dict[str, Any]:
    """Execute a tool within a session context with full tracking, workspace management, and global rate limiting"""

    start_time = time.time()
    success = False
    error_details = None
    result_data = None

    # Touch session at start of tool execution to prevent expiry during long operations
    session.touch()

    try:
        # CHECK GLOBAL TOOL RATE LIMITS FIRST
        if global_tool_rate_limiter:
            allowed, deny_reason = await global_tool_rate_limiter.is_allowed(tool_name)
            if not allowed:
                # Parse the denial reason to create structured rate limit violation
                rate_limit_violation = _parse_rate_limit_denial(tool_name, deny_reason)

                # Create user-friendly error message
                user_message = rate_limit_violation.to_user_friendly_message()
                technical_message = rate_limit_violation.to_technical_message()

                logger.warning(f"Session {session.id}: {technical_message}")

                result_data = {
                    "success": False,
                    "error": user_message,
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "error_type": "rate_limit",
                    "tool_name": tool_name,
                    "limit_type": rate_limit_violation.limit_type,
                    "retry_after_seconds": rate_limit_violation.retry_after_seconds,
                    "data": None,
                    "rate_limited": True,  # Keep for backward compatibility
                    "technical_details": technical_message  # For debugging
                }

                # Still log this for tracking purposes
                duration_ms = (time.time() - start_time) * 1000
                tracker = session.get_tool_tracker()
                if tracker:
                    try:
                        agent_info = {
                            "client_ip": client_ip,
                            "type": "unknown",
                            "session_request_count": session.request_count
                        }

                        tracker.log_tool_call(
                            tool_name=tool_name,
                            input_args=tool_args,
                            output_result=result_data,
                            success=False,
                            duration_ms=duration_ms,
                            error_details=user_message,
                            agent_info=agent_info
                        )
                    except Exception as e:
                        logger.error(f"Failed to log rate-limited tool call: {e}")

                return result_data

        # Get MCP tools instance for this session (handles workspace isolation)
        mcp_tools = session.get_mcp_tools(prefer_async=True)

        # Get tool method directly from the mcp_tools instance
        if not hasattr(mcp_tools, tool_name):
            raise ValueError(f"Tool '{tool_name}' not implemented")

        tool_method = getattr(mcp_tools, tool_name)

        # Add session context to tool arguments for workspace-aware tools
        if hasattr(mcp_tools, 'set_session_context'):
            mcp_tools.set_session_context(session.id, str(session.workspace_path))

        # Execute tool with keep-alive for potentially long operations
        logger.info(f"Session {session.id}: Executing tool '{tool_name}' with args: {list(tool_args.keys())}")

        # Use keep-alive wrapper for tools that might take a long time
        long_running_tools = {'batch_web_search', 'url_crawler', 'document_qa', 'document_extract', 'bash'}

        # Check if the tool method is async
        import inspect
        is_async_tool = inspect.iscoroutinefunction(tool_method)

        # Execute tool based on whether it's async or sync
        if is_async_tool:
            # Tool is async - execute directly
            logger.debug("Executing async tool '%s'", tool_name)

            if config.enable_session_keepalive and tool_name in long_running_tools:
                # For long-running async tools, use keep-alive
                with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
                    result = await tool_method(**tool_args)
            else:
                # For regular async tools, execute directly
                result = await tool_method(**tool_args)
        else:
            # Tool is sync - execute in thread pool
            logger.debug("Executing sync tool '%s' in thread pool", tool_name)

            # Define the synchronous tool execution function
            def execute_tool_sync():
                """Synchronous tool execution to be run in thread pool"""
                return tool_method(**tool_args)

            # Execute tool asynchronously in thread pool for true non-blocking execution
            import asyncio
            import concurrent.futures

            # Run blocking work on a thread pool so the event loop stays responsive
            loop = asyncio.get_running_loop()

            if config.enable_session_keepalive and tool_name in long_running_tools:
                # For long-running tools, use keep-alive with async execution
                with KeepAliveSessionWrapper(session, touch_interval=config.keepalive_touch_interval):
                    # Run in thread pool to avoid blocking the event loop
                    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                        result = await loop.run_in_executor(executor, execute_tool_sync)
            else:
                # For regular tools, use async execution without keep-alive
                with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                    result = await loop.run_in_executor(executor, execute_tool_sync)

        # Touch session after tool execution to update activity
        session.touch()

        # Handle different result formats
        if hasattr(result, 'to_dict'):
            result_data = result.to_dict()
        elif isinstance(result, dict):
            result_data = result
        else:
            result_data = {"result": result}

        success = result_data.get('success', True)

        if success:
            logger.info(f"Session {session.id}: Tool '{tool_name}' completed successfully")

            # RECORD SUCCESSFUL REQUEST FOR RATE LIMITING
            if global_tool_rate_limiter:
                await global_tool_rate_limiter.record_request(tool_name)

        else:
            error_details = result_data.get('error', 'Unknown error')
            logger.warning(f"Session {session.id}: Tool '{tool_name}' failed: {error_details}")

    except Exception as e:
        success = False
        error_details = str(e)
        result_data = {
            "success": False,
            "error": error_details,
            "data": None
        }
        logger.error(f"Session {session.id}: Tool '{tool_name}' exception: {e}")

    # Calculate execution time
    duration_ms = (time.time() - start_time) * 1000

    # Log tool call if tracking is enabled
    tracker = session.get_tool_tracker()
    if tracker:
        try:
            agent_info = {
                "client_ip": client_ip,
                "type": "unknown",  # Could be enhanced to detect agent type
                "session_request_count": session.request_count
            }

            tracker.log_tool_call(
                tool_name=tool_name,
                input_args=tool_args,
                output_result=result_data,
                success=success,
                duration_ms=duration_ms,
                error_details=error_details,
                agent_info=agent_info
            )
        except Exception as e:
            logger.error(f"Failed to log tool call: {e}")

    return result_data


def create_sse_response(response_data: dict, session_id: str = None) -> StreamingResponse:
    """Create Server-Sent Events response with proper formatting"""
    def generate_sse():
        try:
            # Add session info to response if available
            if session_id:
                response_data["session_id"] = session_id

            json_data = json.dumps(response_data, ensure_ascii=False)
            yield "event: message\n"
            yield f"data: {json_data}\n"
            yield "\n"
        except Exception as e:
            error_data = {
                "jsonrpc": "2.0",
                "error": {"code": JsonRpcErr.INTERNAL_ERROR, "message": f"Internal error: {str(e)}"},
                "id": response_data.get("id")
            }
            json_data = json.dumps(error_data, ensure_ascii=False)
            yield "event: error\n"
            yield f"data: {json_data}\n"
            yield "\n"

    return StreamingResponse(
        generate_sse(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        }
    )


def create_error_response(request_id: Any, code: int, message: str, session_id: str = None) -> StreamingResponse:
    """Create error response in SSE format"""
    error_data = {
        "jsonrpc": "2.0",
        "error": {"code": code, "message": message},
        "id": request_id
    }
    return create_sse_response(error_data, session_id)

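`create_sse_response` emits one SSE frame per response: an `event:` line, a `data:` line carrying the JSON payload, and a blank line terminating the frame. A standalone sketch of the wire format (payload values here are hypothetical):

```python
import json

# A response payload like the ones handled above (hypothetical values)
response_data = {"jsonrpc": "2.0", "result": {"ok": True}, "id": 1}

# SSE frame: "event:" line, "data:" line, then a blank line to end the frame
frame = "event: message\n" + f"data: {json.dumps(response_data, ensure_ascii=False)}\n" + "\n"
print(frame)
```

The trailing blank line is what tells an SSE client the event is complete, so every frame must end with `\n\n`.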
def create_rate_limit_response(
    request_id: Any,
    tool_name: str,
    error_message: str,
    retry_after_seconds: float,
    limit_type: str,
    technical_details: str = "",
    session_id: str = None
) -> JSONResponse:
    """
    Create HTTP 429 Rate Limit Exceeded response with proper headers and error format.

    Returns proper HTTP status code instead of SSE for rate limiting errors.
    """

    # Calculate retry-after header value
    retry_after_header = int(max(1.0, retry_after_seconds))

    # Create standardized error response
    error_data = {
        "error": {
            "type": "rate_limit_exceeded",
            "code": "RATE_LIMIT_EXCEEDED",
            "message": error_message,
            "details": {
                "tool_name": tool_name,
                "limit_type": limit_type,
                "retry_after_seconds": retry_after_seconds,
                "technical_details": technical_details
            }
        },
        "request_id": request_id,
        "timestamp": datetime.now().isoformat(),
        "session_id": session_id
    }

    # Set appropriate headers
    headers = {
        "Retry-After": str(retry_after_header),  # HTTP standard header
        "X-RateLimit-Limit-Type": limit_type,
        "X-RateLimit-Tool": tool_name,
        "X-RateLimit-Retry-After": str(retry_after_seconds),
        "Content-Type": "application/json"
    }

    return JSONResponse(
        status_code=HTTPStatus.TOO_MANY_REQUESTS,  # Too Many Requests
        content=error_data,
        headers=headers
    )

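The `Retry-After` header must be a non-negative integer number of seconds, so the function clamps sub-second retry windows up to 1 before converting. A minimal sketch of that clamping step:

```python
# Sub-second retry windows still need a valid integer Retry-After value
retry_after_seconds = 0.25
retry_after_header = int(max(1.0, retry_after_seconds))
print(retry_after_header)

# Larger windows pass through after truncation
print(int(max(1.0, 10.0)))
```

Without the clamp, `int(0.25)` would yield `Retry-After: 0`, which tells clients they may retry immediately and defeats the rate limit.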
async def handle_mcp_request(request: Request) -> StreamingResponse:
    """Main MCP request handler with session management and tool execution"""

    try:
        # Check content length before reading body
        content_length = request.headers.get("content-length")
        if content_length:
            content_size_mb = int(content_length) / (1024 * 1024)
            if content_size_mb > config.max_request_size_mb:
                logger.warning(f"Request too large: {content_size_mb:.2f}MB > {config.max_request_size_mb}MB")
                return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Request too large: {content_size_mb:.2f}MB")

        # Parse request with timeout protection
        try:
            body = await asyncio.wait_for(request.body(), timeout=config.request_timeout_seconds)
        except asyncio.TimeoutError:
            logger.error("Timeout while reading request body")
            return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request body read timeout")

        if not body:
            return create_error_response(None, JsonRpcErr.PARSE_ERROR, "Empty request body")

        try:
            data = json.loads(body.decode('utf-8'))
        except json.JSONDecodeError as e:
            return create_error_response(None, JsonRpcErr.PARSE_ERROR, f"Invalid JSON: {str(e)}")

        # Validate MCP request structure
        is_valid, error_msg = RequestValidator.validate_mcp_request(data)
        if not is_valid:
            return create_error_response(data.get("id"), JsonRpcErr.INVALID_REQUEST, error_msg)

        request_id = data["id"]
        method = data["method"]
        params = data.get("params", {})

        # Get or create session
        session_id = request.headers.get("X-Session-ID")
        client_ip = request.client.host if request.client else "unknown"

        session = await session_manager.get_or_create_session(session_id)
        logger.info(f"Processing {method} request for session {session.id} from {client_ip}")

        # Handle different MCP methods
        if method == "initialize":
            # MCP initialization
            response_data = {
                "jsonrpc": "2.0",
                "result": {
                    "protocolVersion": "2025-06-18",
                    "capabilities": {
                        "tools": {"supportsProgress": True},
                        "resources": {},
                        "prompts": {}
                    },
                    "serverInfo": {
                        "name": "DeepDiver-Demo-MCP",
                        "version": "1.0.0"
                    }
                },
                "id": request_id
            }

        elif method == "tools/list":
            # List available tools using detailed schemas from get_tool_schemas()
            tools_list = []
            detailed_schemas = get_tool_schemas()

            # Build tools list from schemas
            for _, detailed_schema in detailed_schemas.items():
                tools_list.append({
                    "name": detailed_schema["name"],
                    "description": detailed_schema["description"],
                    "inputSchema": detailed_schema["inputSchema"]
                })

            logger.info(f"Serving {len(tools_list)} tools with detailed schemas to client")

            response_data = {
                "jsonrpc": "2.0",
                "result": {"tools": tools_list},
                "id": request_id
            }

        elif method == "tools/call":
            # Execute tool call
            is_valid, error_msg = RequestValidator.validate_tool_call(params)
            if not is_valid:
                return create_error_response(request_id, JsonRpcErr.INVALID_PARAMS, error_msg, session.id)

            tool_name = params["name"]
            tool_arguments = params["arguments"]

            # Execute tool in session context asynchronously
            result = await _call_session_tool_async(session, tool_name, tool_arguments, client_ip)

            # CHECK FOR RATE LIMITING AND RETURN PROPER HTTP STATUS
            if result.get("rate_limited", False):
                return create_rate_limit_response(
                    request_id=request_id,
                    tool_name=tool_name,
                    error_message=result.get("error", "Rate limit exceeded"),
                    retry_after_seconds=result.get("retry_after_seconds", 60),
                    limit_type=result.get("limit_type", "unknown"),
                    technical_details=result.get("technical_details", ""),
                    session_id=session.id
                )

            # Format normal response
            response_data = {
                "jsonrpc": "2.0",
                "result": {
                    "content": [
                        {
                            "type": "text",
                            "text": json.dumps(result, indent=2, ensure_ascii=False)
                        }
                    ]
                },
                "id": request_id
            }

        else:
            return create_error_response(request_id, JsonRpcErr.METHOD_NOT_FOUND, f"Method not found: {method}", session.id)

        return create_sse_response(response_data, session.id)

    except asyncio.TimeoutError:
        logger.warning("Request timeout - client may have disconnected")
        return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Request timeout")
    except Exception as e:
        # Handle client disconnects gracefully
        if "ClientDisconnect" in str(e) or "ConnectionClosedError" in str(e):
            logger.warning(f"Client disconnected during request processing: {e}")
            return create_error_response(None, JsonRpcErr.REQUEST_TIMEOUT, "Client disconnected")

        logger.error(f"Unexpected error in MCP request handler: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return create_error_response(None, JsonRpcErr.INTERNAL_ERROR, f"Internal server error: {str(e)}")

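The `tools/call` branch above expects a JSON-RPC 2.0 envelope with `params.name` and `params.arguments`. A sketch of the body a client could POST to `/mcp` (the tool arguments shown are a hypothetical example, not a documented schema):

```python
import json

# Hypothetical JSON-RPC body a client could POST to /mcp to invoke a tool
payload = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {"name": "bash", "arguments": {"command": "echo hello"}},
}
body = json.dumps(payload)
print(body)
```

The same envelope with `"method": "initialize"` or `"method": "tools/list"` (and no tool params) exercises the other two branches of the handler.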
async def handle_health_check(request: Request) -> JSONResponse:
    """Health check endpoint"""
    try:
        stats = await session_manager.get_stats() if session_manager else {}

        # Get rate limiting summary
        rate_limit_summary = {}
        if global_tool_rate_limiter:
            all_stats = global_tool_rate_limiter.get_all_stats()
            rate_limit_summary = {
                "enabled": True,
                "tools_with_limits": len(all_stats),
                "total_configured_tools": list(all_stats.keys())
            }
        else:
            rate_limit_summary = {"enabled": False}

        health_data = {
            "status": "healthy",
            "timestamp": datetime.now().isoformat(),
            "version": "1.0.0",
            "session_stats": stats,
            "features": {
                "workspace_isolation": True,
                "tool_call_tracking": config.enable_tool_tracking if config else False,
                "client_rate_limiting": True,
                "global_tool_rate_limiting": rate_limit_summary["enabled"],
                "security_middleware": True,
                "standardized_rate_limit_responses": True
            },
            "rate_limiting": rate_limit_summary,
            "error_formats": {
                "rate_limit_exceeded": {
                    "http_status": HTTPStatus.TOO_MANY_REQUESTS,
                    "headers": ["Retry-After", "X-RateLimit-*"],
                    "error_code": "RATE_LIMIT_EXCEEDED",
                    "response_format": "application/json"
                }
            }
        }

        return JSONResponse(content=health_data)

    except Exception as e:
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"status": "unhealthy", "error": str(e)}
        )

async def handle_tracking_info(request: Request) -> JSONResponse:
    """Get tool call tracking information for a session"""
    try:
        session_id = request.query_params.get("session_id")
        if not session_id:
            return JSONResponse(
                status_code=HTTPStatus.BAD_REQUEST,
                content={"error": "session_id parameter required"}
            )

        session = await session_manager.get_session(session_id)
        if not session:
            return JSONResponse(
                status_code=HTTPStatus.NOT_FOUND,
                content={"error": f"Session {session_id} not found"}
            )

        tracker = session.get_tool_tracker()
        if not tracker:
            return JSONResponse(
                content={
                    "session_id": session_id,
                    "tracking_enabled": False,
                    "message": "Tool call tracking not enabled or no workspace"
                }
            )

        # Read session summary
        summary_data = {}
        if tracker.summary_file.exists():
            try:
                with open(tracker.summary_file, 'r') as f:
                    summary_data = json.load(f)
            except Exception as e:
                logger.error(f"Failed to read session summary: {e}")

        return JSONResponse(content={
            "session_id": session_id,
            "tracking_enabled": True,
            "summary": summary_data,
            "logs_directory": str(tracker.logs_dir),
            "current_log_file": str(tracker.current_log_file)
        })

    except Exception as e:
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error": str(e)}
        )

async def handle_rate_limit_stats(request: Request) -> JSONResponse:
    """Get global tool rate limiting statistics"""
    try:
        if not global_tool_rate_limiter:
            return JSONResponse(
                status_code=HTTPStatus.NOT_FOUND,
                content={"error": "Global tool rate limiter not initialized"}
            )

        # Check if specific tool requested
        tool_name = request.query_params.get("tool")

        if tool_name:
            # Get stats for specific tool
            stats = await global_tool_rate_limiter.get_tool_stats(tool_name)
            return JSONResponse(content=stats)
        else:
            # Get stats for all tools
            all_stats = global_tool_rate_limiter.get_all_stats()
            return JSONResponse(content={
                "timestamp": datetime.now().isoformat(),
                "global_tool_rate_limiting": True,
                "tools": all_stats,
                "summary": {
                    "total_tools_with_limits": len(all_stats),
                    "tools_configured": list(all_stats.keys())
                }
            })

    except Exception as e:
        logger.error(f"Failed to get rate limit stats: {e}")
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            content={"error": str(e)}
        )

def create_app() -> Starlette:
    """Create and configure the Starlette application"""
    global session_manager, rate_limiter, global_tool_rate_limiter

    if not config:
        raise RuntimeError("Server configuration not initialized")

    # Initialize global components
    session_manager = ThreadSafeSessionManager(
        ttl_seconds=config.session_ttl_seconds,
        max_sessions=config.max_sessions,
        base_workspace_dir=config.base_workspace_dir
    )
    rate_limiter = RateLimiter(config.rate_limit_requests_per_minute)

    # Initialize global tool rate limiter
    if config.tool_rate_limits:
        global_tool_rate_limiter = GlobalToolRateLimiter(config.tool_rate_limits)
        logger.info(f"Initialized global tool rate limiter with {len(config.tool_rate_limits)} tool limits")
    else:
        logger.info("No tool rate limits configured - tools will run without global rate limiting")

    # Create app
    app = Starlette(debug=config.debug_mode)

    app.add_middleware(SecurityMiddleware)
    app.add_middleware(RateLimitMiddleware, input_rate_limiter=rate_limiter)

    # Add routes
    app.add_route("/mcp", handle_mcp_request, methods=["POST"])
    app.add_route("/health", handle_health_check, methods=["GET"])
    app.add_route("/tracking", handle_tracking_info, methods=["GET"])
    app.add_route("/rate-limits", handle_rate_limit_stats, methods=["GET"])

    return app

def parse_arguments():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Demo-Ready MCP Server with Per-Tool Rate Limiting",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python src/tools/mcp_server_standard.py --config src/tools/server_config.yaml
  python src/tools/mcp_server_standard.py --host 127.0.0.1 --port 8080
  python src/tools/mcp_server_standard.py --config custom_config.yaml --debug
"""
    )

    parser.add_argument(
        '--config', '-c',
        type=str,
        help='Path to YAML configuration file'
    )

    parser.add_argument(
        '--host',
        type=str,
        help='Server host (overrides config file)'
    )

    parser.add_argument(
        '--port', '-p',
        type=int,
        help='Server port (overrides config file)'
    )

    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug mode (overrides config file)'
    )

    parser.add_argument(
        '--workspace-dir',
        type=str,
        help='Base workspace directory (overrides config file)'
    )

    return parser.parse_args()

def print_startup_info():
    """Print server startup information"""
    logger.info("🚀 DeepDiver Demo MCP Server")
    logger.info("=" * 50)
    logger.info("📊 Features:")
    logger.info(f"  • Session Management: ✅ (TTL: {config.session_ttl_seconds}s)")
    logger.info(f"  • Workspace Isolation: ✅ (Base: {config.base_workspace_dir})")
    logger.info(f"  • Tool Call Tracking: {'✅' if config.enable_tool_tracking else '❌'}")
    logger.info(f"  • Client Rate Limiting: ✅ ({config.rate_limit_requests_per_minute}/min)")
    logger.info(f"  • Global Tool Rate Limiting: {'✅' if config.tool_rate_limits else '❌'}")
    logger.info("  • Security Middleware: ✅")

    # Tool rate limiting information
    if config.tool_rate_limits:
        logger.info(f"🚦 Tool Rate Limits: {len(config.tool_rate_limits)} tools configured")
        for tool_name, limits in list(config.tool_rate_limits.items())[:3]:
            burst = limits.get('burst_limit', '∞')
            rpm = limits.get('requests_per_minute', '∞')
            logger.info(f"  • {tool_name}: {rpm}/min, burst: {burst}")
        if len(config.tool_rate_limits) > 3:
            logger.info(f"  • ... and {len(config.tool_rate_limits) - 3} more tools")

    # Tool information from schemas
    tool_schemas = get_tool_schemas()
    available_tools = list(tool_schemas.keys())

    logger.info(f"🔧 Tools Available: {len(available_tools)}")
    logger.info(f"  • All tools defined in schemas: {len(available_tools)} tools")
    logger.info(f"  • Sample tools: {', '.join(sorted(available_tools)[:5])}...")
    logger.info("=" * 50)

def main():
|
| 1685 |
+
"""Main function to run the production MCP server"""
|
| 1686 |
+
global config
|
| 1687 |
+
|
| 1688 |
+
# Parse command line arguments
|
| 1689 |
+
args = parse_arguments()
|
| 1690 |
+
|
| 1691 |
+
config = ServerConfig.from_yaml("./src/tools/server_config.yaml")
|
| 1692 |
+
|
| 1693 |
+
# Apply CLI overrides
|
| 1694 |
+
if args.host:
|
| 1695 |
+
config.host = args.host
|
| 1696 |
+
logger.info(f"🔧 Override: Host = {config.host}")
|
| 1697 |
+
|
| 1698 |
+
if args.port:
|
| 1699 |
+
config.port = args.port
|
| 1700 |
+
logger.info(f"🔧 Override: Port = {config.port}")
|
| 1701 |
+
|
| 1702 |
+
if args.debug:
|
| 1703 |
+
config.debug_mode = True
|
| 1704 |
+
logger.info(f"🔧 Override: Debug mode enabled")
|
| 1705 |
+
|
| 1706 |
+
if args.workspace_dir:
|
| 1707 |
+
config.base_workspace_dir = args.workspace_dir
|
| 1708 |
+
logger.info(f"🔧 Override: Workspace directory = {config.base_workspace_dir}")
|
| 1709 |
+
|
| 1710 |
+
print_startup_info()
|
| 1711 |
+
|
| 1712 |
+
try:
|
| 1713 |
+
import os
|
| 1714 |
+
|
| 1715 |
+
# Calculate optimal worker count for high-concurrency FIRST
|
| 1716 |
+
# Use CPU core count indirectly via uvicorn's defaults; no local variable needed
|
| 1717 |
+
|
| 1718 |
+
# Override for high-concurrency scenarios
|
| 1719 |
+
if os.getenv('FORCE_HIGH_CONCURRENCY', '').lower() == 'true':
|
| 1720 |
+
pass # Configuration handled elsewhere if needed
|
| 1721 |
+
|
| 1722 |
+
app = create_app()
|
| 1723 |
+
|
| 1724 |
+
logger.info(f"🌐 Starting server at http://{config.host}:{config.port}")
|
| 1725 |
+
logger.info(f"📡 MCP endpoint: http://{config.host}:{config.port}/mcp")
|
| 1726 |
+
logger.info(f"🏥 Health check: http://{config.host}:{config.port}/health")
|
| 1727 |
+
logger.info(f"📊 Tracking info: http://{config.host}:{config.port}/tracking?session_id=<id>")
|
| 1728 |
+
logger.info(f"🚦 Rate limit stats: http://{config.host}:{config.port}/rate-limits")
|
| 1729 |
+
|
| 1730 |
+
uvicorn.run(
|
| 1731 |
+
app, # Use app instance directly for single worker with async optimizations
|
| 1732 |
+
host=config.host,
|
| 1733 |
+
port=config.port,
|
| 1734 |
+
log_level="info",
|
| 1735 |
+
timeout_keep_alive=config.request_timeout_seconds,
|
| 1736 |
+
workers=1, # Single worker with async optimizations
|
| 1737 |
+
backlog=1024, # Larger backlog for high-concurrency
|
| 1738 |
+
access_log=False, # Disable access logs for better performance
|
| 1739 |
+
limit_concurrency=None, # No artificial concurrency limit
|
| 1740 |
+
)
|
| 1741 |
+
|
| 1742 |
+
except KeyboardInterrupt:
|
| 1743 |
+
print("\n⏹️ Server stopped by user")
|
| 1744 |
+
except Exception as e:
|
| 1745 |
+
print(f"❌ Server startup failed: {e}")
|
| 1746 |
+
import traceback
|
| 1747 |
+
traceback.print_exc()
|
| 1748 |
+
raise
|
| 1749 |
+
|
| 1750 |
+
if __name__ == "__main__":
|
| 1751 |
+
main()
|
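The `main()` function above layers CLI flags over YAML-loaded defaults. A minimal, self-contained sketch of that pattern (`DemoConfig` is a stand-in dataclass for illustration, not the real `ServerConfig`):

```python
import argparse
from dataclasses import dataclass


@dataclass
class DemoConfig:  # stand-in for ServerConfig; field names mirror the config above
    host: str = "127.0.0.1"
    port: int = 6274
    debug_mode: bool = False


def load_with_overrides(argv):
    """Parse CLI args and let them override the file-based defaults."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str)
    parser.add_argument('--port', '-p', type=int)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args(argv)

    config = DemoConfig()  # in the real server: ServerConfig.from_yaml(...)
    if args.host:
        config.host = args.host
    if args.port:
        config.port = args.port
    if args.debug:
        config.debug_mode = True
    return config
```

Flags that are omitted leave the file-based defaults untouched, so the same YAML can back many ad-hoc invocations.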
deepdiver_v2/src/tools/mcp_tools.py
ADDED
|
The diff for this file is too large to render.
deepdiver_v2/src/tools/server_config.yaml
ADDED
|
@@ -0,0 +1,73 @@
|
# =================================================================
# DeepDiver MCP Server Configuration
# =================================================================
# This file contains ONLY the configuration options that are actually
# implemented and used by the server. No unused options!

# =================================================================
# SERVER CORE SETTINGS
# =================================================================
server:
  # Network Configuration
  host: "127.0.0.1"                  # Server bind address
  port: 6274                         # Server port
  debug_mode: false                  # Enable debug logging and error details

  # Session Management
  session_ttl_seconds: 21600         # Session timeout (6 hours)
  max_sessions: 1000                 # Maximum concurrent sessions
  cleanup_interval_seconds: 600      # How often to clean expired sessions (10 min)
  enable_session_keepalive: true     # Keep sessions alive during long operations
  keepalive_touch_interval: 300      # Touch session every N seconds during long ops

  # Request Handling
  request_timeout_seconds: 1800      # Request timeout
  max_request_size_mb: 1000          # Maximum request size

  # Client Rate Limiting (per IP address)
  rate_limit_requests_per_minute: 300000  # Requests per minute per client IP

  # Workspace Management
  base_workspace_dir: "workspaces"   # Base directory for session workspaces

# =================================================================
# TOOL CALL TRACKING & LOGGING
# =================================================================
tracking:
  enable_tool_tracking: true               # Enable detailed tool call logging
  max_tracked_calls_per_session: 10000     # Limit tool calls logged per session
  track_detailed_errors: true              # Include full error details in logs

# =================================================================
# GLOBAL PER-TOOL RATE LIMITING
# =================================================================
# These limits control requests to external APIs to avoid hitting provider limits.
# They are shared across ALL sessions and clients.
#
# Each tool can have these limits:
#   - requests_per_second: QPS limit
#   - requests_per_minute: Per-minute limit
#   - requests_per_hour: Hourly limit
#   - burst_limit: Short-term burst allowance
#
# Omit a limit to disable it (infinite). All limits are optional.

tool_rate_limits:
  # API-based tools with external service limits
  batch_web_search:
    requests_per_minute: 9000
    burst_limit: 35

  url_crawler:
    requests_per_minute: 9000
    burst_limit: 60

  document_qa:
    requests_per_minute: 15000
    burst_limit: 150

  document_extract:
    requests_per_minute: 15000
    burst_limit: 150
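Each `tool_rate_limits` entry pairs a sustained rate (`requests_per_minute`) with a short-term `burst_limit`. A token-bucket sketch of that semantics (illustrative only, not the server's actual limiter):

```python
class TokenBucket:
    """Allows up to `burst_limit` immediate calls, refilling at
    requests_per_minute / 60 tokens per second."""

    def __init__(self, requests_per_minute, burst_limit):
        self.rate = requests_per_minute / 60.0  # tokens added per second
        self.capacity = burst_limit
        self.tokens = float(burst_limit)
        self.last = 0.0

    def allow(self, now):
        # Refill proportionally to elapsed time, capped at the burst capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False
```

With `batch_web_search` (9000/min, burst 35), 35 calls pass back-to-back before the limiter falls back to the steady 150-per-second refill.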
deepdiver_v2/src/utils/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
#!/usr/bin/env python3
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.

from .status_codes import JsonRpcErr

__all__ = [
    'JsonRpcErr',
]
deepdiver_v2/src/utils/status_codes.py
ADDED
|
@@ -0,0 +1,12 @@
|
#!/usr/bin/env python3
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
from enum import IntEnum


class JsonRpcErr(IntEnum):
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_PARAMS = -32602
    INTERNAL_ERROR = -32603
    REQUEST_TIMEOUT = -32000
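These are the standard JSON-RPC 2.0 error codes (plus the implementation-defined `-32000` timeout). A sketch of how a server might wrap one in a spec-shaped error response (the real response helper lives elsewhere in the server; the enum is restated here so the example is self-contained):

```python
from enum import IntEnum


class JsonRpcErr(IntEnum):  # mirrors src/utils/status_codes.py
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_PARAMS = -32602
    INTERNAL_ERROR = -32603
    REQUEST_TIMEOUT = -32000


def error_response(request_id, err: JsonRpcErr, message: str) -> dict:
    """Build the error shape mandated by JSON-RPC 2.0: code + message under 'error'."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": int(err), "message": message},
    }
```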
deepdiver_v2/src/workspace/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
"""
Workspace management module for the DeepDiver Multi-Agent System.

This module provides local workspace management capabilities that don't require
external dependencies like E2B. Each chat session gets its own isolated workspace
directory for file operations and data persistence.
"""

from .local_workspace_manager import (
    LocalWorkspaceManager,
    WorkspaceInfo,
    WorkspaceStatus,
    get_workspace_manager,
    initialize_workspace_manager,
    shutdown_workspace_manager
)

__all__ = [
    'LocalWorkspaceManager',
    'WorkspaceInfo',
    'WorkspaceStatus',
    'get_workspace_manager',
    'initialize_workspace_manager',
    'shutdown_workspace_manager'
]
deepdiver_v2/src/workspace/local_workspace_manager.py
ADDED
|
@@ -0,0 +1,420 @@
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
"""
Local Workspace Manager for Multi-Agent System

This module provides session-based workspace management using local directories.
Each chat session gets its own isolated workspace directory that persists
throughout the conversation and can be cleaned up when the session ends.

Features:
- Session-based workspace lifecycle management
- Local directory isolation per session
- File operations within session workspaces
- Integration with existing MCP tools
- Comprehensive error handling and logging
"""

import shutil
import logging
from typing import Dict, Optional, Any, List, Union
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass, field
from enum import Enum
import json

# Configure logging
logger = logging.getLogger(__name__)


class WorkspaceStatus(Enum):
    """Workspace lifecycle status"""
    CREATING = "creating"
    ACTIVE = "active"
    DESTROYING = "destroying"
    DESTROYED = "destroyed"
    ERROR = "error"


@dataclass
class WorkspaceInfo:
    """Information about a workspace instance"""
    workspace_id: str
    session_id: str
    workspace_path: Path
    created_at: datetime
    last_activity: datetime
    status: WorkspaceStatus
    workspace_files: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization"""
        return {
            "workspace_id": self.workspace_id,
            "session_id": self.session_id,
            "workspace_path": str(self.workspace_path),
            "created_at": self.created_at.isoformat(),
            "last_activity": self.last_activity.isoformat(),
            "status": self.status.value,
            "workspace_files": self.workspace_files,
            "metadata": self.metadata,
            "error_message": self.error_message
        }


class LocalWorkspaceManager:
    """
    Manages local workspaces for multi-agent chat sessions.

    Each chat session gets its own isolated workspace directory that persists
    throughout the conversation. Workspaces are automatically managed
    with cleanup capabilities.
    """

    def __init__(
        self,
        base_workspace_dir: str = "workspaces",
        default_timeout: int = 86400,  # 24 hours default
        cleanup_on_exit: bool = False  # Don't auto-cleanup by default
    ):
        """
        Initialize the workspace manager.

        Args:
            base_workspace_dir: Base directory for all workspaces
            default_timeout: Default workspace timeout in seconds
            cleanup_on_exit: Whether to cleanup workspaces on manager shutdown
        """
        self.base_workspace_dir = Path(base_workspace_dir)
        self.base_workspace_dir.mkdir(exist_ok=True)
        self.default_timeout = default_timeout
        self.cleanup_on_exit = cleanup_on_exit

        # Active workspaces by session ID
        self.workspaces: Dict[str, WorkspaceInfo] = {}

        # Load existing workspaces from metadata
        self._load_existing_workspaces()

        logger.info(f"LocalWorkspaceManager initialized with base_dir={base_workspace_dir}")

    def _load_existing_workspaces(self):
        """Load existing workspaces from metadata files"""
        try:
            for workspace_dir in self.base_workspace_dir.iterdir():
                if workspace_dir.is_dir():
                    metadata_file = workspace_dir / ".workspace_metadata.json"
                    if metadata_file.exists():
                        try:
                            with open(metadata_file, 'r') as f:
                                data = json.load(f)

                            workspace_info = WorkspaceInfo(
                                workspace_id=data["workspace_id"],
                                session_id=data["session_id"],
                                workspace_path=Path(data["workspace_path"]),
                                created_at=datetime.fromisoformat(data["created_at"]),
                                last_activity=datetime.fromisoformat(data["last_activity"]),
                                status=WorkspaceStatus(data["status"]),
                                workspace_files=data.get("workspace_files", []),
                                metadata=data.get("metadata", {}),
                                error_message=data.get("error_message")
                            )

                            self.workspaces[workspace_info.session_id] = workspace_info
                            logger.info(f"Loaded existing workspace for session {workspace_info.session_id}")

                        except Exception as e:
                            logger.warning(f"Failed to load workspace metadata from {metadata_file}: {e}")

        except Exception as e:
            logger.warning(f"Failed to load existing workspaces: {e}")

    @staticmethod
    def _save_workspace_metadata(workspace_info: WorkspaceInfo):
        """Save workspace metadata to disk"""
        try:
            metadata_file = workspace_info.workspace_path / ".workspace_metadata.json"
            with open(metadata_file, 'w') as f:
                json.dump(workspace_info.to_dict(), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save workspace metadata: {e}")

    def create_workspace(
        self,
        session_id: str,
        workspace_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> WorkspaceInfo:
        """
        Create a new workspace for a chat session.

        Args:
            session_id: Unique session identifier
            workspace_id: Optional custom workspace ID (defaults to session_id)
            metadata: Additional metadata to store with the workspace

        Returns:
            WorkspaceInfo: Information about the created workspace

        Raises:
            ValueError: If session already has an active workspace
            Exception: If workspace creation fails
        """
        if session_id in self.workspaces:
            raise ValueError(f"Session {session_id} already has an active workspace")

        workspace_id = workspace_id or session_id
        workspace_path = self.base_workspace_dir / workspace_id

        logger.info(f"Creating workspace for session {session_id} at {workspace_path}")

        # Create workspace info with creating status
        workspace_info = WorkspaceInfo(
            workspace_id=workspace_id,
            session_id=session_id,
            workspace_path=workspace_path,
            created_at=datetime.now(),
            last_activity=datetime.now(),
            status=WorkspaceStatus.CREATING,
            metadata=metadata or {}
        )

        try:
            # Create workspace directory
            workspace_path.mkdir(parents=True, exist_ok=True)

            # Create subdirectories
            (workspace_path / "downloads").mkdir(exist_ok=True)
            (workspace_path / "outputs").mkdir(exist_ok=True)
            (workspace_path / "temp").mkdir(exist_ok=True)

            # Update status
            workspace_info.status = WorkspaceStatus.ACTIVE
            self.workspaces[session_id] = workspace_info

            # Save metadata
            self._save_workspace_metadata(workspace_info)

            # Update workspace files list
            self._update_workspace_files(session_id)

            logger.info(f"Workspace created successfully: {workspace_path} for session {session_id}")
            return workspace_info

        except Exception as e:
            workspace_info.status = WorkspaceStatus.ERROR
            workspace_info.error_message = str(e)
            logger.error(f"Failed to create workspace for session {session_id}: {e}")
            raise

    def get_workspace(self, session_id: str) -> Optional[WorkspaceInfo]:
        """Get workspace info for a session"""
        workspace_info = self.workspaces.get(session_id)
        if workspace_info:
            # Update last activity
            workspace_info.last_activity = datetime.now()
            self._save_workspace_metadata(workspace_info)
        return workspace_info

    def get_workspace_path(self, session_id: str) -> Optional[Path]:
        """Get workspace path for a session"""
        workspace_info = self.get_workspace(session_id)
        return workspace_info.workspace_path if workspace_info else None

    def list_sessions(self) -> List[str]:
        """List all active session IDs"""
        return list(self.workspaces.keys())

    def destroy_workspace(self, session_id: str, force: bool = False) -> bool:
        """
        Destroy a workspace for a session.

        Args:
            session_id: Session identifier
            force: Force removal even if files exist

        Returns:
            bool: True if destroyed successfully
        """
        if session_id not in self.workspaces:
            logger.warning(f"No workspace found for session {session_id}")
            return False

        workspace_info = self.workspaces[session_id]

        try:
            logger.info(f"Destroying workspace for session {session_id}")
            workspace_info.status = WorkspaceStatus.DESTROYING

            # Remove workspace directory
            if workspace_info.workspace_path.exists():
                if force or not any(workspace_info.workspace_path.iterdir()):
                    shutil.rmtree(workspace_info.workspace_path)
                    logger.info(f"Workspace directory removed: {workspace_info.workspace_path}")
                else:
                    logger.warning(f"Workspace contains files, use force=True to remove: {workspace_info.workspace_path}")
                    return False

            # Update status and remove from active workspaces
            workspace_info.status = WorkspaceStatus.DESTROYED
            del self.workspaces[session_id]

            logger.info(f"Workspace destroyed for session {session_id}")
            return True

        except Exception as e:
            workspace_info.status = WorkspaceStatus.ERROR
            workspace_info.error_message = str(e)
            logger.error(f"Failed to destroy workspace for session {session_id}: {e}")
            return False

    def write_file(self, session_id: str, file_path: str, content: Union[str, bytes]) -> bool:
        """Write content to a file in the workspace"""
        workspace_info = self.get_workspace(session_id)
        if not workspace_info:
            logger.error(f"No workspace found for session {session_id}")
            return False

        try:
            full_path = workspace_info.workspace_path / file_path
            full_path.parent.mkdir(parents=True, exist_ok=True)

            if isinstance(content, str):
                with open(full_path, 'w', encoding='utf-8') as f:
                    f.write(content)
            else:
                with open(full_path, 'wb') as f:
                    f.write(content)

            # Update workspace files list
            self._update_workspace_files(session_id)

            logger.info(f"File written to workspace: {file_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to write file {file_path} in workspace {session_id}: {e}")
            return False

    def read_file(self, session_id: str, file_path: str) -> Optional[Union[str, bytes]]:
        """Read content from a file in the workspace"""
        workspace_info = self.get_workspace(session_id)
        if not workspace_info:
            logger.error(f"No workspace found for session {session_id}")
            return None

        try:
            full_path = workspace_info.workspace_path / file_path

            if not full_path.exists():
                logger.error(f"File not found: {file_path}")
                return None

            # Try to read as text first
            try:
                with open(full_path, 'r', encoding='utf-8') as f:
                    return f.read()
            except UnicodeDecodeError:
                # If text reading fails, read as bytes
                with open(full_path, 'rb') as f:
                    return f.read()

        except Exception as e:
            logger.error(f"Failed to read file {file_path} from workspace {session_id}: {e}")
            return None

    def list_files(self, session_id: str, directory: str = "") -> List[str]:
        """List files in the workspace directory"""
        workspace_info = self.get_workspace(session_id)
        if not workspace_info:
            logger.error(f"No workspace found for session {session_id}")
            return []

        try:
            target_path = workspace_info.workspace_path / directory if directory else workspace_info.workspace_path

            if not target_path.exists():
                return []

            files = []
            for item in target_path.rglob('*'):
                if item.is_file() and not item.name.startswith('.'):
                    rel_path = item.relative_to(workspace_info.workspace_path)
                    files.append(str(rel_path))

            return sorted(files)

        except Exception as e:
            logger.error(f"Failed to list files in workspace {session_id}: {e}")
            return []

    def _update_workspace_files(self, session_id: str):
        """Update the list of workspace files for a session."""
        try:
            workspace_info = self.workspaces.get(session_id)
            if workspace_info:
                files = self.list_files(session_id)
                workspace_info.workspace_files = files
                self._save_workspace_metadata(workspace_info)
        except Exception as e:
            logger.debug("Failed to update workspace files for session %s: %s", session_id, e)

    def cleanup_expired_workspaces(self, max_age_hours: int = 24):
        """Clean up workspaces older than max_age_hours"""
        cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
        expired_sessions = []

        for session_id, workspace_info in self.workspaces.items():
            if workspace_info.last_activity < cutoff_time:
                expired_sessions.append(session_id)

        for session_id in expired_sessions:
            logger.info(f"Cleaning up expired workspace for session {session_id}")
            self.destroy_workspace(session_id, force=True)

    def shutdown(self):
        """Shutdown the workspace manager"""
        logger.info("Shutting down LocalWorkspaceManager...")

        if self.cleanup_on_exit:
            # Clean up all workspaces
            session_ids = list(self.workspaces.keys())
            for session_id in session_ids:
                self.destroy_workspace(session_id, force=True)
        else:
            # Just save metadata for all workspaces
            for workspace_info in self.workspaces.values():
                self._save_workspace_metadata(workspace_info)

        logger.info("LocalWorkspaceManager shutdown complete")


# Global instance
_workspace_manager: Optional[LocalWorkspaceManager] = None


def get_workspace_manager(base_workspace_dir: str = "workspaces") -> LocalWorkspaceManager:
    """Get or create the global workspace manager instance"""
    global _workspace_manager
    if _workspace_manager is None:
        _workspace_manager = LocalWorkspaceManager(base_workspace_dir)
    return _workspace_manager


def initialize_workspace_manager(base_workspace_dir: str = "workspaces", **kwargs) -> LocalWorkspaceManager:
    """Initialize the workspace manager with custom settings"""
    global _workspace_manager
    _workspace_manager = LocalWorkspaceManager(base_workspace_dir, **kwargs)
    logger.info(f"Workspace manager initialized with base directory: {base_workspace_dir}")
    return _workspace_manager


def shutdown_workspace_manager():
    """Shutdown the global workspace manager"""
    global _workspace_manager
    if _workspace_manager:
        _workspace_manager.shutdown()
        _workspace_manager = None
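End to end, the manager implements a create → write → list → destroy lifecycle over plain directories. The sketch below re-creates that flow minimally over a temp directory (it does not import the module, so names like `base` and `ws` are local to the example):

```python
import json
import shutil
import tempfile
from pathlib import Path

# One base directory holds all session workspaces
base = Path(tempfile.mkdtemp())

# create_workspace: one directory per session, plus the standard subdirectories
ws = base / "session-1"
for sub in ("downloads", "outputs", "temp"):
    (ws / sub).mkdir(parents=True, exist_ok=True)

# write_file: paths are always resolved relative to the session workspace
(ws / "outputs" / "report.md").write_text("# findings\n", encoding="utf-8")

# metadata sidecar, as .workspace_metadata.json in the manager above
(ws / ".workspace_metadata.json").write_text(json.dumps({"session_id": "session-1"}))

# list_files: recursive, skipping dotfiles such as the metadata sidecar
files = sorted(
    str(p.relative_to(ws))
    for p in ws.rglob("*")
    if p.is_file() and not p.name.startswith(".")
)

# destroy_workspace(force=True): remove the tree outright
shutil.rmtree(base)
```

Hiding the metadata sidecar behind the dotfile filter is what keeps bookkeeping files out of the file lists agents see.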
docs/openpangu-deepdiver-v2-tech-report.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
version https://git-lfs.github.com/spec/v1
oid sha256:d9ff32d4bd7190ea26049ef1f5d009b9861c652a980623a5f5cd043e7dcec2a4
size 39847395
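This three-line file is a Git LFS pointer: the repository stores only the object's hash and byte size, and the real file lives in LFS storage. A tiny illustrative parser for the format:

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse a Git LFS pointer file into its version/oid/size fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    fields["size"] = int(fields["size"])  # size is the byte count of the real file
    return fields


pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:d9ff32d4bd7190ea26049ef1f5d009b9861c652a980623a5f5cd043e7dcec2a4
size 39847395
"""
info = parse_lfs_pointer(pointer)
```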
generation_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
{
  "_from_model_config": true,
  "do_sample": true,
  "bos_token_id": 1,
  "pad_token_id": 0,
  "eos_token_id": 45892,
  "temperature": 1.0,
  "top_k": 0,
  "top_p": 0.8,
  "transformers_version": "4.53.2"
}
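With `do_sample: true`, `top_k: 0` (disabled), and `top_p: 0.8`, decoding keeps the smallest set of highest-probability tokens whose cumulative mass reaches the `top_p` threshold, then samples from the renormalized set. A pure-Python sketch of that nucleus filter (illustrative, not the transformers implementation):

```python
def top_p_filter(probs, top_p=0.8):
    """Keep the smallest prefix of tokens (by descending probability) whose
    cumulative mass reaches top_p, then renormalize.
    Returns {token_index: renormalized_prob}."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}
```

Low-probability tail tokens are pruned entirely, which is why top-p sampling tends to be more robust than a fixed top-k when the distribution's sharpness varies from step to step.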
model-00001-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
version https://git-lfs.github.com/spec/v1
oid sha256:7b8ec6cd94b1921560d37755c7c0c08280c1f9123195d14d352ad0607788f7f6
size 4926842416
model-00002-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
version https://git-lfs.github.com/spec/v1
oid sha256:fc05d80f52ce44d1433a942e867bf61ea49eb1eebb0700312f76d6b3a3dee917
size 4991686576
model-00003-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
version https://git-lfs.github.com/spec/v1
oid sha256:1ed37f38214c755b51bea06a71e154c9ea27670eb3b8506c06addcfbea2066f2
size 4886853760
model-00004-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
version https://git-lfs.github.com/spec/v1
oid sha256:0145e255ba965ed0e75164a037b9a0137c5e5c12ffc42463ff82568054fe0186
size 1256456320
model.safetensors.index.json
ADDED
@@ -0,0 +1,486 @@
+{
+  "metadata": {
+    "total_size": 16061784576
+  },
+  "weight_map": {
+    "lm_head.weight": "model-00004-of-00004.safetensors",
+    "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.20.input_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.21.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.o_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.rotary_emb.inv_freq": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
+    "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
+    "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
+    "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
+    "model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
+    "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
+    "model.layers.30.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
+    "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 351 |
+
"model.layers.30.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 352 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 353 |
+
"model.layers.30.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 354 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 355 |
+
"model.layers.30.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 356 |
+
"model.layers.30.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 357 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 358 |
+
"model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 359 |
+
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 360 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 361 |
+
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 362 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 363 |
+
"model.layers.31.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 364 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 365 |
+
"model.layers.31.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 366 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 367 |
+
"model.layers.31.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 368 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 369 |
+
"model.layers.31.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 370 |
+
"model.layers.31.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 371 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 372 |
+
"model.layers.32.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 373 |
+
"model.layers.32.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 374 |
+
"model.layers.32.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 375 |
+
"model.layers.32.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 376 |
+
"model.layers.32.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 377 |
+
"model.layers.32.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 378 |
+
"model.layers.32.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 379 |
+
"model.layers.32.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 380 |
+
"model.layers.32.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 381 |
+
"model.layers.32.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 382 |
+
"model.layers.32.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 383 |
+
"model.layers.32.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 384 |
+
"model.layers.32.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 385 |
+
"model.layers.32.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 386 |
+
"model.layers.33.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 387 |
+
"model.layers.33.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 388 |
+
"model.layers.33.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 389 |
+
"model.layers.33.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 390 |
+
"model.layers.33.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 391 |
+
"model.layers.33.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 392 |
+
"model.layers.33.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 393 |
+
"model.layers.33.self_attn.o_proj.bias": "model-00003-of-00004.safetensors",
|
| 394 |
+
"model.layers.33.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 395 |
+
"model.layers.33.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 396 |
+
"model.layers.33.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 397 |
+
"model.layers.33.self_attn.rotary_emb.inv_freq": "model-00003-of-00004.safetensors",
|
| 398 |
+
"model.layers.33.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 399 |
+
"model.layers.33.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 400 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 401 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 402 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 403 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 404 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 405 |
+
"model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 406 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 407 |
+
"model.layers.4.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 408 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 409 |
+
"model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 410 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 411 |
+
"model.layers.4.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 412 |
+
"model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 413 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 414 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 415 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 416 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 417 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 418 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 419 |
+
"model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 420 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 421 |
+
"model.layers.5.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 422 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 423 |
+
"model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 424 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 425 |
+
"model.layers.5.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 426 |
+
"model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 427 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 428 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 429 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 430 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 431 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 432 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 433 |
+
"model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 434 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 435 |
+
"model.layers.6.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 436 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 437 |
+
"model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 438 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 439 |
+
"model.layers.6.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 440 |
+
"model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 441 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 442 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 443 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 444 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 445 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 446 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 447 |
+
"model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 448 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 449 |
+
"model.layers.7.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 450 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 451 |
+
"model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 452 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 453 |
+
"model.layers.7.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 454 |
+
"model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 455 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 456 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 457 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 458 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 459 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 460 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 461 |
+
"model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 462 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 463 |
+
"model.layers.8.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 464 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 465 |
+
"model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 466 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 467 |
+
"model.layers.8.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 468 |
+
"model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 469 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 470 |
+
"model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 471 |
+
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 472 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 473 |
+
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 474 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 475 |
+
"model.layers.9.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 476 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 477 |
+
"model.layers.9.self_attn.o_proj.bias": "model-00001-of-00004.safetensors",
|
| 478 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 479 |
+
"model.layers.9.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 480 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 481 |
+
"model.layers.9.self_attn.rotary_emb.inv_freq": "model-00001-of-00004.safetensors",
|
| 482 |
+
"model.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 483 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 484 |
+
"model.norm.weight": "model-00003-of-00004.safetensors"
|
| 485 |
+
}
|
| 486 |
+
}
|
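The `weight_map` in this index tells a loader which shard file holds each tensor. A minimal sketch of how a loader can group tensor names by shard so each `.safetensors` file is opened only once (the `weight_map` here is a hypothetical excerpt trimmed to three keys):

```python
from collections import defaultdict

# Illustrative excerpt in the same shape as the index above.
weight_map = {
    "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
    "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
    "model.norm.weight": "model-00003-of-00004.safetensors",
}

def group_by_shard(weight_map):
    """Invert the tensor-name -> shard-file mapping into shard-file -> names."""
    shards = defaultdict(list)
    for tensor_name, shard_file in weight_map.items():
        shards[shard_file].append(tensor_name)
    return dict(shards)

shards = group_by_shard(weight_map)
# Each shard file can then be opened once and all of its tensors read together.
```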
modeling_openpangu_dense.py ADDED
@@ -0,0 +1,585 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from modular_openpangu_dense.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_openpangu_dense.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union

import torch
from torch import nn

import torch_npu
from torch_npu.contrib import transfer_to_npu

if "910" in torch.npu.get_device_name():
    NPU_ATTN_INFR = True
    print("[INFO] torch_npu detected. Using NPU fused infer attention.")
else:
    NPU_ATTN_INFR = False

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging

from .configuration_openpangu_dense import PanguEmbeddedConfig


logger = logging.get_logger(__name__)


class PanguEmbeddedRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        PanguEmbeddedRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

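The RMSNorm above divides each hidden vector by its root-mean-square (in float32) and rescales with a learned weight; unlike LayerNorm, no mean is subtracted. A dependency-free numeric sketch of the same formula, using plain lists and the weight fixed at 1.0:

```python
import math

def rms_norm(vector, eps=1e-6, weight=None):
    """y_i = w_i * x_i / sqrt(mean(x^2) + eps), mirroring the module above."""
    if weight is None:
        weight = [1.0] * len(vector)
    variance = sum(x * x for x in vector) / len(vector)  # mean of squares
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * x * inv_rms for w, x in zip(weight, vector)]

out = rms_norm([3.0, 4.0])
# mean(x^2) = 12.5, rms ~= 3.5355, so out ~= [0.8485, 1.1314]
```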
class PanguEmbeddedRotaryEmbedding(nn.Module):
    def __init__(self, config: PanguEmbeddedConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

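The `q * cos + rotate_half(q) * sin` form above is a 2D rotation applied to paired dimensions of each head. A small dependency-free sketch with `head_dim = 2`, showing that a quarter-turn angle maps the vector (1, 0) to (0, 1):

```python
import math

def rotate_half(x):
    # Mirrors the tensor version: negate the second half, swap with the first.
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)

def apply_rope(q, cos, sin):
    # Elementwise q * cos + rotate_half(q) * sin, as in apply_rotary_pos_emb.
    r = rotate_half(q)
    return [qi * c + ri * s for qi, ri, c, s in zip(q, r, cos, sin)]

theta = math.pi / 2              # the rotation angle for this (position, frequency)
q = [1.0, 0.0]                   # one dimension pair of a head
cos = [math.cos(theta)] * 2
sin = [math.sin(theta)] * 2
rotated = apply_rope(q, cos, sin)
# A quarter-turn maps (1, 0) to approximately (0, 1); the norm is preserved.
```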
class PanguEmbeddedMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

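`repeat_kv` above implements grouped-query attention's bookkeeping: each key/value head is duplicated `n_rep` times so it lines up with its group of query heads. The same interleaving on plain lists (one string per head, standing in for a tensor slice):

```python
def repeat_kv(heads, n_rep):
    """Repeat each KV head n_rep times in place, like repeat_interleave on the
    num_key_value_heads dimension (the tensor version uses expand + reshape)."""
    if n_rep == 1:
        return heads
    return [h for h in heads for _ in range(n_rep)]

kv = ["kv0", "kv1"]            # 2 key/value heads
expanded = repeat_kv(kv, 3)    # serve 6 query heads (GQA group size 3)
# expanded == ["kv0", "kv0", "kv0", "kv1", "kv1", "kv1"]
```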
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

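The eager path above is plain scaled dot-product attention: scores from query-key dot products, a (numerically stabilized) softmax, then a weighted sum over values. A single-query, dependency-free sketch of the same computation on tiny lists:

```python
import math

def attention(query, keys, values, scaling):
    """Single-query scaled dot-product attention, mirroring the eager path."""
    # scores = q . k * scaling, one score per key
    scores = [sum(qi * ki for qi, ki in zip(query, k)) * scaling for k in keys]
    # softmax with max-subtraction for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # weighted sum of value vectors
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
out = attention(q, keys, values, scaling=1.0)
# The first key matches q better, so the output leans toward the first value.
```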
+
class PanguEmbeddedAttention(nn.Module):
|
| 206 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 207 |
+
|
| 208 |
+
def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.config = config
|
| 211 |
+
self.layer_idx = layer_idx
|
| 212 |
+
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
| 213 |
+
self.num_heads = config.num_attention_heads
|
| 214 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 215 |
+
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
| 216 |
+
self.scaling = self.head_dim**-0.5
|
| 217 |
+
self.attention_dropout = config.attention_dropout
|
| 218 |
+
self.is_causal = True
|
| 219 |
+
|
| 220 |
+
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
|
| 221 |
+
self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
|
| 222 |
+
self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
|
| 223 |
+
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)
|
| 224 |
+
|
| 225 |
+
def forward(
|
| 226 |
+
self,
|
| 227 |
+
hidden_states: torch.Tensor,
|
| 228 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 229 |
+
attention_mask: Optional[torch.Tensor],
|
| 230 |
+
past_key_value: Optional[Cache] = None,
|
| 231 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 232 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 233 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 234 |
+
input_shape = hidden_states.shape[:-1]
|
| 235 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 236 |
+
|
| 237 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 238 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 239 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 240 |
+
|
| 241 |
+
cos, sin = position_embeddings
|
| 242 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 243 |
+
|
| 244 |
+
if past_key_value is not None:
|
| 245 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 246 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 247 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 248 |
+
|
| 249 |
+
attention_interface: Callable = eager_attention_forward
|
| 250 |
+
if self.config._attn_implementation != "eager":
|
| 251 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 252 |
+
|
| 253 |
+
if not self.training and NPU_ATTN_INFR:
|
| 254 |
+
q_len = input_shape[1]
|
| 255 |
+
if attention_mask is not None:
|
| 256 |
+
attention_mask = ~attention_mask.bool()
|
| 257 |
+
elif q_len > 1:
|
| 258 |
+
attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)
|
| 259 |
+
|
| 260 |
+
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 261 |
+
query_states, key_states, value_states,
|
| 262 |
+
num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
|
| 263 |
+
input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
|
| 264 |
+
attn_output = attn_output.transpose(1, 2)
|
| 265 |
+
attn_weights = None
|
| 266 |
+
else:
|
| 267 |
+
attn_output, attn_weights = attention_interface(
|
| 268 |
+
self,
|
| 269 |
+
query_states,
|
| 270 |
+
key_states,
|
| 271 |
+
value_states,
|
| 272 |
+
attention_mask,
|
| 273 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 274 |
+
scaling=self.scaling,
|
| 275 |
+
**kwargs,
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 279 |
+
attn_output = self.o_proj(attn_output)
|
| 280 |
+
return attn_output, attn_weights
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class PanguEmbeddedDecoderLayer(GradientCheckpointingLayer):
|
| 284 |
+
def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.hidden_size = config.hidden_size
|
| 287 |
+
self.self_attn = PanguEmbeddedAttention(config=config, layer_idx=layer_idx)
|
| 288 |
+
self.mlp = PanguEmbeddedMLP(config)
|
| 289 |
+
self.input_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 290 |
+
self.post_attention_layernorm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 291 |
+
|
| 292 |
+
def forward(
|
| 293 |
+
self,
|
| 294 |
+
hidden_states: torch.Tensor,
|
| 295 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 296 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 297 |
+
past_key_value: Optional[Cache] = None,
|
| 298 |
+
output_attentions: Optional[bool] = False,
|
| 299 |
+
use_cache: Optional[bool] = False,
|
| 300 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 301 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 302 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 303 |
+
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 304 |
+
residual = hidden_states
|
| 305 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 306 |
+
|
| 307 |
+
# Self Attention
|
| 308 |
+
hidden_states, self_attn_weights = self.self_attn(
|
| 309 |
+
hidden_states=hidden_states,
|
| 310 |
+
attention_mask=attention_mask,
|
| 311 |
+
position_ids=position_ids,
|
| 312 |
+
past_key_value=past_key_value,
|
| 313 |
+
output_attentions=output_attentions,
|
| 314 |
+
use_cache=use_cache,
|
| 315 |
+
cache_position=cache_position,
|
| 316 |
+
position_embeddings=position_embeddings,
|
| 317 |
+
**kwargs,
|
| 318 |
+
)
|
| 319 |
+
hidden_states = residual + hidden_states
|
| 320 |
+
|
| 321 |
+
# Fully Connected
|
| 322 |
+
residual = hidden_states
|
| 323 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 324 |
+
hidden_states = self.mlp(hidden_states)
|
| 325 |
+
hidden_states = residual + hidden_states
|
| 326 |
+
|
| 327 |
+
outputs = (hidden_states,)
|
| 328 |
+
if output_attentions:
|
| 329 |
+
outputs += (self_attn_weights,)
|
| 330 |
+
|
| 331 |
+
return outputs
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@auto_docstring
|
| 335 |
+
class PanguEmbeddedPreTrainedModel(PreTrainedModel):
|
| 336 |
+
config_class = PanguEmbeddedConfig
|
| 337 |
+
base_model_prefix = "model"
|
| 338 |
+
supports_gradient_checkpointing = True
|
| 339 |
+
_no_split_modules = ["PanguEmbeddedDecoderLayer"]
|
| 340 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 341 |
+
_supports_flash_attn_3 = True
|
| 342 |
+
_supports_flash_attn_2 = True
|
| 343 |
+
_supports_sdpa = True
|
| 344 |
+
_supports_flex_attn = True
|
| 345 |
+
_supports_cache_class = True
|
| 346 |
+
_supports_quantized_cache = True
|
| 347 |
+
_supports_static_cache = True
|
| 348 |
+
_supports_attention_backend = True
|
| 349 |
+
|
| 350 |
+
def _init_weights(self, module):
|
| 351 |
+
std = self.config.initializer_range
|
| 352 |
+
if isinstance(module, nn.Linear):
|
| 353 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 354 |
+
if module.bias is not None:
|
| 355 |
+
module.bias.data.zero_()
|
| 356 |
+
elif isinstance(module, nn.Embedding):
|
| 357 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 358 |
+
if module.padding_idx is not None:
|
| 359 |
+
module.weight.data[module.padding_idx].zero_()
|
| 360 |
+
elif isinstance(module, PanguEmbeddedRMSNorm):
|
| 361 |
+
module.weight.data.fill_(1.0)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
@auto_docstring
|
| 365 |
+
class PanguEmbeddedModel(PanguEmbeddedPreTrainedModel):
|
| 366 |
+
def __init__(self, config: PanguEmbeddedConfig):
|
| 367 |
+
super().__init__(config)
|
| 368 |
+
self.padding_idx = config.pad_token_id
|
| 369 |
+
self.vocab_size = config.vocab_size
|
| 370 |
+
|
| 371 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 372 |
+
self.layers = nn.ModuleList(
|
| 373 |
+
[PanguEmbeddedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 374 |
+
)
|
| 375 |
+
self.norm = PanguEmbeddedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 376 |
+
self.rotary_emb = PanguEmbeddedRotaryEmbedding(config=config)
|
| 377 |
+
self.gradient_checkpointing = False
|
| 378 |
+
|
| 379 |
+
# Initialize weights and apply final processing
|
| 380 |
+
self.post_init()
|
| 381 |
+
|
| 382 |
+
def get_input_embeddings(self):
|
| 383 |
+
return self.embed_tokens
|
| 384 |
+
|
| 385 |
+
def set_input_embeddings(self, value):
|
| 386 |
+
self.embed_tokens = value
|
| 387 |
+
|
| 388 |
+
@can_return_tuple
|
| 389 |
+
@auto_docstring
|
| 390 |
+
def forward(
|
| 391 |
+
self,
|
| 392 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 393 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 394 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 395 |
+
past_key_values: Optional[Cache] = None,
|
| 396 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 397 |
+
use_cache: Optional[bool] = None,
|
| 398 |
+
output_attentions: Optional[bool] = None,
|
| 399 |
+
output_hidden_states: Optional[bool] = None,
|
| 400 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 401 |
+
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
| 402 |
+
) -> BaseModelOutputWithPast:
|
| 403 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 404 |
+
output_hidden_states = (
|
| 405 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 406 |
+
)
|
| 407 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 408 |
+
|
| 409 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 410 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 411 |
+
|
| 412 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
| 413 |
+
logger.warning_once(
|
| 414 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
| 415 |
+
)
|
| 416 |
+
use_cache = False
|
| 417 |
+
|
| 418 |
+
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
|
| 419 |
+
if not isinstance(past_key_values, (type(None), Cache)):
|
| 420 |
+
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
|
| 421 |
+
|
| 422 |
+
if inputs_embeds is None:
|
| 423 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 424 |
+
|
| 425 |
+
if use_cache and past_key_values is None:
|
| 426 |
+
past_key_values = DynamicCache()
|
| 427 |
+
|
| 428 |
+
if cache_position is None:
|
| 429 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 430 |
+
cache_position = torch.arange(
|
| 431 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if position_ids is None:
|
| 435 |
+
position_ids = cache_position.unsqueeze(0)
|
| 436 |
+
|
| 437 |
+
causal_mask = create_causal_mask(
|
| 438 |
+
config=self.config,
|
| 439 |
+
input_embeds=inputs_embeds,
|
| 440 |
+
attention_mask=attention_mask,
|
| 441 |
+
cache_position=cache_position,
|
| 442 |
+
past_key_values=past_key_values,
|
| 443 |
+
position_ids=position_ids,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
hidden_states = inputs_embeds
|
| 447 |
+
|
| 448 |
+
# create position embeddings to be shared across the decoder layers
|
| 449 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 450 |
+
|
| 451 |
+
# decoder layers
|
| 452 |
+
all_hidden_states = () if output_hidden_states else None
|
| 453 |
+
all_self_attns = () if output_attentions else None
|
| 454 |
+
|
| 455 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 456 |
+
if output_hidden_states:
|
| 457 |
+
all_hidden_states += (hidden_states,)
|
| 458 |
+
|
| 459 |
+
layer_outputs = decoder_layer(
|
| 460 |
+
hidden_states,
|
| 461 |
+
attention_mask=causal_mask,
|
| 462 |
+
position_ids=position_ids,
|
| 463 |
+
past_key_value=past_key_values,
|
| 464 |
+
output_attentions=output_attentions,
|
| 465 |
+
use_cache=use_cache,
|
| 466 |
+
cache_position=cache_position,
|
| 467 |
+
position_embeddings=position_embeddings,
|
| 468 |
+
**flash_attn_kwargs,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
hidden_states = layer_outputs[0]
|
| 472 |
+
|
| 473 |
+
if output_attentions:
|
| 474 |
+
all_self_attns += (layer_outputs[1],)
|
| 475 |
+
|
| 476 |
+
hidden_states = self.norm(hidden_states)
|
| 477 |
+
|
| 478 |
+
# add hidden states from the last decoder layer
|
| 479 |
+
if output_hidden_states:
|
| 480 |
+
all_hidden_states += (hidden_states,)
|
| 481 |
+
|
| 482 |
+
return BaseModelOutputWithPast(
|
| 483 |
+
last_hidden_state=hidden_states,
|
| 484 |
+
past_key_values=past_key_values if use_cache else None,
|
| 485 |
+
hidden_states=all_hidden_states,
|
| 486 |
+
attentions=all_self_attns,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@auto_docstring
|
| 494 |
+
class PanguEmbeddedForCausalLM(PanguEmbeddedPreTrainedModel, GenerationMixin):
|
| 495 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 496 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 497 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 498 |
+
|
| 499 |
+
def __init__(self, config):
|
| 500 |
+
super().__init__(config)
|
| 501 |
+
self.model = PanguEmbeddedModel(config)
|
| 502 |
+
self.vocab_size = config.vocab_size
|
| 503 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 504 |
+
|
| 505 |
+
# Initialize weights and apply final processing
|
| 506 |
+
self.post_init()
|
| 507 |
+
|
| 508 |
+
def get_input_embeddings(self):
|
| 509 |
+
return self.model.embed_tokens
|
| 510 |
+
|
| 511 |
+
def set_input_embeddings(self, value):
|
| 512 |
+
self.model.embed_tokens = value
|
| 513 |
+
|
| 514 |
+
def get_output_embeddings(self):
|
| 515 |
+
return self.lm_head
|
| 516 |
+
|
| 517 |
+
def set_output_embeddings(self, new_embeddings):
|
| 518 |
+
self.lm_head = new_embeddings
|
| 519 |
+
|
| 520 |
+
def set_decoder(self, decoder):
|
| 521 |
+
self.model = decoder
|
| 522 |
+
|
| 523 |
+
def get_decoder(self):
|
| 524 |
+
return self.model
|
| 525 |
+
|
| 526 |
+
@can_return_tuple
|
| 527 |
+
@auto_docstring
|
| 528 |
+
def forward(
|
| 529 |
+
self,
|
| 530 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 531 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 532 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 533 |
+
past_key_values: Optional[Cache] = None,
|
| 534 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 535 |
+
labels: Optional[torch.LongTensor] = None,
|
| 536 |
+
use_cache: Optional[bool] = None,
|
| 537 |
+
output_attentions: Optional[bool] = None,
|
| 538 |
+
output_hidden_states: Optional[bool] = None,
|
| 539 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 540 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 541 |
+
**kwargs: Unpack[KwargsForCausalLM],
|
| 542 |
+
) -> CausalLMOutputWithPast:
|
| 543 |
+
|
| 544 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 545 |
+
output_hidden_states = (
|
| 546 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 550 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 551 |
+
input_ids=input_ids,
|
| 552 |
+
attention_mask=attention_mask,
|
| 553 |
+
position_ids=position_ids,
|
| 554 |
+
past_key_values=past_key_values,
|
| 555 |
+
inputs_embeds=inputs_embeds,
|
| 556 |
+
use_cache=use_cache,
|
| 557 |
+
output_attentions=output_attentions,
|
| 558 |
+
output_hidden_states=output_hidden_states,
|
| 559 |
+
cache_position=cache_position,
|
| 560 |
+
**kwargs,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
hidden_states = outputs.last_hidden_state
|
| 564 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 565 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 566 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 567 |
+
|
| 568 |
+
loss = None
|
| 569 |
+
if labels is not None:
|
| 570 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 571 |
+
|
| 572 |
+
return CausalLMOutputWithPast(
|
| 573 |
+
loss=loss,
|
| 574 |
+
logits=logits,
|
| 575 |
+
past_key_values=outputs.past_key_values,
|
| 576 |
+
hidden_states=outputs.hidden_states,
|
| 577 |
+
attentions=outputs.attentions,
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
__all__ = [
|
| 582 |
+
"PanguEmbeddedForCausalLM",
|
| 583 |
+
"PanguEmbeddedModel",
|
| 584 |
+
"PanguEmbeddedPreTrainedModel",
|
| 585 |
+
]
|
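On the NPU fast path above, a prefill step with no user-supplied mask builds a boolean causal mask via `torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool()`, where `True` marks positions to block. A minimal pure-Python sketch of that mask shape, for illustration only (not the torch implementation):

```python
def causal_block_mask(q_len: int) -> list[list[bool]]:
    """Boolean causal mask: True strictly above the main diagonal (blocked),
    mirroring torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool()."""
    return [[col > row for col in range(q_len)] for row in range(q_len)]

# Row i may attend to columns 0..i; later columns are blocked (True).
mask = causal_block_mask(3)
```

Note the inverted convention on this path: the HF-style `attention_mask` (1 = keep) is flipped with `~attention_mask.bool()` because `npu_fused_infer_attention_score` expects `True` to mean "mask out".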
modular_openpangu_dense.py
ADDED
|
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple

import torch
from torch import nn

import torch_npu
from torch_npu.contrib import transfer_to_npu

if "910" in torch.npu.get_device_name():
    NPU_ATTN_INFR = True
    print("[INFO] torch_npu detected. Using NPU fused infer attention.")
else:
    NPU_ATTN_INFR = False

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaMLP,
    LlamaModel,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from .configuration_openpangu_dense import PanguEmbeddedConfig


logger = logging.get_logger(__name__)


class PanguEmbeddedMLP(LlamaMLP):
    def __init__(self, config):
        super().__init__(config)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)


class PanguEmbeddedAttention(LlamaAttention):
    def __init__(self, config: PanguEmbeddedConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if not self.training and NPU_ATTN_INFR:
            q_len = input_shape[1]
            if attention_mask is not None:
                attention_mask = ~attention_mask.bool()
            elif q_len > 1:
                attention_mask = torch.triu(torch.ones([q_len, q_len]), diagonal=1).bool().unsqueeze(0).unsqueeze(0).to(query_states.device)

            attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                query_states, key_states, value_states,
                num_heads=self.num_heads, num_key_value_heads=self.num_key_value_heads,
                input_layout="BNSD", atten_mask=attention_mask, scale=self.scaling)
            attn_output = attn_output.transpose(1, 2)
            attn_weights = None
        else:
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class PanguEmbeddedDecoderLayer(LlamaDecoderLayer):
    pass


class PanguEmbeddedModel(LlamaModel):
    pass


class PanguEmbeddedForCausalLM(LlamaForCausalLM):
    pass
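The three `pass` subclasses above are the modular-transformers pattern: only the attention and MLP are overridden, while the decoder layer, model, and LM head inherit the full Llama implementation through plain Python inheritance. A minimal, self-contained sketch of that pattern (generic classes, not the transformers ones):

```python
class BaseLayer:
    """Stands in for a fully implemented Llama component."""
    def forward(self, x):
        return x * 2

class CustomAttention(BaseLayer):
    """Overridden component: replaces the parent's behavior."""
    def forward(self, x):
        return x + 1

class PassThroughLayer(BaseLayer):
    """Like `class PanguEmbeddedDecoderLayer(LlamaDecoderLayer): pass` —
    inherits every method unchanged."""
    pass
```

The pass-through subclass exists only to give the inherited behavior a model-specific name, which the modular conversion tooling then expands into a standalone `modeling_*.py` file.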
special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "[unused10]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenization_openpangu.py
ADDED
|
@@ -0,0 +1,273 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
from shutil import copyfile
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 25 |
+
|
| 26 |
+
import sentencepiece as spm
|
| 27 |
+
|
| 28 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 29 |
+
from transformers.utils import logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
|
| 35 |
+
|
| 36 |
+
PRETRAINED_VOCAB_FILES_MAP = {}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def convert_bool(string):
|
| 40 |
+
if isinstance(string, str):
|
| 41 |
+
if string.lower() == "true":
|
| 42 |
+
return True
|
| 43 |
+
elif string.lower() == "false":
|
| 44 |
+
return False
|
| 45 |
+
else:
|
| 46 |
+
return string
|
| 47 |
+
else:
|
| 48 |
+
return string
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PanguTokenizer(PreTrainedTokenizer):
|
| 52 |
+
"""
|
| 53 |
+
Construct a tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
vocab_file (`str`):
|
| 57 |
+
Path to the vocabulary file.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 61 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
| 62 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 63 |
+
_auto_class = "AutoTokenizer"
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
vocab_file,
|
| 68 |
+
unk_token="<unk>",
|
| 69 |
+
bos_token="<s>",
|
| 70 |
+
eos_token="</s>",
|
| 71 |
+
pad_token="</s>",
|
| 72 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 73 |
+
add_bos_token=True,
|
| 74 |
+
add_eos_token=False,
|
| 75 |
+
decode_with_prefix_space=False,
|
| 76 |
+
clean_up_tokenization_spaces=False,
|
| 77 |
+
**kwargs,
|
| 78 |
+
):
|
| 79 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 80 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 81 |
+
self.sp_model.Load(vocab_file)
|
| 82 |
+
super().__init__(
|
| 83 |
+
bos_token=bos_token,
|
| 84 |
+
eos_token=eos_token,
|
| 85 |
+
unk_token=unk_token,
|
| 86 |
+
pad_token=pad_token,
|
| 87 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 88 |
+
**kwargs,
|
| 89 |
+
)
|
| 90 |
+
self.vocab_file = vocab_file
|
| 91 |
+
self.add_bos_token = convert_bool(add_bos_token)
|
| 92 |
+
self.add_eos_token = add_eos_token
|
| 93 |
+
self.decode_with_prefix_space = decode_with_prefix_space
|
| 94 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 95 |
+
self.sp_model.Load(vocab_file)
|
| 96 |
+
self._no_prefix_space_tokens = None
|
| 97 |
+
|
| 98 |
+
""" Initialisation"""
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def no_prefix_space_tokens(self):
|
| 102 |
+
if self._no_prefix_space_tokens is None:
|
| 103 |
+
vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
|
| 104 |
+
self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
|
| 105 |
+
return self._no_prefix_space_tokens
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def vocab_size(self):
|
| 109 |
+
"""Returns vocab size"""
|
| 110 |
+
return self.sp_model.get_piece_size()
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def bos_token_id(self) -> Optional[int]:
|
| 114 |
+
return self.sp_model.bos_id()
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def eos_token_id(self) -> Optional[int]:
|
| 118 |
+
return super().eos_token_id
|
| 119 |
+
|
| 120 |
+
def get_vocab(self):
|
| 121 |
+
"""Returns vocab as a dict"""
|
| 122 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 123 |
+
vocab.update(self.added_tokens_encoder)
|
| 124 |
+
return vocab
|
| 125 |
+
|
| 126 |
+
def _tokenize(self, text):
|
| 127 |
+
"""Returns a tokenized string."""
|
| 128 |
+
return self.sp_model.encode(text, out_type=str)
|
| 129 |
+
|
| 130 |
+
def _convert_token_to_id(self, token):
|
| 131 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 132 |
+
return self.sp_model.piece_to_id(token)
|
| 133 |
+
|
| 134 |
+
def _convert_id_to_token(self, index):
|
| 135 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 136 |
+
token = self.sp_model.IdToPiece(index)
|
| 137 |
+
return token
|
| 138 |
+
|
| 139 |
+
def _maybe_add_prefix_space(self, tokens, decoded):
|
| 140 |
+
if tokens and tokens[0] not in self.no_prefix_space_tokens:
|
| 141 |
+
return " " + decoded
|
| 142 |
+
else:
|
| 143 |
+
return decoded
|
| 144 |
+
|
| 145 |
+
def convert_tokens_to_string(self, tokens):
|
| 146 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 147 |
+
current_sub_tokens = []
|
| 148 |
+
out_string = ""
|
| 149 |
+
prev_is_special = False
|
| 150 |
+
for token in tokens:
|
| 151 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
| 152 |
+
if token in self.all_special_tokens:
|
| 153 |
+
# Decode the current sub-tokens first
|
| 154 |
+
if current_sub_tokens:
|
| 155 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 156 |
+
current_sub_tokens = []
|
| 157 |
+
# Append the special token without adding extra spaces
|
| 158 |
+
out_string += token
|
| 159 |
+
prev_is_special = True
|
| 160 |
+
else:
|
| 161 |
+
current_sub_tokens.append(token)
|
| 162 |
+
prev_is_special = False
|
| 163 |
+
# Decode any remaining sub-tokens
|
| 164 |
+
if current_sub_tokens:
|
| 165 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 166 |
+
# Clean up leading and trailing spaces
|
| 167 |
+
if self.clean_up_tokenization_spaces:
|
| 168 |
+
out_string = self.clean_up_tokenization(out_string)
|
| 169 |
+
out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
|
| 170 |
+
return out_string[1:]
|
| 171 |
+
|
| 172 |
+
# Override decode to set spaces_between_special_tokens to True as default
|
| 173 |
+
def decode(self,
|
| 174 |
+
token_ids,
|
| 175 |
+
spaces_between_special_tokens: bool = False,
|
| 176 |
+
**kwargs):
|
| 177 |
+
return super().decode(
|
| 178 |
+
token_ids=token_ids,
|
| 179 |
+
spaces_between_special_tokens=spaces_between_special_tokens,
|
| 180 |
+
**kwargs,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return ("",)
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

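The output path construction above can be sketched in isolation (`vocab_path` is a hypothetical helper name; `tokenizer.model` matches the vocab filename shipped in this repo):

```python
import os

# mirrors the VOCAB_FILES_NAMES mapping used by the tokenizer module
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

def vocab_path(save_directory, filename_prefix=None):
    # an optional prefix is joined with "-" in front of the canonical filename
    return os.path.join(
        save_directory,
        (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
    )

print(vocab_path("ckpt"))            # ckpt/tokenizer.model on POSIX
print(vocab_path("ckpt", "pangu"))   # ckpt/pangu-tokenizer.model on POSIX
```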
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        if self.add_bos_token:
            bos_token_ids = [self.bos_token_id]
        else:
            bos_token_ids = []

        output = bos_token_ids + token_ids_0

        if token_ids_1 is not None:
            output = output + token_ids_1

        if self.add_eos_token:
            output = output + [self.eos_token_id]

        return output

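A self-contained sketch of the wrapping logic above (`build_inputs` is an illustrative name; the defaults `add_bos=True, add_eos=False` and ids 1/2 for `<s>`/`</s>` follow `tokenizer_config.json` in this repo):

```python
def build_inputs(token_ids_0, token_ids_1=None,
                 add_bos=True, add_eos=False, bos_id=1, eos_id=2):
    # BOS first (id 1 = <s> per added_tokens_decoder), then the sequence(s);
    # EOS (id 2 = </s>) only when requested — add_eos_token is false in the config
    out = ([bos_id] if add_bos else []) + token_ids_0
    if token_ids_1 is not None:
        out = out + token_ids_1
    if add_eos:
        out = out + [eos_id]
    return out

print(build_inputs([10, 11]))                      # [1, 10, 11]
print(build_inputs([10, 11], [12], add_eos=True))  # [1, 10, 11, 12, 2]
```

Note that a sequence pair is simply concatenated, with no separator token between `token_ids_0` and `token_ids_1`.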
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

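The mask shape above is easiest to see on concrete inputs; a standalone sketch (`special_tokens_mask` is an illustrative name — note this method assumes a fully BOS/EOS-wrapped layout, with `</s><s>` between the two sequences of a pair):

```python
def special_tokens_mask(ids_0, ids_1=None):
    # 1 marks a position reserved for a special token, 0 a sequence token
    if ids_1 is None:
        return [1] + [0] * len(ids_0) + [1]          # <s> ids_0 </s>
    return [1] + [0] * len(ids_0) + [1, 1] + [0] * len(ids_1) + [1]  # <s> ids_0 </s><s> ids_1 </s>

print(special_tokens_mask([5, 6, 7]))   # [1, 0, 0, 0, 1]
print(special_tokens_mask([5, 6], [8])) # [1, 0, 0, 1, 1, 0, 1]
```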
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. This tokenizer
        does not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec
size 2477809
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"add_bos_token": true, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": 
true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": "[unused25]", 
"lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguTokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguTokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif 
%}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}
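The Jinja `chat_template` above can be mimicked with plain string concatenation. A sketch under these assumptions: `apply_chat_template` and `ROLE_LABELS` are illustrative names, and the role labels are the Chinese strings from the template itself (系统 = system, 用户 = user, 助手 = assistant, 工具 = tool, 方法 = function), kept verbatim because changing them would change the tokenization:

```python
# [unused9]/[unused10] open and close each turn; [unused10] is also the
# eos_token in this config, so every completed turn ends with EOS.
ROLE_LABELS = {"system": "系统", "user": "用户", "assistant": "助手",
               "tool": "工具", "function": "方法"}

def apply_chat_template(messages, add_generation_prompt=False):
    out = ""
    # the template prepends an empty system turn when none is supplied
    if messages and messages[0]["role"] != "system":
        out += "[unused9]系统:[unused10]"
    for m in messages:
        out += "[unused9]" + ROLE_LABELS[m["role"]] + ":" + m["content"] + "[unused10]"
    if add_generation_prompt:
        out += "[unused9]助手:"
    return out

print(apply_chat_template([{"role": "user", "content": "hi"}], add_generation_prompt=True))
# → [unused9]系统:[unused10][unused9]用户:hi[unused10][unused9]助手:
```

Note the colon after each role label is the full-width character `:`, exactly as in the template; the generation prompt is an *unclosed* assistant turn, so the model's reply is terminated by it emitting `[unused10]`.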