Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .DS_Store +0 -0
- .gitattributes +3 -0
- .pytest_cache/.gitignore +2 -0
- .pytest_cache/CACHEDIR.TAG +4 -0
- .pytest_cache/README.md +8 -0
- .pytest_cache/v/cache/nodeids +3 -0
- .ruff_cache/.gitignore +2 -0
- .ruff_cache/0.14.10/10241894308290549172 +0 -0
- .ruff_cache/0.14.10/1073426088278906643 +0 -0
- .ruff_cache/0.14.10/13957033273656742151 +0 -0
- .ruff_cache/0.14.10/1442719585850318975 +0 -0
- .ruff_cache/0.14.10/14754177912317367819 +0 -0
- .ruff_cache/0.14.10/14978186029505022734 +0 -0
- .ruff_cache/0.14.10/15569745458013874055 +0 -0
- .ruff_cache/0.14.10/17608220473508725558 +0 -0
- .ruff_cache/0.14.10/18191902847846296179 +0 -0
- .ruff_cache/0.14.10/2046185769257499142 +0 -0
- .ruff_cache/0.14.10/3165187837348788939 +0 -0
- .ruff_cache/0.14.10/4171122735627067383 +0 -0
- .ruff_cache/0.14.10/8273464926453838394 +0 -0
- .ruff_cache/0.14.10/9088412491868955099 +0 -0
- .ruff_cache/0.14.10/9103521535542433765 +0 -0
- .ruff_cache/0.14.10/9189204400079810969 +0 -0
- .ruff_cache/0.14.10/9226417474992298237 +0 -0
- .ruff_cache/0.14.10/9918913907578606062 +0 -0
- .ruff_cache/CACHEDIR.TAG +1 -0
- ACKNOWLEDGEMENTS +34 -0
- CODE_OF_CONDUCT.md +71 -0
- CONTRIBUTING.md +10 -0
- LICENSE +46 -0
- README.md +407 -0
- docs/Gemfile +23 -0
- docs/_config.yml +61 -0
- docs/getting_started.md +80 -0
- docs/index.md +53 -0
- docs/inference.md +134 -0
- docs/training.md +129 -0
- evaluation/evaluate.py +936 -0
- evaluation/evaluate.py.bak +910 -0
- evaluation/evaluation_data/end_to_end_evaluation/2wiki.zip +3 -0
- evaluation/evaluation_data/end_to_end_evaluation/hotpotqa.zip +3 -0
- evaluation/evaluation_data/end_to_end_evaluation/musique.zip +3 -0
- evaluation/evaluation_data/end_to_end_evaluation/nq.zip +3 -0
- evaluation/evaluation_data/instruction_tuning_evaluation/2wiki.zip +3 -0
- evaluation/evaluation_data/instruction_tuning_evaluation/hotpotqa.zip +3 -0
- evaluation/evaluation_data/instruction_tuning_evaluation/musique.zip +3 -0
- evaluation/evaluation_data/instruction_tuning_evaluation/nq.zip +3 -0
- example/end_to_end_data.jsonl +3 -0
- example/instruction_tuning_data.jsonl +0 -0
- example/pretrain_data.jsonl +0 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
example/end_to_end_data.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
figs/intro.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
figs/sample_main.png filter=lfs diff=lfs merge=lfs -text
|
.pytest_cache/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by pytest automatically.
|
| 2 |
+
*
|
.pytest_cache/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
| 2 |
+
# This file is a cache directory tag created by pytest.
|
| 3 |
+
# For information about cache directory tags, see:
|
| 4 |
+
# https://bford.info/cachedir/spec.html
|
.pytest_cache/README.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytest cache directory #
|
| 2 |
+
|
| 3 |
+
This directory contains data from the pytest's cache plugin,
|
| 4 |
+
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
|
| 5 |
+
|
| 6 |
+
**Do not** commit this to version control.
|
| 7 |
+
|
| 8 |
+
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
|
.pytest_cache/v/cache/nodeids
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"tests/test_placeholder.py::test_placeholder"
|
| 3 |
+
]
|
.ruff_cache/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automatically created by ruff.
|
| 2 |
+
*
|
.ruff_cache/0.14.10/10241894308290549172
ADDED
|
Binary file (136 Bytes). View file
|
|
|
.ruff_cache/0.14.10/1073426088278906643
ADDED
|
Binary file (136 Bytes). View file
|
|
|
.ruff_cache/0.14.10/13957033273656742151
ADDED
|
Binary file (1.05 kB). View file
|
|
|
.ruff_cache/0.14.10/1442719585850318975
ADDED
|
Binary file (171 Bytes). View file
|
|
|
.ruff_cache/0.14.10/14754177912317367819
ADDED
|
Binary file (132 Bytes). View file
|
|
|
.ruff_cache/0.14.10/14978186029505022734
ADDED
|
Binary file (129 Bytes). View file
|
|
|
.ruff_cache/0.14.10/15569745458013874055
ADDED
|
Binary file (136 Bytes). View file
|
|
|
.ruff_cache/0.14.10/17608220473508725558
ADDED
|
Binary file (132 Bytes). View file
|
|
|
.ruff_cache/0.14.10/18191902847846296179
ADDED
|
Binary file (129 Bytes). View file
|
|
|
.ruff_cache/0.14.10/2046185769257499142
ADDED
|
Binary file (129 Bytes). View file
|
|
|
.ruff_cache/0.14.10/3165187837348788939
ADDED
|
Binary file (171 Bytes). View file
|
|
|
.ruff_cache/0.14.10/4171122735627067383
ADDED
|
Binary file (129 Bytes). View file
|
|
|
.ruff_cache/0.14.10/8273464926453838394
ADDED
|
Binary file (1.05 kB). View file
|
|
|
.ruff_cache/0.14.10/9088412491868955099
ADDED
|
Binary file (1.05 kB). View file
|
|
|
.ruff_cache/0.14.10/9103521535542433765
ADDED
|
Binary file (132 Bytes). View file
|
|
|
.ruff_cache/0.14.10/9189204400079810969
ADDED
|
Binary file (136 Bytes). View file
|
|
|
.ruff_cache/0.14.10/9226417474992298237
ADDED
|
Binary file (132 Bytes). View file
|
|
|
.ruff_cache/0.14.10/9918913907578606062
ADDED
|
Binary file (1.05 kB). View file
|
|
|
.ruff_cache/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
ACKNOWLEDGEMENTS
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Acknowledgements
|
| 2 |
+
Portions of this CLaRa Software may utilize the following copyrighted
|
| 3 |
+
material, the use of which is hereby acknowledged.
|
| 4 |
+
|
| 5 |
+
_____________________
|
| 6 |
+
|
| 7 |
+
Naver Labs Europe (PISCO-mistral)
|
| 8 |
+
Licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).
|
| 9 |
+
|
| 10 |
+
Copyright © Naver Labs Europe
|
| 11 |
+
|
| 12 |
+
You are free to:
|
| 13 |
+
- Share — copy and redistribute the material in any medium or format
|
| 14 |
+
- Adapt — remix, transform, and build upon the material
|
| 15 |
+
|
| 16 |
+
Under the following terms:
|
| 17 |
+
- Attribution — You must give appropriate credit, provide a link to the license,
|
| 18 |
+
and indicate if changes were made.
|
| 19 |
+
- NonCommercial — You may not use the material for commercial purposes.
|
| 20 |
+
|
| 21 |
+
Full license text available at: https://creativecommons.org/licenses/by-nc/4.0/
|
| 22 |
+
|
| 23 |
+
OpenRLHF authors
|
| 24 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 25 |
+
you may not use this file except in compliance with the License.
|
| 26 |
+
You may obtain a copy of the License at
|
| 27 |
+
|
| 28 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 29 |
+
|
| 30 |
+
Unless required by applicable law or agreed to in writing, software
|
| 31 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 32 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 33 |
+
See the License for the specific language governing permissions and
|
| 34 |
+
limitations under the License.
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to making participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
## Enforcement
|
| 56 |
+
|
| 57 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 58 |
+
reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
|
| 59 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 60 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 61 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 62 |
+
Further details of specific enforcement policies may be posted separately.
|
| 63 |
+
|
| 64 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 65 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 66 |
+
members of the project's leadership.
|
| 67 |
+
|
| 68 |
+
## Attribution
|
| 69 |
+
|
| 70 |
+
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
|
| 71 |
+
available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contribution Guide
|
| 2 |
+
|
| 3 |
+
Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
|
| 4 |
+
|
| 5 |
+
While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
|
| 6 |
+
|
| 7 |
+
## Before you get started
|
| 8 |
+
|
| 9 |
+
By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
|
| 10 |
+
|
LICENSE
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 2 |
+
|
| 3 |
+
IMPORTANT: This Apple software is supplied to you by Apple
|
| 4 |
+
Inc. ("Apple") in consideration of your agreement to the following
|
| 5 |
+
terms, and your use, installation, modification or redistribution of
|
| 6 |
+
this Apple software constitutes acceptance of these terms. If you do
|
| 7 |
+
not agree with these terms, please do not use, install, modify or
|
| 8 |
+
redistribute this Apple software.
|
| 9 |
+
|
| 10 |
+
In consideration of your agreement to abide by the following terms, and
|
| 11 |
+
subject to these terms, Apple grants you a personal, non-exclusive
|
| 12 |
+
license, under Apple's copyrights in this original Apple software (the
|
| 13 |
+
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
| 14 |
+
Software, with or without modifications, in source and/or binary forms;
|
| 15 |
+
provided that if you redistribute the Apple Software in its entirety and
|
| 16 |
+
without modifications, you must retain this notice and the following
|
| 17 |
+
text and disclaimers in all such redistributions of the Apple Software.
|
| 18 |
+
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
| 19 |
+
be used to endorse or promote products derived from the Apple Software
|
| 20 |
+
without specific prior written permission from Apple. Except as
|
| 21 |
+
expressly stated in this notice, no other rights or licenses, express or
|
| 22 |
+
implied, are granted by Apple herein, including but not limited to any
|
| 23 |
+
patent rights that may be infringed by your derivative works or by other
|
| 24 |
+
works in which the Apple Software may be incorporated.
|
| 25 |
+
|
| 26 |
+
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
| 27 |
+
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
| 28 |
+
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
| 29 |
+
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
| 30 |
+
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
| 31 |
+
|
| 32 |
+
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
| 33 |
+
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
| 34 |
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
| 35 |
+
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
| 36 |
+
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
| 37 |
+
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
| 38 |
+
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
| 39 |
+
POSSIBILITY OF SUCH DAMAGE.
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
This CLaRa Software may utilize third party materials. Please refer to the
|
| 44 |
+
ACKNOWLEDGEMENTS file included with this software for attribution and
|
| 45 |
+
license information related to third party code that may be contained in or
|
| 46 |
+
used with this CLaRa Software.
|
README.md
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning
|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<img src="figs/clara_logo.jpg" width="400"/>
|
| 5 |
+
</div>
|
| 6 |
+
|
| 7 |
+
<div align="center">
|
| 8 |
+
<a href="https://arxiv.org/abs/2511.18659"><img src="https://img.shields.io/badge/arXiv-2511.18659-b31b1b.svg" alt="arXiv"></a>
|
| 9 |
+
<a href="LICENSE"><img src="https://img.shields.io/badge/License-Apple-blue" alt="License"></a>
|
| 10 |
+
<a href="https://huggingface.co/apple/CLaRa-7B-Base"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_Base-FFEB3B" alt="deploy"></a>
|
| 11 |
+
<a href="https://huggingface.co/apple/CLaRa-7B-Instruct"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_Instruct-FFEB3B" alt="deploy"></a>
|
| 12 |
+
<a href="https://huggingface.co/apple/CLaRa-7B-E2E"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_End_to_end-FFEB3B" alt="deploy"></a>
|
| 13 |
+
<a href="https://huggingface.co/datasets/apple/CLaRa_multi_stage"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_Data-FFEB3B" alt="data"></a>
|
| 14 |
+
</div>
|
| 15 |
+
|
| 16 |
+
This is the official open-source release of CLaRa, a state-of-the-art, end-to-end Retrieval-Augmented Generation model.
|
| 17 |
+
|
| 18 |
+
### Updates
|
| 19 |
+
|
| 20 |
+
- Dec 11, 2025. All used data are available on [Huggingface](https://huggingface.co/datasets/apple/CLaRa_multi_stage).
|
| 21 |
+
- Dec 10, 2025. We are working on an MLX version of the model, to be announced soon.
|
| 22 |
+
- Dec 3, 2025. Evaluation data are available in `./evaluation/evaluation_data`.
|
| 23 |
+
- Nov 25, 2025. Models are available on Huggingface.
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
### Motivation
|
| 27 |
+
|
| 28 |
+
Retrieval-Augmented Generation (RAG) enhances large language models with external knowledge but suffers from **long contexts** and **disjoint retrieval-generation optimization**. Existing soft compression frameworks face two key limitations: (i) reconstruction-based objectives bias compressors toward surface patterns rather than semantic preservation; (ii) retrievers and compressors are trained separately, requiring double encoding despite compressed vectors being inherently retrievable.
|
| 29 |
+
|
| 30 |
+
In this work, we investigate:
|
| 31 |
+
|
| 32 |
+
- **How can we improve semantic preservation in compressed representations through better pretraining objectives?**
|
| 33 |
+
- **How can we unify retrieval and generation optimization to avoid redundant encoding and disjoint objectives?**
|
| 34 |
+
|
| 35 |
+
<div align="center">
|
| 36 |
+
|
| 37 |
+
<img src="figs/intro.png" width="100%"/>
|
| 38 |
+
|
| 39 |
+
</div>
|
| 40 |
+
|
| 41 |
+
We design a Three-stage training approach and introduce document compression techniques to improve RAG efficiency. The key findings are listed below.
|
| 42 |
+
|
| 43 |
+
### Findings
|
| 44 |
+
|
| 45 |
+
- **Efficient Compression**: CLaRa achieves significant compression rates (32x-64x) while preserving essential information for accurate answer generation.
|
| 46 |
+
|
| 47 |
+
- **Three-Stage Training**: A carefully designed Three-stage training approach (compression pretraining + compression instruction tuning + end-to-end fine-tuning) enables effective learning of both retrieval and generation.
|
| 48 |
+
|
| 49 |
+
For more interesting findings, please refer to our original paper!
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
### Three-Stage Training
|
| 54 |
+
|
| 55 |
+
CLaRa uses a carefully designed three-stage training approach:
|
| 56 |
+
|
| 57 |
+
**Stage 1: Compression Pretraining**
|
| 58 |
+
- Train the compressor using SCP framework with QA pairs and paraphrases
|
| 59 |
+
- Retain key semantics through QA-based and paraphrase-guided supervision
|
| 60 |
+
- Support compression rates of 1x-256x
|
| 61 |
+
|
| 62 |
+
**Stage 2: Compression Instruction Tuning**
|
| 63 |
+
- Fine-tune the compressor on instruction-following tasks for downstream QA
|
| 64 |
+
- Use text-based QA output to ensure compressed representations retain sufficient semantics
|
| 65 |
+
|
| 66 |
+
**Stage 3: End-to-End Fine-tuning (CLaRa)**
|
| 67 |
+
- Jointly train reranker and generator via a single language modeling loss
|
| 68 |
+
- Unify retrieval and generation in shared continuous space using differentiable top-k estimator
|
| 69 |
+
|
| 70 |
+
In this repository, we release our implementation of **CLaRa**, built upon [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF).
|
| 71 |
+
|
| 72 |
+
### Getting Started
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
├── scripts/ # Training and evaluation scripts
|
| 76 |
+
│ ├── train_pretraining.sh # Stage 1: Compression pretraining
|
| 77 |
+
│ ├── train_instruction_tuning.sh # Stage 2: Compression instruction tuning
|
| 78 |
+
│ ├── train_stage_end_to_end.sh # Stage 3: End-to-end training
|
| 79 |
+
│ └── evaluation_end_to_end.sh # Evaluation scripts
|
| 80 |
+
├── openrlhf/ # Core training framework
|
| 81 |
+
│ ├── models/ # Model implementations
|
| 82 |
+
│ │ └── modeling_clara.py # CLaRa model definition
|
| 83 |
+
│ ├── datasets/ # Dataset handling
|
| 84 |
+
│ │ └── sft_dataset.py # Training dataset
|
| 85 |
+
│ ├── trainer/ # Training utilities
|
| 86 |
+
│ │ └── sft_trainer.py # SFT trainer
|
| 87 |
+
│ └── cli/ # Command line interface
|
| 88 |
+
│ └── train_sft.py # Main training script
|
| 89 |
+
├── evaluation/ # Evaluation framework
|
| 90 |
+
├── example/ # Example training data
|
| 91 |
+
│ ├── pretrain_data.jsonl
|
| 92 |
+
│ ├── instruction_tuning_data.jsonl
|
| 93 |
+
│ └── end_to_end_data.jsonl
|
| 94 |
+
└── README.md # This file
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
Video instruction for installation (from @Fahd Mirza): https://youtu.be/al2VoAKn8GU?si=Q8bq7QNMaTvcArwa
|
| 98 |
+
Video digest (from @Richard Aragon): https://www.youtube.com/watch?v=yRM92mmKNH4
|
| 99 |
+
|
| 100 |
+
#### 1. Prepare code and environment
|
| 101 |
+
|
| 102 |
+
Clone the repository and set up the environment:
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
# Create conda environment
|
| 106 |
+
env=clara
|
| 107 |
+
conda create -n $env python=3.10 -y
|
| 108 |
+
conda activate $env
|
| 109 |
+
|
| 110 |
+
# Install dependencies
|
| 111 |
+
pip install -r requirements.txt
|
| 112 |
+
|
| 113 |
+
# Set up environment variables
|
| 114 |
+
export PYTHONPATH=/path/to/clara:$PYTHONPATH
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
Key dependencies include:
|
| 118 |
+
- PyTorch >= 2.0
|
| 119 |
+
- Transformers >= 4.20
|
| 120 |
+
- DeepSpeed >= 0.18
|
| 121 |
+
- Flash Attention 2
|
| 122 |
+
- Accelerate
|
| 123 |
+
|
| 124 |
+
#### 2. Data preparation
|
| 125 |
+
|
| 126 |
+
Prepare training data in JSONL format. For pretraining stage:
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
# Example data format for pretraining
|
| 130 |
+
{
|
| 131 |
+
"data_type": "qa",
|
| 132 |
+
"question": ["Question 1",],
|
| 133 |
+
"answers": ["Answer 1"],
|
| 134 |
+
"docs": ["Document 1"]
|
| 135 |
+
}
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
For end-to-end training:
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
{
|
| 142 |
+
"question": "Single question text",
|
| 143 |
+
"docs": ["Document 1", "Document 2", ...],
|
| 144 |
+
"gold_answer": "Reference answer"
|
| 145 |
+
}
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
#### 3. Start training
|
| 149 |
+
|
| 150 |
+
**Stage 1: Salient Compressor Pretraining (SCP)**
|
| 151 |
+
|
| 152 |
+
Pre-train the document compressor :
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
bash scripts/train_pretraining.sh
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
Key parameters:
|
| 159 |
+
- `--compress_rate`: Compression rate (default: 32)
|
| 160 |
+
- `--doc_max_length`: Maximum document length (default: 256)
|
| 161 |
+
- `--stage stage1`: Training stage
|
| 162 |
+
- `--mse_loss`: Use MSE loss to align compressed and original representations
|
| 163 |
+
- `--qa_loss`: Use QA loss for semantic preservation
|
| 164 |
+
|
| 165 |
+
**Stage 2: Compression Instruction Tuning**
|
| 166 |
+
|
| 167 |
+
Fine-tune the compressor on instruction-following tasks:
|
| 168 |
+
|
| 169 |
+
```bash
|
| 170 |
+
bash scripts/train_instruction_tuning.sh
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
Key parameters:
|
| 174 |
+
- `--pretrain_checkpoint`: Path to stage 1 checkpoint
|
| 175 |
+
- `--stage stage1_2`: Training stage
|
| 176 |
+
- `--generation_top_k`: Top-k sampling for generation (default: 5)
|
| 177 |
+
- `--mse_loss`: Use MSE loss for compression training
|
| 178 |
+
- `--do_eval_gen`: Enable generation evaluation
|
| 179 |
+
|
| 180 |
+
**Stage 3: End-to-End Training**
|
| 181 |
+
|
| 182 |
+
Fine-tune the model end-to-end with retrieval:
|
| 183 |
+
|
| 184 |
+
```bash
|
| 185 |
+
bash scripts/train_stage_end_to_end.sh
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
Key parameters:
|
| 189 |
+
- `--pretrain_checkpoint`: Path to stage 2 checkpoint
|
| 190 |
+
- `--stage stage2`: Training stage
|
| 191 |
+
- `--generation_top_k`: Top-k sampling for generation
|
| 192 |
+
- `--do_eval_gen`: Enable generation evaluation
|
| 193 |
+
|
| 194 |
+
#### 4. Distributed Training
|
| 195 |
+
|
| 196 |
+
The training scripts support distributed training across multiple nodes and GPUs:
|
| 197 |
+
|
| 198 |
+
- `--max_len`: Maximum sequence length (default: 2048 for stage1/stage2, 1024 for stage3)
|
| 199 |
+
- `--train_batch_size`: Training batch size
|
| 200 |
+
- `--micro_train_batch_size`: Micro batch size for gradient accumulation
|
| 201 |
+
- `--learning_rate`: Learning rate (default: 1e-4 for stage1/stage2, 5e-6 for stage3)
|
| 202 |
+
- `--max_epochs`: Maximum training epochs
|
| 203 |
+
- `--zero_stage`: ZeRO optimization stage (default: 2)
|
| 204 |
+
- `--bf16`: Use bfloat16 precision
|
| 205 |
+
- `--flash_attn`: Use Flash Attention 2
|
| 206 |
+
|
| 207 |
+
### Inference
|
| 208 |
+
|
| 209 |
+
The CLaRa models can be loaded and used for inference. We provide three models corresponding to different training stages:
|
| 210 |
+
|
| 211 |
+
<details>
|
| 212 |
+
<summary>Stage 1: Compression Pretraining model (click to expand)</summary>
|
| 213 |
+
|
| 214 |
+
```python
|
| 215 |
+
from transformers import AutoModel
|
| 216 |
+
|
| 217 |
+
model_path = "path/to/stage1/model"
|
| 218 |
+
model = AutoModel.from_pretrained(
|
| 219 |
+
model_path,
|
| 220 |
+
trust_remote_code=True
|
| 221 |
+
).to('cuda')
|
| 222 |
+
|
| 223 |
+
# Example documents
|
| 224 |
+
documents = [
|
| 225 |
+
[
|
| 226 |
+
"Document 1 content...",
|
| 227 |
+
"Document 2 content...",
|
| 228 |
+
"Document 3 content..."
|
| 229 |
+
]
|
| 230 |
+
]
|
| 231 |
+
|
| 232 |
+
questions = ["" for _ in range(len(documents))]
|
| 233 |
+
|
| 234 |
+
# Generate paraphrase from compressed representations
|
| 235 |
+
output = model.generate_from_paraphrase(
|
| 236 |
+
questions=questions,
|
| 237 |
+
documents=documents,
|
| 238 |
+
max_new_tokens=64
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
print('Generated paraphrase:', output[0])
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
</details>
|
| 245 |
+
|
| 246 |
+
<details>
|
| 247 |
+
<summary>Stage 2: Compression Instruction Tuning model (click to expand)</summary>
|
| 248 |
+
|
| 249 |
+
```python
|
| 250 |
+
from transformers import AutoModel
|
| 251 |
+
|
| 252 |
+
model_path = "path/to/stage2/model"
|
| 253 |
+
model = AutoModel.from_pretrained(
|
| 254 |
+
model_path,
|
| 255 |
+
trust_remote_code=True
|
| 256 |
+
).to('cuda')
|
| 257 |
+
|
| 258 |
+
# Example documents and question
|
| 259 |
+
documents = [
|
| 260 |
+
[
|
| 261 |
+
"Document 1 content...",
|
| 262 |
+
"Document 2 content...",
|
| 263 |
+
"Document 3 content..."
|
| 264 |
+
]
|
| 265 |
+
]
|
| 266 |
+
|
| 267 |
+
questions = ["Your question here"]
|
| 268 |
+
|
| 269 |
+
# Generate answer from compressed representations
|
| 270 |
+
output = model.generate_from_text(
|
| 271 |
+
questions=questions,
|
| 272 |
+
documents=documents,
|
| 273 |
+
max_new_tokens=64
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
print('Generated answer:', output[0])
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
</details>
|
| 280 |
+
|
| 281 |
+
<details>
|
| 282 |
+
<summary>Stage 3: End-to-End (CLaRa) model (click to expand)</summary>
|
| 283 |
+
|
| 284 |
+
```python
|
| 285 |
+
from transformers import AutoModel
|
| 286 |
+
|
| 287 |
+
model_path = "path/to/stage3/model"
|
| 288 |
+
model = AutoModel.from_pretrained(
|
| 289 |
+
model_path,
|
| 290 |
+
trust_remote_code=True
|
| 291 |
+
).to('cuda')
|
| 292 |
+
|
| 293 |
+
# Example documents and question
|
| 294 |
+
# Note: Stage 3 supports retrieval with multiple candidate documents
|
| 295 |
+
documents = [
|
| 296 |
+
["Document 1 content..." for _ in range(20)] # 20 candidate documents
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
questions = ["Your question here"]
|
| 300 |
+
|
| 301 |
+
# Generate answer with retrieval and reranking
|
| 302 |
+
# The top-k is decided by generation_top_k in config.json
|
| 303 |
+
output, topk_indices = model.generate_from_questions(
|
| 304 |
+
questions=questions,
|
| 305 |
+
documents=documents,
|
| 306 |
+
max_new_tokens=64
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
print('Generated answer:', output[0])
|
| 310 |
+
print('Top-k selected document indices:', topk_indices)
|
| 311 |
+
```
|
| 312 |
+
|
| 313 |
+
</details>
|
| 314 |
+
|
| 315 |
+
### Evaluation
|
| 316 |
+
|
| 317 |
+
The evaluation framework is based on standard RAG benchmarks. Run evaluation:
|
| 318 |
+
|
| 319 |
+
**End-to-end evaluation:**
|
| 320 |
+
```bash
|
| 321 |
+
bash scripts/evaluation_end_to_end.sh
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+
**Instruction tuning evaluation:**
|
| 325 |
+
```bash
|
| 326 |
+
bash scripts/evaluation_instruction_tuning.sh
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
Supported datasets:
|
| 330 |
+
- **HotpotQA**: Multi-hop question answering
|
| 331 |
+
- **MuSiQue**: Multi-hop question answering with diverse reasoning
|
| 332 |
+
- **2WikiMultiHopQA**: Multi-hop question answering over Wikipedia
|
| 333 |
+
- **Natural Questions**: Open-domain question answering
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
### Results
|
| 338 |
+
|
| 339 |
+
#### Compression Performance
|
| 340 |
+
|
| 341 |
+
We evaluate our document compressor on four QA datasets (NQ, HotpotQA, MuSiQue, 2WikiMultiHopQA) under two settings: **Normal** (retrieving top-5 documents) and **Oracle** (gold document included). CLaRa consistently outperforms all baselines across different compression ratios.
|
| 342 |
+
|
| 343 |
+
<div align="center">
|
| 344 |
+
|
| 345 |
+
**Main Results (Mistral-7B, Normal Setting)**
|
| 346 |
+
|
| 347 |
+
| Model | CR | NQ | HotpotQA | MuSiQue | 2Wiki | Avg |
|
| 348 |
+
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 349 |
+
| AutoCompressor | - | 17.24 | 14.61 | 3.81 | 19.89 | 13.89 |
|
| 350 |
+
| XRAG | 128 | 32.35 | 25.16 | 3.64 | 28.79 | 22.48 |
|
| 351 |
+
| COCOM | 16 | 24.12 | 21.48 | 3.52 | 24.48 | 18.40 |
|
| 352 |
+
| PCC | 16 | 31.38 | 22.29 | 3.43 | 19.47 | 19.14 |
|
| 353 |
+
| LLMLingua-2 | 4 | 47.53 | 37.05 | 9.02 | 44.35 | 34.49 |
|
| 354 |
+
| PISCO | 16 | 54.39 | 41.94 | 10.09 | 44.88 | 37.83 |
|
| 355 |
+
| Mistral-7B w/ retrieval | - | 54.58 | 42.94 | 8.94 | 44.24 | 37.67 |
|
| 356 |
+
| **CLaRa (CR=4)** | **4** | **57.05** | **45.09** | **10.34** | **46.94** | **39.86** |
|
| 357 |
+
| **CLaRa (CR=16)** | **16** | **55.56** | **43.72** | **10.55** | **46.00** | **38.96** |
|
| 358 |
+
| **CLaRa (CR=32)** | **32** | **54.64** | **43.52** | **10.55** | **46.58** | **38.82** |
|
| 359 |
+
|
| 360 |
+
**Oracle Setting Results (Mistral-7B)**
|
| 361 |
+
|
| 362 |
+
| Model | CR | NQ | HotpotQA | MuSiQue | 2Wiki | Avg |
|
| 363 |
+
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 364 |
+
| PISCO | 16 | 73.44 | 66.53 | 33.80 | 60.45 | 58.55 |
|
| 365 |
+
| Mistral-7B w/ retrieval | - | 71.64 | 70.77 | 45.72 | 68.83 | 64.24 |
|
| 366 |
+
| **CLaRa (CR=4)** | **4** | **76.50** | **73.81** | **46.26** | **70.48** | **66.76** |
|
| 367 |
+
| **CLaRa (CR=16)** | **16** | **75.48** | **70.79** | **43.15** | **66.16** | **63.90** |
|
| 368 |
+
| **CLaRa (CR=32)** | **32** | **73.77** | **69.51** | **38.31** | **64.54** | **61.53** |
|
| 369 |
+
|
| 370 |
+
</div>
|
| 371 |
+
|
| 372 |
+
**Key Findings:**
|
| 373 |
+
- ✅ CLaRa outperforms PISCO by **+1.13%** (Normal) and **+5.35%** (Oracle) on average
|
| 374 |
+
- ✅ CLaRa outperforms LLMLingua-2 by **+5.37%** (Normal) on average
|
| 375 |
+
- ✅ CLaRa matches/exceeds text-based baseline with **+2.36%** average gain on Mistral-7B
|
| 376 |
+
|
| 377 |
+
#### Retrieval Performance
|
| 378 |
+
|
| 379 |
+
<div align="center">
|
| 380 |
+
|
| 381 |
+
<img src="figs/main_recall.png" width="80%"/>
|
| 382 |
+
|
| 383 |
+
</div>
|
| 384 |
+
|
| 385 |
+
For detailed experimental results and analysis, please refer to our paper.
|
| 386 |
+
|
| 387 |
+
## Acknowledgments
|
| 388 |
+
|
| 389 |
+
We sincerely appreciate the following works for CLaRa:
|
| 390 |
+
|
| 391 |
+
- Our implementation is built upon the [OpenRLHF framework](https://github.com/OpenRLHF/OpenRLHF).
|
| 392 |
+
|
| 393 |
+
- Inspired by [PISCO-mistral](https://huggingface.co/naver/pisco-mistral) for document compression techniques
|
| 394 |
+
|
| 395 |
+
## Citation
|
| 396 |
+
|
| 397 |
+
```bibtex
|
| 398 |
+
@misc{he2025clarabridgingretrievalgeneration,
|
| 399 |
+
title={CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning},
|
| 400 |
+
author={Jie He and Richard He Bai and Sinead Williamson and Jeff Z. Pan and Navdeep Jaitly and Yizhe Zhang},
|
| 401 |
+
year={2025},
|
| 402 |
+
eprint={2511.18659},
|
| 403 |
+
archivePrefix={arXiv},
|
| 404 |
+
primaryClass={cs.CL},
|
| 405 |
+
url={https://arxiv.org/abs/2511.18659},
|
| 406 |
+
}
|
| 407 |
+
```
|
docs/Gemfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
source "https://rubygems.org"
|
| 2 |
+
|
| 3 |
+
gem "jekyll", "~> 4.3"
|
| 4 |
+
gem "jekyll-feed", "~> 0.12"
|
| 5 |
+
gem "jekyll-seo-tag", "~> 2.8"
|
| 6 |
+
gem "jekyll-sitemap", "~> 1.4"
|
| 7 |
+
gem "jekyll-theme-cayman", "~> 0.2"
|
| 8 |
+
|
| 9 |
+
# Windows and JRuby do not include zoneinfo files, so bundle the tzinfo-data gem
|
| 10 |
+
# and associated library.
|
| 11 |
+
platforms :mingw, :x64_mingw, :mswin, :jruby do
|
| 12 |
+
gem "tzinfo", ">= 1", "< 3"
|
| 13 |
+
gem "tzinfo-data"
|
| 14 |
+
end
|
| 15 |
+
|
| 16 |
+
# Performance-booster for watching directories on Windows
|
| 17 |
+
gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin]
|
| 18 |
+
|
| 19 |
+
# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem
|
| 20 |
+
# do not have a Java counterpart.
|
| 21 |
+
gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby]
|
| 22 |
+
|
| 23 |
+
|
docs/_config.yml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Jekyll configuration for CLaRa Documentation
|
| 2 |
+
|
| 3 |
+
# Site settings
|
| 4 |
+
title: CLaRa Documentation
|
| 5 |
+
description: Complete documentation for CLaRa - Unified Retrieval-Augmented Generation with Compression
|
| 6 |
+
baseurl: "/CLaRa" # the subpath of your site for GitHub Pages
|
| 7 |
+
url: "https://aiml-oss.github.io" # the base hostname & protocol for GitHub Pages
|
| 8 |
+
|
| 9 |
+
# Theme
|
| 10 |
+
theme: jekyll-theme-cayman
|
| 11 |
+
# Alternative themes:
|
| 12 |
+
# theme: jekyll-theme-minimal
|
| 13 |
+
# theme: jekyll-theme-slate
|
| 14 |
+
# theme: jekyll-theme-architect
|
| 15 |
+
|
| 16 |
+
# GitHub repository info
|
| 17 |
+
repository: probe2/CLaRa
|
| 18 |
+
|
| 19 |
+
# Build settings
|
| 20 |
+
markdown: kramdown
|
| 21 |
+
kramdown:
|
| 22 |
+
input: GFM
|
| 23 |
+
syntax_highlighter: rouge
|
| 24 |
+
|
| 25 |
+
# Plugins
|
| 26 |
+
plugins:
|
| 27 |
+
- jekyll-feed
|
| 28 |
+
- jekyll-seo-tag
|
| 29 |
+
- jekyll-sitemap
|
| 30 |
+
|
| 31 |
+
# Navigation
|
| 32 |
+
header_pages:
|
| 33 |
+
- index.md
|
| 34 |
+
- getting_started.md
|
| 35 |
+
- training.md
|
| 36 |
+
- inference.md
|
| 37 |
+
|
| 38 |
+
# Exclude from processing
|
| 39 |
+
exclude:
|
| 40 |
+
- Gemfile
|
| 41 |
+
- Gemfile.lock
|
| 42 |
+
- node_modules
|
| 43 |
+
- vendor/bundle/
|
| 44 |
+
- vendor/cache/
|
| 45 |
+
- vendor/gems/
|
| 46 |
+
- vendor/ruby/
|
| 47 |
+
- "*.sh"
|
| 48 |
+
- "*.log"
|
| 49 |
+
|
| 50 |
+
# Collections and navigation
|
| 51 |
+
# collections:
|
| 52 |
+
# docs:
|
| 53 |
+
# output: true
|
| 54 |
+
# permalink: /:collection/:path/
|
| 55 |
+
|
| 56 |
+
# Defaults
|
| 57 |
+
defaults:
|
| 58 |
+
- scope:
|
| 59 |
+
path: ""
|
| 60 |
+
values:
|
| 61 |
+
layout: "default"
|
docs/getting_started.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
layout: default
|
| 3 |
+
title: Getting Started
|
| 4 |
+
permalink: /getting_started/
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# Getting Started with CLaRa
|
| 8 |
+
|
| 9 |
+
This guide will help you get started with CLaRa, from installation to running your first training.
|
| 10 |
+
|
| 11 |
+
## Installation
|
| 12 |
+
|
| 13 |
+
### Prerequisites
|
| 14 |
+
|
| 15 |
+
- Python 3.10+
|
| 16 |
+
- CUDA-compatible GPU (recommended)
|
| 17 |
+
- PyTorch 2.0+
|
| 18 |
+
- CUDA 11.8 or 12.x
|
| 19 |
+
|
| 20 |
+
### Step 1: Create Conda Environment
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
env=clara
|
| 24 |
+
conda create -n $env python=3.10 -y
|
| 25 |
+
conda activate $env
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
### Step 2: Install Dependencies
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
pip install -r requirements.txt
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
Key dependencies include:
|
| 35 |
+
- `torch>=2.0`
|
| 36 |
+
- `transformers>=4.20`
|
| 37 |
+
- `deepspeed>=0.18`
|
| 38 |
+
- `flash-attn>=2.8.0`
|
| 39 |
+
- `accelerate>=1.10.1`
|
| 40 |
+
- `peft>=0.17.1`
|
| 41 |
+
|
| 42 |
+
### Step 3: Set Environment Variables
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
export PYTHONPATH=/path/to/clara:$PYTHONPATH
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Quick Start
|
| 49 |
+
|
| 50 |
+
### 1. Prepare Your Data
|
| 51 |
+
|
| 52 |
+
CLaRa uses JSONL format for training data. See the [Training Guide](./training.md) for data format details.
|
| 53 |
+
|
| 54 |
+
### 2. Train Stage 1: Compression Pretraining
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
bash scripts/train_pretraining.sh
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### 3. Train Stage 2: Instruction Tuning
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
bash scripts/train_instruction_tuning.sh
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### 4. Train Stage 3: End-to-End Training
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
bash scripts/train_stage_end_to_end.sh
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### 5. Run Inference
|
| 73 |
+
|
| 74 |
+
See the [Inference Guide](./inference.md) for examples of using all three model stages.
|
| 75 |
+
|
| 76 |
+
## Next Steps
|
| 77 |
+
|
| 78 |
+
- [Training Guide](./training.md) - Detailed training instructions and data formats
|
| 79 |
+
- [Inference Guide](./inference.md) - Inference examples for all model stages
|
| 80 |
+
|
docs/index.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
layout: default
|
| 3 |
+
title: CLaRa Documentation
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# CLaRa Documentation
|
| 7 |
+
|
| 8 |
+
Welcome to the CLaRa documentation! This site provides comprehensive guides and references for using CLaRa.
|
| 9 |
+
|
| 10 |
+
## What is CLaRa?
|
| 11 |
+
|
| 12 |
+
**CLaRa** (Continuous Latent Reasoning) is a unified framework for retrieval-augmented generation that performs embedding-based compression and joint optimization in a shared continuous space.
|
| 13 |
+
|
| 14 |
+
[](https://arxiv.org/abs/XXXX.XXXXX) [](../LICENSE) [](https://huggingface.co/your-org/clara-base) [](https://huggingface.co/your-org/clara-instruct) [](https://huggingface.co/your-org/clara-e)
|
| 15 |
+
|
| 16 |
+
## Documentation
|
| 17 |
+
|
| 18 |
+
- **[Getting Started](./getting_started.md)** - Installation and quick start guide
|
| 19 |
+
- **[Training Guide](./training.md)** - Detailed instructions for all three training stages including data formats
|
| 20 |
+
- **[Inference Guide](./inference.md)** - How to use CLaRa models for inference
|
| 21 |
+
|
| 22 |
+
## Quick Links
|
| 23 |
+
|
| 24 |
+
- **GitHub Repository**: [github.com/apple/ml-CLaRa](https://github.com/apple/ml-CLaRa)
|
| 25 |
+
- **Main README**: [../README.md](../README.md)
|
| 26 |
+
- **Model Checkpoints**: [Hugging Face](https://huggingface.co/your-org/clara-base) (Coming Soon)
|
| 27 |
+
|
| 28 |
+
## Overview
|
| 29 |
+
|
| 30 |
+
CLaRa uses a three-stage training approach:
|
| 31 |
+
|
| 32 |
+
1. **Stage 1: Compression Pretraining** - Learn effective document compression
|
| 33 |
+
2. **Stage 2: Compression Instruction Tuning** - Adapt for downstream QA tasks
|
| 34 |
+
3. **Stage 3: End-to-End Fine-tuning (CLaRa)** - Joint retrieval and generation optimization
|
| 35 |
+
|
| 36 |
+
For more details, see the [Training Guide](./training.md).
|
| 37 |
+
|
| 38 |
+
## Citation
|
| 39 |
+
|
| 40 |
+
If you use CLaRa in your research, please cite:
|
| 41 |
+
|
| 42 |
+
```bibtex
|
| 43 |
+
@article{clara2024,
|
| 44 |
+
title={CLaRa: Unified Retrieval-Augmented Generation with Compression},
|
| 45 |
+
author={[Authors]},
|
| 46 |
+
journal={[Journal]},
|
| 47 |
+
year={2024},
|
| 48 |
+
eprint={XXXX.XXXXX},
|
| 49 |
+
archivePrefix={arXiv},
|
| 50 |
+
primaryClass={cs.CL},
|
| 51 |
+
url={https://arxiv.org/abs/XXXX.XXXXX}
|
| 52 |
+
}
|
| 53 |
+
```
|
docs/inference.md
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
layout: default
|
| 3 |
+
title: Inference Guide
|
| 4 |
+
permalink: /inference/
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# Inference Guide
|
| 8 |
+
|
| 9 |
+
This guide shows how to use CLaRa models for inference at different stages.
|
| 10 |
+
|
| 11 |
+
## Loading Models
|
| 12 |
+
|
| 13 |
+
CLaRa models can be loaded using the standard `AutoModel` interface:
|
| 14 |
+
|
| 15 |
+
```python
|
| 16 |
+
from transformers import AutoModel
|
| 17 |
+
|
| 18 |
+
model = AutoModel.from_pretrained(
|
| 19 |
+
"path/to/model",
|
| 20 |
+
trust_remote_code=True
|
| 21 |
+
).to('cuda')
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
## Stage 1: Compression Pretraining Model
|
| 25 |
+
|
| 26 |
+
Generate paraphrases from compressed document representations.
|
| 27 |
+
|
| 28 |
+
```python
|
| 29 |
+
from transformers import AutoModel
|
| 30 |
+
|
| 31 |
+
model = AutoModel.from_pretrained(
|
| 32 |
+
"path/to/stage1/model",
|
| 33 |
+
trust_remote_code=True
|
| 34 |
+
).to('cuda')
|
| 35 |
+
|
| 36 |
+
# Example documents
|
| 37 |
+
documents = [
|
| 38 |
+
[
|
| 39 |
+
"Document 1 content...",
|
| 40 |
+
"Document 2 content...",
|
| 41 |
+
"Document 3 content..."
|
| 42 |
+
]
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
questions = ["" for _ in range(len(documents))]
|
| 46 |
+
|
| 47 |
+
# Generate paraphrase from compressed representations
|
| 48 |
+
output = model.generate_from_paraphrase(
|
| 49 |
+
questions=questions,
|
| 50 |
+
documents=documents,
|
| 51 |
+
max_new_tokens=64
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
print('Generated paraphrase:', output[0])
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Stage 2: Compression Instruction Tuning Model
|
| 58 |
+
|
| 59 |
+
Generate answers from compressed representations for QA tasks.
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
from transformers import AutoModel
|
| 63 |
+
|
| 64 |
+
model = AutoModel.from_pretrained(
|
| 65 |
+
"path/to/stage2/model",
|
| 66 |
+
trust_remote_code=True
|
| 67 |
+
).to('cuda')
|
| 68 |
+
|
| 69 |
+
# Example documents and question
|
| 70 |
+
documents = [
|
| 71 |
+
[
|
| 72 |
+
"Document 1 content...",
|
| 73 |
+
"Document 2 content...",
|
| 74 |
+
"Document 3 content..."
|
| 75 |
+
]
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
questions = ["Your question here"]
|
| 79 |
+
|
| 80 |
+
# Generate answer from compressed representations
|
| 81 |
+
output = model.generate_from_text(
|
| 82 |
+
questions=questions,
|
| 83 |
+
documents=documents,
|
| 84 |
+
max_new_tokens=64
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
print('Generated answer:', output[0])
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
## Stage 3: End-to-End (CLaRa) Model
|
| 91 |
+
|
| 92 |
+
Generate answers with retrieval and reranking using joint optimization.
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
from transformers import AutoModel
|
| 96 |
+
|
| 97 |
+
model = AutoModel.from_pretrained(
|
| 98 |
+
"path/to/stage3/model",
|
| 99 |
+
trust_remote_code=True
|
| 100 |
+
).to('cuda')
|
| 101 |
+
|
| 102 |
+
# Example documents and question
|
| 103 |
+
# Note: Stage 3 supports retrieval with multiple candidate documents
|
| 104 |
+
documents = [
|
| 105 |
+
["Document 1 content..." for _ in range(20)] # 20 candidate documents
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
questions = ["Your question here"]
|
| 109 |
+
|
| 110 |
+
# Generate answer with retrieval and reranking
|
| 111 |
+
# The top-k is decided by generation_top_k in config.json
|
| 112 |
+
output, topk_indices = model.generate_from_questions(
|
| 113 |
+
questions=questions,
|
| 114 |
+
documents=documents,
|
| 115 |
+
max_new_tokens=64
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
print('Generated answer:', output[0])
|
| 119 |
+
print('Top-k selected document indices:', topk_indices)
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Key Parameters
|
| 123 |
+
|
| 124 |
+
- `max_new_tokens`: Maximum number of tokens to generate (default: 128)
|
| 125 |
+
- `generation_top_k`: Number of top documents to select (configured in model config)
|
| 126 |
+
|
| 127 |
+
## Model Methods
|
| 128 |
+
|
| 129 |
+
- `generate_from_paraphrase()` - Stage 1: Generate paraphrases
|
| 130 |
+
- `generate_from_text()` - Stage 2: Generate answers from compressed docs
|
| 131 |
+
- `generate_from_questions()` - Stage 3: Generate with retrieval and reranking
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
docs/training.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
layout: default
|
| 3 |
+
title: Training Guide
|
| 4 |
+
permalink: /training/
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# Training Guide
|
| 8 |
+
|
| 9 |
+
This guide covers the three-stage training process in CLaRa.
|
| 10 |
+
|
| 11 |
+
## Overview
|
| 12 |
+
|
| 13 |
+
CLaRa uses a three-stage training approach:
|
| 14 |
+
|
| 15 |
+
1. **Stage 1**: Compression Pretraining
|
| 16 |
+
2. **Stage 2**: Compression Instruction Tuning
|
| 17 |
+
3. **Stage 3**: End-to-End Fine-tuning (CLaRa)
|
| 18 |
+
|
| 19 |
+
## Stage 1: Compression Pretraining
|
| 20 |
+
|
| 21 |
+
Train the compressor to learn effective document compression.
|
| 22 |
+
|
| 23 |
+
### Key Parameters
|
| 24 |
+
|
| 25 |
+
- `--stage stage1`: Training stage identifier
|
| 26 |
+
- `--compress_rate`: Compression rate (default: 32)
|
| 27 |
+
- `--doc_max_length`: Maximum document length (default: 256)
|
| 28 |
+
- `--mse_loss`: Use MSE loss for compression alignment
|
| 29 |
+
- `--qa_loss`: Use QA loss for semantic preservation
|
| 30 |
+
|
| 31 |
+
### Example Command
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
bash scripts/train_pretraining.sh
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### Data Format
|
| 38 |
+
|
| 39 |
+
**Stage 1 Pretraining Data:**
|
| 40 |
+
```json
|
| 41 |
+
{
|
| 42 |
+
"data_type": "qa",
|
| 43 |
+
"question": ["Question 1", "Question 2", ...],
|
| 44 |
+
"answers": ["Answer 1", "Answer 2", ...],
|
| 45 |
+
"docs": ["Document 1", "Document 2", ...]
|
| 46 |
+
}
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Stage 2: Compression Instruction Tuning
|
| 50 |
+
|
| 51 |
+
Fine-tune the compressor on instruction-following tasks.
|
| 52 |
+
|
| 53 |
+
### Key Parameters
|
| 54 |
+
|
| 55 |
+
- `--stage stage1_2`: Training stage identifier
|
| 56 |
+
- `--pretrain_checkpoint`: Path to Stage 1 checkpoint
|
| 57 |
+
- `--generation_top_k`: Top-k sampling (default: 5)
|
| 58 |
+
- `--mse_loss`: Continue using MSE loss
|
| 59 |
+
- `--do_eval_gen`: Enable generation evaluation
|
| 60 |
+
|
| 61 |
+
### Example Command
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
bash scripts/train_instruction_tuning.sh
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### Data Format
|
| 68 |
+
|
| 69 |
+
**Stage 2 Instruction Tuning Data:**
|
| 70 |
+
```json
|
| 71 |
+
{
|
| 72 |
+
"question": "Single question text",
|
| 73 |
+
"docs": ["Document 1", "Document 2", ...],
|
| 74 |
+
"gold_answer": "Reference answer",
|
| 75 |
+
"answer": "Generated answer"
|
| 76 |
+
}
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Stage 3: End-to-End Training
|
| 80 |
+
|
| 81 |
+
Jointly train reranker and generator with retrieval.
|
| 82 |
+
|
| 83 |
+
### Key Parameters
|
| 84 |
+
|
| 85 |
+
- `--stage stage2`: Training stage identifier
|
| 86 |
+
- `--pretrain_checkpoint`: Path to Stage 2 checkpoint
|
| 87 |
+
- `--generation_top_k`: Top-k sampling for generation
|
| 88 |
+
- `--do_eval_gen`: Enable generation evaluation
|
| 89 |
+
|
| 90 |
+
### Example Command
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
bash scripts/train_stage_end_to_end.sh
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Data Format
|
| 97 |
+
|
| 98 |
+
**Stage 3 End-to-End Data:**
|
| 99 |
+
```json
|
| 100 |
+
{
|
| 101 |
+
"question": "Single question text",
|
| 102 |
+
"docs": ["Document 1", "Document 2", ...],
|
| 103 |
+
"gold_answer": "Reference answer"
|
| 104 |
+
}
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
## Distributed Training
|
| 108 |
+
|
| 109 |
+
All training stages support distributed training across multiple nodes and GPUs.
|
| 110 |
+
|
| 111 |
+
### Key Parameters
|
| 112 |
+
|
| 113 |
+
- `--max_len`: Maximum sequence length (2048 for stage1/stage2, 1024 for stage3)
|
| 114 |
+
- `--train_batch_size`: Training batch size
|
| 115 |
+
- `--micro_train_batch_size`: Micro batch size for gradient accumulation
|
| 116 |
+
- `--learning_rate`: Learning rate (1e-4 for stage1/stage2, 5e-6 for stage3)
|
| 117 |
+
- `--max_epochs`: Maximum training epochs
|
| 118 |
+
- `--zero_stage`: ZeRO optimization stage (default: 2)
|
| 119 |
+
- `--bf16`: Use bfloat16 precision
|
| 120 |
+
- `--flash_attn`: Use Flash Attention 2
|
| 121 |
+
|
| 122 |
+
## Monitoring Training
|
| 123 |
+
|
| 124 |
+
Training progress is logged via:
|
| 125 |
+
- Console output
|
| 126 |
+
- Wandb (if configured)
|
| 127 |
+
- Checkpoint files
|
| 128 |
+
|
| 129 |
+
Checkpoints are saved at the path specified by `--save_path`.
|
evaluation/evaluate.py
ADDED
|
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# For licensing see accompanying LICENSE file.
|
| 3 |
+
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import argparse
|
| 9 |
+
import gc
|
| 10 |
+
from datetime import timedelta
|
| 11 |
+
from collections import defaultdict, Counter
|
| 12 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import numpy as np
|
| 16 |
+
from accelerate import Accelerator, InitProcessGroupKwargs
|
| 17 |
+
from transformers import AutoModel
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
from sklearn.manifold import TSNE
|
| 22 |
+
from sklearn.decomposition import PCA
|
| 23 |
+
try:
|
| 24 |
+
import spacy
|
| 25 |
+
SPACY_AVAILABLE = True
|
| 26 |
+
except Exception as e:
|
| 27 |
+
SPACY_AVAILABLE = False
|
| 28 |
+
print(f"Warning: spacy not available ({e}). Entity extraction will be disabled.")
|
| 29 |
+
try:
|
| 30 |
+
import evaluate as eval_lib
|
| 31 |
+
EVAL_LIB_AVAILABLE = True
|
| 32 |
+
except Exception as e:
|
| 33 |
+
EVAL_LIB_AVAILABLE = False
|
| 34 |
+
eval_lib = None
|
| 35 |
+
print(f"Warning: evaluate library not available ({e}). BERTScore and ROUGE metrics will be disabled.")
|
| 36 |
+
import re
|
| 37 |
+
import string
|
| 38 |
+
|
| 39 |
+
from openrlhf.models.modeling_clara import CLaRa
|
| 40 |
+
|
| 41 |
+
# Environment setup
|
| 42 |
+
os.environ["NCCL_TIMEOUT"] = "5400"
|
| 43 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 44 |
+
|
| 45 |
+
# Global constants
|
| 46 |
+
TARGET_ENTITY_CATEGORIES = {"PERSON", "GPE", "DATE", "CARDINAL", "ORG"}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class EvaluationMetrics:
|
| 50 |
+
"""Handles all evaluation metrics and scoring functions."""
|
| 51 |
+
|
| 52 |
+
def __init__(self):
|
| 53 |
+
if EVAL_LIB_AVAILABLE:
|
| 54 |
+
self.bertscore = eval_lib.load("bertscore")
|
| 55 |
+
self.rouge = eval_lib.load("rouge")
|
| 56 |
+
else:
|
| 57 |
+
self.bertscore = None
|
| 58 |
+
self.rouge = None
|
| 59 |
+
if SPACY_AVAILABLE:
|
| 60 |
+
self.nlp = spacy.load("en_core_web_sm")
|
| 61 |
+
else:
|
| 62 |
+
self.nlp = None
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def normalize_answer(text: str) -> str:
|
| 66 |
+
"""Normalize text for comparison."""
|
| 67 |
+
def remove_articles(text):
|
| 68 |
+
return re.sub(r"\b(a|an|the)\b", " ", text)
|
| 69 |
+
|
| 70 |
+
def white_space_fix(text):
|
| 71 |
+
return " ".join(text.split())
|
| 72 |
+
|
| 73 |
+
def remove_punc(text):
|
| 74 |
+
exclude = set(string.punctuation)
|
| 75 |
+
return "".join(ch for ch in text if ch not in exclude)
|
| 76 |
+
|
| 77 |
+
return white_space_fix(remove_articles(remove_punc(text.lower())))
|
| 78 |
+
|
| 79 |
+
@staticmethod
|
| 80 |
+
def bool_mapping(text: str) -> str:
|
| 81 |
+
"""Map boolean values to yes/no."""
|
| 82 |
+
mapping = {"True": "yes", "False": "no"}
|
| 83 |
+
return mapping.get(text, text)
|
| 84 |
+
|
| 85 |
+
def exact_match_score(self, prediction: str, ground_truth: str) -> bool:
|
| 86 |
+
"""Calculate exact match score."""
|
| 87 |
+
pred_norm = self.normalize_answer(self.bool_mapping(prediction))
|
| 88 |
+
gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))
|
| 89 |
+
return pred_norm == gt_norm
|
| 90 |
+
|
| 91 |
+
def cover_exact_match_score(self, prediction: str, ground_truth: str) -> bool:
|
| 92 |
+
"""Calculate coverage exact match score."""
|
| 93 |
+
pred_tokens = self.normalize_answer(self.bool_mapping(prediction)).split()
|
| 94 |
+
gt_tokens = self.normalize_answer(self.bool_mapping(ground_truth)).split()
|
| 95 |
+
return all(token in pred_tokens for token in gt_tokens)
|
| 96 |
+
|
| 97 |
+
def f1_score(self, prediction: str, ground_truth: str) -> float:
|
| 98 |
+
"""Calculate F1 score."""
|
| 99 |
+
pred_norm = self.normalize_answer(self.bool_mapping(prediction))
|
| 100 |
+
gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))
|
| 101 |
+
|
| 102 |
+
# Handle yes/no/noanswer cases
|
| 103 |
+
if pred_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
|
| 104 |
+
return 0.0
|
| 105 |
+
if gt_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
|
| 106 |
+
return 0.0
|
| 107 |
+
|
| 108 |
+
pred_tokens = pred_norm.split()
|
| 109 |
+
gt_tokens = gt_norm.split()
|
| 110 |
+
|
| 111 |
+
common = Counter(pred_tokens) & Counter(gt_tokens)
|
| 112 |
+
num_same = sum(common.values())
|
| 113 |
+
|
| 114 |
+
if num_same == 0:
|
| 115 |
+
return 0.0
|
| 116 |
+
|
| 117 |
+
precision = num_same / len(pred_tokens)
|
| 118 |
+
recall = num_same / len(gt_tokens)
|
| 119 |
+
|
| 120 |
+
return (2 * precision * recall) / (precision + recall)
|
| 121 |
+
|
| 122 |
+
def extract_entities(self, text: str) -> set:
|
| 123 |
+
"""Extract entities from text."""
|
| 124 |
+
if self.nlp is None:
|
| 125 |
+
return set() # Return empty set if spacy unavailable
|
| 126 |
+
doc = self.nlp(text)
|
| 127 |
+
return set(ent.text.lower().strip() for ent in doc.ents)
|
| 128 |
+
|
| 129 |
+
def extract_entities_by_category(self, text: str) -> Dict[str, set]:
|
| 130 |
+
"""Extract entities by category."""
|
| 131 |
+
if self.nlp is None:
|
| 132 |
+
return defaultdict(set) # Return empty dict if spacy unavailable
|
| 133 |
+
doc = self.nlp(text)
|
| 134 |
+
entities_by_category = defaultdict(set)
|
| 135 |
+
|
| 136 |
+
for ent in doc.ents:
|
| 137 |
+
if ent.label_ in TARGET_ENTITY_CATEGORIES:
|
| 138 |
+
entities_by_category[ent.label_].add(ent.text.lower().strip())
|
| 139 |
+
|
| 140 |
+
return entities_by_category
|
| 141 |
+
|
| 142 |
+
def entity_preserve_metric(self, prediction: str, reference: str) -> float:
|
| 143 |
+
"""Calculate entity preservation rate."""
|
| 144 |
+
ref_entities = self.extract_entities(reference)
|
| 145 |
+
pred_entities = self.extract_entities(prediction)
|
| 146 |
+
|
| 147 |
+
if not ref_entities:
|
| 148 |
+
return 1.0
|
| 149 |
+
|
| 150 |
+
preserved = ref_entities.intersection(pred_entities)
|
| 151 |
+
return len(preserved) / len(ref_entities)
|
| 152 |
+
|
| 153 |
+
def entity_preserve_metric_by_category(self, prediction_tokens: List[List[str]],
|
| 154 |
+
reference_docs: List[str]) -> Dict[str, float]:
|
| 155 |
+
"""Calculate entity preservation by category."""
|
| 156 |
+
# Merge prediction tokens
|
| 157 |
+
all_prediction_tokens = []
|
| 158 |
+
for tokens in prediction_tokens:
|
| 159 |
+
all_prediction_tokens.extend(tokens)
|
| 160 |
+
prediction_text = " ".join(all_prediction_tokens)
|
| 161 |
+
|
| 162 |
+
# Merge reference documents
|
| 163 |
+
reference_text = " ".join(reference_docs)
|
| 164 |
+
|
| 165 |
+
# Extract entities
|
| 166 |
+
pred_entities = self.extract_entities_by_category(prediction_text)
|
| 167 |
+
ref_entities = self.extract_entities_by_category(reference_text)
|
| 168 |
+
|
| 169 |
+
# Calculate preservation rates
|
| 170 |
+
preservation_rates = {}
|
| 171 |
+
|
| 172 |
+
for category in TARGET_ENTITY_CATEGORIES:
|
| 173 |
+
ref_ents = ref_entities.get(category, set())
|
| 174 |
+
pred_ents = pred_entities.get(category, set())
|
| 175 |
+
|
| 176 |
+
if not ref_ents:
|
| 177 |
+
preservation_rates[category] = 1.0
|
| 178 |
+
else:
|
| 179 |
+
preserved = ref_ents.intersection(pred_ents)
|
| 180 |
+
preservation_rates[category] = len(preserved) / len(ref_ents)
|
| 181 |
+
|
| 182 |
+
# Calculate overall preservation
|
| 183 |
+
all_ref_entities = set()
|
| 184 |
+
all_pred_entities = set()
|
| 185 |
+
|
| 186 |
+
for entities_set in ref_entities.values():
|
| 187 |
+
all_ref_entities.update(entities_set)
|
| 188 |
+
for entities_set in pred_entities.values():
|
| 189 |
+
all_pred_entities.update(entities_set)
|
| 190 |
+
|
| 191 |
+
if not all_ref_entities:
|
| 192 |
+
preservation_rates["overall"] = 1.0
|
| 193 |
+
else:
|
| 194 |
+
preserved_overall = all_ref_entities.intersection(all_pred_entities)
|
| 195 |
+
preservation_rates["overall"] = len(preserved_overall) / len(all_ref_entities)
|
| 196 |
+
|
| 197 |
+
return preservation_rates
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class ResultCalculator:
|
| 201 |
+
"""Handles result calculation and visualization."""
|
| 202 |
+
|
| 203 |
+
def __init__(self):
|
| 204 |
+
self.metrics = EvaluationMetrics()
|
| 205 |
+
|
| 206 |
+
def calculate_basic_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
|
| 207 |
+
"""Calculate basic metrics (F1, accuracy, exact match)."""
|
| 208 |
+
f1_total = 0
|
| 209 |
+
acc_total = 0
|
| 210 |
+
em_total = 0
|
| 211 |
+
avg_output_length = 0
|
| 212 |
+
|
| 213 |
+
answer_key = "golden_answers" if "golden_answers" in result_list[0] else "answer"
|
| 214 |
+
|
| 215 |
+
for result in result_list:
|
| 216 |
+
prediction = result['CLaRa_normal_output']
|
| 217 |
+
ground_truth = result[answer_key][0] if answer_key == "golden_answers" else result[answer_key]
|
| 218 |
+
|
| 219 |
+
acc_total += self.metrics.cover_exact_match_score(prediction, ground_truth)
|
| 220 |
+
f1_total += self.metrics.f1_score(prediction, ground_truth)
|
| 221 |
+
em_total += self.metrics.exact_match_score(prediction, ground_truth)
|
| 222 |
+
avg_output_length += len(prediction.split())
|
| 223 |
+
|
| 224 |
+
n = len(result_list)
|
| 225 |
+
return {
|
| 226 |
+
"f1": f1_total / n,
|
| 227 |
+
"acc": acc_total / n,
|
| 228 |
+
"em": em_total / n,
|
| 229 |
+
"avg_output_length": avg_output_length / n
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
def calculate_stage2_metrics(self, result_list: List[Dict], k_values: List[int] = [1, 3, 5]) -> Dict[str, float]:
|
| 233 |
+
"""Calculate stage2 metrics with recall and precision."""
|
| 234 |
+
basic_metrics = self.calculate_basic_metrics(result_list)
|
| 235 |
+
|
| 236 |
+
recall = {k: 0 for k in k_values}
|
| 237 |
+
precision = {k: 0 for k in k_values}
|
| 238 |
+
|
| 239 |
+
for result in result_list:
|
| 240 |
+
scores = result['topk_idx']
|
| 241 |
+
pos_index = set(result['pos_index'])
|
| 242 |
+
|
| 243 |
+
for k in k_values:
|
| 244 |
+
top_k = set(scores[:k])
|
| 245 |
+
hit = len(top_k & pos_index)
|
| 246 |
+
|
| 247 |
+
recall[k] += hit / len(pos_index) if len(pos_index) > 0 else 0
|
| 248 |
+
precision[k] += hit / k
|
| 249 |
+
|
| 250 |
+
n = len(result_list)
|
| 251 |
+
recall_metrics = {f"recall@{k}": v / n for k, v in recall.items()}
|
| 252 |
+
precision_metrics = {f"precision@{k}": v / n for k, v in precision.items()}
|
| 253 |
+
|
| 254 |
+
return {**basic_metrics, **recall_metrics, **precision_metrics}
|
| 255 |
+
|
| 256 |
+
def calculate_paraphrase_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
|
| 257 |
+
"""Calculate paraphrase metrics."""
|
| 258 |
+
seen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
|
| 259 |
+
unseen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
|
| 260 |
+
|
| 261 |
+
# Process seen data (first 2000)
|
| 262 |
+
for result in result_list[:2000]:
|
| 263 |
+
prediction = result['CLaRa_normal_output']
|
| 264 |
+
ground_truth = result['doc']
|
| 265 |
+
|
| 266 |
+
if EVAL_LIB_AVAILABLE and self.metrics.bertscore is not None:
|
| 267 |
+
bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
|
| 268 |
+
seen_metrics['bert-score'] += bs['f1'][0]
|
| 269 |
+
|
| 270 |
+
if EVAL_LIB_AVAILABLE and self.metrics.rouge is not None:
|
| 271 |
+
rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
|
| 272 |
+
seen_metrics['rouge-1'] += rouge_scores['rouge1']
|
| 273 |
+
seen_metrics['rouge-L'] += rouge_scores['rougeL']
|
| 274 |
+
|
| 275 |
+
seen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
|
| 276 |
+
|
| 277 |
+
# Process unseen data (after 2000)
|
| 278 |
+
for result in result_list[2000:]:
|
| 279 |
+
prediction = result['CLaRa_normal_output']
|
| 280 |
+
ground_truth = result['doc']
|
| 281 |
+
|
| 282 |
+
if EVAL_LIB_AVAILABLE and self.metrics.bertscore is not None:
|
| 283 |
+
bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
|
| 284 |
+
unseen_metrics['bert-score'] += bs['f1'][0]
|
| 285 |
+
|
| 286 |
+
if EVAL_LIB_AVAILABLE and self.metrics.rouge is not None:
|
| 287 |
+
rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
|
| 288 |
+
unseen_metrics['rouge-1'] += rouge_scores['rouge1']
|
| 289 |
+
unseen_metrics['rouge-L'] += rouge_scores['rougeL']
|
| 290 |
+
|
| 291 |
+
unseen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
|
| 292 |
+
|
| 293 |
+
# Normalize
|
| 294 |
+
n_seen = min(len(result_list[:2000]), 2000)
|
| 295 |
+
n_unseen = max(len(result_list) - 2000, 0)
|
| 296 |
+
|
| 297 |
+
final_metrics = {}
|
| 298 |
+
if n_seen > 0:
|
| 299 |
+
for key, value in seen_metrics.items():
|
| 300 |
+
final_metrics[f'seen_{key}'] = float(value / n_seen)
|
| 301 |
+
|
| 302 |
+
if n_unseen > 0:
|
| 303 |
+
for key, value in unseen_metrics.items():
|
| 304 |
+
final_metrics[f'unseen_{key}'] = float(value / n_unseen)
|
| 305 |
+
|
| 306 |
+
return final_metrics
|
| 307 |
+
|
| 308 |
+
def visualize_mse(self, result_list: List[Dict], save_path: str) -> Dict[str, Any]:
|
| 309 |
+
"""Create t-SNE visualization for MSE analysis."""
|
| 310 |
+
# Set scientific style
|
| 311 |
+
plt.rcParams.update({
|
| 312 |
+
'font.family': 'serif',
|
| 313 |
+
'font.size': 12,
|
| 314 |
+
'axes.labelsize': 14,
|
| 315 |
+
'axes.titlesize': 16,
|
| 316 |
+
'figure.titlesize': 18,
|
| 317 |
+
'axes.linewidth': 1.2,
|
| 318 |
+
'grid.alpha': 0.3,
|
| 319 |
+
})
|
| 320 |
+
|
| 321 |
+
# Collect representations
|
| 322 |
+
mem_reps = []
|
| 323 |
+
non_mem_reps = []
|
| 324 |
+
|
| 325 |
+
for result in result_list:
|
| 326 |
+
mem_rep = result['CLaRa_compressed_output']
|
| 327 |
+
non_mem_rep = result['CLaRa_normal_output']
|
| 328 |
+
|
| 329 |
+
if isinstance(mem_rep, torch.Tensor):
|
| 330 |
+
mem_rep = mem_rep.float().cpu().numpy()
|
| 331 |
+
if isinstance(non_mem_rep, torch.Tensor):
|
| 332 |
+
non_mem_rep = non_mem_rep.float().cpu().numpy()
|
| 333 |
+
|
| 334 |
+
mem_reps.append(mem_rep)
|
| 335 |
+
non_mem_reps.append(non_mem_rep)
|
| 336 |
+
|
| 337 |
+
mem_reps = np.array(mem_reps)
|
| 338 |
+
non_mem_reps = np.array(non_mem_reps)
|
| 339 |
+
|
| 340 |
+
print(f"Memory representations shape: {mem_reps.shape}")
|
| 341 |
+
print(f"Document representations shape: {non_mem_reps.shape}")
|
| 342 |
+
|
| 343 |
+
# Combine data for t-SNE
|
| 344 |
+
all_data = np.vstack([mem_reps, non_mem_reps])
|
| 345 |
+
original_dim = all_data.shape[1]
|
| 346 |
+
|
| 347 |
+
# PCA preprocessing if needed
|
| 348 |
+
if all_data.shape[1] > 50:
|
| 349 |
+
print(f"Applying PCA preprocessing from {all_data.shape[1]} to 50 dimensions...")
|
| 350 |
+
pca = PCA(n_components=50)
|
| 351 |
+
all_data = pca.fit_transform(all_data)
|
| 352 |
+
print(f"PCA explained variance ratio: {pca.explained_variance_ratio_[:5].sum():.3f}")
|
| 353 |
+
|
| 354 |
+
# Apply t-SNE
|
| 355 |
+
print("Applying t-SNE...")
|
| 356 |
+
perplexity = min(30, max(5, len(all_data) // 3))
|
| 357 |
+
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity,
|
| 358 |
+
max_iter=1000, learning_rate=200, verbose=1)
|
| 359 |
+
tsne_results = tsne.fit_transform(all_data)
|
| 360 |
+
|
| 361 |
+
# Separate results
|
| 362 |
+
mem_tsne = tsne_results[:len(mem_reps)]
|
| 363 |
+
doc_tsne = tsne_results[len(mem_reps):]
|
| 364 |
+
|
| 365 |
+
# Create visualization
|
| 366 |
+
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
|
| 367 |
+
|
| 368 |
+
# Add jitter to separate overlapping points
|
| 369 |
+
np.random.seed(42)
|
| 370 |
+
jitter_strength = 1.0
|
| 371 |
+
|
| 372 |
+
mem_jitter = mem_tsne.copy()
|
| 373 |
+
doc_jitter = doc_tsne.copy()
|
| 374 |
+
|
| 375 |
+
mem_jitter[:, 0] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
|
| 376 |
+
mem_jitter[:, 1] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
|
| 377 |
+
|
| 378 |
+
doc_jitter[:, 0] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
|
| 379 |
+
doc_jitter[:, 1] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
|
| 380 |
+
|
| 381 |
+
# Plot scatter points
|
| 382 |
+
ax.scatter(doc_jitter[:, 0], doc_jitter[:, 1], c='#0066CC', alpha=0.7, s=25,
|
| 383 |
+
marker='o', edgecolors='white', linewidth=0.5,
|
| 384 |
+
label='Document Representations', zorder=2)
|
| 385 |
+
|
| 386 |
+
ax.scatter(mem_jitter[:, 0], mem_jitter[:, 1], c='#FF3333', alpha=0.7, s=25,
|
| 387 |
+
marker='o', edgecolors='white', linewidth=0.5,
|
| 388 |
+
label='Memory Tokens Representations', zorder=3)
|
| 389 |
+
|
| 390 |
+
# Configure plot
|
| 391 |
+
ax.set_xlabel('')
|
| 392 |
+
ax.set_ylabel('')
|
| 393 |
+
ax.set_title('')
|
| 394 |
+
|
| 395 |
+
legend = ax.legend(frameon=True, fancybox=True, shadow=True,
|
| 396 |
+
loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=2, fontsize=14)
|
| 397 |
+
legend.get_frame().set_facecolor('white')
|
| 398 |
+
legend.get_frame().set_alpha(0.9)
|
| 399 |
+
|
| 400 |
+
ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
|
| 401 |
+
ax.set_axisbelow(True)
|
| 402 |
+
|
| 403 |
+
plt.tight_layout()
|
| 404 |
+
|
| 405 |
+
# Save visualization
|
| 406 |
+
os.makedirs(save_path, exist_ok=True)
|
| 407 |
+
plt.savefig(os.path.join(save_path, 'tsne_visualization_scientific.png'),
|
| 408 |
+
dpi=300, bbox_inches='tight', facecolor='white')
|
| 409 |
+
plt.show()
|
| 410 |
+
|
| 411 |
+
# Calculate statistics
|
| 412 |
+
distances = np.array([
|
| 413 |
+
np.linalg.norm(mem_reps[i] - non_mem_reps[i])
|
| 414 |
+
for i in range(len(mem_reps))
|
| 415 |
+
])
|
| 416 |
+
|
| 417 |
+
statistics = {
|
| 418 |
+
'mean_distance': float(np.mean(distances)),
|
| 419 |
+
'std_distance': float(np.std(distances)),
|
| 420 |
+
'median_distance': float(np.median(distances)),
|
| 421 |
+
'min_distance': float(np.min(distances)),
|
| 422 |
+
'max_distance': float(np.max(distances))
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
print("\n" + "="*60)
|
| 426 |
+
print("VISUALIZATION ANALYSIS REPORT")
|
| 427 |
+
print("="*60)
|
| 428 |
+
print(f"Dataset Statistics:")
|
| 429 |
+
print(f" • Total samples: {len(mem_reps)}")
|
| 430 |
+
print(f" • Original dimension: {original_dim}")
|
| 431 |
+
print(f" • t-SNE perplexity: {perplexity}")
|
| 432 |
+
print(f"\nDistance Analysis:")
|
| 433 |
+
for key, value in statistics.items():
|
| 434 |
+
print(f" • {key.replace('_', ' ').title()}: {value:.4f}")
|
| 435 |
+
print("="*60)
|
| 436 |
+
|
| 437 |
+
return {
|
| 438 |
+
'mem_tsne': mem_tsne,
|
| 439 |
+
'doc_tsne': doc_tsne,
|
| 440 |
+
'original_distances': distances,
|
| 441 |
+
'statistics': statistics
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class DataLoader:
|
| 446 |
+
"""Handles data loading for different datasets and stages."""
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
def load_stage1_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
|
| 450 |
+
"""Load stage1 evaluation data."""
|
| 451 |
+
retrieval_type = "with_pos" if gold_retrieval else "no_pos"
|
| 452 |
+
file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage1_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
|
| 453 |
+
|
| 454 |
+
data = []
|
| 455 |
+
with open(file_path, 'r') as f:
|
| 456 |
+
for line in f:
|
| 457 |
+
data.append(json.loads(line))
|
| 458 |
+
|
| 459 |
+
processed_data = []
|
| 460 |
+
for index, item in enumerate(data):
|
| 461 |
+
docs = item['docs'][:5] # Take top 5 documents
|
| 462 |
+
processed_item = {
|
| 463 |
+
'original_data': item,
|
| 464 |
+
'documents': docs,
|
| 465 |
+
'question': item['question'],
|
| 466 |
+
'global_index': index
|
| 467 |
+
}
|
| 468 |
+
processed_data.append(processed_item)
|
| 469 |
+
|
| 470 |
+
return processed_data
|
| 471 |
+
|
| 472 |
+
@staticmethod
|
| 473 |
+
def load_stage2_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
|
| 474 |
+
"""Load stage2 evaluation data."""
|
| 475 |
+
retrieval_type = "with_pos" if gold_retrieval else "no_pos"
|
| 476 |
+
file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage2_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
|
| 477 |
+
|
| 478 |
+
processed_data = []
|
| 479 |
+
with open(file_path, 'r') as f:
|
| 480 |
+
for index, line in enumerate(f):
|
| 481 |
+
item = json.loads(line)
|
| 482 |
+
processed_item = {
|
| 483 |
+
'original_data': item,
|
| 484 |
+
'documents': item['docs'],
|
| 485 |
+
'question': item['question'],
|
| 486 |
+
'global_index': index,
|
| 487 |
+
'pos_index': item['pos_index']
|
| 488 |
+
}
|
| 489 |
+
processed_data.append(processed_item)
|
| 490 |
+
|
| 491 |
+
return processed_data
|
| 492 |
+
|
| 493 |
+
@staticmethod
|
| 494 |
+
def load_paraphrase_data(file_path: str) -> List[Dict]:
|
| 495 |
+
"""Load paraphrase data."""
|
| 496 |
+
data = []
|
| 497 |
+
with open(file_path, 'r') as f:
|
| 498 |
+
for line in f:
|
| 499 |
+
data.append(json.loads(line))
|
| 500 |
+
|
| 501 |
+
processed_data = []
|
| 502 |
+
for index, item in enumerate(data):
|
| 503 |
+
processed_item = {
|
| 504 |
+
'original_data': item,
|
| 505 |
+
'documents': [item['doc']],
|
| 506 |
+
'question': "",
|
| 507 |
+
'global_index': index
|
| 508 |
+
}
|
| 509 |
+
processed_data.append(processed_item)
|
| 510 |
+
|
| 511 |
+
return processed_data
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class AcceleratedCLaRaInference:
|
| 515 |
+
"""Main inference engine using Accelerate for distributed processing."""
|
| 516 |
+
|
| 517 |
+
def __init__(self, model_path: str, training_stage: str = None,
|
| 518 |
+
generation_top_k: int = None, args = None):
|
| 519 |
+
self.args = args
|
| 520 |
+
|
| 521 |
+
# Initialize Accelerator
|
| 522 |
+
process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
|
| 523 |
+
self.accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])
|
| 524 |
+
|
| 525 |
+
if self.accelerator.is_main_process:
|
| 526 |
+
print(f"Using {self.accelerator.num_processes} GPUs for distributed inference")
|
| 527 |
+
print(f"Current process: {self.accelerator.process_index}")
|
| 528 |
+
print("Loading CLaRa model...")
|
| 529 |
+
|
| 530 |
+
# Load model
|
| 531 |
+
self.model = CLaRa.from_pretrained(
|
| 532 |
+
model_path,
|
| 533 |
+
training_stage=training_stage,
|
| 534 |
+
generation_top_k=generation_top_k,
|
| 535 |
+
pure_inference=True
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# Prepare model with Accelerator
|
| 539 |
+
self.model = self.accelerator.prepare(self.model)
|
| 540 |
+
self.model.eval()
|
| 541 |
+
|
| 542 |
+
if self.accelerator.is_main_process:
|
| 543 |
+
print("Model preparation completed")
|
| 544 |
+
|
| 545 |
+
def _get_model(self):
|
| 546 |
+
"""Get the actual model (handles distributed vs single GPU)."""
|
| 547 |
+
return self.model.module if hasattr(self.model, 'module') else self.model
|
| 548 |
+
|
| 549 |
+
def process_batch(self, batch_questions: List[str], batch_documents: List[List[str]] = None,
|
| 550 |
+
stage2_mips: bool = False, training_stage: str = None,
|
| 551 |
+
batch_answers: List[str] = None, time_count: bool = False) -> Tuple:
|
| 552 |
+
"""Process a batch of questions and documents."""
|
| 553 |
+
model = self._get_model()
|
| 554 |
+
|
| 555 |
+
with torch.no_grad():
|
| 556 |
+
try:
|
| 557 |
+
if training_stage == 'stage2':
|
| 558 |
+
return self._process_stage2(model, batch_questions, batch_documents,
|
| 559 |
+
stage2_mips, time_count)
|
| 560 |
+
elif training_stage in ['stage1', 'stage1_2']:
|
| 561 |
+
return self._process_stage1(model, batch_questions, batch_documents)
|
| 562 |
+
elif training_stage == 'stage2_reasoning':
|
| 563 |
+
return self._process_reasoning(model, batch_questions, batch_answers)
|
| 564 |
+
elif training_stage == 'stage1_paraphrase':
|
| 565 |
+
return self._process_paraphrase(model, batch_questions, batch_documents)
|
| 566 |
+
elif training_stage == 'stage1_mse_visulize':
|
| 567 |
+
return self._process_mse_visualize(model, batch_documents)
|
| 568 |
+
else:
|
| 569 |
+
raise ValueError(f"Unknown training stage: {training_stage}")
|
| 570 |
+
|
| 571 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 572 |
+
self.accelerator.print(f"CUDA OOM error: {e}")
|
| 573 |
+
torch.cuda.empty_cache()
|
| 574 |
+
gc.collect()
|
| 575 |
+
return self._create_empty_results(batch_questions, training_stage)
|
| 576 |
+
|
| 577 |
+
def _process_stage2(self, model, batch_questions, batch_documents, stage2_mips, time_count):
|
| 578 |
+
"""Process stage2 inference."""
|
| 579 |
+
if time_count:
|
| 580 |
+
if stage2_mips:
|
| 581 |
+
results = model.generate_from_questions(
|
| 582 |
+
questions=batch_questions,
|
| 583 |
+
max_new_tokens=64,
|
| 584 |
+
stage2_mips=stage2_mips,
|
| 585 |
+
time_count=True
|
| 586 |
+
)
|
| 587 |
+
else:
|
| 588 |
+
results = model.generate_from_questions(
|
| 589 |
+
questions=batch_questions,
|
| 590 |
+
max_new_tokens=64,
|
| 591 |
+
stage2_mips=stage2_mips,
|
| 592 |
+
documents=batch_documents,
|
| 593 |
+
time_count=True
|
| 594 |
+
)
|
| 595 |
+
return results
|
| 596 |
+
else:
|
| 597 |
+
if stage2_mips:
|
| 598 |
+
batch_out_normal, topk_idx = model.generate_from_questions(
|
| 599 |
+
questions=batch_questions,
|
| 600 |
+
max_new_tokens=64,
|
| 601 |
+
stage2_mips=stage2_mips
|
| 602 |
+
)
|
| 603 |
+
else:
|
| 604 |
+
batch_out_normal, topk_idx = model.generate_from_questions(
|
| 605 |
+
questions=batch_questions,
|
| 606 |
+
max_new_tokens=64,
|
| 607 |
+
stage2_mips=stage2_mips,
|
| 608 |
+
documents=batch_documents
|
| 609 |
+
)
|
| 610 |
+
return batch_out_normal, batch_out_normal, topk_idx
|
| 611 |
+
|
| 612 |
+
def _process_stage1(self, model, batch_questions, batch_documents):
|
| 613 |
+
"""Process stage1 inference."""
|
| 614 |
+
batch_out_compressed = []
|
| 615 |
+
|
| 616 |
+
for docs, question in zip(batch_documents, batch_questions):
|
| 617 |
+
embeddings, _ = model.compress_documents(documents=docs)
|
| 618 |
+
out_compressed = model.generate_from_compressed_documents_and_questions(
|
| 619 |
+
questions=[question],
|
| 620 |
+
compressed_documents=embeddings
|
| 621 |
+
)
|
| 622 |
+
batch_out_compressed.extend(out_compressed)
|
| 623 |
+
|
| 624 |
+
del embeddings
|
| 625 |
+
torch.cuda.empty_cache()
|
| 626 |
+
|
| 627 |
+
return batch_out_compressed, batch_out_compressed, None
|
| 628 |
+
|
| 629 |
+
def _process_reasoning(self, model, batch_questions, batch_answers):
|
| 630 |
+
"""Process reasoning inference."""
|
| 631 |
+
batch_out_normal = []
|
| 632 |
+
batch_out_reasoning_list = []
|
| 633 |
+
|
| 634 |
+
for question, answer in zip(batch_questions, batch_answers):
|
| 635 |
+
temp_out, temp_out_reasoning = model.generate_from_reasoning(
|
| 636 |
+
questions=[question],
|
| 637 |
+
max_new_tokens=1024,
|
| 638 |
+
answers=[answer],
|
| 639 |
+
save_dir=self.args.model_path
|
| 640 |
+
)
|
| 641 |
+
batch_out_normal.append(temp_out[0])
|
| 642 |
+
batch_out_reasoning_list.extend(temp_out_reasoning)
|
| 643 |
+
|
| 644 |
+
return batch_out_normal, batch_out_normal, None, batch_out_reasoning_list
|
| 645 |
+
|
| 646 |
+
def _process_paraphrase(self, model, batch_questions, batch_documents):
|
| 647 |
+
"""Process paraphrase inference."""
|
| 648 |
+
batch_out_compressed = []
|
| 649 |
+
|
| 650 |
+
for docs, question in zip(batch_documents, batch_questions):
|
| 651 |
+
out_compressed = model.generate_from_paraphrase(
|
| 652 |
+
questions=["" for _ in range(len(docs))],
|
| 653 |
+
documents=[docs]
|
| 654 |
+
)
|
| 655 |
+
batch_out_compressed.extend(out_compressed)
|
| 656 |
+
torch.cuda.empty_cache()
|
| 657 |
+
|
| 658 |
+
return batch_out_compressed, batch_out_compressed, None
|
| 659 |
+
|
| 660 |
+
def _process_mse_visualize(self, model, batch_documents):
|
| 661 |
+
"""Process MSE visualization."""
|
| 662 |
+
batch_out_normal = []
|
| 663 |
+
batch_out_compressed = []
|
| 664 |
+
|
| 665 |
+
for docs in batch_documents:
|
| 666 |
+
mem_rep, non_mem_rep = model.compress_documents_mse_visulize(documents=docs)
|
| 667 |
+
batch_out_compressed.append(mem_rep[0])
|
| 668 |
+
batch_out_normal.append(non_mem_rep[0])
|
| 669 |
+
|
| 670 |
+
return batch_out_normal, batch_out_compressed
|
| 671 |
+
|
| 672 |
+
def _create_empty_results(self, batch_questions, training_stage):
|
| 673 |
+
"""Create empty results for error cases."""
|
| 674 |
+
empty_results = [""] * len(batch_questions)
|
| 675 |
+
if training_stage == 'stage2_reasoning':
|
| 676 |
+
return empty_results, empty_results, None, empty_results
|
| 677 |
+
elif training_stage == 'stage1_mse_visulize':
|
| 678 |
+
return empty_results, empty_results
|
| 679 |
+
else:
|
| 680 |
+
return empty_results, empty_results, None
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def convert_embeddings_to_list(data):
|
| 684 |
+
"""Convert tensor embeddings to lists for JSON serialization."""
|
| 685 |
+
if isinstance(data, dict):
|
| 686 |
+
return {k: convert_embeddings_to_list(v) for k, v in data.items()}
|
| 687 |
+
elif isinstance(data, list):
|
| 688 |
+
return [convert_embeddings_to_list(item) for item in data]
|
| 689 |
+
elif isinstance(data, torch.Tensor):
|
| 690 |
+
return data.cpu().to(torch.float32).numpy().tolist()
|
| 691 |
+
elif isinstance(data, np.ndarray):
|
| 692 |
+
return data.tolist()
|
| 693 |
+
else:
|
| 694 |
+
return data
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def main():
|
| 698 |
+
parser = argparse.ArgumentParser(description="CLaRa Model Inference")
|
| 699 |
+
parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
|
| 700 |
+
parser.add_argument('--batch_size', type=int, default=4, help='Batch size per GPU')
|
| 701 |
+
parser.add_argument('--stage', type=str, default='stage1',
|
| 702 |
+
choices=['stage1', 'stage1_2', 'stage2', 'stage2_reasoning',
|
| 703 |
+
'stage1_paraphrase', 'stage1_mse_visulize'],
|
| 704 |
+
help='Training stage')
|
| 705 |
+
parser.add_argument('--stage2_mips', action='store_true', help='Use MIPS for stage2')
|
| 706 |
+
parser.add_argument('--dataset', type=str, default='musique',
|
| 707 |
+
help='Comma-separated list of datasets')
|
| 708 |
+
parser.add_argument('--gold_retrieval', action='store_true',
|
| 709 |
+
help='Use gold retrieval context')
|
| 710 |
+
parser.add_argument('--generation_top_k', type=int, default=5, help='Top-k for generation')
|
| 711 |
+
parser.add_argument('--paraphrase_path', type=str, help='Path to paraphrase data')
|
| 712 |
+
parser.add_argument('--mse_visulize_path', type=str, help='Path to save MSE visualization')
|
| 713 |
+
parser.add_argument('--efficient_count', action='store_true', help='Count efficiency metrics')
|
| 714 |
+
|
| 715 |
+
args = parser.parse_args()
|
| 716 |
+
|
| 717 |
+
# Process datasets
|
| 718 |
+
all_results_metrics = {}
|
| 719 |
+
datasets_list = args.dataset.split(',')
|
| 720 |
+
|
| 721 |
+
for dataset in datasets_list:
|
| 722 |
+
print(f"Processing dataset: {dataset}")
|
| 723 |
+
|
| 724 |
+
# Load data based on stage
|
| 725 |
+
if args.stage in ['stage1', 'stage1_2']:
|
| 726 |
+
processed_data = DataLoader.load_stage1_data(dataset, args.gold_retrieval)
|
| 727 |
+
elif args.stage == 'stage2':
|
| 728 |
+
processed_data = DataLoader.load_stage2_data(dataset, args.gold_retrieval)
|
| 729 |
+
elif args.stage in ['stage1_paraphrase', 'stage1_mse_visulize']:
|
| 730 |
+
if not args.paraphrase_path:
|
| 731 |
+
raise ValueError(f"--paraphrase_path required for stage {args.stage}")
|
| 732 |
+
processed_data = DataLoader.load_paraphrase_data(args.paraphrase_path)
|
| 733 |
+
else:
|
| 734 |
+
raise ValueError(f"Unsupported stage: {args.stage}")
|
| 735 |
+
|
| 736 |
+
print(f"Loaded {len(processed_data)} samples for {dataset}")
|
| 737 |
+
|
| 738 |
+
# Initialize inference engine
|
| 739 |
+
# Use model_path directly if absolute, otherwise use SageMaker path
|
| 740 |
+
if os.path.isabs(args.model_path):
|
| 741 |
+
model_path = args.model_path
|
| 742 |
+
else:
|
| 743 |
+
model_path = os.path.join('/mnt/task_wrapper/user_output/artifacts/data/train_checkpoint', args.model_path)
|
| 744 |
+
args.model_path = model_path
|
| 745 |
+
|
| 746 |
+
inference_engine = AcceleratedCLaRaInference(
|
| 747 |
+
model_path=model_path,
|
| 748 |
+
training_stage=args.stage,
|
| 749 |
+
generation_top_k=args.generation_top_k,
|
| 750 |
+
args=args
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
# Wait for all processes to be ready
|
| 754 |
+
inference_engine.accelerator.wait_for_everyone()
|
| 755 |
+
|
| 756 |
+
# Store results
|
| 757 |
+
all_results = []
|
| 758 |
+
time_count_dic = {"compress_time": 0, "query_time": 0, "generate_time": 0, "total_time": 0, "count": 0}
|
| 759 |
+
|
| 760 |
+
# Process data in batches using accelerator
|
| 761 |
+
with inference_engine.accelerator.split_between_processes(processed_data, apply_padding=False) as local_data:
|
| 762 |
+
print(f"Process {inference_engine.accelerator.process_index}: processing {len(local_data)} samples")
|
| 763 |
+
|
| 764 |
+
batch_size = args.batch_size
|
| 765 |
+
num_batches = (len(local_data) + batch_size - 1) // batch_size
|
| 766 |
+
|
| 767 |
+
for batch_idx in tqdm(range(num_batches),
|
| 768 |
+
desc=f"GPU {inference_engine.accelerator.process_index}",
|
| 769 |
+
disable=not inference_engine.accelerator.is_local_main_process):
|
| 770 |
+
|
| 771 |
+
# Get current batch
|
| 772 |
+
start_idx = batch_idx * batch_size
|
| 773 |
+
end_idx = min(start_idx + batch_size, len(local_data))
|
| 774 |
+
batch = local_data[start_idx:end_idx]
|
| 775 |
+
|
| 776 |
+
# Prepare batch data
|
| 777 |
+
batch_questions = [item['question'] for item in batch]
|
| 778 |
+
batch_documents = [item['documents'] for item in batch] if 'documents' in batch[0] else None
|
| 779 |
+
batch_answers = [item.get('answer') for item in batch] if args.stage == 'stage2_reasoning' else None
|
| 780 |
+
|
| 781 |
+
# Process batch
|
| 782 |
+
if args.efficient_count and args.stage == 'stage2':
|
| 783 |
+
results = inference_engine.process_batch(
|
| 784 |
+
batch_questions=batch_questions,
|
| 785 |
+
batch_documents=batch_documents,
|
| 786 |
+
stage2_mips=args.stage2_mips,
|
| 787 |
+
training_stage=args.stage,
|
| 788 |
+
time_count=True
|
| 789 |
+
)
|
| 790 |
+
batch_out_normal, batch_out_compressed, batch_topk_idx, compress_time, query_time, generate_time, total_time = results
|
| 791 |
+
|
| 792 |
+
time_count_dic["compress_time"] += compress_time
|
| 793 |
+
time_count_dic["query_time"] += query_time
|
| 794 |
+
time_count_dic["generate_time"] += generate_time
|
| 795 |
+
time_count_dic["total_time"] += total_time
|
| 796 |
+
time_count_dic["count"] += 1
|
| 797 |
+
else:
|
| 798 |
+
results = inference_engine.process_batch(
|
| 799 |
+
batch_questions=batch_questions,
|
| 800 |
+
batch_documents=batch_documents,
|
| 801 |
+
stage2_mips=args.stage2_mips,
|
| 802 |
+
training_stage=args.stage,
|
| 803 |
+
batch_answers=batch_answers
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
if args.stage == 'stage2_reasoning':
|
| 807 |
+
batch_out_normal, batch_out_compressed, batch_topk_idx, batch_out_reasoning = results
|
| 808 |
+
elif args.stage == 'stage1_mse_visulize':
|
| 809 |
+
batch_out_normal, batch_out_compressed = results
|
| 810 |
+
batch_topk_idx = None
|
| 811 |
+
else:
|
| 812 |
+
batch_out_normal, batch_out_compressed, batch_topk_idx = results
|
| 813 |
+
|
| 814 |
+
# Prepare results
|
| 815 |
+
batch_results = []
|
| 816 |
+
for i, (item, normal_out, compressed_out) in enumerate(zip(batch, batch_out_normal, batch_out_compressed)):
|
| 817 |
+
result_item = item['original_data'].copy()
|
| 818 |
+
result_item['CLaRa_normal_output'] = normal_out
|
| 819 |
+
result_item['CLaRa_compressed_output'] = compressed_out
|
| 820 |
+
result_item['global_index'] = item['global_index']
|
| 821 |
+
|
| 822 |
+
if args.stage == 'stage2' and batch_topk_idx is not None:
|
| 823 |
+
result_item['topk_idx'] = batch_topk_idx[i].tolist()
|
| 824 |
+
elif args.stage == 'stage2_reasoning':
|
| 825 |
+
result_item['reasoning_output'] = batch_out_reasoning[i]
|
| 826 |
+
|
| 827 |
+
batch_results.append(result_item)
|
| 828 |
+
|
| 829 |
+
all_results.extend(batch_results)
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
# Clean up memory
|
| 833 |
+
torch.cuda.empty_cache()
|
| 834 |
+
if batch_idx % 10 == 0:
|
| 835 |
+
gc.collect()
|
| 836 |
+
|
| 837 |
+
# Save efficiency metrics if requested
|
| 838 |
+
if args.efficient_count and inference_engine.accelerator.is_main_process:
|
| 839 |
+
eff_dic = {
|
| 840 |
+
"compress_time_ms": round((time_count_dic['compress_time'] / time_count_dic['count']) * 1000, 2),
|
| 841 |
+
"query_time_ms": round((time_count_dic['query_time'] / time_count_dic['count']) * 1000, 2),
|
| 842 |
+
"generate_time_ms": round((time_count_dic['generate_time'] / time_count_dic['count']) * 1000, 2),
|
| 843 |
+
"total_time_ms": round((time_count_dic['total_time'] / time_count_dic['count']) * 1000, 2),
|
| 844 |
+
"sample_count": time_count_dic['count']
|
| 845 |
+
}
|
| 846 |
+
eff_output_path = os.path.join(model_path, f"efficiency_{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
|
| 847 |
+
with open(eff_output_path, 'w') as f:
|
| 848 |
+
json.dump(eff_dic, f, indent=2)
|
| 849 |
+
|
| 850 |
+
# Wait for all processes to complete
|
| 851 |
+
inference_engine.accelerator.wait_for_everyone()
|
| 852 |
+
|
| 853 |
+
# Gather results from all processes
|
| 854 |
+
if inference_engine.accelerator.is_main_process:
|
| 855 |
+
print("Collecting results from all processes...")
|
| 856 |
+
|
| 857 |
+
all_results_gathered = inference_engine.accelerator.gather_for_metrics(all_results)
|
| 858 |
+
|
| 859 |
+
# Process and save results (main process only)
|
| 860 |
+
if inference_engine.accelerator.is_main_process:
|
| 861 |
+
print("Processing and saving results...")
|
| 862 |
+
|
| 863 |
+
# Flatten results
|
| 864 |
+
final_results = []
|
| 865 |
+
if isinstance(all_results_gathered, list):
|
| 866 |
+
for result_batch in all_results_gathered:
|
| 867 |
+
if isinstance(result_batch, list):
|
| 868 |
+
final_results.extend(result_batch)
|
| 869 |
+
else:
|
| 870 |
+
final_results.append(result_batch)
|
| 871 |
+
|
| 872 |
+
print(f"Collected {len(final_results)} results")
|
| 873 |
+
|
| 874 |
+
# Sort by global index to maintain order
|
| 875 |
+
final_results.sort(key=lambda x: x.get('global_index', 0))
|
| 876 |
+
|
| 877 |
+
# Verify data integrity
|
| 878 |
+
processed_indices = set(item.get('global_index', -1) for item in final_results)
|
| 879 |
+
expected_indices = set(range(len(processed_data)))
|
| 880 |
+
missing_indices = expected_indices - processed_indices
|
| 881 |
+
|
| 882 |
+
if missing_indices:
|
| 883 |
+
print(f"Warning: Missing indices: {sorted(list(missing_indices))}")
|
| 884 |
+
else:
|
| 885 |
+
print("✓ Data integrity verification passed")
|
| 886 |
+
|
| 887 |
+
# Remove global index for clean output
|
| 888 |
+
for item in final_results:
|
| 889 |
+
item.pop('global_index', None)
|
| 890 |
+
|
| 891 |
+
# Save results
|
| 892 |
+
output_path = os.path.join(model_path, f"{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.jsonl")
|
| 893 |
+
with open(output_path, 'w') as f:
|
| 894 |
+
if args.stage == 'stage1_mse_visulize':
|
| 895 |
+
converted_results = convert_embeddings_to_list(final_results)
|
| 896 |
+
for item in converted_results:
|
| 897 |
+
f.write(json.dumps(item) + '\n')
|
| 898 |
+
else:
|
| 899 |
+
for item in final_results:
|
| 900 |
+
f.write(json.dumps(item) + '\n')
|
| 901 |
+
|
| 902 |
+
print(f"Results saved to: {output_path}")
|
| 903 |
+
|
| 904 |
+
# Calculate metrics
|
| 905 |
+
calculator = ResultCalculator()
|
| 906 |
+
|
| 907 |
+
if args.stage == 'stage2':
|
| 908 |
+
metrics = calculator.calculate_stage2_metrics(final_results)
|
| 909 |
+
elif args.stage == 'stage1_paraphrase':
|
| 910 |
+
metrics = calculator.calculate_paraphrase_metrics(final_results)
|
| 911 |
+
elif args.stage == 'stage1_mse_visulize':
|
| 912 |
+
if args.mse_visulize_path:
|
| 913 |
+
metrics = calculator.visualize_mse(final_results, args.mse_visulize_path)
|
| 914 |
+
else:
|
| 915 |
+
metrics = {"visualization": "completed"}
|
| 916 |
+
else:
|
| 917 |
+
metrics = calculator.calculate_basic_metrics(final_results)
|
| 918 |
+
|
| 919 |
+
print(f"Metrics for {dataset}: {metrics}")
|
| 920 |
+
all_results_metrics[dataset] = metrics
|
| 921 |
+
|
| 922 |
+
# Clean up
|
| 923 |
+
del inference_engine
|
| 924 |
+
torch.cuda.empty_cache()
|
| 925 |
+
gc.collect()
|
| 926 |
+
|
| 927 |
+
# Save final metrics
|
| 928 |
+
if len(all_results_metrics) > 0:
|
| 929 |
+
metrics_path = os.path.join(model_path, f"results_metrics_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
|
| 930 |
+
with open(metrics_path, 'w') as f:
|
| 931 |
+
json.dump(all_results_metrics, f, indent=2)
|
| 932 |
+
print(f"Final metrics saved to: {metrics_path}")
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
if __name__ == '__main__':
|
| 936 |
+
main()
|
evaluation/evaluate.py.bak
ADDED
|
@@ -0,0 +1,910 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# For licensing see accompanying LICENSE file.
|
| 3 |
+
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 4 |
+
#
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import argparse
|
| 9 |
+
import gc
|
| 10 |
+
from datetime import timedelta
|
| 11 |
+
from collections import defaultdict, Counter
|
| 12 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import numpy as np
|
| 16 |
+
from accelerate import Accelerator, InitProcessGroupKwargs
|
| 17 |
+
from transformers import AutoModel
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
from sklearn.manifold import TSNE
|
| 22 |
+
from sklearn.decomposition import PCA
|
| 23 |
+
import spacy
|
| 24 |
+
import evaluate
|
| 25 |
+
import re
|
| 26 |
+
import string
|
| 27 |
+
|
| 28 |
+
from openrlhf.models.modeling_clara import CLaRa
|
| 29 |
+
|
| 30 |
+
# Environment setup
|
| 31 |
+
os.environ["NCCL_TIMEOUT"] = "5400"
|
| 32 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 33 |
+
|
| 34 |
+
# Global constants
|
| 35 |
+
TARGET_ENTITY_CATEGORIES = {"PERSON", "GPE", "DATE", "CARDINAL", "ORG"}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class EvaluationMetrics:
|
| 39 |
+
"""Handles all evaluation metrics and scoring functions."""
|
| 40 |
+
|
| 41 |
+
def __init__(self):
|
| 42 |
+
self.bertscore = evaluate.load("bertscore")
|
| 43 |
+
self.rouge = evaluate.load("rouge")
|
| 44 |
+
self.nlp = spacy.load("en_core_web_sm")
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
def normalize_answer(text: str) -> str:
|
| 48 |
+
"""Normalize text for comparison."""
|
| 49 |
+
def remove_articles(text):
|
| 50 |
+
return re.sub(r"\b(a|an|the)\b", " ", text)
|
| 51 |
+
|
| 52 |
+
def white_space_fix(text):
|
| 53 |
+
return " ".join(text.split())
|
| 54 |
+
|
| 55 |
+
def remove_punc(text):
|
| 56 |
+
exclude = set(string.punctuation)
|
| 57 |
+
return "".join(ch for ch in text if ch not in exclude)
|
| 58 |
+
|
| 59 |
+
return white_space_fix(remove_articles(remove_punc(text.lower())))
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def bool_mapping(text: str) -> str:
|
| 63 |
+
"""Map boolean values to yes/no."""
|
| 64 |
+
mapping = {"True": "yes", "False": "no"}
|
| 65 |
+
return mapping.get(text, text)
|
| 66 |
+
|
| 67 |
+
def exact_match_score(self, prediction: str, ground_truth: str) -> bool:
|
| 68 |
+
"""Calculate exact match score."""
|
| 69 |
+
pred_norm = self.normalize_answer(self.bool_mapping(prediction))
|
| 70 |
+
gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))
|
| 71 |
+
return pred_norm == gt_norm
|
| 72 |
+
|
| 73 |
+
def cover_exact_match_score(self, prediction: str, ground_truth: str) -> bool:
|
| 74 |
+
"""Calculate coverage exact match score."""
|
| 75 |
+
pred_tokens = self.normalize_answer(self.bool_mapping(prediction)).split()
|
| 76 |
+
gt_tokens = self.normalize_answer(self.bool_mapping(ground_truth)).split()
|
| 77 |
+
return all(token in pred_tokens for token in gt_tokens)
|
| 78 |
+
|
| 79 |
+
def f1_score(self, prediction: str, ground_truth: str) -> float:
|
| 80 |
+
"""Calculate F1 score."""
|
| 81 |
+
pred_norm = self.normalize_answer(self.bool_mapping(prediction))
|
| 82 |
+
gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))
|
| 83 |
+
|
| 84 |
+
# Handle yes/no/noanswer cases
|
| 85 |
+
if pred_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
|
| 86 |
+
return 0.0
|
| 87 |
+
if gt_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
|
| 88 |
+
return 0.0
|
| 89 |
+
|
| 90 |
+
pred_tokens = pred_norm.split()
|
| 91 |
+
gt_tokens = gt_norm.split()
|
| 92 |
+
|
| 93 |
+
common = Counter(pred_tokens) & Counter(gt_tokens)
|
| 94 |
+
num_same = sum(common.values())
|
| 95 |
+
|
| 96 |
+
if num_same == 0:
|
| 97 |
+
return 0.0
|
| 98 |
+
|
| 99 |
+
precision = num_same / len(pred_tokens)
|
| 100 |
+
recall = num_same / len(gt_tokens)
|
| 101 |
+
|
| 102 |
+
return (2 * precision * recall) / (precision + recall)
|
| 103 |
+
|
| 104 |
+
def extract_entities(self, text: str) -> set:
|
| 105 |
+
"""Extract entities from text."""
|
| 106 |
+
doc = self.nlp(text)
|
| 107 |
+
return set(ent.text.lower().strip() for ent in doc.ents)
|
| 108 |
+
|
| 109 |
+
def extract_entities_by_category(self, text: str) -> Dict[str, set]:
|
| 110 |
+
"""Extract entities by category."""
|
| 111 |
+
doc = self.nlp(text)
|
| 112 |
+
entities_by_category = defaultdict(set)
|
| 113 |
+
|
| 114 |
+
for ent in doc.ents:
|
| 115 |
+
if ent.label_ in TARGET_ENTITY_CATEGORIES:
|
| 116 |
+
entities_by_category[ent.label_].add(ent.text.lower().strip())
|
| 117 |
+
|
| 118 |
+
return entities_by_category
|
| 119 |
+
|
| 120 |
+
def entity_preserve_metric(self, prediction: str, reference: str) -> float:
|
| 121 |
+
"""Calculate entity preservation rate."""
|
| 122 |
+
ref_entities = self.extract_entities(reference)
|
| 123 |
+
pred_entities = self.extract_entities(prediction)
|
| 124 |
+
|
| 125 |
+
if not ref_entities:
|
| 126 |
+
return 1.0
|
| 127 |
+
|
| 128 |
+
preserved = ref_entities.intersection(pred_entities)
|
| 129 |
+
return len(preserved) / len(ref_entities)
|
| 130 |
+
|
| 131 |
+
def entity_preserve_metric_by_category(self, prediction_tokens: List[List[str]],
|
| 132 |
+
reference_docs: List[str]) -> Dict[str, float]:
|
| 133 |
+
"""Calculate entity preservation by category."""
|
| 134 |
+
# Merge prediction tokens
|
| 135 |
+
all_prediction_tokens = []
|
| 136 |
+
for tokens in prediction_tokens:
|
| 137 |
+
all_prediction_tokens.extend(tokens)
|
| 138 |
+
prediction_text = " ".join(all_prediction_tokens)
|
| 139 |
+
|
| 140 |
+
# Merge reference documents
|
| 141 |
+
reference_text = " ".join(reference_docs)
|
| 142 |
+
|
| 143 |
+
# Extract entities
|
| 144 |
+
pred_entities = self.extract_entities_by_category(prediction_text)
|
| 145 |
+
ref_entities = self.extract_entities_by_category(reference_text)
|
| 146 |
+
|
| 147 |
+
# Calculate preservation rates
|
| 148 |
+
preservation_rates = {}
|
| 149 |
+
|
| 150 |
+
for category in TARGET_ENTITY_CATEGORIES:
|
| 151 |
+
ref_ents = ref_entities.get(category, set())
|
| 152 |
+
pred_ents = pred_entities.get(category, set())
|
| 153 |
+
|
| 154 |
+
if not ref_ents:
|
| 155 |
+
preservation_rates[category] = 1.0
|
| 156 |
+
else:
|
| 157 |
+
preserved = ref_ents.intersection(pred_ents)
|
| 158 |
+
preservation_rates[category] = len(preserved) / len(ref_ents)
|
| 159 |
+
|
| 160 |
+
# Calculate overall preservation
|
| 161 |
+
all_ref_entities = set()
|
| 162 |
+
all_pred_entities = set()
|
| 163 |
+
|
| 164 |
+
for entities_set in ref_entities.values():
|
| 165 |
+
all_ref_entities.update(entities_set)
|
| 166 |
+
for entities_set in pred_entities.values():
|
| 167 |
+
all_pred_entities.update(entities_set)
|
| 168 |
+
|
| 169 |
+
if not all_ref_entities:
|
| 170 |
+
preservation_rates["overall"] = 1.0
|
| 171 |
+
else:
|
| 172 |
+
preserved_overall = all_ref_entities.intersection(all_pred_entities)
|
| 173 |
+
preservation_rates["overall"] = len(preserved_overall) / len(all_ref_entities)
|
| 174 |
+
|
| 175 |
+
return preservation_rates
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class ResultCalculator:
|
| 179 |
+
"""Handles result calculation and visualization."""
|
| 180 |
+
|
| 181 |
+
def __init__(self):
|
| 182 |
+
self.metrics = EvaluationMetrics()
|
| 183 |
+
|
| 184 |
+
def calculate_basic_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
|
| 185 |
+
"""Calculate basic metrics (F1, accuracy, exact match)."""
|
| 186 |
+
f1_total = 0
|
| 187 |
+
acc_total = 0
|
| 188 |
+
em_total = 0
|
| 189 |
+
avg_output_length = 0
|
| 190 |
+
|
| 191 |
+
answer_key = "golden_answers" if "golden_answers" in result_list[0] else "answer"
|
| 192 |
+
|
| 193 |
+
for result in result_list:
|
| 194 |
+
prediction = result['CLaRa_normal_output']
|
| 195 |
+
ground_truth = result[answer_key][0] if answer_key == "golden_answers" else result[answer_key]
|
| 196 |
+
|
| 197 |
+
acc_total += self.metrics.cover_exact_match_score(prediction, ground_truth)
|
| 198 |
+
f1_total += self.metrics.f1_score(prediction, ground_truth)
|
| 199 |
+
em_total += self.metrics.exact_match_score(prediction, ground_truth)
|
| 200 |
+
avg_output_length += len(prediction.split())
|
| 201 |
+
|
| 202 |
+
n = len(result_list)
|
| 203 |
+
return {
|
| 204 |
+
"f1": f1_total / n,
|
| 205 |
+
"acc": acc_total / n,
|
| 206 |
+
"em": em_total / n,
|
| 207 |
+
"avg_output_length": avg_output_length / n
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
def calculate_stage2_metrics(self, result_list: List[Dict], k_values: List[int] = [1, 3, 5]) -> Dict[str, float]:
|
| 211 |
+
"""Calculate stage2 metrics with recall and precision."""
|
| 212 |
+
basic_metrics = self.calculate_basic_metrics(result_list)
|
| 213 |
+
|
| 214 |
+
recall = {k: 0 for k in k_values}
|
| 215 |
+
precision = {k: 0 for k in k_values}
|
| 216 |
+
|
| 217 |
+
for result in result_list:
|
| 218 |
+
scores = result['topk_idx']
|
| 219 |
+
pos_index = set(result['pos_index'])
|
| 220 |
+
|
| 221 |
+
for k in k_values:
|
| 222 |
+
top_k = set(scores[:k])
|
| 223 |
+
hit = len(top_k & pos_index)
|
| 224 |
+
|
| 225 |
+
recall[k] += hit / len(pos_index) if len(pos_index) > 0 else 0
|
| 226 |
+
precision[k] += hit / k
|
| 227 |
+
|
| 228 |
+
n = len(result_list)
|
| 229 |
+
recall_metrics = {f"recall@{k}": v / n for k, v in recall.items()}
|
| 230 |
+
precision_metrics = {f"precision@{k}": v / n for k, v in precision.items()}
|
| 231 |
+
|
| 232 |
+
return {**basic_metrics, **recall_metrics, **precision_metrics}
|
| 233 |
+
|
| 234 |
+
def calculate_paraphrase_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
|
| 235 |
+
"""Calculate paraphrase metrics."""
|
| 236 |
+
seen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
|
| 237 |
+
unseen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
|
| 238 |
+
|
| 239 |
+
# Process seen data (first 2000)
|
| 240 |
+
for result in result_list[:2000]:
|
| 241 |
+
prediction = result['CLaRa_normal_output']
|
| 242 |
+
ground_truth = result['doc']
|
| 243 |
+
|
| 244 |
+
bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
|
| 245 |
+
seen_metrics['bert-score'] += bs['f1'][0]
|
| 246 |
+
|
| 247 |
+
rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
|
| 248 |
+
seen_metrics['rouge-1'] += rouge_scores['rouge1']
|
| 249 |
+
seen_metrics['rouge-L'] += rouge_scores['rougeL']
|
| 250 |
+
|
| 251 |
+
seen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
|
| 252 |
+
|
| 253 |
+
# Process unseen data (after 2000)
|
| 254 |
+
for result in result_list[2000:]:
|
| 255 |
+
prediction = result['CLaRa_normal_output']
|
| 256 |
+
ground_truth = result['doc']
|
| 257 |
+
|
| 258 |
+
bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
|
| 259 |
+
unseen_metrics['bert-score'] += bs['f1'][0]
|
| 260 |
+
|
| 261 |
+
rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
|
| 262 |
+
unseen_metrics['rouge-1'] += rouge_scores['rouge1']
|
| 263 |
+
unseen_metrics['rouge-L'] += rouge_scores['rougeL']
|
| 264 |
+
|
| 265 |
+
unseen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
|
| 266 |
+
|
| 267 |
+
# Normalize
|
| 268 |
+
n_seen = min(len(result_list[:2000]), 2000)
|
| 269 |
+
n_unseen = max(len(result_list) - 2000, 0)
|
| 270 |
+
|
| 271 |
+
final_metrics = {}
|
| 272 |
+
if n_seen > 0:
|
| 273 |
+
for key, value in seen_metrics.items():
|
| 274 |
+
final_metrics[f'seen_{key}'] = float(value / n_seen)
|
| 275 |
+
|
| 276 |
+
if n_unseen > 0:
|
| 277 |
+
for key, value in unseen_metrics.items():
|
| 278 |
+
final_metrics[f'unseen_{key}'] = float(value / n_unseen)
|
| 279 |
+
|
| 280 |
+
return final_metrics
|
| 281 |
+
|
| 282 |
+
def visualize_mse(self, result_list: List[Dict], save_path: str) -> Dict[str, Any]:
|
| 283 |
+
"""Create t-SNE visualization for MSE analysis."""
|
| 284 |
+
# Set scientific style
|
| 285 |
+
plt.rcParams.update({
|
| 286 |
+
'font.family': 'serif',
|
| 287 |
+
'font.size': 12,
|
| 288 |
+
'axes.labelsize': 14,
|
| 289 |
+
'axes.titlesize': 16,
|
| 290 |
+
'figure.titlesize': 18,
|
| 291 |
+
'axes.linewidth': 1.2,
|
| 292 |
+
'grid.alpha': 0.3,
|
| 293 |
+
})
|
| 294 |
+
|
| 295 |
+
# Collect representations
|
| 296 |
+
mem_reps = []
|
| 297 |
+
non_mem_reps = []
|
| 298 |
+
|
| 299 |
+
for result in result_list:
|
| 300 |
+
mem_rep = result['CLaRa_compressed_output']
|
| 301 |
+
non_mem_rep = result['CLaRa_normal_output']
|
| 302 |
+
|
| 303 |
+
if isinstance(mem_rep, torch.Tensor):
|
| 304 |
+
mem_rep = mem_rep.float().cpu().numpy()
|
| 305 |
+
if isinstance(non_mem_rep, torch.Tensor):
|
| 306 |
+
non_mem_rep = non_mem_rep.float().cpu().numpy()
|
| 307 |
+
|
| 308 |
+
mem_reps.append(mem_rep)
|
| 309 |
+
non_mem_reps.append(non_mem_rep)
|
| 310 |
+
|
| 311 |
+
mem_reps = np.array(mem_reps)
|
| 312 |
+
non_mem_reps = np.array(non_mem_reps)
|
| 313 |
+
|
| 314 |
+
print(f"Memory representations shape: {mem_reps.shape}")
|
| 315 |
+
print(f"Document representations shape: {non_mem_reps.shape}")
|
| 316 |
+
|
| 317 |
+
# Combine data for t-SNE
|
| 318 |
+
all_data = np.vstack([mem_reps, non_mem_reps])
|
| 319 |
+
original_dim = all_data.shape[1]
|
| 320 |
+
|
| 321 |
+
# PCA preprocessing if needed
|
| 322 |
+
if all_data.shape[1] > 50:
|
| 323 |
+
print(f"Applying PCA preprocessing from {all_data.shape[1]} to 50 dimensions...")
|
| 324 |
+
pca = PCA(n_components=50)
|
| 325 |
+
all_data = pca.fit_transform(all_data)
|
| 326 |
+
print(f"PCA explained variance ratio: {pca.explained_variance_ratio_[:5].sum():.3f}")
|
| 327 |
+
|
| 328 |
+
# Apply t-SNE
|
| 329 |
+
print("Applying t-SNE...")
|
| 330 |
+
perplexity = min(30, max(5, len(all_data) // 3))
|
| 331 |
+
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity,
|
| 332 |
+
max_iter=1000, learning_rate=200, verbose=1)
|
| 333 |
+
tsne_results = tsne.fit_transform(all_data)
|
| 334 |
+
|
| 335 |
+
# Separate results
|
| 336 |
+
mem_tsne = tsne_results[:len(mem_reps)]
|
| 337 |
+
doc_tsne = tsne_results[len(mem_reps):]
|
| 338 |
+
|
| 339 |
+
# Create visualization
|
| 340 |
+
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
|
| 341 |
+
|
| 342 |
+
# Add jitter to separate overlapping points
|
| 343 |
+
np.random.seed(42)
|
| 344 |
+
jitter_strength = 1.0
|
| 345 |
+
|
| 346 |
+
mem_jitter = mem_tsne.copy()
|
| 347 |
+
doc_jitter = doc_tsne.copy()
|
| 348 |
+
|
| 349 |
+
mem_jitter[:, 0] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
|
| 350 |
+
mem_jitter[:, 1] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
|
| 351 |
+
|
| 352 |
+
doc_jitter[:, 0] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
|
| 353 |
+
doc_jitter[:, 1] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
|
| 354 |
+
|
| 355 |
+
# Plot scatter points
|
| 356 |
+
ax.scatter(doc_jitter[:, 0], doc_jitter[:, 1], c='#0066CC', alpha=0.7, s=25,
|
| 357 |
+
marker='o', edgecolors='white', linewidth=0.5,
|
| 358 |
+
label='Document Representations', zorder=2)
|
| 359 |
+
|
| 360 |
+
ax.scatter(mem_jitter[:, 0], mem_jitter[:, 1], c='#FF3333', alpha=0.7, s=25,
|
| 361 |
+
marker='o', edgecolors='white', linewidth=0.5,
|
| 362 |
+
label='Memory Tokens Representations', zorder=3)
|
| 363 |
+
|
| 364 |
+
# Configure plot
|
| 365 |
+
ax.set_xlabel('')
|
| 366 |
+
ax.set_ylabel('')
|
| 367 |
+
ax.set_title('')
|
| 368 |
+
|
| 369 |
+
legend = ax.legend(frameon=True, fancybox=True, shadow=True,
|
| 370 |
+
loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=2, fontsize=14)
|
| 371 |
+
legend.get_frame().set_facecolor('white')
|
| 372 |
+
legend.get_frame().set_alpha(0.9)
|
| 373 |
+
|
| 374 |
+
ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
|
| 375 |
+
ax.set_axisbelow(True)
|
| 376 |
+
|
| 377 |
+
plt.tight_layout()
|
| 378 |
+
|
| 379 |
+
# Save visualization
|
| 380 |
+
os.makedirs(save_path, exist_ok=True)
|
| 381 |
+
plt.savefig(os.path.join(save_path, 'tsne_visualization_scientific.png'),
|
| 382 |
+
dpi=300, bbox_inches='tight', facecolor='white')
|
| 383 |
+
plt.show()
|
| 384 |
+
|
| 385 |
+
# Calculate statistics
|
| 386 |
+
distances = np.array([
|
| 387 |
+
np.linalg.norm(mem_reps[i] - non_mem_reps[i])
|
| 388 |
+
for i in range(len(mem_reps))
|
| 389 |
+
])
|
| 390 |
+
|
| 391 |
+
statistics = {
|
| 392 |
+
'mean_distance': float(np.mean(distances)),
|
| 393 |
+
'std_distance': float(np.std(distances)),
|
| 394 |
+
'median_distance': float(np.median(distances)),
|
| 395 |
+
'min_distance': float(np.min(distances)),
|
| 396 |
+
'max_distance': float(np.max(distances))
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
print("\n" + "="*60)
|
| 400 |
+
print("VISUALIZATION ANALYSIS REPORT")
|
| 401 |
+
print("="*60)
|
| 402 |
+
print(f"Dataset Statistics:")
|
| 403 |
+
print(f" • Total samples: {len(mem_reps)}")
|
| 404 |
+
print(f" • Original dimension: {original_dim}")
|
| 405 |
+
print(f" • t-SNE perplexity: {perplexity}")
|
| 406 |
+
print(f"\nDistance Analysis:")
|
| 407 |
+
for key, value in statistics.items():
|
| 408 |
+
print(f" • {key.replace('_', ' ').title()}: {value:.4f}")
|
| 409 |
+
print("="*60)
|
| 410 |
+
|
| 411 |
+
return {
|
| 412 |
+
'mem_tsne': mem_tsne,
|
| 413 |
+
'doc_tsne': doc_tsne,
|
| 414 |
+
'original_distances': distances,
|
| 415 |
+
'statistics': statistics
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class DataLoader:
|
| 420 |
+
"""Handles data loading for different datasets and stages."""
|
| 421 |
+
|
| 422 |
+
@staticmethod
|
| 423 |
+
def load_stage1_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
|
| 424 |
+
"""Load stage1 evaluation data."""
|
| 425 |
+
retrieval_type = "with_pos" if gold_retrieval else "no_pos"
|
| 426 |
+
file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage1_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
|
| 427 |
+
|
| 428 |
+
data = []
|
| 429 |
+
with open(file_path, 'r') as f:
|
| 430 |
+
for line in f:
|
| 431 |
+
data.append(json.loads(line))
|
| 432 |
+
|
| 433 |
+
processed_data = []
|
| 434 |
+
for index, item in enumerate(data):
|
| 435 |
+
docs = item['docs'][:5] # Take top 5 documents
|
| 436 |
+
processed_item = {
|
| 437 |
+
'original_data': item,
|
| 438 |
+
'documents': docs,
|
| 439 |
+
'question': item['question'],
|
| 440 |
+
'global_index': index
|
| 441 |
+
}
|
| 442 |
+
processed_data.append(processed_item)
|
| 443 |
+
|
| 444 |
+
return processed_data
|
| 445 |
+
|
| 446 |
+
@staticmethod
|
| 447 |
+
def load_stage2_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
|
| 448 |
+
"""Load stage2 evaluation data."""
|
| 449 |
+
retrieval_type = "with_pos" if gold_retrieval else "no_pos"
|
| 450 |
+
file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage2_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
|
| 451 |
+
|
| 452 |
+
processed_data = []
|
| 453 |
+
with open(file_path, 'r') as f:
|
| 454 |
+
for index, line in enumerate(f):
|
| 455 |
+
item = json.loads(line)
|
| 456 |
+
processed_item = {
|
| 457 |
+
'original_data': item,
|
| 458 |
+
'documents': item['docs'],
|
| 459 |
+
'question': item['question'],
|
| 460 |
+
'global_index': index,
|
| 461 |
+
'pos_index': item['pos_index']
|
| 462 |
+
}
|
| 463 |
+
processed_data.append(processed_item)
|
| 464 |
+
|
| 465 |
+
return processed_data
|
| 466 |
+
|
| 467 |
+
@staticmethod
|
| 468 |
+
def load_paraphrase_data(file_path: str) -> List[Dict]:
|
| 469 |
+
"""Load paraphrase data."""
|
| 470 |
+
data = []
|
| 471 |
+
with open(file_path, 'r') as f:
|
| 472 |
+
for line in f:
|
| 473 |
+
data.append(json.loads(line))
|
| 474 |
+
|
| 475 |
+
processed_data = []
|
| 476 |
+
for index, item in enumerate(data):
|
| 477 |
+
processed_item = {
|
| 478 |
+
'original_data': item,
|
| 479 |
+
'documents': [item['doc']],
|
| 480 |
+
'question': "",
|
| 481 |
+
'global_index': index
|
| 482 |
+
}
|
| 483 |
+
processed_data.append(processed_item)
|
| 484 |
+
|
| 485 |
+
return processed_data
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class AcceleratedCLaRaInference:
|
| 489 |
+
"""Main inference engine using Accelerate for distributed processing."""
|
| 490 |
+
|
| 491 |
+
def __init__(self, model_path: str, training_stage: str = None,
|
| 492 |
+
generation_top_k: int = None, args = None):
|
| 493 |
+
self.args = args
|
| 494 |
+
|
| 495 |
+
# Initialize Accelerator
|
| 496 |
+
process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
|
| 497 |
+
self.accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])
|
| 498 |
+
|
| 499 |
+
if self.accelerator.is_main_process:
|
| 500 |
+
print(f"Using {self.accelerator.num_processes} GPUs for distributed inference")
|
| 501 |
+
print(f"Current process: {self.accelerator.process_index}")
|
| 502 |
+
print("Loading CLaRa model...")
|
| 503 |
+
|
| 504 |
+
# Load model
|
| 505 |
+
self.model = CLaRa.from_pretrained(
|
| 506 |
+
model_path,
|
| 507 |
+
training_stage=training_stage,
|
| 508 |
+
generation_top_k=generation_top_k,
|
| 509 |
+
pure_inference=True
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# Prepare model with Accelerator
|
| 513 |
+
self.model = self.accelerator.prepare(self.model)
|
| 514 |
+
self.model.eval()
|
| 515 |
+
|
| 516 |
+
if self.accelerator.is_main_process:
|
| 517 |
+
print("Model preparation completed")
|
| 518 |
+
|
| 519 |
+
def _get_model(self):
|
| 520 |
+
"""Get the actual model (handles distributed vs single GPU)."""
|
| 521 |
+
return self.model.module if hasattr(self.model, 'module') else self.model
|
| 522 |
+
|
| 523 |
+
def process_batch(self, batch_questions: List[str], batch_documents: List[List[str]] = None,
|
| 524 |
+
stage2_mips: bool = False, training_stage: str = None,
|
| 525 |
+
batch_answers: List[str] = None, time_count: bool = False) -> Tuple:
|
| 526 |
+
"""Process a batch of questions and documents."""
|
| 527 |
+
model = self._get_model()
|
| 528 |
+
|
| 529 |
+
with torch.no_grad():
|
| 530 |
+
try:
|
| 531 |
+
if training_stage == 'stage2':
|
| 532 |
+
return self._process_stage2(model, batch_questions, batch_documents,
|
| 533 |
+
stage2_mips, time_count)
|
| 534 |
+
elif training_stage in ['stage1', 'stage1_2']:
|
| 535 |
+
return self._process_stage1(model, batch_questions, batch_documents)
|
| 536 |
+
elif training_stage == 'stage2_reasoning':
|
| 537 |
+
return self._process_reasoning(model, batch_questions, batch_answers)
|
| 538 |
+
elif training_stage == 'stage1_paraphrase':
|
| 539 |
+
return self._process_paraphrase(model, batch_questions, batch_documents)
|
| 540 |
+
elif training_stage == 'stage1_mse_visulize':
|
| 541 |
+
return self._process_mse_visualize(model, batch_documents)
|
| 542 |
+
else:
|
| 543 |
+
raise ValueError(f"Unknown training stage: {training_stage}")
|
| 544 |
+
|
| 545 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 546 |
+
self.accelerator.print(f"CUDA OOM error: {e}")
|
| 547 |
+
torch.cuda.empty_cache()
|
| 548 |
+
gc.collect()
|
| 549 |
+
return self._create_empty_results(batch_questions, training_stage)
|
| 550 |
+
|
| 551 |
+
def _process_stage2(self, model, batch_questions, batch_documents, stage2_mips, time_count):
|
| 552 |
+
"""Process stage2 inference."""
|
| 553 |
+
if time_count:
|
| 554 |
+
if stage2_mips:
|
| 555 |
+
results = model.generate_from_questions(
|
| 556 |
+
questions=batch_questions,
|
| 557 |
+
max_new_tokens=64,
|
| 558 |
+
stage2_mips=stage2_mips,
|
| 559 |
+
time_count=True
|
| 560 |
+
)
|
| 561 |
+
else:
|
| 562 |
+
results = model.generate_from_questions(
|
| 563 |
+
questions=batch_questions,
|
| 564 |
+
max_new_tokens=64,
|
| 565 |
+
stage2_mips=stage2_mips,
|
| 566 |
+
documents=batch_documents,
|
| 567 |
+
time_count=True
|
| 568 |
+
)
|
| 569 |
+
return results
|
| 570 |
+
else:
|
| 571 |
+
if stage2_mips:
|
| 572 |
+
batch_out_normal, topk_idx = model.generate_from_questions(
|
| 573 |
+
questions=batch_questions,
|
| 574 |
+
max_new_tokens=64,
|
| 575 |
+
stage2_mips=stage2_mips
|
| 576 |
+
)
|
| 577 |
+
else:
|
| 578 |
+
batch_out_normal, topk_idx = model.generate_from_questions(
|
| 579 |
+
questions=batch_questions,
|
| 580 |
+
max_new_tokens=64,
|
| 581 |
+
stage2_mips=stage2_mips,
|
| 582 |
+
documents=batch_documents
|
| 583 |
+
)
|
| 584 |
+
return batch_out_normal, batch_out_normal, topk_idx
|
| 585 |
+
|
| 586 |
+
def _process_stage1(self, model, batch_questions, batch_documents):
|
| 587 |
+
"""Process stage1 inference."""
|
| 588 |
+
batch_out_compressed = []
|
| 589 |
+
|
| 590 |
+
for docs, question in zip(batch_documents, batch_questions):
|
| 591 |
+
embeddings, _ = model.compress_documents(documents=docs)
|
| 592 |
+
out_compressed = model.generate_from_compressed_documents_and_questions(
|
| 593 |
+
questions=[question],
|
| 594 |
+
compressed_documents=embeddings
|
| 595 |
+
)
|
| 596 |
+
batch_out_compressed.extend(out_compressed)
|
| 597 |
+
|
| 598 |
+
del embeddings
|
| 599 |
+
torch.cuda.empty_cache()
|
| 600 |
+
|
| 601 |
+
return batch_out_compressed, batch_out_compressed, None
|
| 602 |
+
|
| 603 |
+
def _process_reasoning(self, model, batch_questions, batch_answers):
|
| 604 |
+
"""Process reasoning inference."""
|
| 605 |
+
batch_out_normal = []
|
| 606 |
+
batch_out_reasoning_list = []
|
| 607 |
+
|
| 608 |
+
for question, answer in zip(batch_questions, batch_answers):
|
| 609 |
+
temp_out, temp_out_reasoning = model.generate_from_reasoning(
|
| 610 |
+
questions=[question],
|
| 611 |
+
max_new_tokens=1024,
|
| 612 |
+
answers=[answer],
|
| 613 |
+
save_dir=self.args.model_path
|
| 614 |
+
)
|
| 615 |
+
batch_out_normal.append(temp_out[0])
|
| 616 |
+
batch_out_reasoning_list.extend(temp_out_reasoning)
|
| 617 |
+
|
| 618 |
+
return batch_out_normal, batch_out_normal, None, batch_out_reasoning_list
|
| 619 |
+
|
| 620 |
+
def _process_paraphrase(self, model, batch_questions, batch_documents):
|
| 621 |
+
"""Process paraphrase inference."""
|
| 622 |
+
batch_out_compressed = []
|
| 623 |
+
|
| 624 |
+
for docs, question in zip(batch_documents, batch_questions):
|
| 625 |
+
out_compressed = model.generate_from_paraphrase(
|
| 626 |
+
questions=["" for _ in range(len(docs))],
|
| 627 |
+
documents=[docs]
|
| 628 |
+
)
|
| 629 |
+
batch_out_compressed.extend(out_compressed)
|
| 630 |
+
torch.cuda.empty_cache()
|
| 631 |
+
|
| 632 |
+
return batch_out_compressed, batch_out_compressed, None
|
| 633 |
+
|
| 634 |
+
def _process_mse_visualize(self, model, batch_documents):
|
| 635 |
+
"""Process MSE visualization."""
|
| 636 |
+
batch_out_normal = []
|
| 637 |
+
batch_out_compressed = []
|
| 638 |
+
|
| 639 |
+
for docs in batch_documents:
|
| 640 |
+
mem_rep, non_mem_rep = model.compress_documents_mse_visulize(documents=docs)
|
| 641 |
+
batch_out_compressed.append(mem_rep[0])
|
| 642 |
+
batch_out_normal.append(non_mem_rep[0])
|
| 643 |
+
|
| 644 |
+
return batch_out_normal, batch_out_compressed
|
| 645 |
+
|
| 646 |
+
def _create_empty_results(self, batch_questions, training_stage):
|
| 647 |
+
"""Create empty results for error cases."""
|
| 648 |
+
empty_results = [""] * len(batch_questions)
|
| 649 |
+
if training_stage == 'stage2_reasoning':
|
| 650 |
+
return empty_results, empty_results, None, empty_results
|
| 651 |
+
elif training_stage == 'stage1_mse_visulize':
|
| 652 |
+
return empty_results, empty_results
|
| 653 |
+
else:
|
| 654 |
+
return empty_results, empty_results, None
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def convert_embeddings_to_list(data):
|
| 658 |
+
"""Convert tensor embeddings to lists for JSON serialization."""
|
| 659 |
+
if isinstance(data, dict):
|
| 660 |
+
return {k: convert_embeddings_to_list(v) for k, v in data.items()}
|
| 661 |
+
elif isinstance(data, list):
|
| 662 |
+
return [convert_embeddings_to_list(item) for item in data]
|
| 663 |
+
elif isinstance(data, torch.Tensor):
|
| 664 |
+
return data.cpu().to(torch.float32).numpy().tolist()
|
| 665 |
+
elif isinstance(data, np.ndarray):
|
| 666 |
+
return data.tolist()
|
| 667 |
+
else:
|
| 668 |
+
return data
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def main():
|
| 672 |
+
parser = argparse.ArgumentParser(description="CLaRa Model Inference")
|
| 673 |
+
parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
|
| 674 |
+
parser.add_argument('--batch_size', type=int, default=4, help='Batch size per GPU')
|
| 675 |
+
parser.add_argument('--stage', type=str, default='stage1',
|
| 676 |
+
choices=['stage1', 'stage1_2', 'stage2', 'stage2_reasoning',
|
| 677 |
+
'stage1_paraphrase', 'stage1_mse_visulize'],
|
| 678 |
+
help='Training stage')
|
| 679 |
+
parser.add_argument('--stage2_mips', action='store_true', help='Use MIPS for stage2')
|
| 680 |
+
parser.add_argument('--dataset', type=str, default='musique',
|
| 681 |
+
help='Comma-separated list of datasets')
|
| 682 |
+
parser.add_argument('--gold_retrieval', action='store_true',
|
| 683 |
+
help='Use gold retrieval context')
|
| 684 |
+
parser.add_argument('--generation_top_k', type=int, default=5, help='Top-k for generation')
|
| 685 |
+
parser.add_argument('--paraphrase_path', type=str, help='Path to paraphrase data')
|
| 686 |
+
parser.add_argument('--mse_visulize_path', type=str, help='Path to save MSE visualization')
|
| 687 |
+
parser.add_argument('--efficient_count', action='store_true', help='Count efficiency metrics')
|
| 688 |
+
|
| 689 |
+
args = parser.parse_args()
|
| 690 |
+
|
| 691 |
+
# Process datasets
|
| 692 |
+
all_results_metrics = {}
|
| 693 |
+
datasets_list = args.dataset.split(',')
|
| 694 |
+
|
| 695 |
+
for dataset in datasets_list:
|
| 696 |
+
print(f"Processing dataset: {dataset}")
|
| 697 |
+
|
| 698 |
+
# Load data based on stage
|
| 699 |
+
if args.stage in ['stage1', 'stage1_2']:
|
| 700 |
+
processed_data = DataLoader.load_stage1_data(dataset, args.gold_retrieval)
|
| 701 |
+
elif args.stage == 'stage2':
|
| 702 |
+
processed_data = DataLoader.load_stage2_data(dataset, args.gold_retrieval)
|
| 703 |
+
elif args.stage in ['stage1_paraphrase', 'stage1_mse_visulize']:
|
| 704 |
+
if not args.paraphrase_path:
|
| 705 |
+
raise ValueError(f"--paraphrase_path required for stage {args.stage}")
|
| 706 |
+
processed_data = DataLoader.load_paraphrase_data(args.paraphrase_path)
|
| 707 |
+
else:
|
| 708 |
+
raise ValueError(f"Unsupported stage: {args.stage}")
|
| 709 |
+
|
| 710 |
+
print(f"Loaded {len(processed_data)} samples for {dataset}")
|
| 711 |
+
|
| 712 |
+
# Initialize inference engine
|
| 713 |
+
# Use model_path directly if absolute, otherwise use SageMaker path
|
| 714 |
+
if os.path.isabs(args.model_path):
|
| 715 |
+
model_path = args.model_path
|
| 716 |
+
else:
|
| 717 |
+
model_path = os.path.join('/mnt/task_wrapper/user_output/artifacts/data/train_checkpoint', args.model_path)
|
| 718 |
+
args.model_path = model_path
|
| 719 |
+
|
| 720 |
+
inference_engine = AcceleratedCLaRaInference(
|
| 721 |
+
model_path=model_path,
|
| 722 |
+
training_stage=args.stage,
|
| 723 |
+
generation_top_k=args.generation_top_k,
|
| 724 |
+
args=args
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
# Wait for all processes to be ready
|
| 728 |
+
inference_engine.accelerator.wait_for_everyone()
|
| 729 |
+
|
| 730 |
+
# Store results
|
| 731 |
+
all_results = []
|
| 732 |
+
time_count_dic = {"compress_time": 0, "query_time": 0, "generate_time": 0, "total_time": 0, "count": 0}
|
| 733 |
+
|
| 734 |
+
# Process data in batches using accelerator
|
| 735 |
+
with inference_engine.accelerator.split_between_processes(processed_data, apply_padding=False) as local_data:
|
| 736 |
+
print(f"Process {inference_engine.accelerator.process_index}: processing {len(local_data)} samples")
|
| 737 |
+
|
| 738 |
+
batch_size = args.batch_size
|
| 739 |
+
num_batches = (len(local_data) + batch_size - 1) // batch_size
|
| 740 |
+
|
| 741 |
+
for batch_idx in tqdm(range(num_batches),
|
| 742 |
+
desc=f"GPU {inference_engine.accelerator.process_index}",
|
| 743 |
+
disable=not inference_engine.accelerator.is_local_main_process):
|
| 744 |
+
|
| 745 |
+
# Get current batch
|
| 746 |
+
start_idx = batch_idx * batch_size
|
| 747 |
+
end_idx = min(start_idx + batch_size, len(local_data))
|
| 748 |
+
batch = local_data[start_idx:end_idx]
|
| 749 |
+
|
| 750 |
+
# Prepare batch data
|
| 751 |
+
batch_questions = [item['question'] for item in batch]
|
| 752 |
+
batch_documents = [item['documents'] for item in batch] if 'documents' in batch[0] else None
|
| 753 |
+
batch_answers = [item.get('answer') for item in batch] if args.stage == 'stage2_reasoning' else None
|
| 754 |
+
|
| 755 |
+
# Process batch
|
| 756 |
+
if args.efficient_count and args.stage == 'stage2':
|
| 757 |
+
results = inference_engine.process_batch(
|
| 758 |
+
batch_questions=batch_questions,
|
| 759 |
+
batch_documents=batch_documents,
|
| 760 |
+
stage2_mips=args.stage2_mips,
|
| 761 |
+
training_stage=args.stage,
|
| 762 |
+
time_count=True
|
| 763 |
+
)
|
| 764 |
+
batch_out_normal, batch_out_compressed, batch_topk_idx, compress_time, query_time, generate_time, total_time = results
|
| 765 |
+
|
| 766 |
+
time_count_dic["compress_time"] += compress_time
|
| 767 |
+
time_count_dic["query_time"] += query_time
|
| 768 |
+
time_count_dic["generate_time"] += generate_time
|
| 769 |
+
time_count_dic["total_time"] += total_time
|
| 770 |
+
time_count_dic["count"] += 1
|
| 771 |
+
else:
|
| 772 |
+
results = inference_engine.process_batch(
|
| 773 |
+
batch_questions=batch_questions,
|
| 774 |
+
batch_documents=batch_documents,
|
| 775 |
+
stage2_mips=args.stage2_mips,
|
| 776 |
+
training_stage=args.stage,
|
| 777 |
+
batch_answers=batch_answers
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
if args.stage == 'stage2_reasoning':
|
| 781 |
+
batch_out_normal, batch_out_compressed, batch_topk_idx, batch_out_reasoning = results
|
| 782 |
+
elif args.stage == 'stage1_mse_visulize':
|
| 783 |
+
batch_out_normal, batch_out_compressed = results
|
| 784 |
+
batch_topk_idx = None
|
| 785 |
+
else:
|
| 786 |
+
batch_out_normal, batch_out_compressed, batch_topk_idx = results
|
| 787 |
+
|
| 788 |
+
# Prepare results
|
| 789 |
+
batch_results = []
|
| 790 |
+
for i, (item, normal_out, compressed_out) in enumerate(zip(batch, batch_out_normal, batch_out_compressed)):
|
| 791 |
+
result_item = item['original_data'].copy()
|
| 792 |
+
result_item['CLaRa_normal_output'] = normal_out
|
| 793 |
+
result_item['CLaRa_compressed_output'] = compressed_out
|
| 794 |
+
result_item['global_index'] = item['global_index']
|
| 795 |
+
|
| 796 |
+
if args.stage == 'stage2' and batch_topk_idx is not None:
|
| 797 |
+
result_item['topk_idx'] = batch_topk_idx[i].tolist()
|
| 798 |
+
elif args.stage == 'stage2_reasoning':
|
| 799 |
+
result_item['reasoning_output'] = batch_out_reasoning[i]
|
| 800 |
+
|
| 801 |
+
batch_results.append(result_item)
|
| 802 |
+
|
| 803 |
+
all_results.extend(batch_results)
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
# Clean up memory
|
| 807 |
+
torch.cuda.empty_cache()
|
| 808 |
+
if batch_idx % 10 == 0:
|
| 809 |
+
gc.collect()
|
| 810 |
+
|
| 811 |
+
# Save efficiency metrics if requested
|
| 812 |
+
if args.efficient_count and inference_engine.accelerator.is_main_process:
|
| 813 |
+
eff_dic = {
|
| 814 |
+
"compress_time_ms": round((time_count_dic['compress_time'] / time_count_dic['count']) * 1000, 2),
|
| 815 |
+
"query_time_ms": round((time_count_dic['query_time'] / time_count_dic['count']) * 1000, 2),
|
| 816 |
+
"generate_time_ms": round((time_count_dic['generate_time'] / time_count_dic['count']) * 1000, 2),
|
| 817 |
+
"total_time_ms": round((time_count_dic['total_time'] / time_count_dic['count']) * 1000, 2),
|
| 818 |
+
"sample_count": time_count_dic['count']
|
| 819 |
+
}
|
| 820 |
+
eff_output_path = os.path.join(model_path, f"efficiency_{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
|
| 821 |
+
with open(eff_output_path, 'w') as f:
|
| 822 |
+
json.dump(eff_dic, f, indent=2)
|
| 823 |
+
|
| 824 |
+
# Wait for all processes to complete
|
| 825 |
+
inference_engine.accelerator.wait_for_everyone()
|
| 826 |
+
|
| 827 |
+
# Gather results from all processes
|
| 828 |
+
if inference_engine.accelerator.is_main_process:
|
| 829 |
+
print("Collecting results from all processes...")
|
| 830 |
+
|
| 831 |
+
all_results_gathered = inference_engine.accelerator.gather_for_metrics(all_results)
|
| 832 |
+
|
| 833 |
+
# Process and save results (main process only)
|
| 834 |
+
if inference_engine.accelerator.is_main_process:
|
| 835 |
+
print("Processing and saving results...")
|
| 836 |
+
|
| 837 |
+
# Flatten results
|
| 838 |
+
final_results = []
|
| 839 |
+
if isinstance(all_results_gathered, list):
|
| 840 |
+
for result_batch in all_results_gathered:
|
| 841 |
+
if isinstance(result_batch, list):
|
| 842 |
+
final_results.extend(result_batch)
|
| 843 |
+
else:
|
| 844 |
+
final_results.append(result_batch)
|
| 845 |
+
|
| 846 |
+
print(f"Collected {len(final_results)} results")
|
| 847 |
+
|
| 848 |
+
# Sort by global index to maintain order
|
| 849 |
+
final_results.sort(key=lambda x: x.get('global_index', 0))
|
| 850 |
+
|
| 851 |
+
# Verify data integrity
|
| 852 |
+
processed_indices = set(item.get('global_index', -1) for item in final_results)
|
| 853 |
+
expected_indices = set(range(len(processed_data)))
|
| 854 |
+
missing_indices = expected_indices - processed_indices
|
| 855 |
+
|
| 856 |
+
if missing_indices:
|
| 857 |
+
print(f"Warning: Missing indices: {sorted(list(missing_indices))}")
|
| 858 |
+
else:
|
| 859 |
+
print("✓ Data integrity verification passed")
|
| 860 |
+
|
| 861 |
+
# Remove global index for clean output
|
| 862 |
+
for item in final_results:
|
| 863 |
+
item.pop('global_index', None)
|
| 864 |
+
|
| 865 |
+
# Save results
|
| 866 |
+
output_path = os.path.join(model_path, f"{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.jsonl")
|
| 867 |
+
with open(output_path, 'w') as f:
|
| 868 |
+
if args.stage == 'stage1_mse_visulize':
|
| 869 |
+
converted_results = convert_embeddings_to_list(final_results)
|
| 870 |
+
for item in converted_results:
|
| 871 |
+
f.write(json.dumps(item) + '\n')
|
| 872 |
+
else:
|
| 873 |
+
for item in final_results:
|
| 874 |
+
f.write(json.dumps(item) + '\n')
|
| 875 |
+
|
| 876 |
+
print(f"Results saved to: {output_path}")
|
| 877 |
+
|
| 878 |
+
# Calculate metrics
|
| 879 |
+
calculator = ResultCalculator()
|
| 880 |
+
|
| 881 |
+
if args.stage == 'stage2':
|
| 882 |
+
metrics = calculator.calculate_stage2_metrics(final_results)
|
| 883 |
+
elif args.stage == 'stage1_paraphrase':
|
| 884 |
+
metrics = calculator.calculate_paraphrase_metrics(final_results)
|
| 885 |
+
elif args.stage == 'stage1_mse_visulize':
|
| 886 |
+
if args.mse_visulize_path:
|
| 887 |
+
metrics = calculator.visualize_mse(final_results, args.mse_visulize_path)
|
| 888 |
+
else:
|
| 889 |
+
metrics = {"visualization": "completed"}
|
| 890 |
+
else:
|
| 891 |
+
metrics = calculator.calculate_basic_metrics(final_results)
|
| 892 |
+
|
| 893 |
+
print(f"Metrics for {dataset}: {metrics}")
|
| 894 |
+
all_results_metrics[dataset] = metrics
|
| 895 |
+
|
| 896 |
+
# Clean up
|
| 897 |
+
del inference_engine
|
| 898 |
+
torch.cuda.empty_cache()
|
| 899 |
+
gc.collect()
|
| 900 |
+
|
| 901 |
+
# Save final metrics
|
| 902 |
+
if len(all_results_metrics) > 0:
|
| 903 |
+
metrics_path = os.path.join(model_path, f"results_metrics_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
|
| 904 |
+
with open(metrics_path, 'w') as f:
|
| 905 |
+
json.dump(all_results_metrics, f, indent=2)
|
| 906 |
+
print(f"Final metrics saved to: {metrics_path}")
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
if __name__ == '__main__':
|
| 910 |
+
main()
|
evaluation/evaluation_data/end_to_end_evaluation/2wiki.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:39cf44bcfa24938c40617ef5bba90235642bf02f537297f4055b3f6bc756846c
|
| 3 |
+
size 93670063
|
evaluation/evaluation_data/end_to_end_evaluation/hotpotqa.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f46d7cfc23199f6cdff5e3ce1872ff150e6d940eb83b343cf37431cd740fa4db
|
| 3 |
+
size 61751762
|
evaluation/evaluation_data/end_to_end_evaluation/musique.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85a55afc5c6067d00eef1888e13b598039a515f787791b23fbb495c35827e264
|
| 3 |
+
size 18789210
|
evaluation/evaluation_data/end_to_end_evaluation/nq.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d26d5c29694cd81cccfcac4fd29c16ae7f245b4c554623cbe3c6ec8c3a0ad41
|
| 3 |
+
size 60057585
|
evaluation/evaluation_data/instruction_tuning_evaluation/2wiki.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:690a91abab47ebb8e32f335b2f1e31f40a1b2e78452988c7bffd0c30cc5f5463
|
| 3 |
+
size 93670063
|
evaluation/evaluation_data/instruction_tuning_evaluation/hotpotqa.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:808713253d9a0821c43800b16e56fd76a9b046b42088afb217630c735196b4e4
|
| 3 |
+
size 61751762
|
evaluation/evaluation_data/instruction_tuning_evaluation/musique.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:171fb546901128d338de9a343080e6e41ca73c3f8246c7c0d270330f06cef0d9
|
| 3 |
+
size 18789210
|
evaluation/evaluation_data/instruction_tuning_evaluation/nq.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:225bc8fd6f5b5b3156f42952602fc296a7a390bf873de24ceb0328ceb61eabfd
|
| 3 |
+
size 60057585
|
example/end_to_end_data.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6bf9c2e07e6c833288c7041a93dfa7fd3cf41e34b22800d97aafbd93c78d3597
|
| 3 |
+
size 13128781
|
example/instruction_tuning_data.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
example/pretrain_data.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|