dl3239491 committed on
Commit 30c14cd · verified · 1 Parent(s): ce477b1

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +3 -0
  3. .pytest_cache/.gitignore +2 -0
  4. .pytest_cache/CACHEDIR.TAG +4 -0
  5. .pytest_cache/README.md +8 -0
  6. .pytest_cache/v/cache/nodeids +3 -0
  7. .ruff_cache/.gitignore +2 -0
  8. .ruff_cache/0.14.10/10241894308290549172 +0 -0
  9. .ruff_cache/0.14.10/1073426088278906643 +0 -0
  10. .ruff_cache/0.14.10/13957033273656742151 +0 -0
  11. .ruff_cache/0.14.10/1442719585850318975 +0 -0
  12. .ruff_cache/0.14.10/14754177912317367819 +0 -0
  13. .ruff_cache/0.14.10/14978186029505022734 +0 -0
  14. .ruff_cache/0.14.10/15569745458013874055 +0 -0
  15. .ruff_cache/0.14.10/17608220473508725558 +0 -0
  16. .ruff_cache/0.14.10/18191902847846296179 +0 -0
  17. .ruff_cache/0.14.10/2046185769257499142 +0 -0
  18. .ruff_cache/0.14.10/3165187837348788939 +0 -0
  19. .ruff_cache/0.14.10/4171122735627067383 +0 -0
  20. .ruff_cache/0.14.10/8273464926453838394 +0 -0
  21. .ruff_cache/0.14.10/9088412491868955099 +0 -0
  22. .ruff_cache/0.14.10/9103521535542433765 +0 -0
  23. .ruff_cache/0.14.10/9189204400079810969 +0 -0
  24. .ruff_cache/0.14.10/9226417474992298237 +0 -0
  25. .ruff_cache/0.14.10/9918913907578606062 +0 -0
  26. .ruff_cache/CACHEDIR.TAG +1 -0
  27. ACKNOWLEDGEMENTS +34 -0
  28. CODE_OF_CONDUCT.md +71 -0
  29. CONTRIBUTING.md +10 -0
  30. LICENSE +46 -0
  31. README.md +407 -0
  32. docs/Gemfile +23 -0
  33. docs/_config.yml +61 -0
  34. docs/getting_started.md +80 -0
  35. docs/index.md +53 -0
  36. docs/inference.md +134 -0
  37. docs/training.md +129 -0
  38. evaluation/evaluate.py +936 -0
  39. evaluation/evaluate.py.bak +910 -0
  40. evaluation/evaluation_data/end_to_end_evaluation/2wiki.zip +3 -0
  41. evaluation/evaluation_data/end_to_end_evaluation/hotpotqa.zip +3 -0
  42. evaluation/evaluation_data/end_to_end_evaluation/musique.zip +3 -0
  43. evaluation/evaluation_data/end_to_end_evaluation/nq.zip +3 -0
  44. evaluation/evaluation_data/instruction_tuning_evaluation/2wiki.zip +3 -0
  45. evaluation/evaluation_data/instruction_tuning_evaluation/hotpotqa.zip +3 -0
  46. evaluation/evaluation_data/instruction_tuning_evaluation/musique.zip +3 -0
  47. evaluation/evaluation_data/instruction_tuning_evaluation/nq.zip +3 -0
  48. example/end_to_end_data.jsonl +3 -0
  49. example/instruction_tuning_data.jsonl +0 -0
  50. example/pretrain_data.jsonl +0 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+example/end_to_end_data.jsonl filter=lfs diff=lfs merge=lfs -text
+figs/intro.png filter=lfs diff=lfs merge=lfs -text
+figs/sample_main.png filter=lfs diff=lfs merge=lfs -text
.pytest_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
# Created by pytest automatically.
*
.pytest_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1,4 @@
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by pytest.
# For information about cache directory tags, see:
#     https://bford.info/cachedir/spec.html
.pytest_cache/README.md ADDED
@@ -0,0 +1,8 @@
# pytest cache directory #

This directory contains data from the pytest's cache plugin,
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.

**Do not** commit this to version control.

See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
.pytest_cache/v/cache/nodeids ADDED
@@ -0,0 +1,3 @@
[
  "tests/test_placeholder.py::test_placeholder"
]
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
# Automatically created by ruff.
*
.ruff_cache/0.14.10/10241894308290549172 ADDED
Binary file (136 Bytes).

.ruff_cache/0.14.10/1073426088278906643 ADDED
Binary file (136 Bytes).

.ruff_cache/0.14.10/13957033273656742151 ADDED
Binary file (1.05 kB).

.ruff_cache/0.14.10/1442719585850318975 ADDED
Binary file (171 Bytes).

.ruff_cache/0.14.10/14754177912317367819 ADDED
Binary file (132 Bytes).

.ruff_cache/0.14.10/14978186029505022734 ADDED
Binary file (129 Bytes).

.ruff_cache/0.14.10/15569745458013874055 ADDED
Binary file (136 Bytes).

.ruff_cache/0.14.10/17608220473508725558 ADDED
Binary file (132 Bytes).

.ruff_cache/0.14.10/18191902847846296179 ADDED
Binary file (129 Bytes).

.ruff_cache/0.14.10/2046185769257499142 ADDED
Binary file (129 Bytes).

.ruff_cache/0.14.10/3165187837348788939 ADDED
Binary file (171 Bytes).

.ruff_cache/0.14.10/4171122735627067383 ADDED
Binary file (129 Bytes).

.ruff_cache/0.14.10/8273464926453838394 ADDED
Binary file (1.05 kB).

.ruff_cache/0.14.10/9088412491868955099 ADDED
Binary file (1.05 kB).

.ruff_cache/0.14.10/9103521535542433765 ADDED
Binary file (132 Bytes).

.ruff_cache/0.14.10/9189204400079810969 ADDED
Binary file (136 Bytes).

.ruff_cache/0.14.10/9226417474992298237 ADDED
Binary file (132 Bytes).

.ruff_cache/0.14.10/9918913907578606062 ADDED
Binary file (1.05 kB).
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
Signature: 8a477f597d28d172789f06886806bc55
ACKNOWLEDGEMENTS ADDED
@@ -0,0 +1,34 @@
Acknowledgements
Portions of this CLaRa Software may utilize the following copyrighted
material, the use of which is hereby acknowledged.

_____________________

Naver Labs Europe (PISCO-mistral)
Licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

Copyright © Naver Labs Europe

You are free to:
- Share — copy and redistribute the material in any medium or format
- Adapt — remix, transform, and build upon the material

Under the following terms:
- Attribution — You must give appropriate credit, provide a link to the license,
  and indicate if changes were made.
- NonCommercial — You may not use the material for commercial purposes.

Full license text available at: https://creativecommons.org/licenses/by-nc/4.0/

OpenRLHF authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,71 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
CONTRIBUTING.md ADDED
@@ -0,0 +1,10 @@
# Contribution Guide

Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.

While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.

## Before you get started

By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
LICENSE ADDED
@@ -0,0 +1,46 @@
Copyright (C) 2025 Apple Inc. All Rights Reserved.

IMPORTANT: This Apple software is supplied to you by Apple
Inc. ("Apple") in consideration of your agreement to the following
terms, and your use, installation, modification or redistribution of
this Apple software constitutes acceptance of these terms. If you do
not agree with these terms, please do not use, install, modify or
redistribute this Apple software.

In consideration of your agreement to abide by the following terms, and
subject to these terms, Apple grants you a personal, non-exclusive
license, under Apple's copyrights in this original Apple software (the
"Apple Software"), to use, reproduce, modify and redistribute the Apple
Software, with or without modifications, in source and/or binary forms;
provided that if you redistribute the Apple Software in its entirety and
without modifications, you must retain this notice and the following
text and disclaimers in all such redistributions of the Apple Software.
Neither the name, trademarks, service marks or logos of Apple Inc. may
be used to endorse or promote products derived from the Apple Software
without specific prior written permission from Apple. Except as
expressly stated in this notice, no other rights or licenses, express or
implied, are granted by Apple herein, including but not limited to any
patent rights that may be infringed by your derivative works or by other
works in which the Apple Software may be incorporated.

The Apple Software is provided by Apple on an "AS IS" basis. APPLE
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.

IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

---

This CLaRa Software may utilize third party materials. Please refer to the
ACKNOWLEDGEMENTS file included with this software for attribution and
license information related to third party code that may be contained in or
used with this CLaRa Software.
README.md ADDED
@@ -0,0 +1,407 @@
## CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning

<div align="center">
    <img src="figs/clara_logo.jpg" width="400"/>
</div>

<div align="center">
    <a href="https://arxiv.org/abs/2511.18659"><img src="https://img.shields.io/badge/arXiv-2511.18659-b31b1b.svg" alt="arXiv"></a>
    <a href="LICENSE"><img src="https://img.shields.io/badge/License-Apple-blue" alt="License"></a>
    <a href="https://huggingface.co/apple/CLaRa-7B-Base"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_Base-FFEB3B" alt="deploy"></a>
    <a href="https://huggingface.co/apple/CLaRa-7B-Instruct"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_Instruct-FFEB3B" alt="deploy"></a>
    <a href="https://huggingface.co/apple/CLaRa-7B-E2E"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_End_to_end-FFEB3B" alt="deploy"></a>
    <a href="https://huggingface.co/datasets/apple/CLaRa_multi_stage"><img src="https://img.shields.io/badge/Hugging%20Face-CLaRa_Data-FFEB3B" alt="data"></a>
</div>

This is the official open-source release of CLaRa, a state-of-the-art, end-to-end Retrieval-Augmented Generation model.

### Updates

- Dec 11, 2025. All data used in the paper are available on [Huggingface](https://huggingface.co/datasets/apple/CLaRa_multi_stage).
- Dec 10, 2025. We are working on an MLX version of the model, to be announced soon.
- Dec 3, 2025. Evaluation data are available in `./evaluation/evaluation_data`.
- Nov 25, 2025. Models are available on Huggingface.

### Motivation

Retrieval-Augmented Generation (RAG) enhances large language models with external knowledge but suffers from **long contexts** and **disjoint retrieval-generation optimization**. Existing soft compression frameworks face two key limitations: (i) reconstruction-based objectives bias compressors toward surface patterns rather than semantic preservation; (ii) retrievers and compressors are trained separately, requiring double encoding despite compressed vectors being inherently retrievable.

In this work, we investigate:

- **How can we improve semantic preservation in compressed representations through better pretraining objectives?**
- **How can we unify retrieval and generation optimization to avoid redundant encoding and disjoint objectives?**

<div align="center">
<img src="figs/intro.png" width="100%"/>
</div>

We design a three-stage training approach and introduce document compression techniques to improve RAG efficiency. The key findings are listed below.

### Findings

- **Efficient Compression**: CLaRa achieves significant compression rates (32x-64x) while preserving the information essential for accurate answer generation.

- **Three-Stage Training**: A carefully designed three-stage training approach (compression pretraining + compression instruction tuning + end-to-end fine-tuning) enables effective learning of both retrieval and generation.

For more findings, please refer to our original paper!

---

### Three-Stage Training

CLaRa uses a carefully designed three-stage training approach:

**Stage 1: Compression Pretraining**
- Train the compressor with the SCP framework using QA pairs and paraphrases
- Retain key semantics through QA-based and paraphrase-guided supervision
- Support compression rates of 1x-256x

**Stage 2: Compression Instruction Tuning**
- Fine-tune the compressor on instruction-following tasks for downstream QA
- Use text-based QA output to ensure compressed representations retain sufficient semantics

**Stage 3: End-to-End Fine-tuning (CLaRa)**
- Jointly train the reranker and generator via a single language modeling loss
- Unify retrieval and generation in a shared continuous space using a differentiable top-k estimator (see the sketch after this list)
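
CLaRa's exact differentiable top-k estimator is specified in the paper; purely as a generic illustration of how a hard top-k selection can still pass gradients to a reranker, a straight-through soft top-k can be sketched as below. Everything here (`soft_topk_mask`, the temperature `tau`, the tensor shapes) is an illustrative assumption, not the repository's implementation.

```python
import torch

def soft_topk_mask(scores: torch.Tensor, k: int, tau: float = 1.0) -> torch.Tensor:
    """Generic straight-through soft top-k (illustrative only).
    Forward pass: a hard 0/1 mask over the k highest-scoring documents.
    Backward pass: gradients flow through a softmax surrogate."""
    soft = torch.softmax(scores / tau, dim=-1)              # differentiable surrogate
    hard = torch.zeros_like(scores)
    hard.scatter_(-1, scores.topk(k, dim=-1).indices, 1.0)  # hard selection
    return hard + soft - soft.detach()                      # straight-through trick

scores = torch.randn(1, 20, requires_grad=True)  # reranker scores for 20 candidate docs
mask = soft_topk_mask(scores, k=5)
mask.sum().backward()                            # gradients reach the reranker scores
```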

In this repository, we release our implementation of **CLaRa**, built upon [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF).

### Getting Started

```
├── scripts/                             # Training and evaluation scripts
│   ├── train_pretraining.sh             # Stage 1: Compression pretraining
│   ├── train_instruction_tuning.sh      # Stage 2: Compression instruction tuning
│   ├── train_stage_end_to_end.sh        # Stage 3: End-to-end training
│   └── evaluation_end_to_end.sh         # Evaluation scripts
├── openrlhf/                            # Core training framework
│   ├── models/                          # Model implementations
│   │   └── modeling_clara.py            # CLaRa model definition
│   ├── datasets/                        # Dataset handling
│   │   └── sft_dataset.py               # Training dataset
│   ├── trainer/                         # Training utilities
│   │   └── sft_trainer.py               # SFT trainer
│   └── cli/                             # Command line interface
│       └── train_sft.py                 # Main training script
├── evaluation/                          # Evaluation framework
├── example/                             # Example training data
│   ├── pretrain_data.jsonl
│   ├── instruction_tuning_data.jsonl
│   └── end_to_end_data.jsonl
└── README.md                            # This file
```

- Video instruction for installation (from @Fahd Mirza): https://youtu.be/al2VoAKn8GU?si=Q8bq7QNMaTvcArwa
- Video digest (from @Richard Aragon): https://www.youtube.com/watch?v=yRM92mmKNH4

#### 1. Prepare code and environment

Clone the repository and set up the environment:

```bash
# Create conda environment
env=clara
conda create -n $env python=3.10 -y
conda activate $env

# Install dependencies
pip install -r requirements.txt

# Set up environment variables
export PYTHONPATH=/path/to/clara:$PYTHONPATH
```

Key dependencies include:
- PyTorch >= 2.0
- Transformers >= 4.20
- DeepSpeed >= 0.18
- Flash Attention 2
- Accelerate

#### 2. Data preparation

Prepare training data in JSONL format. Example record for the pretraining stage:

```json
{
    "data_type": "qa",
    "question": ["Question 1", ...],
    "answers": ["Answer 1"],
    "docs": ["Document 1"]
}
```

For end-to-end training:

```json
{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer"
}
```
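
Since each training file is JSONL (one JSON object per line), records are easy to produce programmatically. The sketch below is illustrative only: the field values are invented and `write_jsonl` is a hypothetical helper, but the keys follow the two formats above.

```python
import json

def write_jsonl(records, path):
    # Hypothetical helper: one JSON object per line, as JSONL loaders expect.
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

# Invented example values; only the keys are prescribed by the formats above.
pretrain_record = {
    "data_type": "qa",
    "question": ["Who wrote Hamlet?"],
    "answers": ["William Shakespeare"],
    "docs": ["Hamlet is a tragedy written by William Shakespeare."],
}
end_to_end_record = {
    "question": "Who wrote Hamlet?",
    "docs": ["Hamlet is a tragedy written by William Shakespeare.",
             "Macbeth is another Shakespeare tragedy."],
    "gold_answer": "William Shakespeare",
}

write_jsonl([pretrain_record], "my_pretrain_data.jsonl")
write_jsonl([end_to_end_record], "my_end_to_end_data.jsonl")
```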

#### 3. Start training

**Stage 1: Salient Compressor Pretraining (SCP)**

Pre-train the document compressor:

```bash
bash scripts/train_pretraining.sh
```

Key parameters:
- `--compress_rate`: Compression rate (default: 32)
- `--doc_max_length`: Maximum document length (default: 256)
- `--stage stage1`: Training stage
- `--mse_loss`: Use MSE loss to align compressed and original representations
- `--qa_loss`: Use QA loss for semantic preservation

**Stage 2: Compression Instruction Tuning**

Fine-tune the compressor on instruction-following tasks:

```bash
bash scripts/train_instruction_tuning.sh
```

Key parameters:
- `--pretrain_checkpoint`: Path to the stage 1 checkpoint
- `--stage stage1_2`: Training stage
- `--generation_top_k`: Number of top documents used for generation (default: 5)
- `--mse_loss`: Use MSE loss for compression training
- `--do_eval_gen`: Enable generation evaluation

**Stage 3: End-to-End Training**

Fine-tune the model end-to-end with retrieval:

```bash
bash scripts/train_stage_end_to_end.sh
```

Key parameters:
- `--pretrain_checkpoint`: Path to the stage 2 checkpoint
- `--stage stage2`: Training stage
- `--generation_top_k`: Number of top documents used for generation
- `--do_eval_gen`: Enable generation evaluation

#### 4. Distributed Training

The training scripts support distributed training across multiple nodes and GPUs:

- `--max_len`: Maximum sequence length (default: 2048 for stage 1/2, 1024 for stage 3)
- `--train_batch_size`: Training batch size
- `--micro_train_batch_size`: Micro batch size for gradient accumulation (see the sketch below)
- `--learning_rate`: Learning rate (default: 1e-4 for stage 1/2, 5e-6 for stage 3)
- `--max_epochs`: Maximum training epochs
- `--zero_stage`: ZeRO optimization stage (default: 2)
- `--bf16`: Use bfloat16 precision
- `--flash_attn`: Use Flash Attention 2
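
Global and micro batch sizes are typically tied together through gradient accumulation. The arithmetic below shows the usual relationship under the common DeepSpeed/OpenRLHF convention; it is a sketch with invented values, not numbers checked against this repository's scripts.

```python
# Illustrative arithmetic only; the flag values are invented.
train_batch_size = 128       # --train_batch_size: global batch per optimizer step
micro_train_batch_size = 4   # --micro_train_batch_size: per GPU per forward pass
world_size = 8               # total number of GPUs
grad_accum_steps = train_batch_size // (micro_train_batch_size * world_size)
print(grad_accum_steps)      # 4 forward passes accumulated per optimizer step
```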

### Inference

The CLaRa models can be loaded and used for inference. We provide three models corresponding to the three training stages:

<details>
<summary>Stage 1: Compression Pretraining model (click to expand)</summary>

```python
from transformers import AutoModel

model_path = "path/to/stage1/model"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True
).to('cuda')

# Example documents
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]

questions = ["" for _ in range(len(documents))]

# Generate paraphrase from compressed representations
output = model.generate_from_paraphrase(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)

print('Generated paraphrase:', output[0])
```

</details>

<details>
<summary>Stage 2: Compression Instruction Tuning model (click to expand)</summary>

```python
from transformers import AutoModel

model_path = "path/to/stage2/model"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True
).to('cuda')

# Example documents and question
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]

questions = ["Your question here"]

# Generate answer from compressed representations
output = model.generate_from_text(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)

print('Generated answer:', output[0])
```

</details>

<details>
<summary>Stage 3: End-to-End (CLaRa) model (click to expand)</summary>

```python
from transformers import AutoModel

model_path = "path/to/stage3/model"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True
).to('cuda')

# Example documents and question
# Note: Stage 3 supports retrieval with multiple candidate documents
documents = [
    ["Document 1 content..." for _ in range(20)]  # 20 candidate documents
]

questions = ["Your question here"]

# Generate answer with retrieval and reranking
# The top-k is decided by generation_top_k in config.json
output, topk_indices = model.generate_from_questions(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)

print('Generated answer:', output[0])
print('Top-k selected document indices:', topk_indices)
```

</details>

### Evaluation

The evaluation framework is based on standard RAG benchmarks. Run the evaluation scripts:

**End-to-end evaluation:**
```bash
bash scripts/evaluation_end_to_end.sh
```

**Instruction tuning evaluation:**
```bash
bash scripts/evaluation_instruction_tuning.sh
```

Supported datasets:
- **HotpotQA**: Multi-hop question answering
- **MuSiQue**: Multi-hop question answering with diverse reasoning
- **2WikiMultiHopQA**: Multi-hop question answering over Wikipedia
- **Natural Questions**: Open-domain question answering
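
Accuracy-style metrics (EM, F1, accuracy) in `evaluation/evaluate.py` compare answers only after text normalization. A condensed sketch of that normalization, mirroring `EvaluationMetrics.normalize_answer` in the evaluation script included in this release:

```python
import re
import string

def normalize_answer(text: str) -> str:
    # Lowercase, drop punctuation, drop articles, collapse whitespace
    # (mirrors EvaluationMetrics.normalize_answer in evaluation/evaluate.py).
    text = "".join(ch for ch in text.lower() if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

# Exact match holds when the normalized strings are identical.
assert normalize_answer("The Eiffel Tower!") == normalize_answer("eiffel tower")
```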

### Results

#### Compression Performance

We evaluate our document compressor on four QA datasets (NQ, HotpotQA, MuSiQue, 2WikiMultiHopQA) under two settings: **Normal** (retrieving top-5 documents) and **Oracle** (gold document included). CLaRa consistently outperforms all baselines across different compression ratios.

<div align="center">

**Main Results (Mistral-7B, Normal Setting)**

| Model | CR | NQ | HotpotQA | MuSiQue | 2Wiki | Avg |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| AutoCompressor | - | 17.24 | 14.61 | 3.81 | 19.89 | 13.89 |
| XRAG | 128 | 32.35 | 25.16 | 3.64 | 28.79 | 22.48 |
| COCOM | 16 | 24.12 | 21.48 | 3.52 | 24.48 | 18.40 |
| PCC | 16 | 31.38 | 22.29 | 3.43 | 19.47 | 19.14 |
| LLMLingua-2 | 4 | 47.53 | 37.05 | 9.02 | 44.35 | 34.49 |
| PISCO | 16 | 54.39 | 41.94 | 10.09 | 44.88 | 37.83 |
| Mistral-7B w/ retrieval | - | 54.58 | 42.94 | 8.94 | 44.24 | 37.67 |
| **CLaRa (CR=4)** | **4** | **57.05** | **45.09** | **10.34** | **46.94** | **39.86** |
| **CLaRa (CR=16)** | **16** | **55.56** | **43.72** | **10.55** | **46.00** | **38.96** |
| **CLaRa (CR=32)** | **32** | **54.64** | **43.52** | **10.55** | **46.58** | **38.82** |

**Oracle Setting Results (Mistral-7B)**

| Model | CR | NQ | HotpotQA | MuSiQue | 2Wiki | Avg |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|
| PISCO | 16 | 73.44 | 66.53 | 33.80 | 60.45 | 58.55 |
| Mistral-7B w/ retrieval | - | 71.64 | 70.77 | 45.72 | 68.83 | 64.24 |
| **CLaRa (CR=4)** | **4** | **76.50** | **73.81** | **46.26** | **70.48** | **66.76** |
| **CLaRa (CR=16)** | **16** | **75.48** | **70.79** | **43.15** | **66.16** | **63.90** |
| **CLaRa (CR=32)** | **32** | **73.77** | **69.51** | **38.31** | **64.54** | **61.53** |

</div>

**Key Findings:**
- ✅ CLaRa outperforms PISCO by **+1.13%** (Normal) and **+5.35%** (Oracle) on average
- ✅ CLaRa outperforms LLMLingua-2 by **+5.37%** (Normal) on average
- ✅ CLaRa matches or exceeds the text-based baseline, with a **+2.36%** average gain on Mistral-7B

#### Retrieval Performance

<div align="center">
<img src="figs/main_recall.png" width="80%"/>
</div>

For detailed experimental results and analysis, please refer to our paper.

### Acknowledgments

We sincerely appreciate the following works, on which CLaRa builds:

- Our implementation is built upon the [OpenRLHF framework](https://github.com/OpenRLHF/OpenRLHF).
- Our document compression techniques are inspired by [PISCO-mistral](https://huggingface.co/naver/pisco-mistral).

### Citation

```bibtex
@misc{he2025clarabridgingretrievalgeneration,
      title={CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning},
      author={Jie He and Richard He Bai and Sinead Williamson and Jeff Z. Pan and Navdeep Jaitly and Yizhe Zhang},
      year={2025},
      eprint={2511.18659},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2511.18659},
}
```
docs/Gemfile ADDED
@@ -0,0 +1,23 @@
source "https://rubygems.org"

gem "jekyll", "~> 4.3"
gem "jekyll-feed", "~> 0.12"
gem "jekyll-seo-tag", "~> 2.8"
gem "jekyll-sitemap", "~> 1.4"
gem "jekyll-theme-cayman", "~> 0.2"

# Windows and JRuby do not include zoneinfo files, so bundle the tzinfo-data gem
# and associated library.
platforms :mingw, :x64_mingw, :mswin, :jruby do
  gem "tzinfo", ">= 1", "< 3"
  gem "tzinfo-data"
end

# Performance-booster for watching directories on Windows
gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin]

# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem
# do not have a Java counterpart.
gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby]
docs/_config.yml ADDED
@@ -0,0 +1,61 @@
# Jekyll configuration for CLaRa Documentation

# Site settings
title: CLaRa Documentation
description: Complete documentation for CLaRa - Unified Retrieval-Augmented Generation with Compression
baseurl: "/CLaRa" # the subpath of your site for GitHub Pages
url: "https://aiml-oss.github.io" # the base hostname & protocol for GitHub Pages

# Theme
theme: jekyll-theme-cayman
# Alternative themes:
# theme: jekyll-theme-minimal
# theme: jekyll-theme-slate
# theme: jekyll-theme-architect

# GitHub repository info
repository: probe2/CLaRa

# Build settings
markdown: kramdown
kramdown:
  input: GFM
  syntax_highlighter: rouge

# Plugins
plugins:
  - jekyll-feed
  - jekyll-seo-tag
  - jekyll-sitemap

# Navigation
header_pages:
  - index.md
  - getting_started.md
  - training.md
  - inference.md

# Exclude from processing
exclude:
  - Gemfile
  - Gemfile.lock
  - node_modules
  - vendor/bundle/
  - vendor/cache/
  - vendor/gems/
  - vendor/ruby/
  - "*.sh"
  - "*.log"

# Collections and navigation
# collections:
#   docs:
#     output: true
#     permalink: /:collection/:path/

# Defaults
defaults:
  - scope:
      path: ""
    values:
      layout: "default"
docs/getting_started.md ADDED
@@ -0,0 +1,80 @@
---
layout: default
title: Getting Started
permalink: /getting_started/
---

# Getting Started with CLaRa

This guide will help you get started with CLaRa, from installation to running your first training.

## Installation

### Prerequisites

- Python 3.10+
- CUDA-compatible GPU (recommended)
- PyTorch 2.0+
- CUDA 11.8 or 12.x

### Step 1: Create Conda Environment

```bash
env=clara
conda create -n $env python=3.10 -y
conda activate $env
```

### Step 2: Install Dependencies

```bash
pip install -r requirements.txt
```

Key dependencies include:
- `torch>=2.0`
- `transformers>=4.20`
- `deepspeed>=0.18`
- `flash-attn>=2.8.0`
- `accelerate>=1.10.1`
- `peft>=0.17.1`

### Step 3: Set Environment Variables

```bash
export PYTHONPATH=/path/to/clara:$PYTHONPATH
```
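
Before moving on, it can help to confirm that the core dependencies import cleanly and that a GPU is visible. A quick, illustrative sanity check using only the packages listed above:

```python
# Verify the environment before launching any training (illustrative check).
import torch
import transformers

print("PyTorch:", torch.__version__)
print("Transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```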

## Quick Start

### 1. Prepare Your Data

CLaRa uses JSONL format for training data. See the [Training Guide](./training.md) for data format details.

### 2. Train Stage 1: Compression Pretraining

```bash
bash scripts/train_pretraining.sh
```

### 3. Train Stage 2: Instruction Tuning

```bash
bash scripts/train_instruction_tuning.sh
```

### 4. Train Stage 3: End-to-End Training

```bash
bash scripts/train_stage_end_to_end.sh
```

### 5. Run Inference

See the [Inference Guide](./inference.md) for examples of using all three model stages.

## Next Steps

- [Training Guide](./training.md) - Detailed training instructions and data formats
- [Inference Guide](./inference.md) - Inference examples for all model stages
docs/index.md ADDED
@@ -0,0 +1,53 @@
---
layout: default
title: CLaRa Documentation
---

# CLaRa Documentation

Welcome to the CLaRa documentation! This site provides comprehensive guides and references for using CLaRa.

## What is CLaRa?

**CLaRa** (Continuous Latent Reasoning) is a unified framework for retrieval-augmented generation that performs embedding-based compression and joint optimization in a shared continuous space.

[![Paper](https://img.shields.io/badge/Paper-Arxiv%20Link-green)](https://arxiv.org/abs/2511.18659) [![License](https://img.shields.io/badge/License-Apple-blue)](../LICENSE) [![deploy](https://img.shields.io/badge/Hugging%20Face-CLaRa_Base-FFEB3B)](https://huggingface.co/apple/CLaRa-7B-Base) [![deploy](https://img.shields.io/badge/Hugging%20Face-CLaRa_Instruct-FFEB3B)](https://huggingface.co/apple/CLaRa-7B-Instruct) [![deploy](https://img.shields.io/badge/Hugging%20Face-CLaRa_End_to_end-FFEB3B)](https://huggingface.co/apple/CLaRa-7B-E2E)

## Documentation

- **[Getting Started](./getting_started.md)** - Installation and quick start guide
- **[Training Guide](./training.md)** - Detailed instructions for all three training stages, including data formats
- **[Inference Guide](./inference.md)** - How to use CLaRa models for inference

## Quick Links

- **GitHub Repository**: [github.com/apple/ml-CLaRa](https://github.com/apple/ml-CLaRa)
- **Main README**: [../README.md](../README.md)
- **Model Checkpoints**: [CLaRa-7B-Base](https://huggingface.co/apple/CLaRa-7B-Base), [CLaRa-7B-Instruct](https://huggingface.co/apple/CLaRa-7B-Instruct), [CLaRa-7B-E2E](https://huggingface.co/apple/CLaRa-7B-E2E)

## Overview

CLaRa uses a three-stage training approach:

1. **Stage 1: Compression Pretraining** - Learn effective document compression
2. **Stage 2: Compression Instruction Tuning** - Adapt for downstream QA tasks
3. **Stage 3: End-to-End Fine-tuning (CLaRa)** - Joint retrieval and generation optimization

For more details, see the [Training Guide](./training.md).

## Citation

If you use CLaRa in your research, please cite:

```bibtex
@misc{he2025clarabridgingretrievalgeneration,
      title={CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning},
      author={Jie He and Richard He Bai and Sinead Williamson and Jeff Z. Pan and Navdeep Jaitly and Yizhe Zhang},
      year={2025},
      eprint={2511.18659},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2511.18659},
}
```
docs/inference.md ADDED
@@ -0,0 +1,134 @@
---
layout: default
title: Inference Guide
permalink: /inference/
---

# Inference Guide

This guide shows how to use CLaRa models for inference at different stages.

## Loading Models

CLaRa models can be loaded using the standard `AutoModel` interface:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "path/to/model",
    trust_remote_code=True
).to('cuda')
```

## Stage 1: Compression Pretraining Model

Generate paraphrases from compressed document representations.

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "path/to/stage1/model",
    trust_remote_code=True
).to('cuda')

# Example documents
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]

questions = ["" for _ in range(len(documents))]

# Generate paraphrase from compressed representations
output = model.generate_from_paraphrase(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)

print('Generated paraphrase:', output[0])
```

## Stage 2: Compression Instruction Tuning Model

Generate answers from compressed representations for QA tasks.

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "path/to/stage2/model",
    trust_remote_code=True
).to('cuda')

# Example documents and question
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]

questions = ["Your question here"]

# Generate answer from compressed representations
output = model.generate_from_text(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)

print('Generated answer:', output[0])
```

## Stage 3: End-to-End (CLaRa) Model

Generate answers with retrieval and reranking using joint optimization.

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "path/to/stage3/model",
    trust_remote_code=True
).to('cuda')

# Example documents and question
# Note: Stage 3 supports retrieval with multiple candidate documents
documents = [
    ["Document 1 content..." for _ in range(20)]  # 20 candidate documents
]

questions = ["Your question here"]

# Generate answer with retrieval and reranking
# The top-k is decided by generation_top_k in config.json
output, topk_indices = model.generate_from_questions(
    questions=questions,
    documents=documents,
    max_new_tokens=64
)

print('Generated answer:', output[0])
print('Top-k selected document indices:', topk_indices)
```
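
The returned `topk_indices` can be mapped back onto the candidate list to inspect which passages the reranker selected. A small sketch, assuming `topk_indices` holds one list of integer indices per question:

```python
# Recover the selected passages for the first question (illustrative).
selected = [documents[0][i] for i in topk_indices[0]]
for rank, doc in enumerate(selected, start=1):
    print(f"Rank {rank}: {doc[:80]}")
```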

## Key Parameters

- `max_new_tokens`: Maximum number of tokens to generate (default: 128)
- `generation_top_k`: Number of top documents to select (configured in model config)

## Model Methods

- `generate_from_paraphrase()` - Stage 1: Generate paraphrases
- `generate_from_text()` - Stage 2: Generate answers from compressed docs
- `generate_from_questions()` - Stage 3: Generate with retrieval and reranking
docs/training.md ADDED
@@ -0,0 +1,129 @@
---
layout: default
title: Training Guide
permalink: /training/
---

# Training Guide

This guide covers the three-stage training process in CLaRa.

## Overview

CLaRa uses a three-stage training approach:

1. **Stage 1**: Compression Pretraining
2. **Stage 2**: Compression Instruction Tuning
3. **Stage 3**: End-to-End Fine-tuning (CLaRa)

## Stage 1: Compression Pretraining

Train the compressor to learn effective document compression.

### Key Parameters

- `--stage stage1`: Training stage identifier
- `--compress_rate`: Compression rate (default: 32)
- `--doc_max_length`: Maximum document length (default: 256)
- `--mse_loss`: Use MSE loss for compression alignment
- `--qa_loss`: Use QA loss for semantic preservation

### Example Command

```bash
bash scripts/train_pretraining.sh
```

### Data Format

**Stage 1 Pretraining Data:**
```json
{
    "data_type": "qa",
    "question": ["Question 1", "Question 2", ...],
    "answers": ["Answer 1", "Answer 2", ...],
    "docs": ["Document 1", "Document 2", ...]
}
```

## Stage 2: Compression Instruction Tuning

Fine-tune the compressor on instruction-following tasks.

### Key Parameters

- `--stage stage1_2`: Training stage identifier
- `--pretrain_checkpoint`: Path to the Stage 1 checkpoint
- `--generation_top_k`: Number of top documents to select (default: 5)
- `--mse_loss`: Continue using MSE loss
- `--do_eval_gen`: Enable generation evaluation

### Example Command

```bash
bash scripts/train_instruction_tuning.sh
```

### Data Format

**Stage 2 Instruction Tuning Data:**
```json
{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer",
    "answer": "Generated answer"
}
```

## Stage 3: End-to-End Training

Jointly train the reranker and generator with retrieval.

### Key Parameters

- `--stage stage2`: Training stage identifier
- `--pretrain_checkpoint`: Path to the Stage 2 checkpoint
- `--generation_top_k`: Number of top documents used for generation
- `--do_eval_gen`: Enable generation evaluation

### Example Command

```bash
bash scripts/train_stage_end_to_end.sh
```

### Data Format

**Stage 3 End-to-End Data:**
```json
{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer"
}
```

## Distributed Training

All training stages support distributed training across multiple nodes and GPUs.

### Key Parameters

- `--max_len`: Maximum sequence length (2048 for stage 1/2, 1024 for stage 3)
- `--train_batch_size`: Training batch size
- `--micro_train_batch_size`: Micro batch size for gradient accumulation
- `--learning_rate`: Learning rate (1e-4 for stage 1/2, 5e-6 for stage 3)
- `--max_epochs`: Maximum training epochs
- `--zero_stage`: ZeRO optimization stage (default: 2)
- `--bf16`: Use bfloat16 precision
- `--flash_attn`: Use Flash Attention 2

## Monitoring Training

Training progress is logged via:
- Console output
- Wandb (if configured)
- Checkpoint files

Checkpoints are saved at the path specified by `--save_path`.
evaluation/evaluate.py ADDED
@@ -0,0 +1,936 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#

import os
import json
import argparse
import gc
from datetime import timedelta
from collections import defaultdict, Counter
from typing import List, Dict, Any, Optional, Tuple

import torch
import numpy as np
from accelerate import Accelerator, InitProcessGroupKwargs
from transformers import AutoModel
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

try:
    import spacy
    SPACY_AVAILABLE = True
except Exception as e:
    SPACY_AVAILABLE = False
    print(f"Warning: spacy not available ({e}). Entity extraction will be disabled.")

try:
    import evaluate as eval_lib
    EVAL_LIB_AVAILABLE = True
except Exception as e:
    EVAL_LIB_AVAILABLE = False
    eval_lib = None
    print(f"Warning: evaluate library not available ({e}). BERTScore and ROUGE metrics will be disabled.")

import re
import string

from openrlhf.models.modeling_clara import CLaRa

# Environment setup
os.environ["NCCL_TIMEOUT"] = "5400"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Global constants
TARGET_ENTITY_CATEGORIES = {"PERSON", "GPE", "DATE", "CARDINAL", "ORG"}


class EvaluationMetrics:
    """Handles all evaluation metrics and scoring functions."""

    def __init__(self):
        if EVAL_LIB_AVAILABLE:
            self.bertscore = eval_lib.load("bertscore")
            self.rouge = eval_lib.load("rouge")
        else:
            self.bertscore = None
            self.rouge = None
        if SPACY_AVAILABLE:
            self.nlp = spacy.load("en_core_web_sm")
        else:
            self.nlp = None

    @staticmethod
    def normalize_answer(text: str) -> str:
        """Normalize text for comparison."""
        def remove_articles(text):
            return re.sub(r"\b(a|an|the)\b", " ", text)

        def white_space_fix(text):
            return " ".join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return "".join(ch for ch in text if ch not in exclude)

        return white_space_fix(remove_articles(remove_punc(text.lower())))

    @staticmethod
    def bool_mapping(text: str) -> str:
        """Map boolean values to yes/no."""
        mapping = {"True": "yes", "False": "no"}
        return mapping.get(text, text)

    def exact_match_score(self, prediction: str, ground_truth: str) -> bool:
        """Calculate exact match score."""
        pred_norm = self.normalize_answer(self.bool_mapping(prediction))
        gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))
        return pred_norm == gt_norm

    def cover_exact_match_score(self, prediction: str, ground_truth: str) -> bool:
        """Calculate coverage exact match score."""
        pred_tokens = self.normalize_answer(self.bool_mapping(prediction)).split()
        gt_tokens = self.normalize_answer(self.bool_mapping(ground_truth)).split()
        return all(token in pred_tokens for token in gt_tokens)

    def f1_score(self, prediction: str, ground_truth: str) -> float:
        """Calculate F1 score."""
        pred_norm = self.normalize_answer(self.bool_mapping(prediction))
        gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))

        # Handle yes/no/noanswer cases
        if pred_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
            return 0.0
        if gt_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
            return 0.0

        pred_tokens = pred_norm.split()
        gt_tokens = gt_norm.split()

        common = Counter(pred_tokens) & Counter(gt_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return 0.0

        precision = num_same / len(pred_tokens)
        recall = num_same / len(gt_tokens)

        return (2 * precision * recall) / (precision + recall)
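
    # Worked example (illustrative): prediction "Barack Obama" vs. ground
    # truth "Obama" shares one normalized token, so precision = 1/2,
    # recall = 1/1, and F1 = 2 * 0.5 * 1.0 / (0.5 + 1.0) = 2/3.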

    def extract_entities(self, text: str) -> set:
        """Extract entities from text."""
        if self.nlp is None:
            return set()  # Return empty set if spacy unavailable
        doc = self.nlp(text)
        return set(ent.text.lower().strip() for ent in doc.ents)

    def extract_entities_by_category(self, text: str) -> Dict[str, set]:
        """Extract entities by category."""
        if self.nlp is None:
            return defaultdict(set)  # Return empty dict if spacy unavailable
        doc = self.nlp(text)
        entities_by_category = defaultdict(set)

        for ent in doc.ents:
            if ent.label_ in TARGET_ENTITY_CATEGORIES:
                entities_by_category[ent.label_].add(ent.text.lower().strip())

        return entities_by_category

    def entity_preserve_metric(self, prediction: str, reference: str) -> float:
        """Calculate entity preservation rate."""
        ref_entities = self.extract_entities(reference)
        pred_entities = self.extract_entities(prediction)

        if not ref_entities:
            return 1.0

        preserved = ref_entities.intersection(pred_entities)
        return len(preserved) / len(ref_entities)

    def entity_preserve_metric_by_category(self, prediction_tokens: List[List[str]],
                                           reference_docs: List[str]) -> Dict[str, float]:
        """Calculate entity preservation by category."""
        # Merge prediction tokens
        all_prediction_tokens = []
        for tokens in prediction_tokens:
            all_prediction_tokens.extend(tokens)
        prediction_text = " ".join(all_prediction_tokens)

        # Merge reference documents
        reference_text = " ".join(reference_docs)

        # Extract entities
        pred_entities = self.extract_entities_by_category(prediction_text)
        ref_entities = self.extract_entities_by_category(reference_text)

        # Calculate preservation rates
        preservation_rates = {}

        for category in TARGET_ENTITY_CATEGORIES:
            ref_ents = ref_entities.get(category, set())
            pred_ents = pred_entities.get(category, set())

            if not ref_ents:
                preservation_rates[category] = 1.0
            else:
                preserved = ref_ents.intersection(pred_ents)
                preservation_rates[category] = len(preserved) / len(ref_ents)

        # Calculate overall preservation
        all_ref_entities = set()
        all_pred_entities = set()

        for entities_set in ref_entities.values():
            all_ref_entities.update(entities_set)
        for entities_set in pred_entities.values():
            all_pred_entities.update(entities_set)

        if not all_ref_entities:
            preservation_rates["overall"] = 1.0
        else:
            preserved_overall = all_ref_entities.intersection(all_pred_entities)
            preservation_rates["overall"] = len(preserved_overall) / len(all_ref_entities)

        return preservation_rates


class ResultCalculator:
    """Handles result calculation and visualization."""

    def __init__(self):
        self.metrics = EvaluationMetrics()

    def calculate_basic_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
        """Calculate basic metrics (F1, accuracy, exact match)."""
        f1_total = 0
        acc_total = 0
        em_total = 0
        avg_output_length = 0

        answer_key = "golden_answers" if "golden_answers" in result_list[0] else "answer"

        for result in result_list:
            prediction = result['CLaRa_normal_output']
            ground_truth = result[answer_key][0] if answer_key == "golden_answers" else result[answer_key]

            acc_total += self.metrics.cover_exact_match_score(prediction, ground_truth)
            f1_total += self.metrics.f1_score(prediction, ground_truth)
            em_total += self.metrics.exact_match_score(prediction, ground_truth)
            avg_output_length += len(prediction.split())

        n = len(result_list)
        return {
            "f1": f1_total / n,
            "acc": acc_total / n,
            "em": em_total / n,
            "avg_output_length": avg_output_length / n
        }

    def calculate_stage2_metrics(self, result_list: List[Dict], k_values: List[int] = [1, 3, 5]) -> Dict[str, float]:
        """Calculate stage2 metrics with recall and precision."""
        basic_metrics = self.calculate_basic_metrics(result_list)

        recall = {k: 0 for k in k_values}
        precision = {k: 0 for k in k_values}

        for result in result_list:
            scores = result['topk_idx']
            pos_index = set(result['pos_index'])

            for k in k_values:
                top_k = set(scores[:k])
                hit = len(top_k & pos_index)

                recall[k] += hit / len(pos_index) if len(pos_index) > 0 else 0
                precision[k] += hit / k

        n = len(result_list)
        recall_metrics = {f"recall@{k}": v / n for k, v in recall.items()}
        precision_metrics = {f"precision@{k}": v / n for k, v in precision.items()}

        return {**basic_metrics, **recall_metrics, **precision_metrics}
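
    # Example (illustrative): for one sample with topk_idx = [2, 0, 5] and
    # pos_index = {0, 1}, recall@3 = |{0}| / |{0, 1}| = 0.5 and
    # precision@3 = 1/3; the returned metrics average these over all samples.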
256
+ def calculate_paraphrase_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
257
+ """Calculate paraphrase metrics."""
258
+ seen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
259
+ unseen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
260
+
261
+ # Process seen data (first 2000)
262
+ for result in result_list[:2000]:
263
+ prediction = result['CLaRa_normal_output']
264
+ ground_truth = result['doc']
265
+
266
+ if EVAL_LIB_AVAILABLE and self.metrics.bertscore is not None:
267
+ bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
268
+ seen_metrics['bert-score'] += bs['f1'][0]
269
+
270
+ if EVAL_LIB_AVAILABLE and self.metrics.rouge is not None:
271
+ rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
272
+ seen_metrics['rouge-1'] += rouge_scores['rouge1']
273
+ seen_metrics['rouge-L'] += rouge_scores['rougeL']
274
+
275
+ seen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
276
+
277
+ # Process unseen data (after 2000)
278
+ for result in result_list[2000:]:
279
+ prediction = result['CLaRa_normal_output']
280
+ ground_truth = result['doc']
281
+
282
+ if EVAL_LIB_AVAILABLE and self.metrics.bertscore is not None:
283
+ bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
284
+ unseen_metrics['bert-score'] += bs['f1'][0]
285
+
286
+ if EVAL_LIB_AVAILABLE and self.metrics.rouge is not None:
287
+ rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
288
+ unseen_metrics['rouge-1'] += rouge_scores['rouge1']
289
+ unseen_metrics['rouge-L'] += rouge_scores['rougeL']
290
+
291
+ unseen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
292
+
293
+ # Normalize
294
+ n_seen = min(len(result_list), 2000)
295
+ n_unseen = max(len(result_list) - 2000, 0)
296
+
297
+ final_metrics = {}
298
+ if n_seen > 0:
299
+ for key, value in seen_metrics.items():
300
+ final_metrics[f'seen_{key}'] = float(value / n_seen)
301
+
302
+ if n_unseen > 0:
303
+ for key, value in unseen_metrics.items():
304
+ final_metrics[f'unseen_{key}'] = float(value / n_unseen)
305
+
306
+ return final_metrics
307
+
308
+ def visualize_mse(self, result_list: List[Dict], save_path: str) -> Dict[str, Any]:
309
+ """Create t-SNE visualization for MSE analysis."""
310
+ # Set scientific style
311
+ plt.rcParams.update({
312
+ 'font.family': 'serif',
313
+ 'font.size': 12,
314
+ 'axes.labelsize': 14,
315
+ 'axes.titlesize': 16,
316
+ 'figure.titlesize': 18,
317
+ 'axes.linewidth': 1.2,
318
+ 'grid.alpha': 0.3,
319
+ })
320
+
321
+ # Collect representations
322
+ mem_reps = []
323
+ non_mem_reps = []
324
+
325
+ for result in result_list:
326
+ mem_rep = result['CLaRa_compressed_output']
327
+ non_mem_rep = result['CLaRa_normal_output']
328
+
329
+ if isinstance(mem_rep, torch.Tensor):
330
+ mem_rep = mem_rep.float().cpu().numpy()
331
+ if isinstance(non_mem_rep, torch.Tensor):
332
+ non_mem_rep = non_mem_rep.float().cpu().numpy()
333
+
334
+ mem_reps.append(mem_rep)
335
+ non_mem_reps.append(non_mem_rep)
336
+
337
+ mem_reps = np.array(mem_reps)
338
+ non_mem_reps = np.array(non_mem_reps)
339
+
340
+ print(f"Memory representations shape: {mem_reps.shape}")
341
+ print(f"Document representations shape: {non_mem_reps.shape}")
342
+
343
+ # Combine data for t-SNE
344
+ all_data = np.vstack([mem_reps, non_mem_reps])
345
+ original_dim = all_data.shape[1]
346
+
347
+ # PCA preprocessing if needed
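+ # Reducing to 50 dimensions with PCA first is standard practice to keep
+ # t-SNE fast and stable on high-dimensional embeddings.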
348
+ if all_data.shape[1] > 50:
349
+ print(f"Applying PCA preprocessing from {all_data.shape[1]} to 50 dimensions...")
350
+ pca = PCA(n_components=50)
351
+ all_data = pca.fit_transform(all_data)
352
+ print(f"PCA explained variance ratio: {pca.explained_variance_ratio_[:5].sum():.3f}")
353
+
354
+ # Apply t-SNE
355
+ print("Applying t-SNE...")
356
+ perplexity = min(30, max(5, len(all_data) // 3))
357
+ tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity,
358
+ max_iter=1000, learning_rate=200, verbose=1)
359
+ tsne_results = tsne.fit_transform(all_data)
360
+
361
+ # Separate results
362
+ mem_tsne = tsne_results[:len(mem_reps)]
363
+ doc_tsne = tsne_results[len(mem_reps):]
364
+
365
+ # Create visualization
366
+ fig, ax = plt.subplots(1, 1, figsize=(10, 8))
367
+
368
+ # Add jitter to separate overlapping points
369
+ np.random.seed(42)
370
+ jitter_strength = 1.0
371
+
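+ # The jitter (and the small +/-0.5 offsets below) is purely cosmetic, to
+ # separate overlapping scatter points; the distance statistics further down
+ # use the original, un-jittered representations.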
372
+ mem_jitter = mem_tsne.copy()
373
+ doc_jitter = doc_tsne.copy()
374
+
375
+ mem_jitter[:, 0] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
376
+ mem_jitter[:, 1] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
377
+
378
+ doc_jitter[:, 0] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
379
+ doc_jitter[:, 1] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
380
+
381
+ # Plot scatter points
382
+ ax.scatter(doc_jitter[:, 0], doc_jitter[:, 1], c='#0066CC', alpha=0.7, s=25,
383
+ marker='o', edgecolors='white', linewidth=0.5,
384
+ label='Document Representations', zorder=2)
385
+
386
+ ax.scatter(mem_jitter[:, 0], mem_jitter[:, 1], c='#FF3333', alpha=0.7, s=25,
387
+ marker='o', edgecolors='white', linewidth=0.5,
388
+ label='Memory Tokens Representations', zorder=3)
389
+
390
+ # Configure plot
391
+ ax.set_xlabel('')
392
+ ax.set_ylabel('')
393
+ ax.set_title('')
394
+
395
+ legend = ax.legend(frameon=True, fancybox=True, shadow=True,
396
+ loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=2, fontsize=14)
397
+ legend.get_frame().set_facecolor('white')
398
+ legend.get_frame().set_alpha(0.9)
399
+
400
+ ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
401
+ ax.set_axisbelow(True)
402
+
403
+ plt.tight_layout()
404
+
405
+ # Save visualization
406
+ os.makedirs(save_path, exist_ok=True)
407
+ plt.savefig(os.path.join(save_path, 'tsne_visualization_scientific.png'),
408
+ dpi=300, bbox_inches='tight', facecolor='white')
409
+ plt.show()
410
+
411
+ # Calculate statistics
412
+ distances = np.array([
413
+ np.linalg.norm(mem_reps[i] - non_mem_reps[i])
414
+ for i in range(len(mem_reps))
415
+ ])
416
+
417
+ statistics = {
418
+ 'mean_distance': float(np.mean(distances)),
419
+ 'std_distance': float(np.std(distances)),
420
+ 'median_distance': float(np.median(distances)),
421
+ 'min_distance': float(np.min(distances)),
422
+ 'max_distance': float(np.max(distances))
423
+ }
424
+
425
+ print("\n" + "="*60)
426
+ print("VISUALIZATION ANALYSIS REPORT")
427
+ print("="*60)
428
+ print(f"Dataset Statistics:")
429
+ print(f" • Total samples: {len(mem_reps)}")
430
+ print(f" • Original dimension: {original_dim}")
431
+ print(f" • t-SNE perplexity: {perplexity}")
432
+ print(f"\nDistance Analysis:")
433
+ for key, value in statistics.items():
434
+ print(f" • {key.replace('_', ' ').title()}: {value:.4f}")
435
+ print("="*60)
436
+
437
+ return {
438
+ 'mem_tsne': mem_tsne,
439
+ 'doc_tsne': doc_tsne,
440
+ 'original_distances': distances,
441
+ 'statistics': statistics
442
+ }
443
+
444
+
445
+ class DataLoader:
446
+ """Handles data loading for different datasets and stages."""
447
+
448
+ @staticmethod
449
+ def load_stage1_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
450
+ """Load stage1 evaluation data."""
451
+ retrieval_type = "with_pos" if gold_retrieval else "no_pos"
452
+ file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage1_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
453
+
454
+ data = []
455
+ with open(file_path, 'r') as f:
456
+ for line in f:
457
+ data.append(json.loads(line))
458
+
459
+ processed_data = []
460
+ for index, item in enumerate(data):
461
+ docs = item['docs'][:5] # Take top 5 documents
462
+ processed_item = {
463
+ 'original_data': item,
464
+ 'documents': docs,
465
+ 'question': item['question'],
466
+ 'global_index': index
467
+ }
468
+ processed_data.append(processed_item)
469
+
470
+ return processed_data
471
+
472
+ @staticmethod
473
+ def load_stage2_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
474
+ """Load stage2 evaluation data."""
475
+ retrieval_type = "with_pos" if gold_retrieval else "no_pos"
476
+ file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage2_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
477
+
478
+ processed_data = []
479
+ with open(file_path, 'r') as f:
480
+ for index, line in enumerate(f):
481
+ item = json.loads(line)
482
+ processed_item = {
483
+ 'original_data': item,
484
+ 'documents': item['docs'],
485
+ 'question': item['question'],
486
+ 'global_index': index,
487
+ 'pos_index': item['pos_index']
488
+ }
489
+ processed_data.append(processed_item)
490
+
491
+ return processed_data
492
+
493
+ @staticmethod
494
+ def load_paraphrase_data(file_path: str) -> List[Dict]:
495
+ """Load paraphrase data."""
496
+ data = []
497
+ with open(file_path, 'r') as f:
498
+ for line in f:
499
+ data.append(json.loads(line))
500
+
501
+ processed_data = []
502
+ for index, item in enumerate(data):
503
+ processed_item = {
504
+ 'original_data': item,
505
+ 'documents': [item['doc']],
506
+ 'question': "",
507
+ 'global_index': index
508
+ }
509
+ processed_data.append(processed_item)
510
+
511
+ return processed_data
512
+
513
+
514
+ class AcceleratedCLaRaInference:
515
+ """Main inference engine using Accelerate for distributed processing."""
516
+
517
+ def __init__(self, model_path: str, training_stage: str = None,
518
+ generation_top_k: int = None, args = None):
519
+ self.args = args
520
+
521
+ # Initialize Accelerator
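+ # The 5400-second (90 min) process-group timeout matches the NCCL_TIMEOUT
+ # set at module import, presumably so that long evaluation shards do not
+ # trip the collective-communication watchdog.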
522
+ process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
523
+ self.accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])
524
+
525
+ if self.accelerator.is_main_process:
526
+ print(f"Using {self.accelerator.num_processes} GPUs for distributed inference")
527
+ print(f"Current process: {self.accelerator.process_index}")
528
+ print("Loading CLaRa model...")
529
+
530
+ # Load model
531
+ self.model = CLaRa.from_pretrained(
532
+ model_path,
533
+ training_stage=training_stage,
534
+ generation_top_k=generation_top_k,
535
+ pure_inference=True
536
+ )
537
+
538
+ # Prepare model with Accelerator
539
+ self.model = self.accelerator.prepare(self.model)
540
+ self.model.eval()
541
+
542
+ if self.accelerator.is_main_process:
543
+ print("Model preparation completed")
544
+
545
+ def _get_model(self):
546
+ """Get the actual model (handles distributed vs single GPU)."""
547
+ return self.model.module if hasattr(self.model, 'module') else self.model
548
+
549
+ def process_batch(self, batch_questions: List[str], batch_documents: List[List[str]] = None,
550
+ stage2_mips: bool = False, training_stage: str = None,
551
+ batch_answers: List[str] = None, time_count: bool = False) -> Tuple:
552
+ """Process a batch of questions and documents."""
553
+ model = self._get_model()
554
+
555
+ with torch.no_grad():
556
+ try:
557
+ if training_stage == 'stage2':
558
+ return self._process_stage2(model, batch_questions, batch_documents,
559
+ stage2_mips, time_count)
560
+ elif training_stage in ['stage1', 'stage1_2']:
561
+ return self._process_stage1(model, batch_questions, batch_documents)
562
+ elif training_stage == 'stage2_reasoning':
563
+ return self._process_reasoning(model, batch_questions, batch_answers)
564
+ elif training_stage == 'stage1_paraphrase':
565
+ return self._process_paraphrase(model, batch_questions, batch_documents)
566
+ elif training_stage == 'stage1_mse_visulize':
567
+ return self._process_mse_visualize(model, batch_documents)
568
+ else:
569
+ raise ValueError(f"Unknown training stage: {training_stage}")
570
+
571
+ except torch.cuda.OutOfMemoryError as e:
572
+ self.accelerator.print(f"CUDA OOM error: {e}")
573
+ torch.cuda.empty_cache()
574
+ gc.collect()
575
+ return self._create_empty_results(batch_questions, training_stage)
576
+
577
+ def _process_stage2(self, model, batch_questions, batch_documents, stage2_mips, time_count):
578
+ """Process stage2 inference."""
579
+ if time_count:
580
+ if stage2_mips:
581
+ results = model.generate_from_questions(
582
+ questions=batch_questions,
583
+ max_new_tokens=64,
584
+ stage2_mips=stage2_mips,
585
+ time_count=True
586
+ )
587
+ else:
588
+ results = model.generate_from_questions(
589
+ questions=batch_questions,
590
+ max_new_tokens=64,
591
+ stage2_mips=stage2_mips,
592
+ documents=batch_documents,
593
+ time_count=True
594
+ )
595
+ return results
596
+ else:
597
+ if stage2_mips:
598
+ batch_out_normal, topk_idx = model.generate_from_questions(
599
+ questions=batch_questions,
600
+ max_new_tokens=64,
601
+ stage2_mips=stage2_mips
602
+ )
603
+ else:
604
+ batch_out_normal, topk_idx = model.generate_from_questions(
605
+ questions=batch_questions,
606
+ max_new_tokens=64,
607
+ stage2_mips=stage2_mips,
608
+ documents=batch_documents
609
+ )
610
+ return batch_out_normal, batch_out_normal, topk_idx
611
+
612
+ def _process_stage1(self, model, batch_questions, batch_documents):
613
+ """Process stage1 inference."""
614
+ batch_out_compressed = []
615
+
616
+ for docs, question in zip(batch_documents, batch_questions):
617
+ embeddings, _ = model.compress_documents(documents=docs)
618
+ out_compressed = model.generate_from_compressed_documents_and_questions(
619
+ questions=[question],
620
+ compressed_documents=embeddings
621
+ )
622
+ batch_out_compressed.extend(out_compressed)
623
+
624
+ del embeddings
625
+ torch.cuda.empty_cache()
626
+
627
+ return batch_out_compressed, batch_out_compressed, None
628
+
629
+ def _process_reasoning(self, model, batch_questions, batch_answers):
630
+ """Process reasoning inference."""
631
+ batch_out_normal = []
632
+ batch_out_reasoning_list = []
633
+
634
+ for question, answer in zip(batch_questions, batch_answers):
635
+ temp_out, temp_out_reasoning = model.generate_from_reasoning(
636
+ questions=[question],
637
+ max_new_tokens=1024,
638
+ answers=[answer],
639
+ save_dir=self.args.model_path
640
+ )
641
+ batch_out_normal.append(temp_out[0])
642
+ batch_out_reasoning_list.extend(temp_out_reasoning)
643
+
644
+ return batch_out_normal, batch_out_normal, None, batch_out_reasoning_list
645
+
646
+ def _process_paraphrase(self, model, batch_questions, batch_documents):
647
+ """Process paraphrase inference."""
648
+ batch_out_compressed = []
649
+
650
+ for docs, question in zip(batch_documents, batch_questions):
651
+ out_compressed = model.generate_from_paraphrase(
652
+ questions=["" for _ in range(len(docs))],
653
+ documents=[docs]
654
+ )
655
+ batch_out_compressed.extend(out_compressed)
656
+ torch.cuda.empty_cache()
657
+
658
+ return batch_out_compressed, batch_out_compressed, None
659
+
660
+ def _process_mse_visualize(self, model, batch_documents):
661
+ """Process MSE visualization."""
662
+ batch_out_normal = []
663
+ batch_out_compressed = []
664
+
665
+ for docs in batch_documents:
666
+ mem_rep, non_mem_rep = model.compress_documents_mse_visulize(documents=docs)
667
+ batch_out_compressed.append(mem_rep[0])
668
+ batch_out_normal.append(non_mem_rep[0])
669
+
670
+ return batch_out_normal, batch_out_compressed
671
+
672
+ def _create_empty_results(self, batch_questions, training_stage):
673
+ """Create empty results for error cases."""
674
+ empty_results = [""] * len(batch_questions)
675
+ if training_stage == 'stage2_reasoning':
676
+ return empty_results, empty_results, None, empty_results
677
+ elif training_stage == 'stage1_mse_visulize':
678
+ return empty_results, empty_results
679
+ else:
680
+ return empty_results, empty_results, None
681
+
682
+
683
+ def convert_embeddings_to_list(data):
684
+ """Convert tensor embeddings to lists for JSON serialization."""
685
+ if isinstance(data, dict):
686
+ return {k: convert_embeddings_to_list(v) for k, v in data.items()}
687
+ elif isinstance(data, list):
688
+ return [convert_embeddings_to_list(item) for item in data]
689
+ elif isinstance(data, torch.Tensor):
690
+ return data.cpu().to(torch.float32).numpy().tolist()
691
+ elif isinstance(data, np.ndarray):
692
+ return data.tolist()
693
+ else:
694
+ return data
695
+
696
+
697
+ def main():
698
+ parser = argparse.ArgumentParser(description="CLaRa Model Inference")
699
+ parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
700
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size per GPU')
701
+ parser.add_argument('--stage', type=str, default='stage1',
702
+ choices=['stage1', 'stage1_2', 'stage2', 'stage2_reasoning',
703
+ 'stage1_paraphrase', 'stage1_mse_visulize'],
704
+ help='Training stage')
705
+ parser.add_argument('--stage2_mips', action='store_true', help='Use MIPS for stage2')
706
+ parser.add_argument('--dataset', type=str, default='musique',
707
+ help='Comma-separated list of datasets')
708
+ parser.add_argument('--gold_retrieval', action='store_true',
709
+ help='Use gold retrieval context')
710
+ parser.add_argument('--generation_top_k', type=int, default=5, help='Top-k for generation')
711
+ parser.add_argument('--paraphrase_path', type=str, help='Path to paraphrase data')
712
+ parser.add_argument('--mse_visulize_path', type=str, help='Path to save MSE visualization')
713
+ parser.add_argument('--efficient_count', action='store_true', help='Count efficiency metrics')
714
+
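+ # Example invocation (paths are hypothetical); the script is typically
+ # launched through Accelerate's runner for multi-GPU use:
+ # accelerate launch evaluation/evaluate.py --model_path /path/to/ckpt \
+ # --stage stage2 --dataset musique,hotpotqa --batch_size 4 --generation_top_k 5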
715
+ args = parser.parse_args()
716
+
717
+ # Process datasets
718
+ all_results_metrics = {}
719
+ datasets_list = args.dataset.split(',')
720
+
721
+ for dataset in datasets_list:
722
+ print(f"Processing dataset: {dataset}")
723
+
724
+ # Load data based on stage
725
+ if args.stage in ['stage1', 'stage1_2']:
726
+ processed_data = DataLoader.load_stage1_data(dataset, args.gold_retrieval)
727
+ elif args.stage == 'stage2':
728
+ processed_data = DataLoader.load_stage2_data(dataset, args.gold_retrieval)
729
+ elif args.stage in ['stage1_paraphrase', 'stage1_mse_visulize']:
730
+ if not args.paraphrase_path:
731
+ raise ValueError(f"--paraphrase_path required for stage {args.stage}")
732
+ processed_data = DataLoader.load_paraphrase_data(args.paraphrase_path)
733
+ else:
734
+ raise ValueError(f"Unsupported stage: {args.stage}")
735
+
736
+ print(f"Loaded {len(processed_data)} samples for {dataset}")
737
+
738
+ # Initialize inference engine
739
+ # Use model_path directly if absolute, otherwise use SageMaker path
740
+ if os.path.isabs(args.model_path):
741
+ model_path = args.model_path
742
+ else:
743
+ model_path = os.path.join('/mnt/task_wrapper/user_output/artifacts/data/train_checkpoint', args.model_path)
744
+ args.model_path = model_path
745
+
746
+ inference_engine = AcceleratedCLaRaInference(
747
+ model_path=model_path,
748
+ training_stage=args.stage,
749
+ generation_top_k=args.generation_top_k,
750
+ args=args
751
+ )
752
+
753
+ # Wait for all processes to be ready
754
+ inference_engine.accelerator.wait_for_everyone()
755
+
756
+ # Store results
757
+ all_results = []
758
+ time_count_dic = {"compress_time": 0, "query_time": 0, "generate_time": 0, "total_time": 0, "count": 0}
759
+
760
+ # Process data in batches using accelerator
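+ # split_between_processes shards the list across ranks inside a context
+ # manager; with apply_padding=False the shards may differ in length by one,
+ # and the per-rank results are recombined by gather_for_metrics afterwards.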
761
+ with inference_engine.accelerator.split_between_processes(processed_data, apply_padding=False) as local_data:
762
+ print(f"Process {inference_engine.accelerator.process_index}: processing {len(local_data)} samples")
763
+
764
+ batch_size = args.batch_size
765
+ num_batches = (len(local_data) + batch_size - 1) // batch_size
766
+
767
+ for batch_idx in tqdm(range(num_batches),
768
+ desc=f"GPU {inference_engine.accelerator.process_index}",
769
+ disable=not inference_engine.accelerator.is_local_main_process):
770
+
771
+ # Get current batch
772
+ start_idx = batch_idx * batch_size
773
+ end_idx = min(start_idx + batch_size, len(local_data))
774
+ batch = local_data[start_idx:end_idx]
775
+
776
+ # Prepare batch data
777
+ batch_questions = [item['question'] for item in batch]
778
+ batch_documents = [item['documents'] for item in batch] if 'documents' in batch[0] else None
779
+ batch_answers = [item.get('answer') for item in batch] if args.stage == 'stage2_reasoning' else None
780
+
781
+ # Process batch
782
+ if args.efficient_count and args.stage == 'stage2':
783
+ results = inference_engine.process_batch(
784
+ batch_questions=batch_questions,
785
+ batch_documents=batch_documents,
786
+ stage2_mips=args.stage2_mips,
787
+ training_stage=args.stage,
788
+ time_count=True
789
+ )
790
+ batch_out_normal, batch_out_compressed, batch_topk_idx, compress_time, query_time, generate_time, total_time = results
791
+
792
+ time_count_dic["compress_time"] += compress_time
793
+ time_count_dic["query_time"] += query_time
794
+ time_count_dic["generate_time"] += generate_time
795
+ time_count_dic["total_time"] += total_time
796
+ time_count_dic["count"] += 1
797
+ else:
798
+ results = inference_engine.process_batch(
799
+ batch_questions=batch_questions,
800
+ batch_documents=batch_documents,
801
+ stage2_mips=args.stage2_mips,
802
+ training_stage=args.stage,
803
+ batch_answers=batch_answers
804
+ )
805
+
806
+ if args.stage == 'stage2_reasoning':
807
+ batch_out_normal, batch_out_compressed, batch_topk_idx, batch_out_reasoning = results
808
+ elif args.stage == 'stage1_mse_visulize':
809
+ batch_out_normal, batch_out_compressed = results
810
+ batch_topk_idx = None
811
+ else:
812
+ batch_out_normal, batch_out_compressed, batch_topk_idx = results
813
+
814
+ # Prepare results
815
+ batch_results = []
816
+ for i, (item, normal_out, compressed_out) in enumerate(zip(batch, batch_out_normal, batch_out_compressed)):
817
+ result_item = item['original_data'].copy()
818
+ result_item['CLaRa_normal_output'] = normal_out
819
+ result_item['CLaRa_compressed_output'] = compressed_out
820
+ result_item['global_index'] = item['global_index']
821
+
822
+ if args.stage == 'stage2' and batch_topk_idx is not None:
823
+ result_item['topk_idx'] = batch_topk_idx[i].tolist()
824
+ elif args.stage == 'stage2_reasoning':
825
+ result_item['reasoning_output'] = batch_out_reasoning[i]
826
+
827
+ batch_results.append(result_item)
828
+
829
+ all_results.extend(batch_results)
830
+
831
+
832
+ # Clean up memory
833
+ torch.cuda.empty_cache()
834
+ if batch_idx % 10 == 0:
835
+ gc.collect()
836
+
837
+ # Save efficiency metrics if requested
838
+ if args.efficient_count and inference_engine.accelerator.is_main_process:
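+ # The counter increments once per batch, so these are average per-batch
+ # timings in milliseconds rather than per-sample timings.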
839
+ eff_dic = {
840
+ "compress_time_ms": round((time_count_dic['compress_time'] / time_count_dic['count']) * 1000, 2),
841
+ "query_time_ms": round((time_count_dic['query_time'] / time_count_dic['count']) * 1000, 2),
842
+ "generate_time_ms": round((time_count_dic['generate_time'] / time_count_dic['count']) * 1000, 2),
843
+ "total_time_ms": round((time_count_dic['total_time'] / time_count_dic['count']) * 1000, 2),
844
+ "sample_count": time_count_dic['count']
845
+ }
846
+ eff_output_path = os.path.join(model_path, f"efficiency_{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
847
+ with open(eff_output_path, 'w') as f:
848
+ json.dump(eff_dic, f, indent=2)
849
+
850
+ # Wait for all processes to complete
851
+ inference_engine.accelerator.wait_for_everyone()
852
+
853
+ # Gather results from all processes
854
+ if inference_engine.accelerator.is_main_process:
855
+ print("Collecting results from all processes...")
856
+
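+ # gather_for_metrics is a collective call, so every rank must reach it;
+ # that is why it sits outside the is_main_process guard above.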
857
+ all_results_gathered = inference_engine.accelerator.gather_for_metrics(all_results)
858
+
859
+ # Process and save results (main process only)
860
+ if inference_engine.accelerator.is_main_process:
861
+ print("Processing and saving results...")
862
+
863
+ # Flatten results
864
+ final_results = []
865
+ if isinstance(all_results_gathered, list):
866
+ for result_batch in all_results_gathered:
867
+ if isinstance(result_batch, list):
868
+ final_results.extend(result_batch)
869
+ else:
870
+ final_results.append(result_batch)
871
+
872
+ print(f"Collected {len(final_results)} results")
873
+
874
+ # Sort by global index to maintain order
875
+ final_results.sort(key=lambda x: x.get('global_index', 0))
876
+
877
+ # Verify data integrity
878
+ processed_indices = set(item.get('global_index', -1) for item in final_results)
879
+ expected_indices = set(range(len(processed_data)))
880
+ missing_indices = expected_indices - processed_indices
881
+
882
+ if missing_indices:
883
+ print(f"Warning: Missing indices: {sorted(list(missing_indices))}")
884
+ else:
885
+ print("✓ Data integrity verification passed")
886
+
887
+ # Remove global index for clean output
888
+ for item in final_results:
889
+ item.pop('global_index', None)
890
+
891
+ # Save results
892
+ output_path = os.path.join(model_path, f"{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.jsonl")
893
+ with open(output_path, 'w') as f:
894
+ if args.stage == 'stage1_mse_visulize':
895
+ converted_results = convert_embeddings_to_list(final_results)
896
+ for item in converted_results:
897
+ f.write(json.dumps(item) + '\n')
898
+ else:
899
+ for item in final_results:
900
+ f.write(json.dumps(item) + '\n')
901
+
902
+ print(f"Results saved to: {output_path}")
903
+
904
+ # Calculate metrics
905
+ calculator = ResultCalculator()
906
+
907
+ if args.stage == 'stage2':
908
+ metrics = calculator.calculate_stage2_metrics(final_results)
909
+ elif args.stage == 'stage1_paraphrase':
910
+ metrics = calculator.calculate_paraphrase_metrics(final_results)
911
+ elif args.stage == 'stage1_mse_visulize':
912
+ if args.mse_visulize_path:
913
+ metrics = calculator.visualize_mse(final_results, args.mse_visulize_path)['statistics']  # keep JSON-serializable stats only
914
+ else:
915
+ metrics = {"visualization": "completed"}
916
+ else:
917
+ metrics = calculator.calculate_basic_metrics(final_results)
918
+
919
+ print(f"Metrics for {dataset}: {metrics}")
920
+ all_results_metrics[dataset] = metrics
921
+
922
+ # Clean up
923
+ del inference_engine
924
+ torch.cuda.empty_cache()
925
+ gc.collect()
926
+
927
+ # Save final metrics
928
+ if len(all_results_metrics) > 0:
929
+ metrics_path = os.path.join(model_path, f"results_metrics_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
930
+ with open(metrics_path, 'w') as f:
931
+ json.dump(all_results_metrics, f, indent=2)
932
+ print(f"Final metrics saved to: {metrics_path}")
933
+
934
+
935
+ if __name__ == '__main__':
936
+ main()
evaluation/evaluate.py.bak ADDED
@@ -0,0 +1,910 @@
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4
+ #
5
+
6
+ import os
7
+ import json
8
+ import argparse
9
+ import gc
10
+ from datetime import timedelta
11
+ from collections import defaultdict, Counter
12
+ from typing import List, Dict, Any, Optional, Tuple
13
+
14
+ import torch
15
+ import numpy as np
16
+ from accelerate import Accelerator, InitProcessGroupKwargs
17
+ from transformers import AutoModel
18
+ from datasets import load_dataset
19
+ from tqdm import tqdm
20
+ import matplotlib.pyplot as plt
21
+ from sklearn.manifold import TSNE
22
+ from sklearn.decomposition import PCA
23
+ import spacy
24
+ import evaluate
25
+ import re
26
+ import string
27
+
28
+ from openrlhf.models.modeling_clara import CLaRa
29
+
30
+ # Environment setup
31
+ os.environ["NCCL_TIMEOUT"] = "5400"
32
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
33
+
34
+ # Global constants
35
+ TARGET_ENTITY_CATEGORIES = {"PERSON", "GPE", "DATE", "CARDINAL", "ORG"}
36
+
37
+
38
+ class EvaluationMetrics:
39
+ """Handles all evaluation metrics and scoring functions."""
40
+
41
+ def __init__(self):
42
+ self.bertscore = evaluate.load("bertscore")
43
+ self.rouge = evaluate.load("rouge")
44
+ self.nlp = spacy.load("en_core_web_sm")
45
+
46
+ @staticmethod
47
+ def normalize_answer(text: str) -> str:
48
+ """Normalize text for comparison."""
49
+ def remove_articles(text):
50
+ return re.sub(r"\b(a|an|the)\b", " ", text)
51
+
52
+ def white_space_fix(text):
53
+ return " ".join(text.split())
54
+
55
+ def remove_punc(text):
56
+ exclude = set(string.punctuation)
57
+ return "".join(ch for ch in text if ch not in exclude)
58
+
59
+ return white_space_fix(remove_articles(remove_punc(text.lower())))
60
+
61
+ @staticmethod
62
+ def bool_mapping(text: str) -> str:
63
+ """Map boolean values to yes/no."""
64
+ mapping = {"True": "yes", "False": "no"}
65
+ return mapping.get(text, text)
66
+
67
+ def exact_match_score(self, prediction: str, ground_truth: str) -> bool:
68
+ """Calculate exact match score."""
69
+ pred_norm = self.normalize_answer(self.bool_mapping(prediction))
70
+ gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))
71
+ return pred_norm == gt_norm
72
+
73
+ def cover_exact_match_score(self, prediction: str, ground_truth: str) -> bool:
74
+ """Calculate coverage exact match score."""
75
+ pred_tokens = self.normalize_answer(self.bool_mapping(prediction)).split()
76
+ gt_tokens = self.normalize_answer(self.bool_mapping(ground_truth)).split()
77
+ return all(token in pred_tokens for token in gt_tokens)
78
+
79
+ def f1_score(self, prediction: str, ground_truth: str) -> float:
80
+ """Calculate F1 score."""
81
+ pred_norm = self.normalize_answer(self.bool_mapping(prediction))
82
+ gt_norm = self.normalize_answer(self.bool_mapping(ground_truth))
83
+
84
+ # Handle yes/no/noanswer cases
85
+ if pred_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
86
+ return 0.0
87
+ if gt_norm in ["yes", "no", "noanswer"] and pred_norm != gt_norm:
88
+ return 0.0
89
+
90
+ pred_tokens = pred_norm.split()
91
+ gt_tokens = gt_norm.split()
92
+
93
+ common = Counter(pred_tokens) & Counter(gt_tokens)
94
+ num_same = sum(common.values())
95
+
96
+ if num_same == 0:
97
+ return 0.0
98
+
99
+ precision = num_same / len(pred_tokens)
100
+ recall = num_same / len(gt_tokens)
101
+
102
+ return (2 * precision * recall) / (precision + recall)
103
+
104
+ def extract_entities(self, text: str) -> set:
105
+ """Extract entities from text."""
106
+ doc = self.nlp(text)
107
+ return set(ent.text.lower().strip() for ent in doc.ents)
108
+
109
+ def extract_entities_by_category(self, text: str) -> Dict[str, set]:
110
+ """Extract entities by category."""
111
+ doc = self.nlp(text)
112
+ entities_by_category = defaultdict(set)
113
+
114
+ for ent in doc.ents:
115
+ if ent.label_ in TARGET_ENTITY_CATEGORIES:
116
+ entities_by_category[ent.label_].add(ent.text.lower().strip())
117
+
118
+ return entities_by_category
119
+
120
+ def entity_preserve_metric(self, prediction: str, reference: str) -> float:
121
+ """Calculate entity preservation rate."""
122
+ ref_entities = self.extract_entities(reference)
123
+ pred_entities = self.extract_entities(prediction)
124
+
125
+ if not ref_entities:
126
+ return 1.0
127
+
128
+ preserved = ref_entities.intersection(pred_entities)
129
+ return len(preserved) / len(ref_entities)
130
+
131
+ def entity_preserve_metric_by_category(self, prediction_tokens: List[List[str]],
132
+ reference_docs: List[str]) -> Dict[str, float]:
133
+ """Calculate entity preservation by category."""
134
+ # Merge prediction tokens
135
+ all_prediction_tokens = []
136
+ for tokens in prediction_tokens:
137
+ all_prediction_tokens.extend(tokens)
138
+ prediction_text = " ".join(all_prediction_tokens)
139
+
140
+ # Merge reference documents
141
+ reference_text = " ".join(reference_docs)
142
+
143
+ # Extract entities
144
+ pred_entities = self.extract_entities_by_category(prediction_text)
145
+ ref_entities = self.extract_entities_by_category(reference_text)
146
+
147
+ # Calculate preservation rates
148
+ preservation_rates = {}
149
+
150
+ for category in TARGET_ENTITY_CATEGORIES:
151
+ ref_ents = ref_entities.get(category, set())
152
+ pred_ents = pred_entities.get(category, set())
153
+
154
+ if not ref_ents:
155
+ preservation_rates[category] = 1.0
156
+ else:
157
+ preserved = ref_ents.intersection(pred_ents)
158
+ preservation_rates[category] = len(preserved) / len(ref_ents)
159
+
160
+ # Calculate overall preservation
161
+ all_ref_entities = set()
162
+ all_pred_entities = set()
163
+
164
+ for entities_set in ref_entities.values():
165
+ all_ref_entities.update(entities_set)
166
+ for entities_set in pred_entities.values():
167
+ all_pred_entities.update(entities_set)
168
+
169
+ if not all_ref_entities:
170
+ preservation_rates["overall"] = 1.0
171
+ else:
172
+ preserved_overall = all_ref_entities.intersection(all_pred_entities)
173
+ preservation_rates["overall"] = len(preserved_overall) / len(all_ref_entities)
174
+
175
+ return preservation_rates
176
+
177
+
178
+ class ResultCalculator:
179
+ """Handles result calculation and visualization."""
180
+
181
+ def __init__(self):
182
+ self.metrics = EvaluationMetrics()
183
+
184
+ def calculate_basic_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
185
+ """Calculate basic metrics (F1, accuracy, exact match)."""
186
+ f1_total = 0
187
+ acc_total = 0
188
+ em_total = 0
189
+ avg_output_length = 0
190
+
191
+ answer_key = "golden_answers" if "golden_answers" in result_list[0] else "answer"
192
+
193
+ for result in result_list:
194
+ prediction = result['CLaRa_normal_output']
195
+ ground_truth = result[answer_key][0] if answer_key == "golden_answers" else result[answer_key]
196
+
197
+ acc_total += self.metrics.cover_exact_match_score(prediction, ground_truth)
198
+ f1_total += self.metrics.f1_score(prediction, ground_truth)
199
+ em_total += self.metrics.exact_match_score(prediction, ground_truth)
200
+ avg_output_length += len(prediction.split())
201
+
202
+ n = len(result_list)
203
+ return {
204
+ "f1": f1_total / n,
205
+ "acc": acc_total / n,
206
+ "em": em_total / n,
207
+ "avg_output_length": avg_output_length / n
208
+ }
209
+
210
+ def calculate_stage2_metrics(self, result_list: List[Dict], k_values: List[int] = [1, 3, 5]) -> Dict[str, float]:
211
+ """Calculate stage2 metrics with recall and precision."""
212
+ basic_metrics = self.calculate_basic_metrics(result_list)
213
+
214
+ recall = {k: 0 for k in k_values}
215
+ precision = {k: 0 for k in k_values}
216
+
217
+ for result in result_list:
218
+ scores = result['topk_idx']
219
+ pos_index = set(result['pos_index'])
220
+
221
+ for k in k_values:
222
+ top_k = set(scores[:k])
223
+ hit = len(top_k & pos_index)
224
+
225
+ recall[k] += hit / len(pos_index) if len(pos_index) > 0 else 0
226
+ precision[k] += hit / k
227
+
228
+ n = len(result_list)
229
+ recall_metrics = {f"recall@{k}": v / n for k, v in recall.items()}
230
+ precision_metrics = {f"precision@{k}": v / n for k, v in precision.items()}
231
+
232
+ return {**basic_metrics, **recall_metrics, **precision_metrics}
233
+
234
+ def calculate_paraphrase_metrics(self, result_list: List[Dict]) -> Dict[str, float]:
235
+ """Calculate paraphrase metrics."""
236
+ seen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
237
+ unseen_metrics = {'bert-score': 0, 'rouge-1': 0, 'rouge-L': 0, 'entity_preserve': 0}
238
+
239
+ # Process seen data (first 2000)
240
+ for result in result_list[:2000]:
241
+ prediction = result['CLaRa_normal_output']
242
+ ground_truth = result['doc']
243
+
244
+ bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
245
+ seen_metrics['bert-score'] += bs['f1'][0]
246
+
247
+ rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
248
+ seen_metrics['rouge-1'] += rouge_scores['rouge1']
249
+ seen_metrics['rouge-L'] += rouge_scores['rougeL']
250
+
251
+ seen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
252
+
253
+ # Process unseen data (after 2000)
254
+ for result in result_list[2000:]:
255
+ prediction = result['CLaRa_normal_output']
256
+ ground_truth = result['doc']
257
+
258
+ bs = self.metrics.bertscore.compute(predictions=[prediction], references=[ground_truth], lang="en")
259
+ unseen_metrics['bert-score'] += bs['f1'][0]
260
+
261
+ rouge_scores = self.metrics.rouge.compute(predictions=[prediction], references=[ground_truth])
262
+ unseen_metrics['rouge-1'] += rouge_scores['rouge1']
263
+ unseen_metrics['rouge-L'] += rouge_scores['rougeL']
264
+
265
+ unseen_metrics['entity_preserve'] += self.metrics.entity_preserve_metric(prediction, ground_truth)
266
+
267
+ # Normalize
268
+ n_seen = min(len(result_list[:2000]), 2000)
269
+ n_unseen = max(len(result_list) - 2000, 0)
270
+
271
+ final_metrics = {}
272
+ if n_seen > 0:
273
+ for key, value in seen_metrics.items():
274
+ final_metrics[f'seen_{key}'] = float(value / n_seen)
275
+
276
+ if n_unseen > 0:
277
+ for key, value in unseen_metrics.items():
278
+ final_metrics[f'unseen_{key}'] = float(value / n_unseen)
279
+
280
+ return final_metrics
281
+
282
+ def visualize_mse(self, result_list: List[Dict], save_path: str) -> Dict[str, Any]:
283
+ """Create t-SNE visualization for MSE analysis."""
284
+ # Set scientific style
285
+ plt.rcParams.update({
286
+ 'font.family': 'serif',
287
+ 'font.size': 12,
288
+ 'axes.labelsize': 14,
289
+ 'axes.titlesize': 16,
290
+ 'figure.titlesize': 18,
291
+ 'axes.linewidth': 1.2,
292
+ 'grid.alpha': 0.3,
293
+ })
294
+
295
+ # Collect representations
296
+ mem_reps = []
297
+ non_mem_reps = []
298
+
299
+ for result in result_list:
300
+ mem_rep = result['CLaRa_compressed_output']
301
+ non_mem_rep = result['CLaRa_normal_output']
302
+
303
+ if isinstance(mem_rep, torch.Tensor):
304
+ mem_rep = mem_rep.float().cpu().numpy()
305
+ if isinstance(non_mem_rep, torch.Tensor):
306
+ non_mem_rep = non_mem_rep.float().cpu().numpy()
307
+
308
+ mem_reps.append(mem_rep)
309
+ non_mem_reps.append(non_mem_rep)
310
+
311
+ mem_reps = np.array(mem_reps)
312
+ non_mem_reps = np.array(non_mem_reps)
313
+
314
+ print(f"Memory representations shape: {mem_reps.shape}")
315
+ print(f"Document representations shape: {non_mem_reps.shape}")
316
+
317
+ # Combine data for t-SNE
318
+ all_data = np.vstack([mem_reps, non_mem_reps])
319
+ original_dim = all_data.shape[1]
320
+
321
+ # PCA preprocessing if needed
322
+ if all_data.shape[1] > 50:
323
+ print(f"Applying PCA preprocessing from {all_data.shape[1]} to 50 dimensions...")
324
+ pca = PCA(n_components=50)
325
+ all_data = pca.fit_transform(all_data)
326
+ print(f"PCA explained variance ratio: {pca.explained_variance_ratio_[:5].sum():.3f}")
327
+
328
+ # Apply t-SNE
329
+ print("Applying t-SNE...")
330
+ perplexity = min(30, max(5, len(all_data) // 3))
331
+ tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity,
332
+ max_iter=1000, learning_rate=200, verbose=1)
333
+ tsne_results = tsne.fit_transform(all_data)
334
+
335
+ # Separate results
336
+ mem_tsne = tsne_results[:len(mem_reps)]
337
+ doc_tsne = tsne_results[len(mem_reps):]
338
+
339
+ # Create visualization
340
+ fig, ax = plt.subplots(1, 1, figsize=(10, 8))
341
+
342
+ # Add jitter to separate overlapping points
343
+ np.random.seed(42)
344
+ jitter_strength = 1.0
345
+
346
+ mem_jitter = mem_tsne.copy()
347
+ doc_jitter = doc_tsne.copy()
348
+
349
+ mem_jitter[:, 0] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
350
+ mem_jitter[:, 1] += np.random.normal(0.5, jitter_strength, len(mem_tsne))
351
+
352
+ doc_jitter[:, 0] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
353
+ doc_jitter[:, 1] += np.random.normal(-0.5, jitter_strength, len(doc_tsne))
354
+
355
+ # Plot scatter points
356
+ ax.scatter(doc_jitter[:, 0], doc_jitter[:, 1], c='#0066CC', alpha=0.7, s=25,
357
+ marker='o', edgecolors='white', linewidth=0.5,
358
+ label='Document Representations', zorder=2)
359
+
360
+ ax.scatter(mem_jitter[:, 0], mem_jitter[:, 1], c='#FF3333', alpha=0.7, s=25,
361
+ marker='o', edgecolors='white', linewidth=0.5,
362
+ label='Memory Tokens Representations', zorder=3)
363
+
364
+ # Configure plot
365
+ ax.set_xlabel('')
366
+ ax.set_ylabel('')
367
+ ax.set_title('')
368
+
369
+ legend = ax.legend(frameon=True, fancybox=True, shadow=True,
370
+ loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=2, fontsize=14)
371
+ legend.get_frame().set_facecolor('white')
372
+ legend.get_frame().set_alpha(0.9)
373
+
374
+ ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
375
+ ax.set_axisbelow(True)
376
+
377
+ plt.tight_layout()
378
+
379
+ # Save visualization
380
+ os.makedirs(save_path, exist_ok=True)
381
+ plt.savefig(os.path.join(save_path, 'tsne_visualization_scientific.png'),
382
+ dpi=300, bbox_inches='tight', facecolor='white')
383
+ plt.show()
384
+
385
+ # Calculate statistics
386
+ distances = np.array([
387
+ np.linalg.norm(mem_reps[i] - non_mem_reps[i])
388
+ for i in range(len(mem_reps))
389
+ ])
390
+
391
+ statistics = {
392
+ 'mean_distance': float(np.mean(distances)),
393
+ 'std_distance': float(np.std(distances)),
394
+ 'median_distance': float(np.median(distances)),
395
+ 'min_distance': float(np.min(distances)),
396
+ 'max_distance': float(np.max(distances))
397
+ }
398
+
399
+ print("\n" + "="*60)
400
+ print("VISUALIZATION ANALYSIS REPORT")
401
+ print("="*60)
402
+ print(f"Dataset Statistics:")
403
+ print(f" • Total samples: {len(mem_reps)}")
404
+ print(f" • Original dimension: {original_dim}")
405
+ print(f" • t-SNE perplexity: {perplexity}")
406
+ print(f"\nDistance Analysis:")
407
+ for key, value in statistics.items():
408
+ print(f" • {key.replace('_', ' ').title()}: {value:.4f}")
409
+ print("="*60)
410
+
411
+ return {
412
+ 'mem_tsne': mem_tsne,
413
+ 'doc_tsne': doc_tsne,
414
+ 'original_distances': distances,
415
+ 'statistics': statistics
416
+ }
417
+
418
+
419
+ class DataLoader:
420
+ """Handles data loading for different datasets and stages."""
421
+
422
+ @staticmethod
423
+ def load_stage1_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
424
+ """Load stage1 evaluation data."""
425
+ retrieval_type = "with_pos" if gold_retrieval else "no_pos"
426
+ file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage1_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
427
+
428
+ data = []
429
+ with open(file_path, 'r') as f:
430
+ for line in f:
431
+ data.append(json.loads(line))
432
+
433
+ processed_data = []
434
+ for index, item in enumerate(data):
435
+ docs = item['docs'][:5] # Take top 5 documents
436
+ processed_item = {
437
+ 'original_data': item,
438
+ 'documents': docs,
439
+ 'question': item['question'],
440
+ 'global_index': index
441
+ }
442
+ processed_data.append(processed_item)
443
+
444
+ return processed_data
445
+
446
+ @staticmethod
447
+ def load_stage2_data(dataset: str, gold_retrieval: bool) -> List[Dict]:
448
+ """Load stage2 evaluation data."""
449
+ retrieval_type = "with_pos" if gold_retrieval else "no_pos"
450
+ file_path = f"/mnt/conductor_data/data/compression_rag_data/generator_training_val_data/stage2_eval/{dataset}/eval_processed_{retrieval_type}.jsonl"
451
+
452
+ processed_data = []
453
+ with open(file_path, 'r') as f:
454
+ for index, line in enumerate(f):
455
+ item = json.loads(line)
456
+ processed_item = {
457
+ 'original_data': item,
458
+ 'documents': item['docs'],
459
+ 'question': item['question'],
460
+ 'global_index': index,
461
+ 'pos_index': item['pos_index']
462
+ }
463
+ processed_data.append(processed_item)
464
+
465
+ return processed_data
466
+
467
+ @staticmethod
468
+ def load_paraphrase_data(file_path: str) -> List[Dict]:
469
+ """Load paraphrase data."""
470
+ data = []
471
+ with open(file_path, 'r') as f:
472
+ for line in f:
473
+ data.append(json.loads(line))
474
+
475
+ processed_data = []
476
+ for index, item in enumerate(data):
477
+ processed_item = {
478
+ 'original_data': item,
479
+ 'documents': [item['doc']],
480
+ 'question': "",
481
+ 'global_index': index
482
+ }
483
+ processed_data.append(processed_item)
484
+
485
+ return processed_data
486
+
487
+
488
+ class AcceleratedCLaRaInference:
489
+ """Main inference engine using Accelerate for distributed processing."""
490
+
491
+ def __init__(self, model_path: str, training_stage: str = None,
492
+ generation_top_k: int = None, args = None):
493
+ self.args = args
494
+
495
+ # Initialize Accelerator
496
+ process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
497
+ self.accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])
498
+
499
+ if self.accelerator.is_main_process:
500
+ print(f"Using {self.accelerator.num_processes} GPUs for distributed inference")
501
+ print(f"Current process: {self.accelerator.process_index}")
502
+ print("Loading CLaRa model...")
503
+
504
+ # Load model
505
+ self.model = CLaRa.from_pretrained(
506
+ model_path,
507
+ training_stage=training_stage,
508
+ generation_top_k=generation_top_k,
509
+ pure_inference=True
510
+ )
511
+
512
+ # Prepare model with Accelerator
513
+ self.model = self.accelerator.prepare(self.model)
514
+ self.model.eval()
515
+
516
+ if self.accelerator.is_main_process:
517
+ print("Model preparation completed")
518
+
519
+ def _get_model(self):
520
+ """Get the actual model (handles distributed vs single GPU)."""
521
+ return self.model.module if hasattr(self.model, 'module') else self.model
522
+
523
+ def process_batch(self, batch_questions: List[str], batch_documents: List[List[str]] = None,
524
+ stage2_mips: bool = False, training_stage: str = None,
525
+ batch_answers: List[str] = None, time_count: bool = False) -> Tuple:
526
+ """Process a batch of questions and documents."""
527
+ model = self._get_model()
528
+
529
+ with torch.no_grad():
530
+ try:
531
+ if training_stage == 'stage2':
532
+ return self._process_stage2(model, batch_questions, batch_documents,
533
+ stage2_mips, time_count)
534
+ elif training_stage in ['stage1', 'stage1_2']:
535
+ return self._process_stage1(model, batch_questions, batch_documents)
536
+ elif training_stage == 'stage2_reasoning':
537
+ return self._process_reasoning(model, batch_questions, batch_answers)
538
+ elif training_stage == 'stage1_paraphrase':
539
+ return self._process_paraphrase(model, batch_questions, batch_documents)
540
+ elif training_stage == 'stage1_mse_visulize':
541
+ return self._process_mse_visualize(model, batch_documents)
542
+ else:
543
+ raise ValueError(f"Unknown training stage: {training_stage}")
544
+
545
+ except torch.cuda.OutOfMemoryError as e:
546
+ self.accelerator.print(f"CUDA OOM error: {e}")
547
+ torch.cuda.empty_cache()
548
+ gc.collect()
549
+ return self._create_empty_results(batch_questions, training_stage)
550
+
551
+ def _process_stage2(self, model, batch_questions, batch_documents, stage2_mips, time_count):
552
+ """Process stage2 inference."""
553
+ if time_count:
554
+ if stage2_mips:
555
+ results = model.generate_from_questions(
556
+ questions=batch_questions,
557
+ max_new_tokens=64,
558
+ stage2_mips=stage2_mips,
559
+ time_count=True
560
+ )
561
+ else:
562
+ results = model.generate_from_questions(
563
+ questions=batch_questions,
564
+ max_new_tokens=64,
565
+ stage2_mips=stage2_mips,
566
+ documents=batch_documents,
567
+ time_count=True
568
+ )
569
+ return results
570
+ else:
571
+ if stage2_mips:
572
+ batch_out_normal, topk_idx = model.generate_from_questions(
573
+ questions=batch_questions,
574
+ max_new_tokens=64,
575
+ stage2_mips=stage2_mips
576
+ )
577
+ else:
578
+ batch_out_normal, topk_idx = model.generate_from_questions(
579
+ questions=batch_questions,
580
+ max_new_tokens=64,
581
+ stage2_mips=stage2_mips,
582
+ documents=batch_documents
583
+ )
584
+ return batch_out_normal, batch_out_normal, topk_idx
585
+
586
+ def _process_stage1(self, model, batch_questions, batch_documents):
587
+ """Process stage1 inference."""
588
+ batch_out_compressed = []
589
+
590
+ for docs, question in zip(batch_documents, batch_questions):
591
+ embeddings, _ = model.compress_documents(documents=docs)
592
+ out_compressed = model.generate_from_compressed_documents_and_questions(
593
+ questions=[question],
594
+ compressed_documents=embeddings
595
+ )
596
+ batch_out_compressed.extend(out_compressed)
597
+
598
+ del embeddings
599
+ torch.cuda.empty_cache()
600
+
601
+ return batch_out_compressed, batch_out_compressed, None
602
+
603
+ def _process_reasoning(self, model, batch_questions, batch_answers):
604
+ """Process reasoning inference."""
605
+ batch_out_normal = []
606
+ batch_out_reasoning_list = []
607
+
608
+ for question, answer in zip(batch_questions, batch_answers):
609
+ temp_out, temp_out_reasoning = model.generate_from_reasoning(
610
+ questions=[question],
611
+ max_new_tokens=1024,
612
+ answers=[answer],
613
+ save_dir=self.args.model_path
614
+ )
615
+ batch_out_normal.append(temp_out[0])
616
+ batch_out_reasoning_list.extend(temp_out_reasoning)
617
+
618
+ return batch_out_normal, batch_out_normal, None, batch_out_reasoning_list
619
+
620
+ def _process_paraphrase(self, model, batch_questions, batch_documents):
621
+ """Process paraphrase inference."""
622
+ batch_out_compressed = []
623
+
624
+ for docs, question in zip(batch_documents, batch_questions):
625
+ out_compressed = model.generate_from_paraphrase(
626
+ questions=["" for _ in range(len(docs))],
627
+ documents=[docs]
628
+ )
629
+ batch_out_compressed.extend(out_compressed)
630
+ torch.cuda.empty_cache()
631
+
632
+ return batch_out_compressed, batch_out_compressed, None
633
+
634
+ def _process_mse_visualize(self, model, batch_documents):
635
+ """Process MSE visualization."""
636
+ batch_out_normal = []
637
+ batch_out_compressed = []
638
+
639
+ for docs in batch_documents:
640
+ mem_rep, non_mem_rep = model.compress_documents_mse_visulize(documents=docs)
641
+ batch_out_compressed.append(mem_rep[0])
642
+ batch_out_normal.append(non_mem_rep[0])
643
+
644
+ return batch_out_normal, batch_out_compressed
645
+
646
+ def _create_empty_results(self, batch_questions, training_stage):
647
+ """Create empty results for error cases."""
648
+ empty_results = [""] * len(batch_questions)
649
+ if training_stage == 'stage2_reasoning':
650
+ return empty_results, empty_results, None, empty_results
651
+ elif training_stage == 'stage1_mse_visulize':
652
+ return empty_results, empty_results
653
+ else:
654
+ return empty_results, empty_results, None
655
+
656
+
657
+ def convert_embeddings_to_list(data):
658
+ """Convert tensor embeddings to lists for JSON serialization."""
659
+ if isinstance(data, dict):
660
+ return {k: convert_embeddings_to_list(v) for k, v in data.items()}
661
+ elif isinstance(data, list):
662
+ return [convert_embeddings_to_list(item) for item in data]
663
+ elif isinstance(data, torch.Tensor):
664
+ return data.cpu().to(torch.float32).numpy().tolist()
665
+ elif isinstance(data, np.ndarray):
666
+ return data.tolist()
667
+ else:
668
+ return data
669
+
670
+
671
+ def main():
672
+ parser = argparse.ArgumentParser(description="CLaRa Model Inference")
673
+ parser.add_argument('--model_path', type=str, required=True, help='Path to model checkpoint')
674
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size per GPU')
675
+ parser.add_argument('--stage', type=str, default='stage1',
676
+ choices=['stage1', 'stage1_2', 'stage2', 'stage2_reasoning',
677
+ 'stage1_paraphrase', 'stage1_mse_visulize'],
678
+ help='Training stage')
679
+ parser.add_argument('--stage2_mips', action='store_true', help='Use MIPS for stage2')
680
+ parser.add_argument('--dataset', type=str, default='musique',
681
+ help='Comma-separated list of datasets')
682
+ parser.add_argument('--gold_retrieval', action='store_true',
683
+ help='Use gold retrieval context')
684
+ parser.add_argument('--generation_top_k', type=int, default=5, help='Top-k for generation')
685
+ parser.add_argument('--paraphrase_path', type=str, help='Path to paraphrase data')
686
+ parser.add_argument('--mse_visulize_path', type=str, help='Path to save MSE visualization')
687
+ parser.add_argument('--efficient_count', action='store_true', help='Count efficiency metrics')
688
+
689
+ args = parser.parse_args()
690
+
691
+ # Process datasets
692
+ all_results_metrics = {}
693
+ datasets_list = args.dataset.split(',')
694
+
695
+ for dataset in datasets_list:
696
+ print(f"Processing dataset: {dataset}")
697
+
698
+ # Load data based on stage
699
+ if args.stage in ['stage1', 'stage1_2']:
700
+ processed_data = DataLoader.load_stage1_data(dataset, args.gold_retrieval)
701
+ elif args.stage == 'stage2':
702
+ processed_data = DataLoader.load_stage2_data(dataset, args.gold_retrieval)
703
+ elif args.stage in ['stage1_paraphrase', 'stage1_mse_visulize']:
704
+ if not args.paraphrase_path:
705
+ raise ValueError(f"--paraphrase_path required for stage {args.stage}")
706
+ processed_data = DataLoader.load_paraphrase_data(args.paraphrase_path)
707
+ else:
708
+ raise ValueError(f"Unsupported stage: {args.stage}")
709
+
710
+ print(f"Loaded {len(processed_data)} samples for {dataset}")
711
+
712
+ # Initialize inference engine
713
+ # Use model_path directly if absolute, otherwise use SageMaker path
714
+ if os.path.isabs(args.model_path):
715
+ model_path = args.model_path
716
+ else:
717
+ model_path = os.path.join('/mnt/task_wrapper/user_output/artifacts/data/train_checkpoint', args.model_path)
718
+ args.model_path = model_path
719
+
720
+ inference_engine = AcceleratedCLaRaInference(
721
+ model_path=model_path,
722
+ training_stage=args.stage,
723
+ generation_top_k=args.generation_top_k,
724
+ args=args
725
+ )
726
+
727
+ # Wait for all processes to be ready
728
+ inference_engine.accelerator.wait_for_everyone()
729
+
730
+ # Store results
731
+ all_results = []
732
+ time_count_dic = {"compress_time": 0, "query_time": 0, "generate_time": 0, "total_time": 0, "count": 0}
733
+
734
+ # Process data in batches using accelerator
735
+ with inference_engine.accelerator.split_between_processes(processed_data, apply_padding=False) as local_data:
736
+ print(f"Process {inference_engine.accelerator.process_index}: processing {len(local_data)} samples")
737
+
738
+ batch_size = args.batch_size
739
+ num_batches = (len(local_data) + batch_size - 1) // batch_size
740
+
741
+ for batch_idx in tqdm(range(num_batches),
742
+ desc=f"GPU {inference_engine.accelerator.process_index}",
743
+ disable=not inference_engine.accelerator.is_local_main_process):
744
+
745
+ # Get current batch
746
+ start_idx = batch_idx * batch_size
747
+ end_idx = min(start_idx + batch_size, len(local_data))
748
+ batch = local_data[start_idx:end_idx]
749
+
750
+ # Prepare batch data
751
+ batch_questions = [item['question'] for item in batch]
752
+ batch_documents = [item['documents'] for item in batch] if 'documents' in batch[0] else None
753
+ batch_answers = [item.get('answer') for item in batch] if args.stage == 'stage2_reasoning' else None
+
+                 # Process batch
+                 if args.efficient_count and args.stage == 'stage2':
+                     results = inference_engine.process_batch(
+                         batch_questions=batch_questions,
+                         batch_documents=batch_documents,
+                         stage2_mips=args.stage2_mips,
+                         training_stage=args.stage,
+                         time_count=True
+                     )
+                     batch_out_normal, batch_out_compressed, batch_topk_idx, compress_time, query_time, generate_time, total_time = results
+
+                     time_count_dic["compress_time"] += compress_time
+                     time_count_dic["query_time"] += query_time
+                     time_count_dic["generate_time"] += generate_time
+                     time_count_dic["total_time"] += total_time
+                     time_count_dic["count"] += 1
+                 else:
+                     results = inference_engine.process_batch(
+                         batch_questions=batch_questions,
+                         batch_documents=batch_documents,
+                         stage2_mips=args.stage2_mips,
+                         training_stage=args.stage,
+                         batch_answers=batch_answers
+                     )
+
+                     if args.stage == 'stage2_reasoning':
+                         batch_out_normal, batch_out_compressed, batch_topk_idx, batch_out_reasoning = results
+                     elif args.stage == 'stage1_mse_visulize':
+                         batch_out_normal, batch_out_compressed = results
+                         batch_topk_idx = None
+                     else:
+                         batch_out_normal, batch_out_compressed, batch_topk_idx = results
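+                 # process_batch's return arity varies by stage: three outputs by
+                 # default, four with reasoning, two (no top-k indices) for
+                 # stage1_mse_visulize, and seven when timings are requested above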
+
+                 # Prepare results
+                 batch_results = []
+                 for i, (item, normal_out, compressed_out) in enumerate(zip(batch, batch_out_normal, batch_out_compressed)):
+                     result_item = item['original_data'].copy()
+                     result_item['CLaRa_normal_output'] = normal_out
+                     result_item['CLaRa_compressed_output'] = compressed_out
+                     result_item['global_index'] = item['global_index']
+
+                     if args.stage == 'stage2' and batch_topk_idx is not None:
+                         result_item['topk_idx'] = batch_topk_idx[i].tolist()
+                     elif args.stage == 'stage2_reasoning':
+                         result_item['reasoning_output'] = batch_out_reasoning[i]
+
+                     batch_results.append(result_item)
+
+                 all_results.extend(batch_results)
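+                 # Each accumulated record later becomes one JSONL line, e.g.
+                 # (values invented): {...original fields..., "CLaRa_normal_output": "...",
+                 #  "CLaRa_compressed_output": "...", "topk_idx": [3, 0, 7]}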
+
+                 # Clean up memory: release cached CUDA blocks every batch and run
+                 # the Python garbage collector every 10 batches
+                 torch.cuda.empty_cache()
+                 if batch_idx % 10 == 0:
+                     gc.collect()
+
+         # Save efficiency metrics if requested (the count guard avoids a
+         # division by zero when no timed batches were run)
+         if args.efficient_count and time_count_dic["count"] > 0 and inference_engine.accelerator.is_main_process:
+             eff_dic = {
+                 "compress_time_ms": round((time_count_dic['compress_time'] / time_count_dic['count']) * 1000, 2),
+                 "query_time_ms": round((time_count_dic['query_time'] / time_count_dic['count']) * 1000, 2),
+                 "generate_time_ms": round((time_count_dic['generate_time'] / time_count_dic['count']) * 1000, 2),
+                 "total_time_ms": round((time_count_dic['total_time'] / time_count_dic['count']) * 1000, 2),
+                 "sample_count": time_count_dic['count']
+             }
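+             # Averages are per processed batch ("count" increments once per batch),
+             # converted to milliseconds: e.g. 12.3456 s over 100 batches gives
+             # round((12.3456 / 100) * 1000, 2) = 123.46 ms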
+             eff_output_path = os.path.join(model_path, f"efficiency_{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
+             with open(eff_output_path, 'w') as f:
+                 json.dump(eff_dic, f, indent=2)
+
+         # Wait for all processes to complete
+         inference_engine.accelerator.wait_for_everyone()
+
+         # Gather results from all processes
+         if inference_engine.accelerator.is_main_process:
+             print("Collecting results from all processes...")
+
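+         # gather_for_metrics is a collective operation, so every rank must reach
+         # this call; only the main process consumes the gathered list below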
+         all_results_gathered = inference_engine.accelerator.gather_for_metrics(all_results)
+
+         # Process and save results (main process only)
+         if inference_engine.accelerator.is_main_process:
+             print("Processing and saving results...")
+
+             # Flatten results
+             final_results = []
+             if isinstance(all_results_gathered, list):
+                 for result_batch in all_results_gathered:
+                     if isinstance(result_batch, list):
+                         final_results.extend(result_batch)
+                     else:
+                         final_results.append(result_batch)
+
+             print(f"Collected {len(final_results)} results")
+
+             # Sort by global index to restore the original sample order after the gather
+             final_results.sort(key=lambda x: x.get('global_index', 0))
+
+             # Verify data integrity
+             processed_indices = set(item.get('global_index', -1) for item in final_results)
+             expected_indices = set(range(len(processed_data)))
+             missing_indices = expected_indices - processed_indices
+
+             if missing_indices:
+                 print(f"Warning: Missing indices: {sorted(missing_indices)}")
+             else:
+                 print("✓ Data integrity verification passed")
+
+             # Remove global index for clean output
+             for item in final_results:
+                 item.pop('global_index', None)
+
+             # Save results
+             output_path = os.path.join(model_path, f"{dataset}_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.jsonl")
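+             # Pattern: <dataset>_<stage>_<gold_retrieval>_<generation_top_k>.jsonl,
+             # e.g. 'nq_stage2_False_5.jsonl' (placeholder values)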
+             with open(output_path, 'w') as f:
+                 if args.stage == 'stage1_mse_visulize':
+                     # Embedding tensors are not JSON-serializable; convert them to lists first
+                     converted_results = convert_embeddings_to_list(final_results)
+                     for item in converted_results:
+                         f.write(json.dumps(item) + '\n')
+                 else:
+                     for item in final_results:
+                         f.write(json.dumps(item) + '\n')
+
+             print(f"Results saved to: {output_path}")
+
+             # Calculate metrics
+             calculator = ResultCalculator()
+
+             if args.stage == 'stage2':
+                 metrics = calculator.calculate_stage2_metrics(final_results)
+             elif args.stage == 'stage1_paraphrase':
+                 metrics = calculator.calculate_paraphrase_metrics(final_results)
+             elif args.stage == 'stage1_mse_visulize':
+                 if args.mse_visulize_path:
+                     metrics = calculator.visualize_mse(final_results, args.mse_visulize_path)
+                 else:
+                     metrics = {"visualization": "completed"}
+             else:
+                 metrics = calculator.calculate_basic_metrics(final_results)
+
+             print(f"Metrics for {dataset}: {metrics}")
+             all_results_metrics[dataset] = metrics
+
+         # Clean up
+         del inference_engine
+         torch.cuda.empty_cache()
+         gc.collect()
+
+     # Save final metrics
+     if all_results_metrics:
+         metrics_path = os.path.join(model_path, f"results_metrics_{args.stage}_{args.gold_retrieval}_{args.generation_top_k}.json")
+         with open(metrics_path, 'w') as f:
+             json.dump(all_results_metrics, f, indent=2)
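+         # The file maps each dataset name to its metrics dict, e.g.
+         # (shape illustrative): {"nq": {...}, "hotpotqa": {...}}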
+         print(f"Final metrics saved to: {metrics_path}")
+
+
+ if __name__ == '__main__':
+     main()
evaluation/evaluation_data/end_to_end_evaluation/2wiki.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:39cf44bcfa24938c40617ef5bba90235642bf02f537297f4055b3f6bc756846c
+ size 93670063
evaluation/evaluation_data/end_to_end_evaluation/hotpotqa.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f46d7cfc23199f6cdff5e3ce1872ff150e6d940eb83b343cf37431cd740fa4db
+ size 61751762
evaluation/evaluation_data/end_to_end_evaluation/musique.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85a55afc5c6067d00eef1888e13b598039a515f787791b23fbb495c35827e264
+ size 18789210
evaluation/evaluation_data/end_to_end_evaluation/nq.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d26d5c29694cd81cccfcac4fd29c16ae7f245b4c554623cbe3c6ec8c3a0ad41
+ size 60057585
evaluation/evaluation_data/instruction_tuning_evaluation/2wiki.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:690a91abab47ebb8e32f335b2f1e31f40a1b2e78452988c7bffd0c30cc5f5463
+ size 93670063
evaluation/evaluation_data/instruction_tuning_evaluation/hotpotqa.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:808713253d9a0821c43800b16e56fd76a9b046b42088afb217630c735196b4e4
+ size 61751762
evaluation/evaluation_data/instruction_tuning_evaluation/musique.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:171fb546901128d338de9a343080e6e41ca73c3f8246c7c0d270330f06cef0d9
+ size 18789210
evaluation/evaluation_data/instruction_tuning_evaluation/nq.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:225bc8fd6f5b5b3156f42952602fc296a7a390bf873de24ceb0328ceb61eabfd
+ size 60057585
example/end_to_end_data.jsonl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6bf9c2e07e6c833288c7041a93dfa7fd3cf41e34b22800d97aafbd93c78d3597
+ size 13128781
example/instruction_tuning_data.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
example/pretrain_data.jsonl ADDED
The diff for this file is too large to render. See raw diff