Upload 47 files
Browse files- MACI-main/LICENSE +21 -0
- MACI-main/README.md +46 -0
- MACI-main/conditional-conformal/conditionalconformal/__init__.py +1 -0
- MACI-main/conditional-conformal/conditionalconformal/condconf.py +877 -0
- MACI-main/conditional-conformal/conditionalconformal/experiment_utils.py +182 -0
- MACI-main/conditional-conformal/conditionalconformal/synthetic_data.py +55 -0
- MACI-main/conditional-conformal/src/atomizer.py +347 -0
- MACI-main/conditional-conformal/src/aws_utils.py +15 -0
- MACI-main/conditional-conformal/src/client.py +89 -0
- MACI-main/conditional-conformal/src/config.py +7 -0
- MACI-main/conditional-conformal/src/conformal.py +68 -0
- MACI-main/conditional-conformal/src/data_utils/sample_names.py +86 -0
- MACI-main/conditional-conformal/src/dataset.py +279 -0
- MACI-main/conditional-conformal/src/featurizer.py +352 -0
- MACI-main/conditional-conformal/src/gpt.py +58 -0
- MACI-main/conditional-conformal/src/llm_utils.py +111 -0
- MACI-main/conditional-conformal/src/postprocess_factscore.py +34 -0
- MACI-main/conditional-conformal/src/prob_model.py +101 -0
- MACI-main/conditional-conformal/src/query.py +112 -0
- MACI-main/conditional-conformal/src/ray_data.py +192 -0
- MACI-main/conditional-conformal/src/retrieval.py +268 -0
- MACI-main/conditional-conformal/src/retrieve_data.py +86 -0
- MACI-main/conditional-conformal/src/run.py +119 -0
- MACI-main/conditional-conformal/src/scorer.py +202 -0
- MACI-main/conformal/__pycache__/adaptive_conformal.cpython-39.pyc +0 -0
- MACI-main/conformal/__pycache__/basic_conformal.cpython-39.pyc +0 -0
- MACI-main/conformal/__pycache__/conditional_conformal.cpython-39.pyc +0 -0
- MACI-main/conformal/adaptive_conformal.py +403 -0
- MACI-main/conformal/basic_conformal.py +189 -0
- MACI-main/conformal/conditional_conformal.py +489 -0
- MACI-main/data/med_scores/medlfqa_frequencies.npz +3 -0
- MACI-main/data/med_scores/medlfqa_logprobs.npz +3 -0
- MACI-main/data/med_scores/medlfqa_scores_deepseek_deepseek-chat-v3-0324.npz +3 -0
- MACI-main/data/med_scores/medlfqa_scores_meta-llama_llama-3.3-70b-instruct.npz +3 -0
- MACI-main/data/med_scores/medlfqa_scores_qwen_qwen-2.5-72b-instruct.npz +3 -0
- MACI-main/data/med_scores/medlfqa_selfevals.npz +3 -0
- MACI-main/data/wiki_scores/wikibio_final.csv +0 -0
- MACI-main/data/wiki_scores/wikibio_final_dataset.pkl +3 -0
- MACI-main/data/wiki_scores/wikibio_final_frequencies.npz +3 -0
- MACI-main/data/wiki_scores/wikibio_final_logprobs.npz +3 -0
- MACI-main/data/wiki_scores/wikibio_final_self_evals.npz +3 -0
- MACI-main/data/wiki_scores/wikibio_scores_deepseek-chat-v3-0324.npz +3 -0
- MACI-main/data/wiki_scores/wikibio_scores_meta-llama_llama-3.3-70b-instruct.npz +3 -0
- MACI-main/data/wiki_scores/wikibio_scores_qwen_qwen-2.5-72b-instruct.npz +3 -0
- MACI-main/experiments/conditional_groupers.py +542 -0
- MACI-main/experiments/run_experiment.py +1127 -0
- MACI-main/requirements.txt +12 -0
MACI-main/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Anonymous2026conf
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
MACI-main/README.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MACI
|
| 2 |
+
This repository contains an anonymized version of our Multi-LLM Adaptive Conformal Inference experiments. The entry point is `experiments/run_experiment.py`.
|
| 3 |
+
|
| 4 |
+
## Abstract
|
| 5 |
+
|
| 6 |
+
Ensuring factuality is essential for the safe use of Large Language Models (LLMs) in high-stakes domains such as medicine and law. Conformal inference provides distribution-free guarantees, but existing approaches are either overly conservative, discarding many true-claims, or rely on adaptive error rates and simple linear models that fail to capture complex group structures. To address these challenges, we reformulate conformal inference in a multiplicative filtering setting, modeling factuality as a product of claim-level scores. Our method, Multi-LLM Adaptive Conformal Inference (MACI), leverages ensembles to produce more accurate factuality-scores, which in our experiments led to higher retention, while validity is preserved through group-conditional calibration. Experiments show that MACI consistently achieves user-specified coverage with substantially higher retention and lower time cost than baselines.
|
| 7 |
+
|
| 8 |
+
## Running
|
| 9 |
+
|
| 10 |
+
Step 1) Create a fresh Conda environment (Python 3.9)
|
| 11 |
+
|
| 12 |
+
```bash
|
| 13 |
+
conda create -y -n maci python=3.9
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Step 2) Install dependencies from requirements.txt
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
conda run -n maci \
|
| 20 |
+
python -m pip install -r requirements.txt --no-input
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Step 3) Prepare data layout (repo-relative defaults)
|
| 24 |
+
|
| 25 |
+
- Place data under `data/` in the repository root (or pass `--data-dir`).
|
| 26 |
+
- For MedLFQA: put files under `data/med_scores/`.
|
| 27 |
+
- For WikiBio: put files under `data/wiki_scores/`.
|
| 28 |
+
|
| 29 |
+
Step 4) Run a quick experiment (MedLFQA example)
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
conda run -n maci \
|
| 33 |
+
python experiments/run_experiment.py \
|
| 34 |
+
--dataset-type medlfqa \
|
| 35 |
+
--conditional-groups false_claim_risk \
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
Step 5) Where outputs go
|
| 39 |
+
|
| 40 |
+
- Logs: `logs/` (repo-root-relative by default)
|
| 41 |
+
- Results JSON: `analysis/experiment_results/`
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
## CCI Attribution
|
| 46 |
+
Our implementation of the Conditional Conformal Inference (CCI) baseline is a direct adoption of the work from the [conformal-safety](https://github.com/jjcherian/conformal-safety.git) repository. To ensure full reproducibility, we have included a local copy of the necessary modules in the conditional-conformal/ directory. We explicitly state that the code within this directory is not the work of the MACI project. For all details, please refer to the original repository: [conformal-safety](https://github.com/jjcherian/conformal-safety.git)
|
MACI-main/conditional-conformal/conditionalconformal/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .condconf import CondConf
|
MACI-main/conditional-conformal/conditionalconformal/condconf.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cvxpy as cp
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from functools import partial, lru_cache
|
| 5 |
+
from scipy.optimize import linprog
|
| 6 |
+
from sklearn.metrics.pairwise import pairwise_kernels
|
| 7 |
+
from typing import Callable
|
| 8 |
+
|
| 9 |
+
FUNCTION_DEFAULTS = {"kernel": None, "gamma" : 1, "lambda": 1}
|
| 10 |
+
|
| 11 |
+
class CondConf:
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
score_fn : Callable,
|
| 15 |
+
Phi_fn : Callable,
|
| 16 |
+
quantile_fn : Callable = None,
|
| 17 |
+
infinite_params : dict = {},
|
| 18 |
+
seed : int = 0
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
Constructs the CondConf object that caches relevant information for
|
| 22 |
+
generating conditionally valid prediction sets.
|
| 23 |
+
|
| 24 |
+
We define the score function and set of conditional guarantees
|
| 25 |
+
that we care about in this function.
|
| 26 |
+
|
| 27 |
+
Parameters
|
| 28 |
+
---------
|
| 29 |
+
score_fn : Callable[np.ndarray, np.ndarray] -> np.ndarray
|
| 30 |
+
Fixed (vectorized) conformity score function that takes in
|
| 31 |
+
X and Y as inputs and returns S as output
|
| 32 |
+
|
| 33 |
+
Phi_fn : Callable[np.ndarray] -> np.ndarray
|
| 34 |
+
Function that defines finite basis set that we provide
|
| 35 |
+
exact conditional guarantees over
|
| 36 |
+
|
| 37 |
+
infinite_params : dict = {}
|
| 38 |
+
Dictionary containing parameters for the RKHS component of the fit
|
| 39 |
+
Valid keys are ('kernel', 'gamma', 'lambda')
|
| 40 |
+
'kernel' should be a valid kernel name for sklearn.metrics.pairwise_kernels
|
| 41 |
+
'gamma' is a hyperparameter for certain kernels
|
| 42 |
+
'lambda' is the regularization penalty applied to the RKHS component
|
| 43 |
+
"""
|
| 44 |
+
self.score_fn = score_fn
|
| 45 |
+
self.Phi_fn = Phi_fn
|
| 46 |
+
self.quantile_fn = quantile_fn
|
| 47 |
+
self.infinite_params = infinite_params
|
| 48 |
+
self.rng = np.random.default_rng(seed=seed)
|
| 49 |
+
|
| 50 |
+
def setup_problem(
|
| 51 |
+
self,
|
| 52 |
+
x_calib : np.ndarray,
|
| 53 |
+
y_calib : np.ndarray
|
| 54 |
+
):
|
| 55 |
+
"""
|
| 56 |
+
setup_problem sets up the final fitting problem for a
|
| 57 |
+
particular calibration set
|
| 58 |
+
|
| 59 |
+
The resulting cvxpy Problem object is stored inside the CondConf parent.
|
| 60 |
+
|
| 61 |
+
Arguments
|
| 62 |
+
---------
|
| 63 |
+
x_calib : np.ndarray
|
| 64 |
+
Covariate data for the calibration set
|
| 65 |
+
|
| 66 |
+
y_calib : np.ndarray
|
| 67 |
+
Labels for the calibration set
|
| 68 |
+
"""
|
| 69 |
+
self.x_calib = x_calib
|
| 70 |
+
self.y_calib = y_calib
|
| 71 |
+
phi_calib = self.Phi_fn(x_calib)
|
| 72 |
+
|
| 73 |
+
_, s, Vt = np.linalg.svd(phi_calib, full_matrices=False)
|
| 74 |
+
|
| 75 |
+
# Set a tolerance to decide which singular values are nonzero
|
| 76 |
+
tol = 1e-10
|
| 77 |
+
r = np.sum(s > tol)
|
| 78 |
+
|
| 79 |
+
if r < len(s):
|
| 80 |
+
self.Phi_fn_orig = self.Phi_fn
|
| 81 |
+
T = Vt.T[:, :r]
|
| 82 |
+
self.Phi_fn = lambda x: (self.Phi_fn_orig(x) @ T)
|
| 83 |
+
phi_calib = self.Phi_fn(x_calib)
|
| 84 |
+
|
| 85 |
+
self.phi_calib = phi_calib
|
| 86 |
+
self.scores_calib = self.score_fn(x_calib, y_calib)
|
| 87 |
+
|
| 88 |
+
if self.quantile_fn is not None:
|
| 89 |
+
self.quantile_calib = self.quantile_fn(x_calib).reshape(-1,1)
|
| 90 |
+
|
| 91 |
+
self.cvx_problem = setup_cvx_problem(
|
| 92 |
+
self.x_calib,
|
| 93 |
+
self.scores_calib,
|
| 94 |
+
self.phi_calib,
|
| 95 |
+
self.infinite_params
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@lru_cache()
|
| 100 |
+
def _get_calibration_solution(
|
| 101 |
+
self,
|
| 102 |
+
quantile : float
|
| 103 |
+
):
|
| 104 |
+
S = self.scores_calib.reshape(-1,1)
|
| 105 |
+
Phi = self.phi_calib.astype(float)
|
| 106 |
+
zeros = np.zeros((Phi.shape[1],))
|
| 107 |
+
|
| 108 |
+
if quantile is None:
|
| 109 |
+
bounds = np.concatenate((self.quantile_calib - 1, self.quantile_calib), axis=1)
|
| 110 |
+
else:
|
| 111 |
+
bounds = np.asarray([quantile - 1, quantile])
|
| 112 |
+
bounds = np.tile(bounds.reshape(1,-1), (len(S), 1))
|
| 113 |
+
|
| 114 |
+
res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds, method='highs')
|
| 115 |
+
primal_vars = -1 * res.eqlin.marginals.reshape(-1,1)
|
| 116 |
+
dual_vars = res.x.reshape(-1,1)
|
| 117 |
+
|
| 118 |
+
residuals = S - (Phi @ primal_vars)
|
| 119 |
+
interpolated_pts = np.isclose(residuals, 0)
|
| 120 |
+
|
| 121 |
+
# if I didn't converge to a solution that interpolates at least Phi.shape[1] pts,
|
| 122 |
+
# I need to manually find one via a modified simplex iteration
|
| 123 |
+
if interpolated_pts.sum() < Phi.shape[1]:
|
| 124 |
+
num_to_add = Phi.shape[1] - interpolated_pts.sum()
|
| 125 |
+
for _ in range(num_to_add):
|
| 126 |
+
candidate_pts = interpolated_pts.copy().flatten()
|
| 127 |
+
|
| 128 |
+
# find candidate idx for interpolation, e.g., new covariate that is
|
| 129 |
+
# linearly independent of the previously interpolated points
|
| 130 |
+
Q, _ = np.linalg.qr(Phi[candidate_pts].T)
|
| 131 |
+
projections = Phi @ Q @ Q.T
|
| 132 |
+
norms = np.linalg.norm(Phi - projections, axis=1)
|
| 133 |
+
candidate_idx = np.where(norms > 1e-5)[0][0]
|
| 134 |
+
candidate_pts[candidate_idx] = True
|
| 135 |
+
|
| 136 |
+
# find direction to solution that would interpolate the new point
|
| 137 |
+
gamma, _, _, _ = np.linalg.lstsq(Phi[candidate_pts], S[candidate_pts], rcond=None)
|
| 138 |
+
direction = gamma.reshape(-1,1) - primal_vars
|
| 139 |
+
step_sizes = residuals / (Phi @ direction)
|
| 140 |
+
|
| 141 |
+
# check the non-basic indices for which a step in this direction could have led to interpolation
|
| 142 |
+
# e.g., those for which the step size is positive and the point is not already interpolated
|
| 143 |
+
positive_indices = np.where((step_sizes > 0) & ~interpolated_pts)[0]
|
| 144 |
+
|
| 145 |
+
# take smallest possible step that would lead to interpolation
|
| 146 |
+
primal_vars += np.min(step_sizes[positive_indices]) * direction
|
| 147 |
+
|
| 148 |
+
residuals = S - (Phi @ primal_vars)
|
| 149 |
+
interpolated_pts = np.isclose(residuals, 0)
|
| 150 |
+
|
| 151 |
+
return dual_vars, primal_vars
|
| 152 |
+
|
| 153 |
+
def _compute_exact_cutoff(
|
| 154 |
+
self,
|
| 155 |
+
quantiles,
|
| 156 |
+
primals,
|
| 157 |
+
duals,
|
| 158 |
+
phi_test,
|
| 159 |
+
dual_threshold
|
| 160 |
+
):
|
| 161 |
+
def get_current_basis(primals, duals, Phi, S, quantiles):
|
| 162 |
+
interp_bools = np.logical_and(~np.isclose(duals, quantiles - 1), ~np.isclose(duals, quantiles))
|
| 163 |
+
if np.sum(interp_bools) == Phi.shape[1]:
|
| 164 |
+
return interp_bools
|
| 165 |
+
preds = (Phi @ primals).flatten()
|
| 166 |
+
active_indices = np.where(interp_bools)[0]
|
| 167 |
+
interp_indices = np.where(np.isclose(np.abs(S - preds), 0))[0]
|
| 168 |
+
diff_indices = np.setdiff1d(interp_indices, active_indices)
|
| 169 |
+
num_missing = Phi.shape[1] - np.sum(interp_bools)
|
| 170 |
+
if num_missing < len(diff_indices):
|
| 171 |
+
from itertools import combinations
|
| 172 |
+
for cand_indices in combinations(diff_indices, num_missing):
|
| 173 |
+
cand_phi = Phi[np.concatenate((active_indices, cand_indices))]
|
| 174 |
+
if np.isfinite(np.linalg.cond(cand_phi)):
|
| 175 |
+
interp_bools[np.asarray(cand_indices)] = True
|
| 176 |
+
break
|
| 177 |
+
else:
|
| 178 |
+
interp_bools[diff_indices] = True
|
| 179 |
+
if np.sum(interp_bools) != Phi.shape[1]:
|
| 180 |
+
raise ValueError("Initial basis could not be found - retry with exact=False.")
|
| 181 |
+
return interp_bools
|
| 182 |
+
|
| 183 |
+
if np.allclose(phi_test, 0):
|
| 184 |
+
return np.inf if quantiles[-1] >= 0.5 else -np.inf
|
| 185 |
+
|
| 186 |
+
basis = get_current_basis(primals, duals, self.phi_calib, self.scores_calib, quantiles[:-1])
|
| 187 |
+
S_test = phi_test @ primals
|
| 188 |
+
|
| 189 |
+
duals = np.concatenate((duals.flatten(), [0]))
|
| 190 |
+
basis = np.concatenate((basis.flatten(), [False]))
|
| 191 |
+
phi = np.concatenate((self.phi_calib, phi_test.reshape(1,-1)), axis=0)
|
| 192 |
+
S = np.concatenate((self.scores_calib.reshape(-1,1), S_test.reshape(-1,1)), axis=0)
|
| 193 |
+
|
| 194 |
+
candidate_idx = phi.shape[0] - 1
|
| 195 |
+
num_iters = 0
|
| 196 |
+
while True:
|
| 197 |
+
# get direction vector for dual variable step
|
| 198 |
+
direction = -1 * np.linalg.solve(phi[basis].T, phi[candidate_idx].reshape(-1,1)).flatten()
|
| 199 |
+
|
| 200 |
+
# only consider non-zero entries of the direction vector
|
| 201 |
+
active_indices = ~np.isclose(direction, 0)
|
| 202 |
+
active_direction = direction[active_indices]
|
| 203 |
+
active_basis = basis.copy()
|
| 204 |
+
active_basis[np.where(basis)[0][~active_indices]] = False
|
| 205 |
+
|
| 206 |
+
positive_step = True if duals[candidate_idx] <= 0 else False
|
| 207 |
+
if candidate_idx == phi.shape[0] - 1:
|
| 208 |
+
positive_step = True if dual_threshold >= 0 else False
|
| 209 |
+
|
| 210 |
+
if positive_step:
|
| 211 |
+
gap_to_bounds = np.maximum(
|
| 212 |
+
(quantiles[active_basis].flatten() - duals[active_basis]) / active_direction,
|
| 213 |
+
((quantiles[active_basis].flatten() - 1) - duals[active_basis]) / active_direction
|
| 214 |
+
)
|
| 215 |
+
step_size = np.min(gap_to_bounds)
|
| 216 |
+
departing_idx = np.where(active_basis)[0][np.argmin(gap_to_bounds)]
|
| 217 |
+
else:
|
| 218 |
+
gap_to_bounds = np.minimum(
|
| 219 |
+
(quantiles[active_basis].flatten() - duals[active_basis]) / active_direction,
|
| 220 |
+
((quantiles[active_basis].flatten() - 1) - duals[active_basis]) / active_direction
|
| 221 |
+
)
|
| 222 |
+
step_size = np.max(gap_to_bounds)
|
| 223 |
+
departing_idx = np.where(active_basis)[0][np.argmax(gap_to_bounds)]
|
| 224 |
+
step_size_clip = np.clip(
|
| 225 |
+
step_size,
|
| 226 |
+
a_max=quantiles[candidate_idx] - duals[candidate_idx],
|
| 227 |
+
a_min=(quantiles[candidate_idx] - 1) - duals[candidate_idx]
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
duals[basis] += step_size_clip * direction
|
| 231 |
+
duals[candidate_idx] += step_size_clip
|
| 232 |
+
# print("Current value of final dual", duals[-1], "target threshold", dual_threshold)
|
| 233 |
+
|
| 234 |
+
if dual_threshold > 0 and duals[-1] > dual_threshold:
|
| 235 |
+
break
|
| 236 |
+
|
| 237 |
+
if dual_threshold < 0 and duals[-1] < dual_threshold:
|
| 238 |
+
break
|
| 239 |
+
|
| 240 |
+
if step_size_clip == step_size:
|
| 241 |
+
basis[departing_idx] = False
|
| 242 |
+
basis[candidate_idx] = True
|
| 243 |
+
|
| 244 |
+
if np.isclose(duals[-1], dual_threshold):
|
| 245 |
+
break
|
| 246 |
+
|
| 247 |
+
# TODO: make this a SMW update and reuse in the direction vector calc...
|
| 248 |
+
reduced_A = np.linalg.solve(phi[basis].T, phi[~basis].T)
|
| 249 |
+
reduced_costs = (S[~basis].T - S[basis].T @ reduced_A).flatten()
|
| 250 |
+
bottom = reduced_A[-1]
|
| 251 |
+
bottom[np.isclose(bottom, 0)] = np.inf
|
| 252 |
+
req_change = reduced_costs / bottom
|
| 253 |
+
if dual_threshold >= 0:
|
| 254 |
+
ignore_entries = (np.isclose(bottom, 0) | np.asarray(req_change <= 1e-5))
|
| 255 |
+
else:
|
| 256 |
+
ignore_entries = (np.isclose(bottom, 0) | np.asarray(req_change >= -1e-5))
|
| 257 |
+
if np.sum(~ignore_entries) == 0:
|
| 258 |
+
S[-1] = np.inf if quantiles[-1] >= 0.5 else -np.inf
|
| 259 |
+
break
|
| 260 |
+
if dual_threshold >= 0:
|
| 261 |
+
candidate_idx = np.where(~basis)[0][np.where(~ignore_entries, req_change, np.inf).argmin()]
|
| 262 |
+
S[-1] += np.min(req_change[~ignore_entries])
|
| 263 |
+
else:
|
| 264 |
+
candidate_idx = np.where(~basis)[0][np.where(~ignore_entries, req_change, -np.inf).argmax()]
|
| 265 |
+
S[-1] += np.max(req_change[~ignore_entries])
|
| 266 |
+
num_iters += 1
|
| 267 |
+
if num_iters > 10000:
|
| 268 |
+
S[-1] = np.inf if dual_threshold > 0 else -1 * np.inf
|
| 269 |
+
return S[-1]
|
| 270 |
+
|
| 271 |
+
def predict(
|
| 272 |
+
self,
|
| 273 |
+
quantile : float,
|
| 274 |
+
x_test : np.ndarray,
|
| 275 |
+
score_inv_fn : Callable,
|
| 276 |
+
S_min : float = None,
|
| 277 |
+
S_max : float = None,
|
| 278 |
+
randomize : bool = False,
|
| 279 |
+
exact : bool = True,
|
| 280 |
+
threshold : float = None
|
| 281 |
+
):
|
| 282 |
+
"""
|
| 283 |
+
Returns the (conditionally valid) prediction set for a given
|
| 284 |
+
test point
|
| 285 |
+
|
| 286 |
+
Arguments
|
| 287 |
+
---------
|
| 288 |
+
quantile : float
|
| 289 |
+
Nominal quantile level
|
| 290 |
+
x_test : np.ndarray
|
| 291 |
+
Single test point
|
| 292 |
+
score_inv_fn : Callable[float, np.ndarray] -> .
|
| 293 |
+
Function that takes in a score threshold S^* and test point x and
|
| 294 |
+
outputs all values of y such that S(x, y) <= S^*
|
| 295 |
+
S_min : float = None
|
| 296 |
+
Lower bound (if available) on the conformity scores
|
| 297 |
+
S_max : float = None
|
| 298 |
+
Upper bound (if available) on the conformity scores
|
| 299 |
+
randomize : bool = False
|
| 300 |
+
Randomize prediction set for exact coverage
|
| 301 |
+
exact : bool = True
|
| 302 |
+
Avoid binary search and compute threshold exactly
|
| 303 |
+
|
| 304 |
+
Returns
|
| 305 |
+
-------
|
| 306 |
+
prediction_set
|
| 307 |
+
"""
|
| 308 |
+
if quantile is None:
|
| 309 |
+
quantile_test = self.quantile_fn(x_test).reshape(-1,1)
|
| 310 |
+
quantiles = np.concatenate((self.quantile_calib, quantile_test), axis=0)
|
| 311 |
+
else:
|
| 312 |
+
quantile_test = quantile
|
| 313 |
+
quantiles = np.ones((len(self.scores_calib) + 1,1)) * quantile
|
| 314 |
+
if threshold is None:
|
| 315 |
+
if randomize:
|
| 316 |
+
threshold = self.rng.uniform(low=quantile_test - 1, high=quantile_test)
|
| 317 |
+
else:
|
| 318 |
+
if quantile_test < 0.5:
|
| 319 |
+
threshold = quantile_test - 1
|
| 320 |
+
else:
|
| 321 |
+
threshold = quantile_test
|
| 322 |
+
|
| 323 |
+
if exact:
|
| 324 |
+
if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
|
| 325 |
+
raise ValueError("Exact computation doesn't support RKHS quantile regression for now.")
|
| 326 |
+
if np.allclose(quantiles[0], quantiles):
|
| 327 |
+
naive_duals, naive_primals = self._get_calibration_solution(
|
| 328 |
+
quantiles.flatten()[0]
|
| 329 |
+
)
|
| 330 |
+
else:
|
| 331 |
+
naive_duals, naive_primals = self._get_calibration_solution(
|
| 332 |
+
None
|
| 333 |
+
)
|
| 334 |
+
score_cutoff = self._compute_exact_cutoff(
|
| 335 |
+
quantiles,
|
| 336 |
+
naive_primals,
|
| 337 |
+
naive_duals,
|
| 338 |
+
self.Phi_fn(x_test),
|
| 339 |
+
threshold
|
| 340 |
+
)
|
| 341 |
+
else:
|
| 342 |
+
_solve = partial(_solve_dual, gcc=self, x_test=x_test, quantiles=quantiles, threshold=threshold)
|
| 343 |
+
|
| 344 |
+
if S_min is None:
|
| 345 |
+
S_min = np.min(self.scores_calib)
|
| 346 |
+
if S_max is None:
|
| 347 |
+
S_max = np.max(self.scores_calib)
|
| 348 |
+
lower, upper = binary_search(_solve, S_min, S_max * 2)
|
| 349 |
+
|
| 350 |
+
if quantile < 0.5:
|
| 351 |
+
score_cutoff = self._get_threshold(lower, x_test, quantiles)
|
| 352 |
+
else:
|
| 353 |
+
score_cutoff = self._get_threshold(upper, x_test, quantiles)
|
| 354 |
+
return score_inv_fn(score_cutoff, x_test.reshape(-1,1))
|
| 355 |
+
|
| 356 |
+
def estimate_coverage(
|
| 357 |
+
self,
|
| 358 |
+
quantile : float,
|
| 359 |
+
weights : np.ndarray,
|
| 360 |
+
x : np.ndarray = None
|
| 361 |
+
):
|
| 362 |
+
"""
|
| 363 |
+
estimate_coverage estimates the true percentile of the issued estimate of the
|
| 364 |
+
conditional quantile under the covariate shift induced by 'weights'
|
| 365 |
+
|
| 366 |
+
If we are ostensibly estimating the 0.95-quantile using an RKHS fit, we may
|
| 367 |
+
determine using our theory that the true percentile of this estimate is only 0.93
|
| 368 |
+
|
| 369 |
+
Arguments
|
| 370 |
+
---------
|
| 371 |
+
quantile : float
|
| 372 |
+
Nominal quantile level
|
| 373 |
+
weights : np.ndarray
|
| 374 |
+
RKHS weights for tilt under which the coverage is estimated
|
| 375 |
+
x : np.ndarray = None
|
| 376 |
+
Points for which the RKHS weights are defined. If None, we assume
|
| 377 |
+
that weights corresponds to x_calib
|
| 378 |
+
|
| 379 |
+
Returns
|
| 380 |
+
-------
|
| 381 |
+
estimated_alpha : float
|
| 382 |
+
Our estimate for the realized quantile level
|
| 383 |
+
"""
|
| 384 |
+
weights = weights.reshape(-1,1)
|
| 385 |
+
prob = setup_cvx_problem_calib(
|
| 386 |
+
quantile,
|
| 387 |
+
self.x_calib,
|
| 388 |
+
self.scores_calib,
|
| 389 |
+
self.phi_calib,
|
| 390 |
+
self.infinite_params
|
| 391 |
+
)
|
| 392 |
+
if "MOSEK" in cp.installed_solvers():
|
| 393 |
+
prob.solve(solver="MOSEK")
|
| 394 |
+
else:
|
| 395 |
+
prob.solve()
|
| 396 |
+
|
| 397 |
+
fitted_weights = prob.var_dict['weights'].value
|
| 398 |
+
if x is not None:
|
| 399 |
+
K = pairwise_kernels(
|
| 400 |
+
X=x,
|
| 401 |
+
Y=self.x_calib,
|
| 402 |
+
metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
|
| 403 |
+
gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
|
| 404 |
+
)
|
| 405 |
+
else:
|
| 406 |
+
K = pairwise_kernels(
|
| 407 |
+
X=self.x_calib,
|
| 408 |
+
metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
|
| 409 |
+
gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
|
| 410 |
+
)
|
| 411 |
+
inner_prod = weights.T @ K @ fitted_weights
|
| 412 |
+
expectation = np.mean(weights.T @ K)
|
| 413 |
+
#penalty = self.infinite_params['lambda'] * (inner_prod / expectation)
|
| 414 |
+
penalty = (1/(len(self.x_calib) + 1))*(inner_prod / expectation)
|
| 415 |
+
return quantile - penalty
|
| 416 |
+
|
| 417 |
+
def predict_naive(
|
| 418 |
+
self,
|
| 419 |
+
quantile : float,
|
| 420 |
+
x : np.ndarray,
|
| 421 |
+
score_inv_fn : Callable
|
| 422 |
+
):
|
| 423 |
+
"""
|
| 424 |
+
If we do not wish to include the imputed data point, we can sanity check that
|
| 425 |
+
the regression is appropriately adaptive to the conditional variability in the data
|
| 426 |
+
by running a quantile regression on the calibration set without any imputation.
|
| 427 |
+
When n_calib is large and the fit is stable, we expect these two sets to nearly coincide.
|
| 428 |
+
|
| 429 |
+
Arguments
|
| 430 |
+
---------
|
| 431 |
+
quantile : float
|
| 432 |
+
Nominal quantile level
|
| 433 |
+
x : np.ndarray
|
| 434 |
+
Set of points for which we are issuing prediction sets
|
| 435 |
+
score_inv_fn : Callable[np.ndarray, np.ndarray] -> np.ndarray
|
| 436 |
+
Vectorized function that takes in a score threshold S^* and test point x and
|
| 437 |
+
outputs all values of y such that S(x, y) <= S^*
|
| 438 |
+
|
| 439 |
+
Returns
|
| 440 |
+
-------
|
| 441 |
+
prediction_sets
|
| 442 |
+
|
| 443 |
+
"""
|
| 444 |
+
if len(x.shape) < 2:
|
| 445 |
+
raise ValueError("x needs to have shape (m, n), not {x_test.shape}.")
|
| 446 |
+
|
| 447 |
+
if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
|
| 448 |
+
prob = setup_cvx_problem_calib(
|
| 449 |
+
quantile,
|
| 450 |
+
self.x_calib,
|
| 451 |
+
self.scores_calib,
|
| 452 |
+
self.phi_calib,
|
| 453 |
+
self.infinite_params
|
| 454 |
+
)
|
| 455 |
+
if "MOSEK" in cp.installed_solvers():
|
| 456 |
+
prob.solve(solver="MOSEK", verbose=False)
|
| 457 |
+
else:
|
| 458 |
+
prob.solve()
|
| 459 |
+
|
| 460 |
+
weights = prob.var_dict['weights'].value
|
| 461 |
+
beta = prob.constraints[-1].dual_value
|
| 462 |
+
K = pairwise_kernels(
|
| 463 |
+
X=x,
|
| 464 |
+
Y=self.x_calib,
|
| 465 |
+
metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
|
| 466 |
+
gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
|
| 467 |
+
)
|
| 468 |
+
threshold = K @ weights + self.Phi_fn(x) @ beta
|
| 469 |
+
else:
|
| 470 |
+
S = np.concatenate([self.scores_calib, [S]], dtype=float)
|
| 471 |
+
Phi = self.phi_calib.astype(float)
|
| 472 |
+
zeros = np.zeros((Phi.shape[1],))
|
| 473 |
+
|
| 474 |
+
if quantile is None:
|
| 475 |
+
bounds = np.concatenate((self.quantile_calib - 1, self.quantile_calib), axis=1)
|
| 476 |
+
else:
|
| 477 |
+
bounds = [(quantile - 1, quantile)] * (len(self.scores_calib) + 1)
|
| 478 |
+
res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds, method='highs')
|
| 479 |
+
beta = -1 * res.eqlin.marginals
|
| 480 |
+
threshold = self.Phi_fn(x) @ beta
|
| 481 |
+
|
| 482 |
+
return score_inv_fn(threshold, x)
|
| 483 |
+
|
| 484 |
+
def verify_coverage(
|
| 485 |
+
self,
|
| 486 |
+
x : np.ndarray,
|
| 487 |
+
y : np.ndarray,
|
| 488 |
+
quantile : float,
|
| 489 |
+
randomize : bool = False,
|
| 490 |
+
resolve : bool = False,
|
| 491 |
+
return_dual : bool = False,
|
| 492 |
+
eps : float = 0.001
|
| 493 |
+
):
|
| 494 |
+
"""
|
| 495 |
+
In some experiments, we may simply be interested in verifying the coverage of our method.
|
| 496 |
+
In this case, we do not need to binary search for the threshold S^*, but only need to verify that
|
| 497 |
+
S <= f_S(x) for the true value of S. This function implements this check for test points
|
| 498 |
+
denoted by x and y
|
| 499 |
+
|
| 500 |
+
Arguments
|
| 501 |
+
---------
|
| 502 |
+
x : np.ndarray
|
| 503 |
+
A vector of test covariates
|
| 504 |
+
y : np.ndarray
|
| 505 |
+
A vector of test labels
|
| 506 |
+
quantile : float
|
| 507 |
+
Nominal quantile level
|
| 508 |
+
resolve : bool
|
| 509 |
+
Resolve LP/QP with posited value to determine coverage
|
| 510 |
+
|
| 511 |
+
Returns
|
| 512 |
+
-------
|
| 513 |
+
coverage_booleans : np.ndarray
|
| 514 |
+
"""
|
| 515 |
+
covers = []
|
| 516 |
+
duals = []
|
| 517 |
+
|
| 518 |
+
if quantile is None:
|
| 519 |
+
quantiles = np.concatenate((self.quantile_calib, [[0.]]), axis=0).flatten()
|
| 520 |
+
else:
|
| 521 |
+
quantiles = quantile * np.ones((len(self.scores_calib) + 1, 1))
|
| 522 |
+
|
| 523 |
+
if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
|
| 524 |
+
for x_val, y_val in zip(x, y):
|
| 525 |
+
S_true = self.score_fn(x_val.reshape(1,-1), y_val)
|
| 526 |
+
eta = self._get_dual_solution(S_true[0], x_val.reshape(1,-1), quantiles) # no need to recompute quantiles
|
| 527 |
+
if randomize:
|
| 528 |
+
threshold = self.rng.uniform(low=quantile - 1, high=quantile)
|
| 529 |
+
elif quantile > 0.5:
|
| 530 |
+
threshold = quantile - eps
|
| 531 |
+
else:
|
| 532 |
+
threshold = quantile - 1 + eps
|
| 533 |
+
if quantile > 0.5:
|
| 534 |
+
covers.append(eta[-1] < threshold)
|
| 535 |
+
else:
|
| 536 |
+
covers.append(eta[-1] > threshold)
|
| 537 |
+
duals.append(eta[-1])
|
| 538 |
+
|
| 539 |
+
else:
|
| 540 |
+
for x_val, y_val in zip(x, y):
|
| 541 |
+
if randomize:
|
| 542 |
+
threshold = self.rng.uniform(low=quantiles[-1] - 1, high=quantiles[-1])
|
| 543 |
+
elif quantiles[-1] > 0.5:
|
| 544 |
+
threshold = quantiles[-1]
|
| 545 |
+
else:
|
| 546 |
+
threshold = quantiles[-1] - 1
|
| 547 |
+
|
| 548 |
+
S_true = self.score_fn(x_val.reshape(1,-1), y_val)
|
| 549 |
+
if resolve:
|
| 550 |
+
eta = self._get_dual_solution(S_true[0], x_val.reshape(1,-1), quantile)
|
| 551 |
+
if quantile > 0.5:
|
| 552 |
+
covers.append(eta[-1] < threshold)
|
| 553 |
+
else:
|
| 554 |
+
covers.append(eta[-1] > threshold)
|
| 555 |
+
duals.append(eta[-1])
|
| 556 |
+
else:
|
| 557 |
+
naive_duals, naive_primals = self._get_calibration_solution(
|
| 558 |
+
quantile
|
| 559 |
+
)
|
| 560 |
+
score_cutoff = self._compute_exact_cutoff(
|
| 561 |
+
quantiles,
|
| 562 |
+
naive_primals,
|
| 563 |
+
naive_duals,
|
| 564 |
+
self.Phi_fn(x_val),
|
| 565 |
+
threshold
|
| 566 |
+
)
|
| 567 |
+
if quantile > 0.5:
|
| 568 |
+
covers.append(S_true < score_cutoff)
|
| 569 |
+
else:
|
| 570 |
+
covers.append(S_true > score_cutoff)
|
| 571 |
+
duals.append(np.nan)
|
| 572 |
+
if return_dual:
|
| 573 |
+
return np.asarray(covers), np.asarray(duals)
|
| 574 |
+
return np.asarray(covers)
|
| 575 |
+
|
| 576 |
+
def _get_dual_solution(
|
| 577 |
+
self,
|
| 578 |
+
S : float,
|
| 579 |
+
x : np.ndarray,
|
| 580 |
+
quantiles : np.ndarray
|
| 581 |
+
):
|
| 582 |
+
if self.infinite_params.get("kernel", FUNCTION_DEFAULTS['kernel']):
|
| 583 |
+
prob = finish_dual_setup(
|
| 584 |
+
self.cvx_problem,
|
| 585 |
+
S,
|
| 586 |
+
x,
|
| 587 |
+
quantiles[-1][0],
|
| 588 |
+
self.Phi_fn(x),
|
| 589 |
+
self.x_calib,
|
| 590 |
+
self.infinite_params
|
| 591 |
+
)
|
| 592 |
+
if "MOSEK" in cp.installed_solvers():
|
| 593 |
+
prob.solve(solver="MOSEK")
|
| 594 |
+
else:
|
| 595 |
+
prob.solve()
|
| 596 |
+
# TODO: THIS IS WRONG
|
| 597 |
+
#raise ValueError("need to get variable out of problem and return its value")
|
| 598 |
+
return prob.var_dict['weights'].value
|
| 599 |
+
else:
|
| 600 |
+
S = np.concatenate([self.scores_calib, [S]])
|
| 601 |
+
Phi = np.concatenate([self.phi_calib, self.Phi_fn(x)], axis=0)
|
| 602 |
+
zeros = np.zeros((Phi.shape[1],))
|
| 603 |
+
bounds = np.concatenate((quantiles - 1, quantiles), axis=1)
|
| 604 |
+
res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds,
|
| 605 |
+
method='highs-ds', options={'presolve': False})
|
| 606 |
+
eta = res.x
|
| 607 |
+
return eta
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def _get_primal_solution(
|
| 611 |
+
self,
|
| 612 |
+
S : float,
|
| 613 |
+
x : np.ndarray,
|
| 614 |
+
quantiles : np.ndarray
|
| 615 |
+
):
|
| 616 |
+
if self.infinite_params.get("kernel", FUNCTION_DEFAULTS['kernel']):
|
| 617 |
+
prob = finish_dual_setup(
|
| 618 |
+
self.cvx_problem,
|
| 619 |
+
S,
|
| 620 |
+
x,
|
| 621 |
+
quantiles[-1][0],
|
| 622 |
+
self.Phi_fn(x),
|
| 623 |
+
self.x_calib,
|
| 624 |
+
self.infinite_params
|
| 625 |
+
)
|
| 626 |
+
if "MOSEK" in cp.installed_solvers():
|
| 627 |
+
prob.solve(solver="MOSEK")
|
| 628 |
+
else:
|
| 629 |
+
prob.solve()
|
| 630 |
+
|
| 631 |
+
weights = prob.var_dict['weights'].value
|
| 632 |
+
beta = prob.constraints[-1].dual_value
|
| 633 |
+
else:
|
| 634 |
+
S = np.concatenate([self.scores_calib, [S]])
|
| 635 |
+
Phi = np.concatenate([self.phi_calib, self.Phi_fn(x)], axis=0)
|
| 636 |
+
zeros = np.zeros((Phi.shape[1],))
|
| 637 |
+
bounds = np.concatenate((quantiles - 1, quantiles), axis=1)
|
| 638 |
+
res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds,
|
| 639 |
+
method='highs-ds', options={'presolve': False})
|
| 640 |
+
beta = -1 * res.eqlin.marginals
|
| 641 |
+
weights = None
|
| 642 |
+
return beta, weights
|
| 643 |
+
|
| 644 |
+
def _get_threshold(
|
| 645 |
+
self,
|
| 646 |
+
S : float,
|
| 647 |
+
x : np.ndarray,
|
| 648 |
+
quantiles : np.ndarray
|
| 649 |
+
):
|
| 650 |
+
beta, weights = self._get_primal_solution(S, x, quantiles)
|
| 651 |
+
|
| 652 |
+
threshold = self.Phi_fn(x) @ beta
|
| 653 |
+
if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
|
| 654 |
+
K = pairwise_kernels(
|
| 655 |
+
X=np.concatenate([self.x_calib, x.reshape(1,-1)], axis=0),
|
| 656 |
+
Y=np.concatenate([self.x_calib, x.reshape(1,-1)], axis=0),
|
| 657 |
+
metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
|
| 658 |
+
gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
|
| 659 |
+
)
|
| 660 |
+
threshold = (K @ weights)[-1] + threshold
|
| 661 |
+
return threshold
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def binary_search(func, min, max, tol=1e-3):
    """Bisect [min, max] down to width <= tol on the sign change of func.

    Maintains lo/hi such that func values <= 0 stay at lo and values > 0 at hi
    (assuming func(min) <= 0 < func(max) holds initially — not checked here).

    Arguments
    ---------
    func : Callable[[float], float]
        Function whose sign change is being bracketed.
    min, max : float
        Initial bracketing endpoints. (Parameter names kept for backward
        compatibility with keyword callers, despite shadowing builtins.)
    tol : float
        Target bracket width.

    Returns
    -------
    (lo, hi) : tuple[float, float]
        Final bracket with hi - lo <= tol.
    """
    # Idiom fix: work with non-builtin-shadowing locals inside the body.
    lo, hi = float(min), float(max)
    # guard: tol must be representable at this magnitude, else infinite loop
    assert (hi + tol) > hi
    while (hi - lo) > tol:
        mid = (lo + hi) / 2
        if func(mid) > 0:
            hi = mid
        else:
            lo = mid
    return lo, hi
| 675 |
+
|
| 676 |
+
def _solve_dual(S, gcc, x_test, quantiles, threshold=None):
    """Solve the imputed dual problem and return eta_{n+1} minus the cutoff.

    A nonnegative/nonpositive sign of the return value is what callers use to
    drive the binary search over candidate scores S.
    """
    if gcc.infinite_params.get('kernel', None):
        prob = finish_dual_setup(
            gcc.cvx_problem,
            S,
            x_test,
            quantiles[-1][0],
            gcc.Phi_fn(x_test),
            gcc.x_calib,
            gcc.infinite_params
        )
        chosen_solver = "MOSEK" if "MOSEK" in cp.installed_solvers() else "OSQP"
        prob.solve(solver=chosen_solver)
        weights = prob.var_dict['weights'].value
    else:
        stacked_scores = np.concatenate([gcc.scores_calib, [S]], dtype=float)
        design = np.concatenate([gcc.phi_calib, gcc.Phi_fn(x_test)], axis=0, dtype=float)
        rhs = np.zeros((design.shape[1],))

        box = np.concatenate((quantiles - 1, quantiles), axis=1)
        res = linprog(-1 * stacked_scores, A_eq=design.T, b_eq=rhs, bounds=box,
                      method='highs', options={'presolve': False})
        weights = res.x

    if threshold is None:
        # default cutoff depends on which tail the quantile targets
        threshold = quantiles[-1] - 1 if quantiles[-1] < 0.5 else quantiles[-1]
    # if quantile < 0.5:
    #     return weights[-1] + (1 - quantile)
    return weights[-1] - threshold
| 711 |
+
|
| 712 |
+
def setup_cvx_problem(
    x_calib,
    scores_calib,
    phi_calib,
    infinite_params = {}
):
    """Build the parametric cvxpy dual problem with one imputed test point.

    The test point's score, basis row, quantile level, and (in the RKHS case)
    kernel rows are cp.Parameters so the problem can be re-solved cheaply for
    many candidate values.
    """
    n_calib = len(scores_calib)
    if phi_calib is None:
        # default to an intercept-only basis
        phi_calib = np.ones((n_calib, 1))

    eta = cp.Variable(name="weights", shape=n_calib + 1)

    quantile = cp.Parameter(name="quantile")

    calib_scores = cp.Constant(scores_calib.reshape(-1, 1))
    test_score = cp.Parameter(name="S_test", shape=(1, 1))
    scores = cp.vstack([calib_scores, test_score])

    calib_basis = cp.Constant(phi_calib)
    test_basis = cp.Parameter(name="Phi_test", shape=(1, phi_calib.shape[1]))
    Phi = cp.vstack([calib_basis, test_basis])

    kernel = infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"])
    gamma = infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])

    # box constraints are shared by both objective variants;
    # these are really C * (quantile - 1) and C * quantile in the RKHS case
    constraints = [
        (quantile - 1) <= eta,
        quantile >= eta,
        eta.T @ Phi == 0
    ]
    linear_term = cp.sum(cp.multiply(eta, cp.vec(scores)))

    if kernel is None:  # no RKHS fitting
        objective = cp.Minimize(-1 * linear_term)
    else:  # RKHS fitting
        radius = cp.Parameter(name="radius", nonneg=True)

        _, L_11 = _get_kernel_matrix(x_calib, kernel, gamma)

        L_top = cp.Constant(
            np.hstack([L_11, np.zeros((L_11.shape[0], 1))])
        )
        L_bottom = cp.Parameter(name="L_21_22", shape=(1, n_calib + 1))
        L = cp.vstack([L_top, L_bottom])

        C = radius / (n_calib + 1)
        objective = cp.Minimize(0.5 * C * cp.sum_squares(L.T @ eta) - linear_term)

    return cp.Problem(objective, constraints)
|
| 771 |
+
|
| 772 |
+
def _get_kernel_matrix(x_calib, kernel, gamma):
    """Gram matrix of the calibration points and its Cholesky factor.

    A small diagonal ridge (1e-5) keeps the matrix positive definite so the
    Cholesky factorization cannot fail on near-singular kernels.
    """
    jitter = 1e-5 * np.eye(len(x_calib))
    gram = pairwise_kernels(X=x_calib, metric=kernel, gamma=gamma) + jitter
    return gram, np.linalg.cholesky(gram)
| 782 |
+
|
| 783 |
+
def finish_dual_setup(
    prob : cp.Problem,
    S : np.ndarray,
    X : np.ndarray,
    quantile : float,
    Phi : np.ndarray,
    x_calib : np.ndarray,
    infinite_params = {}
):
    """Fill in the parameters of a problem built by setup_cvx_problem.

    Sets the imputed score, test basis row, quantile level and — when a
    kernel is configured — the blocks of the augmented Cholesky factor.
    """
    prob.param_dict['S_test'].value = np.asarray([[S]])
    prob.param_dict['Phi_test'].value = Phi.reshape(1, -1)
    prob.param_dict['quantile'].value = quantile

    kernel = infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel'])
    gamma = infinite_params.get('gamma', FUNCTION_DEFAULTS['gamma'])
    radius = 1 / infinite_params.get('lambda', FUNCTION_DEFAULTS['lambda'])

    if kernel is None:
        return prob

    test_row = X.reshape(1, -1)
    K_12 = pairwise_kernels(
        X=np.concatenate([x_calib, test_row], axis=0),
        Y=test_row,
        metric=kernel,
        gamma=gamma
    )

    if 'K_12' in prob.param_dict:
        prob.param_dict['K_12'].value = K_12[:-1]
        prob.param_dict['K_21'].value = K_12.T

    _, L_11 = _get_kernel_matrix(x_calib, kernel, gamma)
    K_22 = pairwise_kernels(
        X=test_row,
        metric=kernel,
        gamma=gamma
    )
    # Schur-complement update of the Cholesky factor for the appended point
    L_21 = np.linalg.solve(L_11, K_12[:-1]).T
    L_22 = K_22 - L_21 @ L_21.T
    L_22[L_22 < 0] = 0  # clamp tiny negative round-off before the sqrt
    L_22 = np.sqrt(L_22)
    prob.param_dict['L_21_22'].value = np.hstack([L_21, L_22])

    prob.param_dict['radius'].value = radius

    # update quantile definition for silly cvxpy reasons
    prob.param_dict['quantile'].value = quantile
    # prob.param_dict['quantile'].value *= radius / (len(x_calib) + 1)

    return prob
|
| 832 |
+
def setup_cvx_problem_calib(
    quantile,
    x_calib,
    scores_calib,
    phi_calib,
    infinite_params = {}
):
    """Build the calibration-only dual problem (no imputed test point).

    Unlike setup_cvx_problem, the quantile is fixed at construction time and
    the variable has only n_calib entries.
    """
    n_calib = len(scores_calib)
    if phi_calib is None:
        # default to an intercept-only basis
        phi_calib = np.ones((n_calib, 1))

    eta = cp.Variable(name="weights", shape=n_calib)

    scores = cp.Constant(scores_calib.reshape(-1, 1))

    Phi = cp.Constant(phi_calib)

    kernel = infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"])
    gamma = infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])

    constraints = [
        (quantile - 1) <= eta,
        quantile >= eta,
        eta.T @ Phi == 0
    ]
    linear_term = cp.sum(cp.multiply(eta, cp.vec(scores)))

    if kernel is None:  # no RKHS fitting
        objective = cp.Minimize(-1 * linear_term)
    else:  # RKHS fitting
        radius = 1 / infinite_params.get('lambda', FUNCTION_DEFAULTS['lambda'])

        _, L = _get_kernel_matrix(x_calib, kernel, gamma)

        C = radius / (n_calib + 1)
        objective = cp.Minimize(0.5 * C * cp.sum_squares(L.T @ eta) - linear_term)

    return cp.Problem(objective, constraints)
MACI-main/conditional-conformal/conditionalconformal/experiment_utils.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
from sklearn.linear_model import LinearRegression
|
| 5 |
+
from quantile_forest import RandomForestQuantileRegressor
|
| 6 |
+
|
| 7 |
+
from conditionalconformal import CondConf
|
| 8 |
+
|
| 9 |
+
## get base model for constructing scores
|
| 10 |
+
def fit_model(data_train, base_model):
    """Fit the base regression used to construct conformity scores.

    base_model is one of "ols", "qrf", or "qr"; any other value raises
    NameError (same as the original).
    """
    x_train, y_train = data_train
    if base_model == "ols":
        model = LinearRegression().fit(x_train, y_train)
    elif base_model == "qrf":
        model = RandomForestQuantileRegressor()
        model.fit(x_train, y_train)
    elif base_model == "qr":
        model = CondConf(score_fn=lambda x, y: y, Phi_fn=lambda x: x)
        model.setup_problem(x_train, y_train)
        # overwrite prediction function so it looks like a regression object;
        # expects x to be of form (n_points, n_feats)
        model.predict = lambda x, q: (x @ model._get_calibration_solution(q)[1]).flatten()
    return model
|
| 24 |
+
# helper function for splitting dataset
|
| 25 |
+
def split_dataset(dataset, n_test, n_calib, rng):
    """Randomly partition (X, Y) into train/calib/test splits.

    The first n_test shuffled indices become the test set, the next n_calib
    the calibration set, and the remainder the training set.
    """
    X, Y = dataset
    order = np.arange(len(X))
    rng.shuffle(order)
    test_idx, calib_idx, train_idx = np.array_split(
        order, np.cumsum([n_test, n_calib])
    )
    return (
        (X[train_idx], Y[train_idx]),
        (X[calib_idx], Y[calib_idx]),
        (X[test_idx], Y[test_idx]),
    )
|
| 44 |
+
# get coverages for each method type...
|
| 45 |
+
# get coverages for each method type...
def get_coverage(dataset_calib, dataset_test, score_fn, method, quantile):
    """Coverage indicators on the test set for one conformal method.

    "split" uses a marginal quantile cutoff; any method containing "rand"
    uses the randomized conditional check; everything else the deterministic
    conditional check.
    """
    if method == "split":
        calib_scores = score_fn(*dataset_calib)
        test_scores = score_fn(*dataset_test)

        # finite-sample corrected marginal quantile of calibration scores
        cutoff = np.quantile(
            calib_scores,
            [quantile * (1 + 1 / len(calib_scores))]
        )
        return test_scores <= cutoff if quantile >= 0.5 else test_scores >= cutoff

    condcalib = CondConf(score_fn, lambda x: x)
    condcalib.setup_problem(*dataset_calib)
    X_test, Y_test = dataset_test
    return condcalib.verify_coverage(
        X_test, Y_test, quantile, resolve=True, randomize="rand" in method
    )
|
| 70 |
+
# get coverages for each method type...
|
| 71 |
+
# get coverages for each method type...
def get_cutoff(dataset_calib, dataset_test, score_fn, method, quantile):
    """Per-test-point score cutoffs plus the induced coverage indicators."""
    print(method, quantile)
    test_scores = score_fn(*dataset_test)
    if method == "split":
        calib_scores = score_fn(*dataset_calib)
        # one shared, finite-sample corrected cutoff for every test point
        split_cutoff = np.quantile(
            calib_scores,
            [quantile * (1 + 1 / len(calib_scores))]
        )
        cutoffs = np.ones(len(test_scores)) * split_cutoff
    else:
        condcalib = CondConf(score_fn, lambda x: x)
        condcalib.setup_problem(*dataset_calib)
        use_randomization = "rand" in method
        cutoffs = np.asarray([
            condcalib.predict(quantile, x, lambda c, x: c, randomize=use_randomization)
            for x in dataset_test[0]
        ])
    if quantile > 0.5:
        coverages = test_scores <= cutoffs.flatten()
    else:
        coverages = test_scores >= cutoffs.flatten()
    return cutoffs, coverages
|
| 103 |
+
|
| 104 |
+
def run_coverage_experiment(dataset, n_test, n_calib, alpha, methods = [], seed = 0):
    """Two-sided coverage experiment over the requested methods.

    Each method string is (BASE_METHOD)-(CONFORMAL_METHOD):
    BASE_METHOD in {"ols", "qr", "qrf"};
    CONFORMAL_METHOD in {"split", "cc", "ccrand", "lcp", "rlcp"} (todo on last two).
    """
    rng = np.random.default_rng(seed=seed)

    dataset_train, dataset_calib, dataset_test = split_dataset(
        dataset, n_test, n_calib, rng
    )

    ### Compute conformity scores — one base fit per distinct prefix
    base_prefixes = {m.split('-')[0] for m in methods}
    fitted = {prefix: fit_model(dataset_train, prefix) for prefix in base_prefixes}

    coverages = []
    for method in methods:
        base_name, conformal_name = method.split('-')
        reg = fitted[base_name]
        if "q" in base_name:  # quantile regressions need a target level
            upper_score = lambda x, y, reg=reg: y - reg.predict(x, 1 - alpha / 2)
            lower_score = lambda x, y, reg=reg: y - reg.predict(x, alpha / 2)
        else:
            upper_score = lambda x, y, reg=reg: y - reg.predict(x)
            lower_score = lambda x, y, reg=reg: y - reg.predict(x)
        upper_cov = get_coverage(dataset_calib, dataset_test, upper_score, conformal_name, 1 - alpha / 2)
        lower_cov = get_coverage(dataset_calib, dataset_test, lower_score, conformal_name, alpha / 2)
        coverages.append(np.logical_and(upper_cov, lower_cov))

    return dataset_test[0], coverages
|
| 138 |
+
|
| 139 |
+
def run_experiment(dataset, n_test, n_calib, alpha, methods = [], seed = 0):
    """Interval-length and coverage experiment over the requested methods.

    Each method string is (BASE_METHOD)-(CONFORMAL_METHOD):
    BASE_METHOD in {"ols", "qr", "qrf"};
    CONFORMAL_METHOD in {"split", "cc", "ccrand", "lcp", "ccqp"}.
    """
    rng = np.random.default_rng(seed=seed)

    dataset_train, dataset_calib, dataset_test = split_dataset(
        dataset, n_test, n_calib, rng
    )

    ### Compute conformity scores — all three base models are always fit
    fitted = {name: fit_model(dataset_train, name) for name in ["ols", "qrf", "qr"]}

    all_lengths = []
    all_coverages = []
    for method in methods:
        base_name, conformal_name = method.split('-')
        reg = fitted[base_name]
        if "qrf" in base_name:
            # tiny uniform noise — presumably to break ties in the forest's
            # piecewise-constant predictions (TODO confirm intent)
            upper_score = lambda x, y, reg=reg: y - reg.predict(x, 1 - alpha / 2) + rng.uniform(0, 1e-5, size=len(x))
            lower_score = lambda x, y, reg=reg: y - reg.predict(x, alpha / 2) + rng.uniform(0, 1e-5, size=len(x))
        elif "q" in base_name:
            upper_score = lambda x, y, reg=reg: y - reg.predict(x, 1 - alpha / 2)
            lower_score = lambda x, y, reg=reg: y - reg.predict(x, alpha / 2)
        else:
            upper_score = lambda x, y, reg=reg: y - reg.predict(x)
            lower_score = lambda x, y, reg=reg: y - reg.predict(x)
        cutoffs_upper, cov_upper = get_cutoff(dataset_calib, dataset_test, upper_score, conformal_name, 1 - alpha / 2)
        cutoffs_lower, cov_lower = get_cutoff(dataset_calib, dataset_test, lower_score, conformal_name, alpha / 2)
        if "q" in base_name:
            # quantile scores are residuals, so add back the predicted spread
            pred_gap = (reg.predict(dataset_test[0], 1 - alpha / 2)
                        - reg.predict(dataset_test[0], alpha / 2))
        else:
            pred_gap = 0
        all_lengths.append(cutoffs_upper - cutoffs_lower + pred_gap)
        all_coverages.append(np.logical_and(cov_upper, cov_lower))

    return dataset_test[0], (all_lengths, all_coverages)
MACI-main/conditional-conformal/conditionalconformal/synthetic_data.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def generate_cqr_data(seed,n_train=2000,n_calib=1000,n_test=500):
|
| 4 |
+
np.random.seed(seed)
|
| 5 |
+
|
| 6 |
+
n_train = n_train + n_calib
|
| 7 |
+
|
| 8 |
+
def f(x):
|
| 9 |
+
''' Construct data (1D example)
|
| 10 |
+
'''
|
| 11 |
+
ax = 0*x
|
| 12 |
+
for i in range(len(x)):
|
| 13 |
+
ax[i] = np.random.poisson(np.sin(x[i])**2+0.1) + 0.03*x[i]*np.random.randn(1)
|
| 14 |
+
ax[i] += 25*(np.random.uniform(0,1,1)<0.01)*np.random.randn(1)
|
| 15 |
+
return ax.astype(np.float32)
|
| 16 |
+
|
| 17 |
+
# training features
|
| 18 |
+
x_train = np.random.uniform(0, 5.0, size=n_train).astype(np.float32)
|
| 19 |
+
|
| 20 |
+
# test features
|
| 21 |
+
x_test = np.random.uniform(0, 5.0, size=n_test).astype(np.float32)
|
| 22 |
+
|
| 23 |
+
# generate labels
|
| 24 |
+
y_train = f(x_train)
|
| 25 |
+
y_test = f(x_test)
|
| 26 |
+
|
| 27 |
+
# reshape the features
|
| 28 |
+
x_train = np.reshape(x_train,(n_train,1))
|
| 29 |
+
x_test = np.reshape(x_test,(n_test,1))
|
| 30 |
+
|
| 31 |
+
train_set_size = len(y_train) - n_calib
|
| 32 |
+
x_train_final = x_train[ : train_set_size]
|
| 33 |
+
x_calib = x_train[train_set_size : ]
|
| 34 |
+
y_train_final = y_train[ : train_set_size]
|
| 35 |
+
y_calib = y_train[train_set_size : ]
|
| 36 |
+
|
| 37 |
+
return x_train_final, y_train_final, x_calib, y_calib, x_test, y_test
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def indicator_matrix(scalar_values, disc):
    """One-hot interval-membership matrix.

    Entry (i, j) is 1.0 iff disc[j] <= scalar_values[i] < disc[j + 1], and 0.0
    otherwise; values outside [disc[0], disc[-1]) produce an all-zero row.

    Arguments
    ---------
    scalar_values : array-like of shape (n,)
        Values to bucket.
    disc : array-like of shape (m + 1,)
        Breakpoints defining the m half-open intervals [disc[j], disc[j + 1]).

    Returns
    -------
    np.ndarray of shape (n, m)
    """
    values = np.asarray(scalar_values).reshape(-1, 1)
    edges = np.asarray(disc)
    # Vectorized replacement for the original O(n * m) Python double loop:
    # broadcast every value against all interval edges in one pass.
    return ((values >= edges[:-1]) & (values < edges[1:])).astype(float)
MACI-main/conditional-conformal/src/atomizer.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import numpy as np
|
| 3 |
+
import re
|
| 4 |
+
import string
|
| 5 |
+
import spacy
|
| 6 |
+
import nltk
|
| 7 |
+
from rank_bm25 import BM25Okapi
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 11 |
+
from nltk.tokenize import sent_tokenize
|
| 12 |
+
|
| 13 |
+
nltk.download("punkt")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Atomizer(object):
    """Breaks a model generation into atomic facts.

    Sentences are split with nltk, repaired with fix_sentence_splitter, then
    each sentence is sent to an LLM (via `client`) with few-shot
    demonstrations loaded from `demo_dir` and retrieved per-sentence by BM25.
    """

    def __init__(self, client, demo_dir):
        # spaCy model used downstream for entity checks in postprocess_atomic_facts.
        self.nlp = spacy.load("en_core_web_sm")
        # NOTE(review): hard-coded to biography mode; the non-bio branches
        # below are dead code while this stays True.
        self.is_bio = True
        self.demo_path = os.path.join(demo_dir, "demos.json" if self.is_bio else "demos_complex.json")

        self.client = client

        # get the demos
        with open(self.demo_path, 'r') as f:
            self.demos = json.load(f)

        # BM25 index over the demo sentences, used to retrieve similar demonstrations.
        tokenized_corpus = [doc.split(" ") for doc in self.demos.keys()]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def save_cache(self):
        # Persist the underlying client's prompt/response cache.
        self.client.save_cache()

    def run(self, generation, cost_estimate=None):
        """Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None."""
        assert isinstance(generation, str), "generation must be a string"
        paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]
        return self.get_atomic_facts_from_paragraph(paragraphs, cost_estimate=cost_estimate)

    def get_atomic_facts_from_paragraph(self, paragraphs, cost_estimate=None):
        """Sentence-split each paragraph, atomize the sentences, and return
        (sentence, facts) pairs plus paragraph-break indices; or return a
        word-count estimate when cost_estimate is set."""
        sentences = []
        para_breaks = []
        for para_idx, paragraph in enumerate(paragraphs):
            if para_idx > 0 :
                # Record where each new paragraph starts in the flat sentence list.
                para_breaks.append(len(sentences))

            initials = detect_initials(paragraph)

            curr_sentences = sent_tokenize(paragraph)
            curr_sentences_2 = sent_tokenize(paragraph)

            curr_sentences = fix_sentence_splitter(curr_sentences, initials)
            curr_sentences_2 = fix_sentence_splitter(curr_sentences_2, initials)

            # checking this, just to ensure the crediability of the sentence splitter fixing algorithm
            assert curr_sentences == curr_sentences_2, (paragraph, curr_sentences, curr_sentences_2)

            sentences += curr_sentences

        # Non-bio mode skips boilerplate lead-in/sign-off sentences
        # ("Sure...", "I hope...") before querying the LLM.
        atoms_or_estimate = self.get_init_atomic_facts_from_sentence([sent for i, sent in enumerate(sentences) if not (not self.is_bio and ( \
                            (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or \
                            (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))))], cost_estimate=cost_estimate)

        if cost_estimate:
            return atoms_or_estimate
        else:
            atoms = atoms_or_estimate
            atomic_facts_pairs = []
            for i, sent in enumerate(sentences):
                # Boilerplate sentences get an empty fact list rather than a query result.
                if not self.is_bio and ( \
                    (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or \
                    (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))):
                    atomic_facts_pairs.append((sent, []))
                elif self.is_bio and sent.startswith("This sentence does not contain any facts"):
                    atomic_facts_pairs.append((sent, []))
                elif sent.startswith("Sure") or sent.startswith("Please") or (i==0 and sent.startswith("Here are")):
                    atomic_facts_pairs.append((sent, []))
                else:
                    atomic_facts_pairs.append((sent, atoms[sent]))

            # postprocess_atomic_facts will fix minor issues from InstructGPT
            # it is supposed to handle sentence splitter issue too, but since here
            # we fixed sentence splitter issue already,
            # the new para_breaks should be identical to the original para_breaks
            if self.is_bio:
                atomic_facts_pairs, para_breaks = postprocess_atomic_facts(atomic_facts_pairs, list(para_breaks), self.nlp)

            return atomic_facts_pairs, para_breaks


    def get_init_atomic_facts_from_sentence(self, sentences, cost_estimate=None):
        """Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None."""

        is_bio = self.is_bio
        demos = self.demos

        # k: BM25-retrieved demos per sentence; n: fixed leading demos in every prompt.
        k = 1 if is_bio else 0
        n = 7 if is_bio else 8

        prompts = []
        prompt_to_sent = {}
        atoms = {}
        for sentence in sentences:
            if sentence in atoms:
                continue
            top_matchings = best_demos(sentence, self.bm25, list(demos.keys()), k)
            prompt = ""

            # Fixed few-shot block: the first n demos, each followed by its facts.
            for i in range(n):
                prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(list(demos.keys())[i])
                for fact in demos[list(demos.keys())[i]]:
                    prompt = prompt + "- {}\n".format(fact)
                prompt = prompt + "\n"

            # Retrieved block: demos most similar to this sentence under BM25.
            for match in top_matchings:
                prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(match)
                for fact in demos[match]:
                    prompt = prompt + "- {}\n".format(fact)
                prompt = prompt + "\n"
            prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(sentence)
            prompts.append(prompt)
            prompt_to_sent[prompt] = sentence

        if cost_estimate:
            # Estimate the number of prompt words that would actually be sent.
            total_words_estimate = 0
            for prompt in prompts:
                # "consider_cache" skips prompts already answered in the client cache;
                # "_0" matches the client's cache-key convention (prompt + sample index).
                if cost_estimate == "consider_cache" and (prompt.strip() + "_0") in self.client.cache_dict:
                    continue
                total_words_estimate += len(prompt.split())
            return total_words_estimate
        else:
            outputs = []

            # Fan out all prompts concurrently; client.query serves cache hits locally.
            with ThreadPoolExecutor(max_workers=len(prompts)) as executor:
                outputs = list(
                    executor.map(
                        lambda x : self.client.query(x),
                        prompts
                    )
                )
            for prompt, output in zip(prompts, outputs):
                atoms[prompt_to_sent[prompt]] = text_to_sentences(output[0]['message'])
            # for prompt in prompts:
            #     output = self.client.query(prompt)
            #     outputs.append(output)
            #     atoms[prompt_to_sent[prompt]] = text_to_sentences(output[0]['message'])

            self.client.cache_outputs(
                prompts=prompts,
                sample_indices=np.zeros((len(prompts),), dtype=int),
                outputs=outputs
            )

            # Demo sentences themselves map to their known gold facts.
            for key, value in demos.items():
                if key not in atoms:
                    atoms[key] = value

            return atoms
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def best_demos(query, bm25, demos_sents, k):
    """Retrieve the k demo sentences most similar to `query` under BM25."""
    return bm25.get_top_n(query.split(" "), demos_sents, k)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# transform InstructGPT output into sentences
def text_to_sentences(text):
    """Parse a bulleted LLM response ("- fact") into a list of fact strings.

    Everything before the first bullet is dropped, and the final fact is
    given a trailing period if the model omitted one. Returns [] when the
    text contains no bullets.
    """
    segments = text.split("- ")[1:]
    # Strip whitespace and drop empty bullets. The original indexed
    # sent.strip()[-1], which raised IndexError on an empty segment (and its
    # check for a trailing '\n' after strip() could never be true).
    sentences = [seg.strip() for seg in segments if seg.strip()]
    if sentences and not sentences[-1].endswith('.'):
        sentences[-1] = sentences[-1] + '.'
    return sentences
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def normalize_answer(s):
    """Lower-case the text, strip punctuation and English articles, and
    collapse runs of whitespace to single spaces."""
    lowered = s.lower()
    punctuation = set(string.punctuation)
    no_punc = ''.join(ch for ch in lowered if ch not in punctuation)
    no_articles = re.sub(re.compile(r'\b(a|an|the)\b', re.UNICODE), ' ', no_punc)
    return ' '.join(no_articles.split())
|
| 192 |
+
|
| 193 |
+
# Lower-cased English month names; is_date() checks tokens against this list.
MONTHS = [
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
]
|
| 195 |
+
|
| 196 |
+
def is_num(text):
    """Return True if `text` parses as a base-10 integer (e.g. "1950", "-3").

    int() raises only ValueError or TypeError on bad input, so we catch
    exactly those instead of the original's over-broad `except Exception`
    (which could mask unrelated bugs); the needless rebinding of `text`
    is also gone.
    """
    try:
        int(text)
        return True
    except (ValueError, TypeError):
        return False
|
| 202 |
+
|
| 203 |
+
def is_date(text):
    """True if every whitespace token of the normalized text is an integer
    or an English month name."""
    tokens = normalize_answer(text).split(" ")
    return all(is_num(token) or token in MONTHS for token in tokens)
|
| 209 |
+
|
| 210 |
+
def extract_numeric_values(text):
    """Return the set of standalone integer substrings appearing in `text`
    (kept as strings, not converted)."""
    integer_pattern = r'\b\d+\b'
    return {match for match in re.findall(integer_pattern, text)}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def detect_entities(text, nlp):
    """Collect date tokens and numeric spans from `text` using spaCy NER
    plus a regex pass for bare integers."""
    doc = nlp(text)
    found = set()

    def _register(span_text):
        # Hyphenated spans are split so each side is matched independently.
        if "-" in span_text:
            for piece in span_text.split("-"):
                found.add(piece.strip())
        else:
            found.add(span_text)

    # spacy often has errors with other types of entities
    numeric_labels = ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]
    for ent in doc.ents:
        if ent.label_ not in numeric_labels:
            continue
        if is_date(ent.text):
            _register(ent.text)
        else:
            # Entity is not date-like as a whole; keep any date-like tokens in it.
            for token in ent.text.split():
                if is_date(token):
                    _register(token)

    # Regex pass catches bare integers NER missed; skip ones already
    # contained inside a collected span.
    for candidate in extract_numeric_values(text):
        if not np.any([candidate in existing for existing in found]):
            found.add(candidate)

    return found
|
| 244 |
+
|
| 245 |
+
def postprocess_atomic_facts(_atomic_facts, para_breaks, nlp):
    """Clean up (sentence, facts) pairs produced by the atomizer.

    Pass 1 merges stray one-word sentence fragments into the previous
    sentence and remaps paragraph-break indices accordingly.  Pass 2 drops
    truncated duplicate facts (ending in the verb fragments below) and
    repairs facts whose numeric/date entities don't literally appear in the
    source sentence.  Returns (new_atomic_facts, new_para_breaks).
    """

    # Facts ending in these fragments are likely truncated duplicates and are
    # dropped when a longer fact already contains them.
    verbs = ["born.", " appointed.", " characterized.", " described.", " known.", " member.", " advocate.", "served.", "elected."]
    permitted_verbs = ["founding member."]

    atomic_facts = []
    new_atomic_facts = []
    new_para_breaks = []

    # Pass 1: merge one-word fragments into the preceding sentence.
    for i, (sent, facts) in enumerate(_atomic_facts):
        sent = sent.strip()
        if len(sent.split())==1 and i not in para_breaks and i > 0:
            assert i not in para_breaks
            atomic_facts[-1][0] += " " + sent
            atomic_facts[-1][1] += facts
        else:
            if i in para_breaks:
                # Remap the break onto the merged list's index.
                new_para_breaks.append(len(atomic_facts))
            atomic_facts.append([sent, facts])

    # Pass 2: per-sentence fact filtering/repair based on detected entities.
    for i, (sent, facts) in enumerate(atomic_facts):
        entities = detect_entities(sent, nlp)
        covered_entities = set()
        # print (entities)
        new_facts = []
        # NOTE: this `i` shadows the outer loop index; only the inner binding
        # is used below (for the j != i duplicate check).
        for i, fact in enumerate(facts):
            if any([fact.endswith(verb) for verb in verbs]) and not any([fact.endswith(verb) for verb in permitted_verbs]):
                # Truncated-verb fact: drop it if a longer fact subsumes it.
                if any([fact[:-1] in other_fact for j, other_fact in enumerate(facts) if j != i]):
                    continue
            sent_entities = detect_entities(fact, nlp)
            covered_entities |= set([e for e in sent_entities if e in entities])
            # Entities the fact mentions that the source sentence does not.
            new_entities = sent_entities - entities
            if len(new_entities) > 0:
                do_pass = False
                for new_ent in new_entities:
                    pre_ent = None
                    # Try to repair by replacing with a sentence entity that
                    # extends the fact's entity (e.g. "1990" -> "1990s").
                    for ent in entities:
                        if ent.startswith(new_ent):
                            pre_ent = ent
                            break
                    if pre_ent is None:
                        # Unrepairable hallucinated entity: drop the fact.
                        do_pass = True
                        break
                    fact = fact.replace(new_ent, pre_ent)
                    covered_entities.add(pre_ent)
                if do_pass:
                    continue
            if fact in new_facts:
                continue
            new_facts.append(fact)
        try:
            assert entities==covered_entities
        except Exception:
            new_facts = facts # there is a bug in spacy entity linker, so just go with the previous facts
        new_atomic_facts.append((sent, new_facts))

    return new_atomic_facts, new_para_breaks
|
| 303 |
+
|
| 304 |
+
def is_integer(s):
    """Return True if `s` parses as a base-10 integer.

    Same fix as is_num: int() raises only ValueError or TypeError, so catch
    exactly those rather than the original's over-broad `except Exception`.
    (NOTE(review): this duplicates is_num above — one of them could be
    removed once callers are unified.)
    """
    try:
        int(s)
        return True
    except (ValueError, TypeError):
        return False
|
| 310 |
+
|
| 311 |
+
def detect_initials(text):
    """Find two-letter initial patterns such as "H. G." or "J.R." in `text`."""
    initials_pattern = r"[A-Z]\. ?[A-Z]\."
    return list(re.findall(initials_pattern, text))
|
| 315 |
+
|
| 316 |
+
def fix_sentence_splitter(curr_sentences, initials):
    """Repair common nltk sentence-splitter mistakes.

    1) Re-merges sentence pairs that were split inside a person's initials
       (e.g. "... U." / "S. ..." for the initial "U. S.").
    2) Re-merges one-word fragments, and fragments starting with a
       lowercase letter, into the previous sentence.
    """
    for initial in initials:
        if not np.any([initial in sent for sent in curr_sentences]):
            alpha1, alpha2 = [t.strip() for t in initial.split(".") if len(t.strip()) > 0]
            for i, (sent1, sent2) in enumerate(zip(curr_sentences, curr_sentences[1:])):
                if sent1.endswith(alpha1 + ".") and sent2.startswith(alpha2 + "."):
                    # merge sentence i and i+1
                    curr_sentences = curr_sentences[:i] + [curr_sentences[i] + " " + curr_sentences[i+1]] + curr_sentences[i+2:]
                    break
    sentences = []
    combine_with_previous = None
    for sent_idx, sent in enumerate(curr_sentences):
        if len(sent.split()) <= 1 and sent_idx == 0:
            # Leading one-word fragment: keep it and flag that the next
            # sentence should be folded into it.
            assert not combine_with_previous
            combine_with_previous = True
            sentences.append(sent)
        elif len(sent.split()) <= 1:
            # One-word fragment elsewhere: glue onto the previous sentence.
            assert sent_idx > 0
            sentences[-1] += " " + sent
            # BUG FIX: the original assigned to a misspelled variable
            # (`combined_with_previous`), leaving the flag set and causing
            # the *next* sentence to be merged spuriously.
            combine_with_previous = False
        elif sent[0].isalpha() and not sent[0].isupper() and sent_idx > 0:
            # Starts lowercase: a continuation of the previous sentence.
            assert sent_idx > 0, curr_sentences
            sentences[-1] += " " + sent
            combine_with_previous = False
        elif combine_with_previous:
            assert sent_idx > 0
            sentences[-1] += " " + sent
            combine_with_previous = False
        else:
            assert not combine_with_previous
            sentences.append(sent)
    return sentences
|
MACI-main/conditional-conformal/src/aws_utils.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import boto3
|
| 2 |
+
import io
|
| 3 |
+
|
| 4 |
+
def s3_open(bucket_name, key):
    """Download s3://<bucket_name>/<key> and return its bytes wrapped in a
    BytesIO, mimicking an open file object."""
    # A fresh session picks up AWS credentials from the environment.
    s3_client = boto3.Session().client('s3')
    body = s3_client.get_object(Bucket=bucket_name, Key=key)['Body'].read()
    return io.BytesIO(body)
|
MACI-main/conditional-conformal/src/client.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
from typing import Any, List
|
| 6 |
+
|
| 7 |
+
class Client:
    """
    Wrapper class for language models that we query. It keeps a cache of prompts and
    responses so that we don't have to requery things in experiments.

    The cache maps "<stripped prompt>_<sample_idx>" -> model output and is
    pickled to `cache_file` (a local path, or an s3:// URI for read-only use).
    """

    def __init__(self, cache_file, model : str = 'gpt-3.5-turbo'):
        self.cache_file = cache_file
        self.cache_dict = self.load_cache()
        self.model = model
        # Set by cache_outputs(); save_cache() is a no-op until then.
        self.modified_cache = False

    def load_model(self):
        # load the model and put it as self.model
        raise NotImplementedError()

    def query(
        self,
        prompt : str,
        sample_idx : int = 0,
        **kwargs
    ):
        """Return the cached response for (prompt, sample_idx), querying the
        model on a miss.

        NOTE(review): a cache miss is NOT written back here — callers are
        expected to batch results into cache_outputs() themselves.
        `_query` is assumed to be implemented by subclasses; it is not
        defined on this base class.
        """
        prompt = prompt.strip() # it's important not to end with a whitespace
        cache_key = f"{prompt}_{sample_idx}"

        if cache_key in self.cache_dict:
            return self.cache_dict[cache_key]

        if self.model is None:
            self.load_model()
        # print("I didn't find a cached copy!")
        output = self._query(prompt, **kwargs)

        return output

    def cache_outputs(
        self,
        prompts : List[str],
        sample_indices : List[int],
        outputs : List[Any]
    ):
        """Record a batch of (prompt, sample_idx) -> output entries and mark
        the cache dirty so save_cache() will persist it."""
        for prompt, sample_idx, output in zip(prompts, sample_indices, outputs):
            prompt = prompt.strip()
            cache_key = f"{prompt}_{sample_idx}"
            self.cache_dict[cache_key] = output
        self.modified_cache = True

    def save_cache(self):
        """Persist the cache to disk, merging in any entries written by
        concurrent processes first."""
        if self.modified_cache == False:
            return

        # load the latest cache first, since if there were other processes running in parallel, cache might have been updated
        for k, v in self.load_cache().items():
            self.cache_dict[k] = v

        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)

    def load_cache(self, allow_retry=True):
        """Load the pickled cache from disk (retrying on read races) or from
        S3 when cache_file is an s3:// URI; return {} if nothing exists yet.

        NOTE(review): with allow_retry=True this loops forever on a
        permanently corrupted pickle; with allow_retry=False it uses
        `assert False` (stripped under -O) rather than raising.
        """
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pickle.load(f)
                    break
                except Exception: # if there are concurent processes, things can fail
                    if not allow_retry:
                        assert False
                    print ("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif 's3' in self.cache_file:
            # Local import keeps boto3 optional for purely local runs.
            from aws_utils import s3_open
            s3_path = self.cache_file.removeprefix('s3://')
            bucket_name = s3_path.split('/')[0]
            path_to_file = '/'.join(s3_path.split('/')[1:])
            with s3_open(bucket_name, path_to_file) as fp:
                cache = pickle.load(fp)
        else:
            cache = {}
        return cache
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
MACI-main/conditional-conformal/src/config.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import munch
|
| 2 |
+
import toml
|
| 3 |
+
|
| 4 |
+
def get_config(filepath: str = 'configs/default.toml'):
    """Load a TOML config file and return it as an attribute-accessible Munch."""
    raw_config = toml.load(filepath)
    return munch.munchify(raw_config)
|
MACI-main/conditional-conformal/src/conformal.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from typing import Callable, List
|
| 4 |
+
|
| 5 |
+
def compute_conformity_scores(
    dataset : List,
    scores_list : List,
):
    """Per-unit conformity score: the largest claim score among *unsupported*
    claims (0 when every claim is supported)."""
    conf_scores = []
    for unit, scores in zip(dataset, scores_list):
        supported = np.asarray([c['is_supported'] for c in unit['atomic_facts']])
        # initial=0 covers units with no unsupported claims (empty max).
        conf_scores.append(np.max(scores[~supported], initial=0))
    return conf_scores
|
| 15 |
+
|
| 16 |
+
def calibrate_thresholds(
    feats_test : List,
    feats_valid : List,
    scores_valid : List,
    alpha_fn : Callable
) -> List[float]:
    """Split-conformal calibration: compute one shared quantile threshold
    from the validation scores and repeat it for every test point."""
    alpha = alpha_fn(feats_valid)[0]
    n_valid = len(feats_valid)
    # Finite-sample-corrected level: ceil((1 - alpha) * (n + 1)) / n.
    level = np.ceil((1 - alpha) * (n_valid + 1)) / n_valid
    threshold = np.quantile(scores_valid, q=level)
    return [threshold] * len(feats_test)
|
| 28 |
+
|
| 29 |
+
def conformal_filter(
    dataset : List,
    scores_list : List,
    thresholds : List
) -> List:
    """Annotate each unit in-place with 'filtered_claims': the claims whose
    score meets or exceeds that unit's threshold. Returns the same list."""
    for unit, claim_scores, threshold in zip(dataset, scores_list, thresholds):
        kept = []
        for claim, score in zip(unit['atomic_facts'], claim_scores):
            if score >= threshold:
                kept.append(claim)
        unit['filtered_claims'] = kept
    return dataset
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def assess_factscore_coverage(
    dataset : List,
    nominal_alpha : float
) -> None:
    """Print the nominal vs. realized rate of units whose retained claims
    include a non-factual claim (is_supported == 'F')."""
    nonfactual_list = []
    # right now metadata is only *two* strings...TODO this needs to be more flexible
    # (group-wise coverage reporting is disabled pending a metadata format).
    nonfactual_grps = {}
    for d in dataset:
        labels = [c['is_supported'] for c in d['filtered_claims']]
        nonfactual_list.append('F' in labels)

    print(f"Nominal coverage: {nominal_alpha}")
    print(f"Realized marginal coverage: {np.mean(nonfactual_list)}")
|
MACI-main/conditional-conformal/src/data_utils/sample_names.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import numpy as np
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
from typing import Dict
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 9 |
+
|
| 10 |
+
# Precomputed Wikipedia entity map (npz keyed by entity name).
# NOTE(review): machine-specific absolute path — assumes a particular host.
ENTITY_PATH = '/data/jcherian/wikipedia_entity_map.npz'
# Wikidata MediaWiki API endpoint used for entity lookups.
WIKIDATA_URL = "https://www.wikidata.org/w/api.php"
logger = logging.getLogger(__name__)
# Per-entity validation results are appended to human.log (see validate_entity).
logging.basicConfig(filename='human.log', level=logging.INFO)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_id(response : Dict) -> str:
    """Return the single Wikidata entity id in an API response, or None when
    the response carries no entities."""
    entities = response.get("entities", None)
    if entities is None:
        return None
    codes = list(entities.keys())
    assert len(codes) == 1
    return codes[0]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def is_human(response : Dict, id: str) -> bool:
    """True if the entity's 'instance of' (P31) claims include human (Q5)."""
    instance_claims = response['entities'][id]['claims'].get('P31', [])
    return any(
        claim['mainsnak']['datavalue']['value']['id'] == 'Q5'
        for claim in instance_claims
    )
|
| 30 |
+
|
| 31 |
+
def validate_entity(k):
    """Look up the entity named by the last path segment of `k` on Wikidata
    and return (name, is_human_flag).

    Any lookup or parse failure yields (name, False) rather than raising,
    so a thread-pool sweep over many names keeps going.
    """
    name = k.split('/')[-1]
    adapter = requests.adapters.HTTPAdapter(max_retries=10)
    with requests.session() as s:
        s.mount("https://", adapter)
        response = s.get(url=WIKIDATA_URL, params={"action" : "wbgetentities",
                                                   "sites" : "enwiki",
                                                   "titles" : name,
                                                   "normalize": "1",
                                                   "languages": "en",
                                                   "format": "json",
                                                   "props": "claims"})

    try:
        payload = response.json()
    except ValueError:
        # BUG FIX: the original used a bare `except:` that printed the body
        # but then kept going with the un-parsed Response object, which
        # crashed inside get_id(). Bail out instead. (requests raises a
        # JSONDecodeError subclass of ValueError on a non-JSON body.)
        print(response.text)
        return name, False

    wiki_id = get_id(payload)

    if wiki_id is None:
        return name, False

    try:
        human = is_human(payload, wiki_id)
    except (KeyError, TypeError):
        # Malformed claim structure — treat as "not known to be human"
        # (the original used a bare `except:` here).
        return name, False
    logger.info(f"{name}, {human}")
    return name, human
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
    # Validate every entity name in the saved Wikipedia entity map against
    # Wikidata, in parallel, and pickle the (name, is_human) results.
    wiki_entities = np.load(ENTITY_PATH)
    entity_names = list(wiki_entities.keys())
    try:
        with ThreadPoolExecutor(max_workers=5) as executor:
            res = list(
                tqdm(
                    executor.map(
                        lambda k : validate_entity(k),
                        entity_names
                    ),
                    total=len(entity_names)
                )
            )
    except:
        # Best-effort checkpoint of whatever completed before the failure.
        # NOTE(review): bare except, and `res` is unbound here if the
        # executor failed before the assignment — that would raise NameError.
        import pickle
        with open('human.pkl', 'wb') as fp:
            pickle.dump(res, fp)


    # Final dump (overwrites the checkpoint above on the success path).
    import pickle
    with open('human.pkl', 'wb') as fp:
        pickle.dump(res, fp)

    # Drop into an interactive shell for ad-hoc inspection of `res`.
    import IPython; IPython.embed()
|
MACI-main/conditional-conformal/src/dataset.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from typing import List, Tuple
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
from atomizer import Atomizer, text_to_sentences
|
| 12 |
+
from gpt import GPTClient
|
| 13 |
+
from scorer import Scorer
|
| 14 |
+
|
| 15 |
+
def get_prompts(
    dataset : str,
    data_path : str = None
) -> List:
    """Return (topics, prompts) for the requested dataset.

    For the factscore variants, topics are entity names read from a text
    file and prompts ask for a biographical paragraph. For the MedLFQA
    variants, the questions themselves serve as both topics and prompts.
    Raises ValueError for an unknown dataset name.
    """
    dataset = dataset.lower()

    # The three FactScore variants differed only in which name list they
    # read; the original duplicated the whole branch three times.
    factscore_files = {
        "factscore": "data/factscore_names.txt",
        "factscore_v2": "data/factscore_v2_names.txt",
        "factscore_v3": "data/factscore_v3_names.txt",
    }
    if dataset in factscore_files:
        with open(factscore_files[dataset], 'r') as fp:
            names = [name.strip() for name in fp.readlines()]
        prompts = [
            f"Please write one biographical paragraph about {name.strip()}."
            for name in names
        ]
        return names, prompts

    if dataset == "factscore_final":
        df = pd.read_csv(data_path, index_col=0)
        names = set([n.strip() for n in df['Name']])
        prompts = [
            f"Please write one biographical paragraph about {name.strip()}."
            for name in names
        ]
        return names, prompts

    if dataset == "medlfqa":
        datasets = {}
        suffix = "_test_MedLFQA.jsonl"
        # NOTE(review): hard-coded local path — should come from data_path/config.
        dataset_dir = "/Users/cherian/Projects/OLAPH/MedLFQA"
        for path in os.listdir(dataset_dir):
            if "MedLFQA" not in path:
                continue
            dataset_name = path[:-len(suffix)]
            with open(os.path.join(dataset_dir, path), 'r') as fp:
                datasets[dataset_name] = [json.loads(line) for line in fp.readlines()]

        prompts = []
        for _, data in datasets.items():
            prompts += [pt['Question'] for pt in data]
        # medlfqa deduplicates questions; medlfqav2 below deliberately keeps duplicates.
        prompts = list(set(prompts))
        return prompts, prompts

    if dataset == "medlfqav2":
        datasets = {}
        suffix = ".jsonl"
        for filename in os.listdir(data_path):
            dataset_name = filename[:-len(suffix)]
            with open(os.path.join(data_path, filename), 'r') as fp:
                datasets[dataset_name] = [json.loads(line) for line in fp.readlines()]

        prompts = []
        for _, data in datasets.items():
            prompts += [pt['Question'] for pt in data]
        return prompts, prompts

    raise ValueError("Unsupported data set.")
|
| 93 |
+
|
| 94 |
+
def find_unique_element(lst, condition, approx_index):
    """Find an index i with condition(lst[i]) true, searching outward from
    approx_index (checked first), alternating left then right at each
    distance. Returns None if nothing matches."""
    if condition(lst[approx_index]):
        return approx_index

    for offset in range(1, len(lst)):
        lo = approx_index - offset
        hi = approx_index + offset
        if lo >= 0 and condition(lst[lo]):
            return lo
        if hi < len(lst) and condition(lst[hi]):
            return hi

    return None
|
| 114 |
+
|
| 115 |
+
def load_dataset(
    config : dict
) -> List:
    """Generate, atomize, and annotate LLM responses for the configured dataset.

    Pipeline:
      1. Query the responder model for every prompt (thread pool, cached).
      2. Split each response into atomic facts, either by replaying the
         atomizer cache (CACHE_EXISTS fast path) or by running the atomizer.
      3. Score each atomic fact with the retrieval-based annotator.

    Fixed: removed two unconditional ``import IPython; IPython.embed()``
    debug shells that blocked any unattended run of this pipeline.

    Args:
        config: config object with ``model.responder/parser/annotator`` and
            ``dataset`` sections (attribute access, e.g. an OmegaConf-style
            object — NOTE(review): exact type not visible here).

    Returns:
        List of dicts with keys 'prompt', 'response', and 'atomic_facts'
        (the annotator's per-fact decisions).
    """
    print("Loading responder.")
    responder = GPTClient(config.model.responder.cache_path)

    topics, prompts = get_prompts(config.dataset.name, config.dataset.path)

    with ThreadPoolExecutor(max_workers=25) as executor:
        responses = list(
            tqdm(
                executor.map(
                    lambda x : responder.query(x),
                    prompts
                ),
                total=len(prompts)
            )
        )

    # TODO: Uncomment me if I want to run fresh dataset...

    responder.cache_outputs(
        prompts,
        np.zeros((len(responses),), dtype=int),
        responses
    )

    responder.save_cache()

    responses = [r[0] for r in responses]

    outputs = [{'prompt': p, 'response': o['message']}
               for p, o in zip(prompts, responses)]  # first output is the response we will filter

    print("Loading atomizer.")
    atomizer_client = GPTClient(config.model.parser.cache_path, model=config.model.parser.name)

    atomizer = Atomizer(atomizer_client, demo_dir='data/demos')

    CACHE_EXISTS = True

    if CACHE_EXISTS:  # TODO: dumb hard-coded variable to side step the slow retrieval
        ordered_messages = [r['message'] for r in responses]

        # Cache iteration order need not match prompt order, so facts are
        # matched back to messages by sentence containment below.
        responder_cache = responder.cache_dict
        messages = []
        for val in responder_cache.values():
            messages.append(val[0]['message'])

        atomizer_cache = atomizer_client.cache_dict
        idx_guess = 0
        atomic_facts = [[] for _ in range(len(messages))]
        atomic_facts_ph = [[] for _ in range(len(messages))]

        sentences = defaultdict(int)
        for k in tqdm(atomizer_cache.keys()):
            atomized_msg = atomizer_cache[k][0]['message']
            atomized_facts = text_to_sentences(atomized_msg)
            # The atomizer prompt embeds the source sentence after 'facts:';
            # recover it to locate the owning message.
            sentence = k.split('\n')[-1].split('facts:')[-1].strip()[:-2]
            cur_idx = -1
            sentences[sentence] += 1
            # if the sentence has appeared more than once we need to find the appropriate match...
            for i in range(sentences[sentence]):
                # NOTE(review): find_unique_element returns an index relative
                # to the slice messages[cur_idx + 1:], not to messages —
                # confirm this offset is intended.
                cur_idx = find_unique_element(messages[cur_idx + 1:], lambda x: sentence in x, approx_index=idx_guess)
            if cur_idx is None:  # TODO: TERRIBLE SPECIAL CASING that I looked at by hand...
                raise ValueError()
                # NOTE(review): unreachable after the raise above — this
                # hand-tuned fallback was presumably disabled deliberately.
                if idx_guess in (4148, 4149, 4150):
                    cur_idx = 4149
                elif cur_idx == 993:
                    cur_idx = 993
                else:
                    continue
            idx_guess = cur_idx
            atomic_facts[cur_idx].extend(atomized_facts)

        # Re-order cache-ordered facts back into prompt order.
        for af, msg in zip(atomic_facts, messages):
            if len(af) == 0:
                continue
            new_idx = ordered_messages.index(msg)
            atomic_facts_ph[new_idx] = af
        atomic_facts = atomic_facts_ph

    else:
        with ThreadPoolExecutor(max_workers=10) as executor:
            atoms = list(
                tqdm(
                    executor.map(
                        lambda x : atomizer.run(*x),
                        [(o['response'],) for o in outputs]
                    ),
                    total=len(outputs)
                )
            )

        atomizer.save_cache()
        # atom[0] is a list of (sentence, facts) pairs; flatten to facts.
        atomic_facts = [[fact for _, facts in atom[0] for fact in facts] for atom in atoms]

    dataset = []

    for p, r, af in zip(prompts, responses, atomic_facts):
        atoms = [{'atom': fact} for fact in af]
        data_pt = {'prompt': p, 'response': r, 'atomic_facts': atoms}
        dataset.append(data_pt)

    # time to annotate responses using factscore code
    print("Loading annotator.")
    scorer_client = GPTClient(config.model.annotator.cache_path, model=config.model.annotator.name)
    scorer = Scorer(scorer_client, config, model_name="retrieval")

    scorer_inputs = [(topic, output['response'], fact) for topic, output, fact in zip(topics, outputs, atomic_facts)]
    with ThreadPoolExecutor(max_workers=4) as executor:
        scores = list(
            tqdm(
                executor.map(
                    lambda x : scorer.get_score(*x, knowledge_source='medlfqa'),
                    scorer_inputs
                ),
                total=len(scorer_inputs)
            )
        )
    # scorer.save_cache()

    # Rebuild the dataset with annotated facts replacing the raw atoms.
    dataset = []

    for p, r, s in zip(prompts, responses, scores):
        data_pt = {
            'prompt': p,
            'response': r,
            'atomic_facts': s['decisions'][0]
        }
        dataset.append(data_pt)

    return dataset
|
| 254 |
+
|
| 255 |
+
def split_dataset(
    dataset : List,
    train_perc : float = 0.33,
    valid_perc : float = 0.33,
    rng : np.random.Generator = None
) -> Tuple[List, List, List]:
    """
    Splits dataset into three parts (train / validation / test).

    Args:
        dataset: sequence of examples to split.
        train_perc: fraction assigned to the training split.
        valid_perc: fraction assigned to the validation split.
        rng: optional numpy Generator; when given, examples are shuffled
            before splitting.

    Returns:
        (train, valid, test); the test split receives whatever remains
        after the first two (int() truncation rounds both sizes down).
    """
    total_length = len(dataset)

    # Calculate lengths of each part based on percentages.
    len1 = int(total_length * train_perc)
    len2 = int(total_length * valid_perc)

    # If rng passed in, shuffle the dataset.
    if rng is not None:
        # Fixed: shuffle a copy so the caller's list is not mutated in place.
        dataset = list(dataset)
        rng.shuffle(dataset)

    # Split the list using slicing.
    train_data = dataset[:len1]
    valid_data = dataset[len1:len1 + len2]
    test_data = dataset[len1 + len2:]

    return train_data, valid_data, test_data
|
MACI-main/conditional-conformal/src/featurizer.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
from tqdm.contrib.concurrent import process_map
|
| 6 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 7 |
+
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
import client
|
| 11 |
+
|
| 12 |
+
from typing import Dict, List, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
SELF_ASSESS_PROMPT = 'You will get a list of claims and the original prompt that motivated these claims. For each claim, assess the probability of correctness. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "gpt-score":[SCORE]}. Directly return the jsonl with NO explanation or ANY other formatting. For the [SCORE], return the esimated probability of correctness to three significant figures.\n'
|
| 16 |
+
|
| 17 |
+
SELF_BOOL_PROMPT = 'You will get a list of claims and the original prompt that motivated these claims. For each claim, assess the correctness. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "gpt-bool":[BOOL]}. Directly return the jsonl with NO explanation or ANY other formatting. For the [BOOL], return "T" or "F" in quotes so that it is valid json.\n'
|
| 18 |
+
|
| 19 |
+
MAX_WORKERS = 20
|
| 20 |
+
|
| 21 |
+
def get_features(
    dataset: List[Dict],
    config : Dict
) -> np.ndarray:
    """Build a feature matrix with one row per atomic fact across ``dataset``.

    Features are selected by ``config.model.prob.features`` and stacked as
    columns; currently 'frequency' (agreement across resampled responses)
    and 'selfeval' (LLM self-assessed probability) are supported.

    Args:
        dataset: entries with 'prompt' and 'atomic_facts' (list of
            ``{'atom': str}``) keys.
        config: config object read via attribute access.

    Returns:
        (total_num_facts, num_selected_features) float array.
    """
    # NOTE(review): imported lazily in the original; kept local.
    from gpt import GPTClient
    feature_names = config.model.prob.features
    all_features = []
    if 'frequency' in feature_names:
        # Fixed: renamed from ``client`` so the module-level ``client``
        # import is not shadowed inside this function.
        freq_client = GPTClient(f'.cache/{config.dataset.name}_frequency.pkl')

        with ThreadPoolExecutor(max_workers=5) as executor:
            frequencies = list(
                tqdm(
                    executor.map(
                        lambda x: get_frequency(freq_client, [af['atom'] for af in x['atomic_facts']], x['prompt'], config.model.prob.frequency.model),
                        dataset
                    ),
                    total=len(dataset)
                )
            )
        freq_client.save_cache()
        # One score per fact; flatten across examples into a single column.
        all_features.append(np.concatenate(frequencies).reshape(-1, 1))

    if 'selfeval' in feature_names:

        eval_client = GPTClient(f'.cache/{config.dataset.name}_self_evals.pkl')

        with ThreadPoolExecutor(max_workers=25) as executor:
            self_evals = list(
                tqdm(
                    executor.map(
                        lambda x: get_self_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], eval_client),
                        dataset
                    ),
                    total=len(dataset)
                )
            )
        eval_client.save_cache()
        all_features.append(np.concatenate(self_evals).reshape(-1, 1))

    # Stack each selected feature as one column.
    features = np.concatenate(
        all_features,
        axis=1
    )
    return features
|
| 67 |
+
|
| 68 |
+
# def get_features(
|
| 69 |
+
# dataset : List[Dict],
|
| 70 |
+
# config : Dict
|
| 71 |
+
# ) -> np.ndarray:
|
| 72 |
+
# feature_names = config.features
|
| 73 |
+
# num_claims = np.sum([len(dat['claims']) for dat in dataset])
|
| 74 |
+
# all_features = []
|
| 75 |
+
# for feat in feature_names:
|
| 76 |
+
# if feat == "embedding":
|
| 77 |
+
# embeds = np.zeros((num_claims, int(config.embedding.n_dimensions)))
|
| 78 |
+
# print("Fetching embeddings.")
|
| 79 |
+
# embedding_func = partial(get_embedding, model=config.embedding.model, n_dim=config.embedding.n_dimensions)
|
| 80 |
+
# res = process_map(embedding_func, [dat['claims'] for dat in dataset], max_workers=MAX_WORKERS)
|
| 81 |
+
# i = 0
|
| 82 |
+
# for dat in tqdm(dataset):
|
| 83 |
+
# len_dat = len(dat['claims'])
|
| 84 |
+
# embeds[i:(i + len_dat)] = get_embedding(dat['claims'], config.embedding.model, config.embedding.n_dimensions)
|
| 85 |
+
# i += len_dat
|
| 86 |
+
# all_features.append(embeds)
|
| 87 |
+
|
| 88 |
+
# elif feat == "selfeval":
|
| 89 |
+
# print("Fetching selfevals.")
|
| 90 |
+
# evals = np.zeros((num_claims, 1))
|
| 91 |
+
# selfeval_func = partial(get_self_eval, model=config.selfeval.model.name)
|
| 92 |
+
# res = process_map(selfeval_func, dataset, max_workers=MAX_WORKERS)
|
| 93 |
+
# i = 0
|
| 94 |
+
# for dat in tqdm(dataset):
|
| 95 |
+
# len_dat = len(dat['claims'])
|
| 96 |
+
# evals[i:(i + len_dat)] = get_self_eval(dat['claims'], dat['prompt'], config.selfeval.model.name)
|
| 97 |
+
# i += len_dat
|
| 98 |
+
# all_features.append(evals)
|
| 99 |
+
# elif feat == "frequency":
|
| 100 |
+
# print("Fetching frequency.")
|
| 101 |
+
# freqs = np.zeros(((num_claims), 1))
|
| 102 |
+
# i = 0
|
| 103 |
+
# for dat in tqdm(dataset):
|
| 104 |
+
# len_dat = len(dat['claims'])
|
| 105 |
+
# freqs[i:(i + len_dat)] = get_frequency(dat['claims'], dat['prompt'], config.frequency.model.n_samples, config.frequency.model.name)
|
| 106 |
+
# i += len_dat
|
| 107 |
+
# all_features.append(freqs)
|
| 108 |
+
# else:
|
| 109 |
+
# raise ValueError(f"{feat} not supported.")
|
| 110 |
+
# return np.concatenate(all_features, axis=1)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_embedding(
    subclaims : List[str],
    client : client.Client, # needs to be embedding client not *GPT* client
    n_dim : int = 8
) -> np.ndarray:
    """Return the first ``n_dim`` embedding dimensions for each subclaim.

    NOT IMPLEMENTED: raises unconditionally. Everything below the raise is
    an unreachable sketch of the intended implementation.
    """
    raise ValueError("not supported yet")
    # --- unreachable sketch below ---
    embeddings = []
    for claim in subclaims:
        # NOTE(review): treats each entry as a dict with a 'message' key,
        # which contradicts the List[str] annotation — reconcile before
        # enabling this path.
        msg = claim['message'].replace('\n', ' ')
        embed = client.query(msg)
        embeddings.append(embed[:n_dim])
    return np.asarray(embeddings)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _eval_self(
    prompt : str,
    subclaims : List,
    client : client.Client,
    err_msg : str = None
) -> Tuple[Tuple[str, List], np.ndarray]:
    """Ask the LLM to score each subclaim's probability of correctness.

    Builds a single prompt listing all subclaims, queries ``client``, and
    parses the jsonl reply into a score vector. On a malformed reply it
    retries once per failure mode (wrong line count / unparseable JSON) by
    recursing with a correction message appended; a second failure of the
    same mode gives up.

    Returns:
        ``((prompt_sent, raw_response), scores)`` with one float per
        subclaim, or ``(None, None), None`` on unrecoverable failure.
    """
    # One numbered line per claim so the model can reference claims by id.
    claim_string = "\n".join(
        [str(i) + ": " + fact for i, fact in enumerate(subclaims)]
    )
    self_eval_prompt = SELF_ASSESS_PROMPT
    self_eval_prompt += f"The original prompt is: {prompt}.\n"
    self_eval_prompt += f"The claims are: {claim_string}.\n"

    # On a retry, append the correction instruction from the failed pass.
    if err_msg is not None:
        self_eval_prompt += "\n" + err_msg

    self_evals = client.query(self_eval_prompt)
    parsed_evals = self_evals[0]['message']
    # Strip markdown code fences the model sometimes wraps around the jsonl.
    parsed_evals = parsed_evals.replace("```jsonl\n", "")
    parsed_evals = parsed_evals.replace("```", "")
    final_evals = np.zeros((len(parsed_evals.splitlines()),))
    try:
        # The model must emit exactly one jsonl line per claim.
        assert len(final_evals) == len(subclaims)
    except AssertionError:
        # 'exactly' marks the line-count retry message: if it is already in
        # err_msg we retried this failure mode once, so give up.
        if err_msg is not None and 'exactly' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return exactly {len(subclaims)} lines of JSON."
        print(err_msg)
        return _eval_self(prompt, subclaims, client, err_msg=err_msg)
    try:
        for line in parsed_evals.splitlines():
            # Scores are placed by the model-reported id, not line order.
            eval = json.loads(line)
            idx = int(eval["id"])
            final_evals[idx] += float(eval["gpt-score"])
    except Exception as ex:
        # 'requested' marks the bad-JSON retry message; same one-retry rule.
        if err_msg is not None and 'requested' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return the lines in the requested JSON format with NO additional formatting."
        print(err_msg)
        return _eval_self(prompt, subclaims, client, err_msg=err_msg)
    return (self_eval_prompt, self_evals), final_evals
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_self_eval(
    prompt : str,
    subclaims : List[str],
    client : client.Client
) -> np.ndarray:
    """Self-assessed correctness probabilities for each subclaim.

    Delegates to ``_eval_self`` and caches the successful exchange on the
    client. Returns a vector of -1s when the evaluation gave up (the -1
    probability marks the error case for downstream consumers).
    """
    (cached_prompt, raw_response), scores = _eval_self(
        prompt,
        subclaims,
        client
    )

    # A None prompt signals that _eval_self exhausted its retries.
    if cached_prompt is None:
        return -1 * np.ones((len(subclaims),))  # -1 prob is error

    client.cache_outputs(
        [cached_prompt],
        np.zeros((1,), dtype=int),
        [raw_response]
    )

    return scores
|
| 195 |
+
|
| 196 |
+
def _bool_self(
    prompt : str,
    subclaims : List,
    client : client.Client,
    err_msg : str = None
) -> Tuple[Tuple[str, List], np.ndarray]:
    """Ask the LLM for a boolean ('T'/'F') verdict on each subclaim.

    Mirrors ``_eval_self``: one jsonl line per claim, one retry per failure
    mode (wrong line count / unparseable JSON) via recursion with an
    appended correction message, then give up.

    Returns:
        ``((prompt_sent, raw_response), verdicts)`` where ``verdicts`` is a
        list of 'T'/'F' strings (NOTE(review): a list, despite the
        np.ndarray annotation — confirm callers handle both), or
        ``(None, None), None`` on unrecoverable failure.
    """
    # One numbered line per claim so the model can reference claims by id.
    claim_string = "\n".join(
        [str(i) + ": " + fact for i, fact in enumerate(subclaims)]
    )
    self_eval_prompt = SELF_BOOL_PROMPT
    self_eval_prompt += f"The original prompt is: {prompt}.\n"
    self_eval_prompt += f"The claims are: {claim_string}.\n"

    # On a retry, append the correction instruction from the failed pass.
    if err_msg is not None:
        self_eval_prompt += "\n" + err_msg

    self_evals = client.query(self_eval_prompt)
    parsed_evals = self_evals[0]['message']
    # Strip markdown code fences the model sometimes wraps around the jsonl.
    parsed_evals = parsed_evals.replace("```jsonl\n", "")
    parsed_evals = parsed_evals.replace("```", "")
    # Default every verdict to 'T'; overwritten below by model-reported id.
    final_evals = ['T' for i in range(len(parsed_evals.splitlines()))]
    try:
        # The model must emit exactly one jsonl line per claim.
        assert len(final_evals) == len(subclaims)
    except AssertionError:
        # 'exactly' marks the line-count retry message; one retry only.
        if err_msg is not None and 'exactly' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return exactly {len(subclaims)} lines of JSON."
        print(err_msg)
        return _bool_self(prompt, subclaims, client, err_msg=err_msg)
    try:
        for line in parsed_evals.splitlines():
            eval = json.loads(line)
            idx = int(eval["id"])
            final_evals[idx] = eval["gpt-bool"]
    except Exception as ex:
        # 'requested' marks the bad-JSON retry message; one retry only.
        if err_msg is not None and 'requested' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return the lines in the requested JSON format with NO additional formatting."
        print(err_msg)
        return _bool_self(prompt, subclaims, client, err_msg=err_msg)
    return (self_eval_prompt, self_evals), final_evals
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def get_bool_eval(
    prompt : str,
    subclaims : List[str],
    client : client.Client
) -> np.ndarray:
    """Boolean self-assessment ('T'/'F') for each subclaim.

    Delegates to ``_bool_self`` and caches the successful exchange on the
    client. Returns a vector of -1s when the evaluation gave up.
    """
    (cached_prompt, raw_response), verdicts = _bool_self(
        prompt,
        subclaims,
        client
    )

    # A None prompt signals that _bool_self exhausted its retries.
    if cached_prompt is None:
        return -1 * np.ones((len(subclaims),))  # -1 prob is error

    client.cache_outputs(
        [cached_prompt],
        np.zeros((1,), dtype=int),
        [raw_response]
    )

    return verdicts
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _eval_support(
    output : str,
    subclaims : List,
    client : client.Client,
    err_msg : str = None
) -> Tuple[Tuple[str, List], np.ndarray]:
    """Score whether ``output`` supports each subclaim.

    Asks the LLM for a per-claim score (1 supports, -1 contradicts,
    0 unrelated) as jsonl. Same retry discipline as ``_eval_self``: one
    retry per failure mode (wrong line count / unparseable JSON), then
    give up with ``(None, None), None``.

    Returns:
        ``((prompt_sent, raw_response), scores)`` with one float per
        subclaim, or ``(None, None), None`` on unrecoverable failure.
    """
    # One numbered line per claim so the model can reference claims by id.
    claim_string = "\n".join(
        [str(i) + ": " + fact for i, fact in enumerate(subclaims)]
    )
    counting_prompt = (
        'You will get a list of claims and piece of text. For each claim, score whether the text supports, contradicts, or is unrelated to the claim. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "score":[SCORE]}. Directly return the jsonl with NO explanation or ANY other formatting. For the [SCORE], return 1 for supports, -1 for contradicts, and 0 for unrelated. The claims are:\n'
        + claim_string
        + "\n\nThe text is:\n"
        + output
    )
    # On a retry, append the correction instruction from the failed pass.
    if err_msg is not None:
        counting_prompt += "\n" + err_msg

    support_scores = client.query(counting_prompt)
    parsed_scores = support_scores[0]['message']
    # Strip markdown code fences the model sometimes wraps around the jsonl.
    parsed_scores = parsed_scores.replace("```jsonl\n", "")
    parsed_scores = parsed_scores.replace("```", "")
    final_scores = np.zeros((len(parsed_scores.splitlines()),))
    try:
        # The model must emit exactly one jsonl line per claim.
        assert len(final_scores) == len(subclaims)
    except AssertionError:
        # 'exactly' marks the line-count retry message; one retry only.
        if err_msg is not None and 'exactly' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_scores}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return exactly {len(subclaims)} lines of JSON."
        print(err_msg)
        return _eval_support(output, subclaims, client, err_msg=err_msg)
    try:
        for line in parsed_scores.splitlines():
            # Scores are placed by the model-reported id, not line order.
            score = json.loads(line)
            idx = int(score["id"])
            final_scores[idx] += float(score["score"])
    except Exception as ex:
        # 'requested' marks the bad-JSON retry message; one retry only.
        if err_msg is not None and 'requested' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_scores}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return the lines in the requested JSON format with NO additional formatting."
        print(err_msg)
        return _eval_support(output, subclaims, client, err_msg=err_msg)
    return (counting_prompt, support_scores), final_scores
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def get_frequency(
    client : client.Client,
    subclaims : List,
    prompt : str,
    config : dict
) -> np.ndarray:
    """
    Returns a vector of (frequency) scores corresponding to each entry of
    the subclaims list.

    The prompt is re-sampled ``config.n_samples`` times at the configured
    temperature, each alternate output is scored for support of every
    subclaim via ``_eval_support``, and the per-claim support scores are
    averaged across the samples that parsed successfully.

    Returns:
        (len(subclaims),) float array of mean support scores, or a vector
        of -1s when every sample failed to parse (the error convention
        used by get_self_eval).
    """
    # Generate n_samples alternate outputs with temperature 1.0.
    alternate_outputs = client.query(
        prompt, 1, n_samples=config.n_samples, temperature=config.temperature
    )
    client.cache_outputs(
        [prompt],
        [int(1)],
        [alternate_outputs]
    )

    alternate_outputs = [o['message'] for o in alternate_outputs]

    with ThreadPoolExecutor(max_workers=config.n_samples) as executor:
        all_scores = list(
            executor.map(
                lambda x : _eval_support(x, subclaims, client),
                alternate_outputs
            )
        )

    # Fixed (resolves TODO): if every sample failed to parse, fall back to
    # the -1 error vector instead of np.mean([]) producing nan + a warning.
    valid_scores = [s[1] for s in all_scores if s[1] is not None]
    if not valid_scores:
        return -1 * np.ones((len(subclaims),))

    parsed_scores = np.mean(valid_scores, axis=0)

    return parsed_scores
|
MACI-main/conditional-conformal/src/gpt.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import openai
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from client import Client
|
| 6 |
+
|
| 7 |
+
from tenacity import (
|
| 8 |
+
retry,
|
| 9 |
+
stop_after_attempt,
|
| 10 |
+
wait_random_exponential,
|
| 11 |
+
) # for exponential backoff
|
| 12 |
+
|
| 13 |
+
class GPTClient(Client):
    """OpenAI chat-completions client with exponential-backoff retries.

    Caching behaviour is inherited from ``Client``; this class tracks
    cumulative token and request usage across calls.
    """

    def __init__(
        self,
        cache_file : str,
        model : str = 'gpt-3.5-turbo'
    ):
        super().__init__(cache_file, model)
        self.client = openai.Client()
        # Running usage counters, updated by every successful _query call.
        self.tokens_used = 0
        self.requests_made = 0

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
    def _query(
        self,
        prompt : List[str],
        role : List[str] = None,
        max_tokens : int = 1000,
        temperature: float = 0,
        response_format : str = None,
        n_samples: int = 1
    ):
        """Issue one chat-completion request (with logprobs) and return a
        list of ``{'logprobs', 'message'}`` dicts, one per sampled choice.
        """
        # Default to the 'user' role when none is supplied.
        effective_role = "user" if role is None else role
        messages = [{"role": effective_role, "content": prompt}]

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_format=response_format,
            max_tokens=max_tokens,
            temperature=temperature,
            n=n_samples,
            logprobs=True
        )
        self.tokens_used += completion.usage.total_tokens
        self.requests_made += 1

        return [
            {
                'logprobs': choice.logprobs.content,
                'message': choice.message.content
            }
            for choice in completion.choices
        ]
|
MACI-main/conditional-conformal/src/llm_utils.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
from query import (
|
| 9 |
+
generate_subclaim_prompt, generate_annotation_prompt,
|
| 10 |
+
generate_merge_prompt, query_llm
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
import client
|
| 14 |
+
|
| 15 |
+
MERGE_PROMPT = "You will get an instruction and a set of facts that are true. Construct an answer using ONLY the facts provided, and use ALL of the facts provided. If no facts are given, reply and say that you don't know enough to respond.\n"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_responses(
    outputs : List[Dict],
    parser_config : str,
    annotate : bool = False,
    annotator_config : str = None
):
    """Attach a 'claims' list to each output dict (mutated in place).

    Each response is decomposed into subclaims via the parser model;
    when ``annotate`` is set, each subclaim additionally receives an
    annotation from the annotator model.
    """
    for record in tqdm(outputs):
        claims = get_subclaims(record["prompt"], record["response"], parser_config)
        if annotate:
            claims = add_annotations(record["prompt"], claims, annotator_config)
        record["claims"] = claims
    return outputs
|
| 31 |
+
|
| 32 |
+
def get_subclaims(
    prompt : str,
    response : str,
    parser_config : str
) -> List[Dict]:
    """Decompose ``response`` into subclaims using the parser LLM.

    Returns one ``{'message': claim}`` dict per line of the parser's
    first output.
    """
    parser_prompt = generate_subclaim_prompt(prompt, response)
    first_output = query_llm([parser_prompt], parser_config)[0]  # get the first output
    return [{'message': line} for line in first_output['message'].splitlines()]
|
| 41 |
+
|
| 42 |
+
def add_annotations(
    prompt : str,
    subclaims : List[Dict],
    annotator_config : str
) -> List[Dict]:
    """Annotate each subclaim dict in place with a truth value.

    Queries the annotator LLM for one annotation line per subclaim,
    retrying up to 5 times on a line-count mismatch before defaulting
    every annotation to 'F'.

    Fixed: an unparseable annotation line previously fell into a bare
    ``except`` that opened an interactive IPython shell, hanging any
    unattended run; it now logs and defaults that subclaim to 'F'.
    """
    annotation_prompt = generate_annotation_prompt(prompt, subclaims)
    annotations = query_llm([annotation_prompt], annotator_config)[0]
    annotations = annotations['message'].splitlines()
    num_retries = 0
    # The annotator must return exactly one line per subclaim.
    while len(annotations) != len(subclaims):
        print(f"Annotation length does not match subclaims for {prompt}. Retrying query.")
        annotations = query_llm([annotation_prompt], annotator_config)[0]
        annotations = annotations['message'].splitlines()
        num_retries += 1
        if num_retries > 5:
            print("Giving up and assigning False to all subclaims.")
            annotations = ['F' for _ in subclaims]
    for a, subclaim in zip(annotations, subclaims):
        try:
            subclaim['annotation'] = json.loads(a)['value']
        except (ValueError, KeyError, TypeError):
            # Malformed or schema-violating line: treat as not-supported.
            print(f"Could not parse annotation {a!r}; defaulting to 'F'.")
            subclaim['annotation'] = 'F'
    return subclaims
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _concat_claims(
|
| 68 |
+
subclaims : List[str]
|
| 69 |
+
) -> str:
|
| 70 |
+
return "\n".join(
|
| 71 |
+
f"{i}: {subclaim}" for i, subclaim in enumerate(subclaims)
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def _get_merged_output(
    prompt : str,
    subclaims : List[str],
    client : client.Client
) -> str:
    """Ask the LLM to reconstruct a response from the surviving subclaims.

    Returns ``((merge_prompt, raw_output), message)`` where the first
    element is the (prompt, response) pair to cache and ``message`` is the
    merged response text.
    """
    merge_prompt = (
        MERGE_PROMPT
        + f"The original instruction was: {prompt}\n"
        + f"The facts are: {_concat_claims(subclaims)}"
    )

    output = client.query(merge_prompt)

    return (merge_prompt, output), output[0]['message']
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def merge_claims(
    dataset : List,
    client : client.Client
) -> List:
    """Merge each entry's 'filtered_claims' back into a single response.

    Runs ``_get_merged_output`` over the dataset with a thread pool,
    caches every (prompt, raw output) exchange on the client, and returns
    the merged response texts in dataset order.
    """
    with ThreadPoolExecutor(max_workers=25) as executor:
        jobs = executor.map(
            lambda entry: _get_merged_output(entry['prompt'], entry['filtered_claims'], client),
            dataset
        )
        responses = list(tqdm(jobs, total=len(dataset)))

    # Each response is ((merge_prompt, raw_output), message).
    cache_entries = [resp[0] for resp in responses]

    client.cache_outputs(
        [entry[0] for entry in cache_entries],
        np.zeros((len(cache_entries),), dtype=int),
        [entry[1] for entry in cache_entries]
    )

    return [resp[1] for resp in responses]
|
MACI-main/conditional-conformal/src/postprocess_factscore.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json

# One-off script: convert FActScore unlabeled-prediction jsonl files into the
# processed format used by this project (data/factscore_processed.json).
# NOTE(review): input paths are hard-coded to a developer's machine.

output = []
prompt_to_idx = {}  # maps prompt text -> index into `output`
idx = 0
# Pass 1: collect per-prompt facts and their T/F annotations.
with open("/Users/cherian/Downloads/factscore-unlabeled-predictions/ChatGPT.jsonl") as fp:
    for line in fp:
        res = json.loads(line)
        new_res = {}
        new_res['prompt'] = res['prompt']
        new_res['claims'] = []
        # Prefer ChatGPT labels; fall back to LLAMA+NP labels when absent.
        annotator = 'ChatGPT_Labels' if 'ChatGPT_Labels' in res else 'LLAMA+NP_Labels'
        for fact, annotation in zip(res['facts'], res[annotator]):
            # 'S' (supported) maps to 'T'; everything else to 'F'.
            a = 'T' if annotation == 'S' else 'F'
            new_res['claims'].append(
                {'message': fact, 'annotation': a}
            )
        output.append(new_res)
        prompt_to_idx[res['prompt']] = idx
        idx += 1

# Pass 2: join the original responses/topics/metadata onto pass-1 records,
# matching on prompt text; unmatched inputs are skipped.
with open("/Users/cherian/Projects/FActScore/factscore/data/unlabeled/ChatGPT.jsonl", 'r') as fp:
    for line in fp:
        res = json.loads(line)
        idx = prompt_to_idx.get(res['input'], None)
        if idx is None:
            continue
        else:
            output[idx]['response'] = res['output']
            output[idx]['topic'] = res['topic']
            output[idx]['metadata'] = res['cat']

# Write the merged records as a single JSON array.
with open("data/factscore_processed.json", 'w') as fp:
    fp.write(json.dumps(output) + "\n")
|
MACI-main/conditional-conformal/src/prob_model.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from sklearn.linear_model import LogisticRegressionCV
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
|
| 10 |
+
from typing import List
|
| 11 |
+
|
| 12 |
+
from conformal import compute_conformity_scores
|
| 13 |
+
|
| 14 |
+
def fit_model(
    features : np.ndarray,
    labels : np.ndarray,
    config : dict,
    dataset_train : List = None,
    eval_dict : dict = None
):
    """Fit a claim-level probability model of the kind named in the config.

    Parameters
    ----------
    features : (n_claims, n_features) design matrix.
    labels : (n_claims,) binary correctness labels.
    config : config object exposing ``model.prob.name`` and ``conformal.alpha``.
    dataset_train : per-prompt records (each with an ``atomic_facts`` list);
        required by the "torch" model's conformal loss.
    eval_dict : held-out arrays/splits used only for periodic logging of the
        "torch" model's retention statistics.

    Returns
    -------
    The fitted model.

    Raises
    ------
    ValueError
        For unimplemented ("XGBoost") or unknown model names.
    """
    name = config.model.prob.name
    if name == "logistic":
        model = LogisticRegressionCV()
        model.fit(X=features, y=labels)
        return model
    elif name == "XGBoost":
        raise ValueError("not implemented yet")
    elif name == "torch":
        # no data splitting for now when constructing conformal loss
        model = LogisticRegression(features.shape[1])

        # NOTE(review): lr=1 is unusually large for Adam — presumably hand-tuned
        # for this problem; confirm before reusing elsewhere.
        optimizer = optim.Adam(model.parameters(), lr=1)
        x = torch.tensor(features, requires_grad=True, dtype=torch.float32)

        for i in range(500):
            optimizer.zero_grad()
            probs = model.forward(x)

            loss, avg_train = get_conformal_loss(probs, labels, dataset_train, config.conformal.alpha)
            if i % 100 == 0:
                # Periodic diagnostics: calibrate a retention threshold on the
                # validation split, then report the fraction of test claims kept.
                probs_valid = model.forward(torch.tensor(eval_dict['X_valid'], dtype=torch.float32)).detach().numpy()
                probs_split = np.array_split(probs_valid, eval_dict['splits_valid'])
                threshold = np.quantile(compute_conformity_scores(eval_dict['dataset_valid'], probs_split), 1 - config.conformal.alpha)
                probs_test = model.forward(torch.tensor(eval_dict['X_test'], dtype=torch.float32)).detach().numpy()
                probs_split = np.array_split(probs_test, eval_dict['splits_test'])
                avg = 0
                for prob in probs_split:
                    avg_retain = np.mean(prob > threshold.item())
                    avg += avg_retain
                print(f"Average % of train claims retained: {avg_train}")
                print(f"Average % of test claims retained: {avg / len(probs_split)}")
                print(f"Loss at iteration {i}: {loss.item()}")

            loss.backward()
            optimizer.step()
        return model

    else:
        # BUG FIX: this previously *returned* a ValueError instance instead of
        # raising it, silently handing callers an exception object.
        raise ValueError(f"{name} not available.")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_conformal_loss(probs, labels, dataset_train, alpha):
    """Differentiable surrogate for the conformal filtering objective.

    Splits the flat claim probabilities back into per-prompt groups, scores
    each group by the highest probability assigned to an *incorrect* claim,
    calibrates a (1 - alpha) quantile threshold on 25 randomly chosen groups,
    and penalizes (via a smooth sigmoid) claims in the remaining groups that
    fall below the threshold.

    Parameters
    ----------
    probs : 1-D torch tensor of per-claim probabilities (flattened over prompts).
    labels : 1-D numpy array of per-claim labels; assumed 0/1 with 1 == correct.
    dataset_train : list of records, each with an ``atomic_facts`` list whose
        length gives that prompt's claim count.
    alpha : miscoverage level for the calibration quantile.

    Returns
    -------
    (loss, avg_retained) — scalar tensor loss and the mean retention rate over
    the loss groups.

    Raises
    ------
    ValueError if the loss is NaN.
    """
    # Boundaries between prompts: each record contributes len(atomic_facts) claims.
    claim_splits = torch.tensor(
        np.cumsum([len(dat['atomic_facts']) for dat in dataset_train])[:-1]
    )

    claim_probs = torch.tensor_split(probs, claim_splits)
    # NOTE(review): `1 - labels` marks the incorrect claims. If `labels` is an
    # integer (not boolean) array, `c_prob[c_label]` below performs fancy
    # indexing with 0/1 *indices* rather than boolean masking — confirm the
    # intended dtype upstream.
    claim_labels = np.array_split(1 - labels, claim_splits.numpy())
    scores = []
    for c_prob, c_label in zip(claim_probs, claim_labels):
        scores.append(c_prob[c_label].max()) # could replace this with element-wise multiplication and make annotations softer?

    # use random set of scores to calibrate
    random_indices = np.random.permutation(len(scores))
    threshold_indices = random_indices[:25]
    # Set for O(1) membership tests in the loop below (the previous
    # `idx in <ndarray>` scan was O(n) per iteration).
    loss_indices = set(random_indices[25:].tolist())

    # Index directly instead of the previous O(n^2) `if i in threshold_indices`
    # scan; torch.quantile is order-invariant, so behavior is unchanged.
    threshold_scores = [scores[i] for i in threshold_indices]

    threshold = torch.quantile(torch.stack(threshold_scores), 1 - alpha)
    loss = 0
    avg = 0
    for idx, c_prob in enumerate(claim_probs):
        if idx in loss_indices:
            # Smooth proxy for "claim probability falls below the threshold".
            loss += torch.sigmoid((threshold - c_prob)).mean()
            avg_retain = (c_prob > threshold).float().mean()
            avg += avg_retain
    if np.isnan(loss.item()):
        raise ValueError(claim_probs[0])
    return loss, avg / len(loss_indices)
|
| 91 |
+
|
| 92 |
+
class LogisticRegression(nn.Module):
    """Single-layer logistic regression head: sigmoid(W x + b)."""

    def __init__(self, n_features):
        """n_features: dimensionality of the input feature vector."""
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, x):
        """Return per-row probabilities of shape (batch, 1)."""
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        return torch.sigmoid(self.linear(x))
|
| 100 |
+
|
| 101 |
+
|
MACI-main/conditional-conformal/src/query.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import openai
|
| 3 |
+
|
| 4 |
+
# Prompt asking the LLM to split a response into small, independent subclaims
# (one per line).
SUBCLAIM_PROMPT = 'Please breakdown the following response to a prompt into a set of small, independent claims. Return each subclaim (with no other characters) on a new line. \n'

# Prompt asking the LLM to stitch a set of verified facts back into an answer.
MERGE_PROMPT = "You will get an instruction and a set of facts that are true. Construct an answer using ONLY the facts provided, and use ALL of the facts provided. If no facts are given, reply and say that you don't know enough to respond.\n"

# Prompt asking the LLM to label each claim T (factual), S (subjective), or
# F (false), one JSON object per line.
ANNOTATION_PROMPT = 'You will get an instruction and a set of claims made in response to that instruction. Determine whether each claim is true, subjective, or false. Each returned determination should be {"claim_id": ID, "value": TRUTH_VALUE} and be on its own line with NO other characters. The truth value should be in quotes and it should be T for Factual, S for Subjective, and F for False.\n'

# Prompt asking the LLM to score each claim's support (1/-1/0) against a text,
# returned as jsonl.
FREQUENCY_PROMPT = 'You will get a list of claims and piece of text. For each claim, score whether the text supports, contradicts, or is unrelated to the claim. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "score":[SCORE]}. Directly return the jsonl with no explanation or other formatting. For the [SCORE], return 1 for supports, -1 for contradicts, and 0 for unrelated.\n'
|
| 11 |
+
|
| 12 |
+
def _concat_claims(
|
| 13 |
+
subclaims : List[str]
|
| 14 |
+
) -> str:
|
| 15 |
+
return "\n".join(
|
| 16 |
+
f"{i}: {subclaim['message']}" for i, subclaim in enumerate(subclaims)
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def generate_subclaim_prompt(
    prompt : str,
    response : str
) -> str:
    """Assemble the instruction that asks the LLM to atomize `response`."""
    parts = [
        SUBCLAIM_PROMPT,
        f"The original instruction was: {prompt}\n",
        f"The response to be broken down into subclaims is: {response}",
    ]
    return "".join(parts)
|
| 27 |
+
|
| 28 |
+
def generate_merge_prompt(
    prompt : str,
    subclaims : List[str]
) -> str:
    """Assemble the instruction that asks the LLM to answer using only the
    provided facts."""
    header = f"The original instruction was: {prompt}\n"
    facts = f"The facts are: {_concat_claims(subclaims)}"
    return MERGE_PROMPT + header + facts
|
| 37 |
+
|
| 38 |
+
def generate_annotation_prompt(
    prompt : str,
    subclaims : List[str]
) -> str:
    """Assemble the instruction asking the LLM to label each claim T/S/F."""
    return (
        ANNOTATION_PROMPT
        + f"The original instruction was: {prompt}\n"
        + f"The claims are: \n{_concat_claims(subclaims)}"
    )
|
| 46 |
+
|
| 47 |
+
def generate_frequency_prompt(
    subclaims : List[str],
    output : str,
) -> str:
    """Assemble the instruction asking the LLM to score how a text supports
    each claim."""
    pieces = (
        FREQUENCY_PROMPT,
        f"The claims are: {_concat_claims(subclaims)}\n",
        f"The text is: {output}",
    )
    return "".join(pieces)
|
| 54 |
+
|
| 55 |
+
def query_gpt(
    client : openai.Client,
    prompts : List[str],
    model : str = "gpt-3.5-turbo",
    roles : List[str] = None,
    max_tokens : int = 1000,
    temperature: float = 0,
    response_format : str = None,
    n_samples: int = 1
):
    """Send one chat-completion request and return the raw completion object.

    Each prompt becomes one message; `roles` defaults to "user" for every
    prompt. Token log-probabilities are always requested.
    """
    if roles is None:
        roles = ["user"] * len(prompts)
    messages = [
        {"role": role, "content": prompt}
        for role, prompt in zip(roles, prompts)
    ]

    return client.chat.completions.create(
        model=model,
        messages=messages,
        response_format=response_format,
        max_tokens=max_tokens,
        temperature=temperature,
        n=n_samples,
        logprobs=True
    )
|
| 80 |
+
|
| 81 |
+
def query_embedding(
    client : openai.Client,
    prompts : List[str],
    model : str = "text-embedding-3-small",
    **kwargs
):
    """Embed `prompts` and return the first result's embedding vector."""
    response = client.embeddings.create(input = prompts, model = model, **kwargs)
    embed = response.data[0].embedding
    return embed
|
| 89 |
+
|
| 90 |
+
def query_llm(
    prompts : List[str],
    model : str,
    **kwargs
) -> Dict:
    """Dispatch to a chat or embedding query based on the model name.

    Chat models ('gpt' in name) return a list of dicts with 'logprobs' and
    'message'; embedding models return a single embedding vector.
    """
    if 'gpt' in model:
        client = openai.Client() # OPENAI_API_KEY should be set as an environment variable
        completion = query_gpt(client, prompts, model, **kwargs)
        return [
            {
                'logprobs': choice.logprobs.content,
                'message': choice.message.content
            }
            for choice in completion.choices
        ]
    elif 'embedding' in model:
        client = openai.Client()
        return query_embedding(client, prompts, model, **kwargs)

    else:
        raise ValueError(f"Model {model} is not supported in query.")
|
MACI-main/conditional-conformal/src/ray_data.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from config import get_config
|
| 8 |
+
|
| 9 |
+
from featurizer import get_frequency, get_self_eval
|
| 10 |
+
from gpt import GPTClient
|
| 11 |
+
|
| 12 |
+
from atomizer import text_to_sentences
|
| 13 |
+
from dataset import get_prompts
|
| 14 |
+
from scorer import Scorer
|
| 15 |
+
|
| 16 |
+
import ray
|
| 17 |
+
|
| 18 |
+
def parse_args():
    """Parse command-line options (a single optional config-path flag)."""
    parser = argparse.ArgumentParser(
        prog="conformal-safety",
        description="Auto-filter claims from LLM to meet accuracy and safety guarantees.",
    )
    parser.add_argument('-config_path', '-c', default='configs/default.toml', help="Config for construction.")
    return parser.parse_args()
|
| 26 |
+
|
| 27 |
+
def find_unique_element(lst, condition, approx_index):
    """Return the index of an element satisfying `condition`, searching
    outward from `approx_index`; None when nothing matches.

    The approximate index is checked first, then neighbors at increasing
    distance on both sides.
    """
    if condition(lst[approx_index]):
        return approx_index

    n = len(lst)
    lo, hi = approx_index - 1, approx_index + 1
    while lo >= 0 or hi < n:
        if lo >= 0 and condition(lst[lo]):
            return lo
        if hi < n and condition(lst[hi]):
            return hi
        lo -= 1
        hi += 1

    # Nothing in the list satisfies the condition.
    return None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@ray.remote
def parallel_scorer(*args, **kwargs):
    """Placeholder Ray task for remote scoring; currently a no-op.

    The original body contained an unreachable ``return run_experiment(*args,
    **kwargs)`` after the early ``return None`` — and ``run_experiment`` is not
    defined anywhere in this file. The dead statement has been removed; the
    early return is preserved so behavior is unchanged.
    """
    return None
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
args = parse_args()
|
| 57 |
+
config = get_config(args.config_path)
|
| 58 |
+
|
| 59 |
+
import IPython; IPython.embed()
|
| 60 |
+
responder = GPTClient(config.model.responder.cache_path)
|
| 61 |
+
|
| 62 |
+
topics, prompts = get_prompts(config.dataset.name)
|
| 63 |
+
|
| 64 |
+
with ThreadPoolExecutor(max_workers=25) as executor:
|
| 65 |
+
responses = list(
|
| 66 |
+
tqdm(
|
| 67 |
+
executor.map(
|
| 68 |
+
lambda x : responder.query(x),
|
| 69 |
+
prompts
|
| 70 |
+
),
|
| 71 |
+
total=len(prompts)
|
| 72 |
+
)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
responses = [r[0] for r in responses]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
outputs = [{'prompt': p, 'response': o['message']}
|
| 79 |
+
for p, o in zip(prompts, responses)] # first output is the response we will filter
|
| 80 |
+
|
| 81 |
+
print("Loading atomizer.")
|
| 82 |
+
atomizer_client = GPTClient(config.model.parser.cache_path, model=config.model.parser.name)
|
| 83 |
+
|
| 84 |
+
responder_cache = responder.cache_dict
|
| 85 |
+
messages = []
|
| 86 |
+
for val in responder_cache.values():
|
| 87 |
+
messages.append(val[0]['message'])
|
| 88 |
+
|
| 89 |
+
atomizer_cache = atomizer_client.cache_dict
|
| 90 |
+
idx_guess = 0
|
| 91 |
+
atomic_facts = [[] for _ in range(len(messages))]
|
| 92 |
+
for k in tqdm(atomizer_cache.keys()):
|
| 93 |
+
atomized_msg = atomizer_cache[k][0]['message']
|
| 94 |
+
atomized_facts = text_to_sentences(atomized_msg)
|
| 95 |
+
sentence = k.split('\n')[-1].split('facts:')[-1].strip()[:-2]
|
| 96 |
+
cur_idx = find_unique_element(messages, lambda x: sentence in x, approx_index=idx_guess)
|
| 97 |
+
if cur_idx is None: # TODO: TERRIBLE SPECIAL CASING that I looked at by hand...
|
| 98 |
+
if idx_guess == 4151:
|
| 99 |
+
cur_idx = 4152
|
| 100 |
+
else:
|
| 101 |
+
cur_idx = idx_guess
|
| 102 |
+
idx_guess = cur_idx
|
| 103 |
+
atomic_facts[cur_idx].extend(atomized_facts)
|
| 104 |
+
|
| 105 |
+
# time to annotate responses using factscore code
|
| 106 |
+
print("Loading annotator.")
|
| 107 |
+
scorer_client = GPTClient(config.model.annotator.cache_path, model=config.model.annotator.name)
|
| 108 |
+
scorer = Scorer(scorer_client, config, model_name="retrieval")
|
| 109 |
+
|
| 110 |
+
scorer_inputs = [(topic, output['response'], fact) for topic, output, fact in zip(topics, outputs, atomic_facts)]
|
| 111 |
+
|
| 112 |
+
import IPython; IPython.embed()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# connect to cluster
|
| 116 |
+
ray.init(address="auto")
|
| 117 |
+
|
| 118 |
+
results = []
|
| 119 |
+
|
| 120 |
+
for seed in range(args.seed, args.seed + args.n_trials):
|
| 121 |
+
if args.type == 'coverage':
|
| 122 |
+
result = parallel_coverage_experiment.remote(
|
| 123 |
+
(X, Y), n_test, n_calib, alpha, methods=args.methods, seed=seed
|
| 124 |
+
)
|
| 125 |
+
else:
|
| 126 |
+
result = parallel_experiment.remote(
|
| 127 |
+
(X, Y), n_test, n_calib, alpha, methods=args.methods, seed=seed
|
| 128 |
+
)
|
| 129 |
+
results.append(result)
|
| 130 |
+
|
| 131 |
+
trial_results = ray.get(results)
|
| 132 |
+
|
| 133 |
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
| 134 |
+
scores = list(
|
| 135 |
+
tqdm(
|
| 136 |
+
executor.map(
|
| 137 |
+
lambda x : scorer.get_score(*x),
|
| 138 |
+
scorer_inputs
|
| 139 |
+
),
|
| 140 |
+
total=len(scorer_inputs)
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
+
scorer.save_cache()
|
| 144 |
+
|
| 145 |
+
dataset = []
|
| 146 |
+
|
| 147 |
+
for p, r, s in zip(prompts, responses, scores):
|
| 148 |
+
data_pt = {
|
| 149 |
+
'prompt': p,
|
| 150 |
+
'response': r,
|
| 151 |
+
'atomic_facts': s['decisions'][0]
|
| 152 |
+
}
|
| 153 |
+
dataset.append(data_pt)
|
| 154 |
+
|
| 155 |
+
import IPython
|
| 156 |
+
IPython.embed()
|
| 157 |
+
|
| 158 |
+
# client = GPTClient(f'.cache/{config.dataset.name}_frequency.pkl')
|
| 159 |
+
|
| 160 |
+
# with ThreadPoolExecutor(max_workers=5) as executor:
|
| 161 |
+
# frequencies = list(
|
| 162 |
+
# tqdm(
|
| 163 |
+
# executor.map(
|
| 164 |
+
# lambda x: get_frequency(client, [af['atom'] for af in x['atomic_facts']], x['prompt'], config.model.prob.frequency.model),
|
| 165 |
+
# dataset
|
| 166 |
+
# ),
|
| 167 |
+
# total=len(dataset)
|
| 168 |
+
# )
|
| 169 |
+
# )
|
| 170 |
+
# client.save_cache()
|
| 171 |
+
|
| 172 |
+
# eval_client = GPTClient(f'.cache/{config.dataset.name}_self_evals.pkl')
|
| 173 |
+
|
| 174 |
+
# with ThreadPoolExecutor(max_workers=25) as executor:
|
| 175 |
+
# self_evals = list(
|
| 176 |
+
# tqdm(
|
| 177 |
+
# executor.map(
|
| 178 |
+
# lambda x: get_self_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], eval_client),
|
| 179 |
+
# dataset
|
| 180 |
+
# ),
|
| 181 |
+
# total=len(dataset)
|
| 182 |
+
# )
|
| 183 |
+
# )
|
| 184 |
+
# eval_client.save_cache()
|
| 185 |
+
|
| 186 |
+
# features = np.concatenate(
|
| 187 |
+
# [
|
| 188 |
+
# np.concatenate(frequencies).reshape(-1,1),
|
| 189 |
+
# np.concatenate(self_evals).reshape(-1,1)
|
| 190 |
+
# ],
|
| 191 |
+
# axis=1
|
| 192 |
+
# )
|
MACI-main/conditional-conformal/src/retrieval.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import sqlite3
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pickle as pkl
|
| 8 |
+
|
| 9 |
+
from rank_bm25 import BM25Okapi
|
| 10 |
+
|
| 11 |
+
# Delimiter used to join/split a document's passages stored as one DB row.
SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####"
# Maximum number of tokens per stored passage.
MAX_LENGTH = 256
|
| 13 |
+
|
| 14 |
+
class DocDB(object):
    """Sqlite backed document storage.

    Stores one row per document title; the text column holds the document's
    passages joined by SPECIAL_SEPARATOR. A pickle cache (local or S3) maps
    titles to already-split passage dicts.
    """

    def __init__(self, db_path=None, data_path=None, cache_path=None):
        """Open (and, if empty, build) the sqlite document DB.

        db_path: sqlite file path; data_path: jsonl corpus used to build the
        DB when it has no tables; cache_path: pickle of title -> passages.
        """
        self.db_path = db_path
        self.cache_file = cache_path
        # check_same_thread=False: the connection is shared across worker threads.
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)

        self.cache_dict = self.load_cache()

        cursor = self.connection.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")

        if len(cursor.fetchall())==0:
            assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB."
            print (f"{self.db_path} is empty. start building DB from {data_path}...")
            self.build_db(self.db_path, data_path)

    def load_cache(self, allow_retry=True):
        """Load the title->passages cache from disk (or S3).

        Retries local pickle loads forever when allow_retry, since concurrent
        writers can leave the file momentarily unreadable.
        """
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pkl.load(f)
                    break
                except Exception: # if there are concurent processes, things can fail
                    if not allow_retry:
                        assert False
                    print ("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif 's3' in self.cache_file:
            # Remote cache: stream the pickle from S3.
            from aws_utils import s3_open
            s3_path = self.cache_file.removeprefix('s3://')
            bucket_name = s3_path.split('/')[0]
            path_to_file = '/'.join(s3_path.split('/')[1:])
            with s3_open(bucket_name, path_to_file) as fp:
                cache = pkl.load(fp)
        else:
            cache = {}
        return cache

    def save_cache(self):
        """Merge the on-disk cache into ours, then write the union back."""
        # load the latest cache first, since if there were other processes running in parallel, cache might have been updated
        for k, v in self.load_cache().items():
            self.cache_dict[k] = v

        with open(self.cache_file, "wb") as f:
            pkl.dump(self.cache_dict, f)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def path(self):
        """Return the path to the file that backs this database."""
        # BUG FIX: previously returned `self.path` — the bound method object
        # itself — instead of the stored database path.
        return self.db_path

    def close(self):
        """Close the connection to the database."""
        self.connection.close()

    def build_db(self, db_path, data_path):
        """Build the documents table from a jsonl corpus.

        Each document's sentences are packed into passages of at most
        MAX_LENGTH RoBERTa tokens; passages are decoded back to text and
        joined with SPECIAL_SEPARATOR into a single row per title.
        Closes the connection when done.
        """
        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained("roberta-large")

        titles = set()
        output_lines = []
        tot = 0
        start_time = time.time()
        c = self.connection.cursor()
        c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")

        with open(data_path, "r") as f:
            for line in f:
                dp = json.loads(line)
                title = dp["title"]
                text = dp["text"]
                # Skip duplicate titles — first occurrence wins.
                if title in titles:
                    continue
                titles.add(title)
                if type(text)==str:
                    text = [text]
                passages = [[]]
                for sent_idx, sent in enumerate(text):
                    assert len(sent.strip())>0
                    tokens = tokenizer(sent)["input_ids"]
                    max_length = MAX_LENGTH - len(passages[-1])
                    if len(tokens) <= max_length:
                        passages[-1].extend(tokens)
                    else:
                        # Fill the current passage, then spill the rest into
                        # fresh MAX_LENGTH chunks.
                        passages[-1].extend(tokens[:max_length])
                        offset = max_length
                        while offset < len(tokens):
                            passages.append(tokens[offset:offset+MAX_LENGTH])
                            offset += MAX_LENGTH

                # Drop passages that contain only special tokens (0 = <s>, 2 = </s>).
                psgs = [tokenizer.decode(tokens) for tokens in passages if np.sum([t not in [0, 2] for t in tokens])>0]
                text = SPECIAL_SEPARATOR.join(psgs)
                output_lines.append((title, text))
                tot += 1

                # Flush in 1M-row batches to bound memory.
                if len(output_lines) == 1000000:
                    c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
                    output_lines = []
                    print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))

        if len(output_lines) > 0:
            c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
            print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))

        self.connection.commit()
        self.connection.close()

    def get_text_from_title(self, title):
        """Fetch the raw text of the doc for 'doc_id'."""
        # Apply manual title corrections (file of `bad=good` lines).
        # NOTE(review): this file is re-read on every call — presumably cheap
        # enough in practice; consider hoisting if it shows up in profiles.
        with open('data/wiki_corrections.txt') as fp:
            all_names = fp.readlines()
        all_names = [n.strip() for n in all_names]
        name_converter = {names.split('=')[0]:names.split('=')[1] for names in all_names}
        if title in name_converter:
            title = name_converter[title]

        if title in self.cache_dict:
            results = self.cache_dict[title]
        else:
            # Cache miss: expected to be rare (the cache should be pre-warmed).
            print("I SHOULD NOT BE HERE.")
            cursor = self.connection.cursor()
            cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
            results = cursor.fetchall()
            results = [r for r in results]
            cursor.close()
            try:
                assert results is not None and len(results)==1, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            except Exception: # if there are concurent processes, things can fail
                # HACK: on lookup failure, retry once and then substitute a
                # placeholder document rather than aborting the run.
                print (f"Retrieval error for {title}: Retry in 5sec...")
                # time.sleep(5)
                cursor = self.connection.cursor()
                cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
                results = cursor.fetchall()
                results = [r for r in results]
                results = [['blah blah blah']]
                cursor.close()
            results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)]
            assert len(results)>0, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            self.cache_dict[title] = results
        return results
|
| 165 |
+
|
| 166 |
+
class Retrieval(object):
    """Passage retriever over a DocDB, using BM25 or a GTR sentence encoder.

    Maintains two caches: a JSON cache of final (topic, query) -> passages
    results, and a pickle cache of per-topic BM25 indexes / passage embeddings.
    """

    def __init__(self, db, cache_path, embed_cache_path,
                 retrieval_type="gtr-t5-large", batch_size=None):
        # db: DocDB instance; cache_path: JSON results cache;
        # embed_cache_path: pickle cache of BM25 objects or embedding matrices.
        self.db = db
        self.cache_path = cache_path
        self.embed_cache_path = embed_cache_path
        self.retrieval_type = retrieval_type
        self.batch_size = batch_size
        assert retrieval_type=="bm25" or retrieval_type.startswith("gtr-")

        # The GTR encoder is loaded lazily on first use.
        self.encoder = None
        self.load_cache()
        # Counters of new entries since the last save; save_cache skips
        # writing when nothing changed.
        self.add_n = 0
        self.add_n_embed = 0

    def load_encoder(self):
        """Lazily load the GTR sentence encoder onto the GPU (eval mode)."""
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer("sentence-transformers/" + self.retrieval_type)
        encoder = encoder.cuda()
        encoder = encoder.eval()
        self.encoder = encoder
        assert self.batch_size is not None

    def load_cache(self):
        """Load the JSON results cache and the pickle embedding cache."""
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r") as f:
                self.cache = json.load(f)
        else:
            self.cache = {}
        if os.path.exists(self.embed_cache_path):
            with open(self.embed_cache_path, "rb") as f:
                self.embed_cache = pkl.load(f)
        else:
            self.embed_cache = {}

    def save_cache(self):
        """Persist both caches, merging in any entries written by other
        processes since our last load."""
        if self.add_n > 0:
            if os.path.exists(self.cache_path):
                with open(self.cache_path, "r") as f:
                    new_cache = json.load(f)
                self.cache.update(new_cache)

            with open(self.cache_path, "w") as f:
                json.dump(self.cache, f)

        if self.add_n_embed > 0:
            if os.path.exists(self.embed_cache_path):
                with open(self.embed_cache_path, "rb") as f:
                    new_cache = pkl.load(f)
                self.embed_cache.update(new_cache)

            with open(self.embed_cache_path, "wb") as f:
                pkl.dump(self.embed_cache, f)

    def get_bm25_passages(self, topic, query, passages, k):
        """Return the top-k passages for `query`, building (and caching) a
        per-topic BM25 index over the passages on first use."""
        if topic in self.embed_cache:
            bm25 = self.embed_cache[topic]
        else:
            bm25 = BM25Okapi([psg["text"].replace("<s>", "").replace("</s>", "").split() for psg in passages])
            self.embed_cache[topic] = bm25
            self.add_n_embed += 1
        scores = bm25.get_scores(query.split())
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_gtr_passages(self, topic, retrieval_query, passages, k):
        """Return the top-k passages by inner product between the encoded
        query and (cached) encoded passages."""
        if self.encoder is None:
            self.load_encoder()
        if topic in self.embed_cache:
            passage_vectors = self.embed_cache[topic]
        else:
            inputs = [psg["title"] + " " + psg["text"].replace("<s>", "").replace("</s>", "") for psg in passages]
            passage_vectors = self.encoder.encode(inputs, batch_size=self.batch_size, device=self.encoder.device)
            self.embed_cache[topic] = passage_vectors
            self.add_n_embed += 1
        query_vectors = self.encoder.encode([retrieval_query],
                                            batch_size=self.batch_size,
                                            device=self.encoder.device)[0]
        scores = np.inner(query_vectors, passage_vectors)
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_passages(self, topic, question, k):
        """Return (and cache) the top-k passages for `question` within `topic`.

        The retrieval query prefixes the topic onto the question; results are
        memoized under "topic#query".
        """
        retrieval_query = topic + " " + question.strip()
        cache_key = topic + "#" + retrieval_query

        if cache_key not in self.cache:
            passages = self.db.get_text_from_title(topic)
            if self.retrieval_type=="bm25":
                self.cache[cache_key] = self.get_bm25_passages(topic, retrieval_query, passages, k)
            else:
                self.cache[cache_key] = self.get_gtr_passages(topic, retrieval_query, passages, k)
            # A topic with fewer than k passages legitimately returns them all.
            assert len(self.cache[cache_key]) in [k, len(passages)]
            self.add_n += 1


        return self.cache[cache_key]
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
|
MACI-main/conditional-conformal/src/retrieve_data.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from dataset import load_dataset
|
| 8 |
+
from config import get_config
|
| 9 |
+
|
| 10 |
+
from featurizer import get_frequency, get_self_eval, get_bool_eval
|
| 11 |
+
from gpt import GPTClient
|
| 12 |
+
|
| 13 |
+
def parse_args():
    """Parse command-line options; returns a namespace with `config_path`."""
    cli = argparse.ArgumentParser(
        prog="conformal-safety",
        description="Auto-filter claims from LLM to meet accuracy and safety guarantees.",
    )
    cli.add_argument(
        '-config_path', '-c',
        default='configs/default.toml',
        help="Config for construction.",
    )
    return cli.parse_args()
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
    # Entry point: load the configured dataset, then drop into an interactive
    # shell for ad-hoc inspection of the annotated data.
    args = parse_args()
    config = get_config(args.config_path)

    dataset = load_dataset(config)

    # NOTE(review): the three GPT-based feature passes below (claim frequency,
    # self-evaluation, boolean self-evaluation) are currently disabled;
    # re-enable as needed. Each caches its LLM calls to a per-dataset pickle.
    # client = GPTClient(f'.cache/{config.dataset.name}_frequency.pkl')

    # with ThreadPoolExecutor(max_workers=8) as executor:
    #     frequencies = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_frequency(client, [af['atom'] for af in x['atomic_facts']], x['prompt'], config.model.prob.frequency.model),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # client.save_cache()

    # eval_client = GPTClient(f'.cache/{config.dataset.name}_self_evals.pkl')

    # with ThreadPoolExecutor(max_workers=25) as executor:
    #     self_evals = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_self_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], eval_client),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # eval_client.save_cache()

    # bool_client = GPTClient(f'.cache/{config.dataset.name}_bool_evals.pkl')

    # with ThreadPoolExecutor(max_workers=25) as executor:
    #     self_bools = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_bool_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], bool_client),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # bool_client.save_cache()

    # features = np.concatenate(
    #     [
    #         np.concatenate(frequencies).reshape(-1,1),
    #         np.concatenate(self_evals).reshape(-1,1)
    #     ],
    #     axis=1
    # )

    # Interactive session over `dataset` (and any re-enabled features above).
    import IPython; IPython.embed()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
MACI-main/conditional-conformal/src/run.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from config import get_config
|
| 5 |
+
from conformal import compute_conformity_scores, calibrate_thresholds, conformal_filter, assess_factscore_coverage
|
| 6 |
+
from dataset import load_dataset, split_dataset
|
| 7 |
+
from featurizer import get_features
|
| 8 |
+
from llm_utils import merge_claims
|
| 9 |
+
from prob_model import fit_model
|
| 10 |
+
from gpt import GPTClient
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args():
    """Build and run the CLI parser for the conformal-safety pipeline."""
    p = argparse.ArgumentParser(
        prog="conformal-safety",
        description="Auto-filter claims from LLM to meet accuracy and safety guarantees.",
    )
    p.add_argument('-config_path', '-c',
                   default='configs/default.toml',
                   help="Config for construction.")
    parsed = p.parse_args()
    return parsed
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
if __name__ == "__main__":
    # End-to-end pipeline: annotate -> split -> featurize -> fit probability
    # model -> conformal calibration -> filter test claims -> merge responses.
    args = parse_args()

    config = get_config(args.config_path)

    # Single RNG drives both the dataset split and the demo sample below.
    rng = np.random.default_rng(seed=config.dataset.seed)

    # annotate dataset
    dataset = load_dataset(config)

    # split dataset into train / validation / test
    dataset_train, dataset_valid, dataset_test = split_dataset(
        dataset,
        train_perc=config.dataset.train_percent,
        valid_perc=config.dataset.valid_percent,
        rng=rng if config.dataset.randomize else None
    )

    def _support_labels(split):
        """Flatten each response's per-claim `is_supported` flags into one
        0/1 int8 vector (claims concatenated across responses).

        Replaces the previous triplicated `y[y==True]=1; y[y==False]=0`
        boilerplate, which was a no-op before the astype."""
        flags = np.concatenate(
            [[c['is_supported'] for c in dat['atomic_facts']] for dat in split]
        )
        return flags.astype(np.int8)

    def _claim_splits(split):
        """Cut points that partition a flat claim vector back into
        per-response chunks (suitable for np.array_split)."""
        return np.cumsum([len(dat['atomic_facts']) for dat in split])[:-1]

    X_train = get_features(dataset_train, config)
    y_train = _support_labels(dataset_train)

    X_valid = get_features(dataset_valid, config)
    y_valid = _support_labels(dataset_valid)
    splits_valid = _claim_splits(dataset_valid)

    X_test = get_features(dataset_test, config)
    y_test = _support_labels(dataset_test)
    splits_test = _claim_splits(dataset_test)

    model = fit_model(X_train, y_train, config, dataset_train,
                      eval_dict={'X_valid': X_valid, 'X_test': X_test, 'dataset_valid': dataset_valid, 'splits_valid': splits_valid, 'splits_test': splits_test})

    # Per-claim support probabilities, regrouped per response.
    scores_valid = model.predict_proba(X_valid)[:, 1]
    scores_valid = np.array_split(scores_valid, splits_valid)

    scores_test = model.predict_proba(X_test)[:, 1]
    scores_test = np.array_split(scores_test, splits_test)

    # identify features for scoring (placeholder: constant zero features;
    # kept for the interactive session below)
    score_features_v = [np.zeros((len(u['atomic_facts']), 1)) for u in dataset_valid]
    score_features_te = [np.zeros((len(u['atomic_facts']), 1)) for u in dataset_test]

    conf_scores_valid = compute_conformity_scores(dataset_valid, scores_valid)

    # fit error probability function using training set (or just define it?)
    # we want to be more sure about correctness on more sensitive prompts
    alpha_fn = lambda x: [config.conformal.alpha] * len(x)  # TODO: dumb one for now.

    # identify features for conditional calibration (placeholder: constant)
    conf_features_v = np.zeros((len(dataset_valid), 1))
    conf_features_te = np.zeros((len(dataset_test), 1))

    # calibrate a threshold on the validation set
    thresholds = calibrate_thresholds(
        conf_features_te,
        conf_features_v,
        conf_scores_valid,
        alpha_fn
    )

    dataset_test = conformal_filter(
        dataset_test,
        scores_test,
        thresholds
    )

    if config.dataset.name.lower() == "factscore":
        assess_factscore_coverage(dataset_test, config.conformal.alpha)

    print("Merging filtered responses.")

    merge_client = GPTClient(cache_file=config.model.merger.cache_path)
    merged_responses = merge_claims(
        dataset_test,
        merge_client
    )
    merge_client.save_cache()

    # Show one random before/after example.
    rand_idx = rng.integers(0, len(dataset_test))
    print(dataset_test[rand_idx]['response']['message'] + "\n")
    print(merged_responses[rand_idx])

    import IPython; IPython.embed()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
MACI-main/conditional-conformal/src/scorer.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import string
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 7 |
+
# import logging
|
| 8 |
+
|
| 9 |
+
# from tqdm import tqdm
|
| 10 |
+
# from factscore.abstain_detection import is_response_abstained
|
| 11 |
+
from retrieval import DocDB, Retrieval
|
| 12 |
+
|
| 13 |
+
class Scorer(object):
    """Fact-checks atomic claims against a registered knowledge source using
    a retrieval backend plus an LLM verdict client."""

    def __init__(self,
                 client,
                 config,
                 model_name="retrieval+ChatGPT",
                 batch_size=256):
        # Reject unknown scoring pipelines up front.
        supported_pipelines = [
            "retrieval+llama",
            "retrieval+llama+npm",
            "retrieval+ChatGPT",
            "npm",
            "retrieval+ChatGPT+npm",
            "retrieval",
        ]
        assert model_name in supported_pipelines

        self.model_name = model_name
        self.client = client
        self.config = config

        # Paths for knowledge-source data and retrieval caches.
        self.data_dir = config.model.annotator.data_path
        self.cache_dir = config.model.annotator.retrieval_cache_path

        # Per-knowledge-source components, populated lazily by
        # register_knowledge_source().
        self.db = {}
        self.retrieval = {}
        self.npm = {}
        self.batch_size = batch_size  # batch size for retrieval

        self.af_generator = None
|
| 40 |
+
|
| 41 |
+
def save_cache(self):
    """Flush all sub-component caches (LLM client, NPM models, retrievers,
    document DBs) to disk."""
    self.client.save_cache()
    if "npm" in self.model_name:
        for npm_model in self.npm.values():
            npm_model.save_cache()
    # NOTE(review): for the "medlfqa" source, self.retrieval holds a plain
    # dict with no save_cache() — confirm this path is never reached there.
    for retriever in self.retrieval.values():
        retriever.save_cache()
    # BUG FIX: `for k, v in self.db:` iterated the dict's KEYS and tried to
    # unpack each key into (k, v) — a ValueError on any non-2-char key, and
    # the DocDB caches were never saved. Iterate the values instead.
    for doc_db in self.db.values():
        doc_db.save_cache()
|
| 50 |
+
|
| 51 |
+
def register_knowledge_source(self, name="enwiki-20230401", db_path=None, data_path=None):
    """Lazily attach a knowledge source under `name`.

    For "medlfqa": builds an in-memory Question -> {context, must_have,
    nice_to_have} dict from *_test_MedLFQA.jsonl files in self.data_dir.
    Otherwise: opens a DocDB plus a BM25 Retrieval over it, with caches
    under self.cache_dir.
    """
    assert name not in self.retrieval, f"{name} already registered"

    # Default DB/data locations are derived from the source name.
    if db_path is None:
        db_path = os.path.join(self.data_dir, f"{name}.db")

    if data_path is None:
        data_path = os.path.join(self.data_dir, f"{name}.jsonl")

    if name == "medlfqa":
        datasets = {}
        suffix = "_test_MedLFQA.jsonl"

        # dataset_dir = "/Users/cherian/Projects/OLAPH/MedLFQA"
        for path in os.listdir(self.data_dir):
            # Only MedLFQA shards; other files in data_dir are ignored.
            if "MedLFQA" not in path:
                continue
            dataset_name = path[:-len(suffix)]
            with open(os.path.join(self.data_dir, path), 'r') as fp:
                datasets[dataset_name] = [json.loads(line) for line in fp.readlines()]
        # Flatten all shards into one Question-keyed lookup; later shards
        # overwrite earlier ones on duplicate questions.
        retrieval = {}
        for _, dataset in datasets.items():
            for pt in dataset:
                retrieval[pt['Question']] = {
                    'context': pt['Free_form_answer'],
                    'must_have': pt['Must_have'],
                    'nice_to_have': pt['Nice_to_have']
                }
        # NOTE: here self.retrieval[name] is a plain dict, not a Retrieval
        # object — callers branch on knowledge_source == "medlfqa".
        self.retrieval[name] = retrieval

    else:
        db_cache_path = os.path.join(self.cache_dir, f"db-{name}.pkl")
        cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.json")
        embed_cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.pkl")

        self.db[name] = DocDB(db_path=db_path, data_path=data_path, cache_path=db_cache_path)
        self.retrieval[name] = Retrieval(self.db[name], cache_path, embed_cache_path, retrieval_type="bm25", batch_size=self.batch_size)
        # if "npm" in self.model_name:
        #     cache_path = os.path.join(self.cache_dir, f"bm25-{name}.json")
        #     embed_cache_path = os.path.join(self.cache_dir, f"bm25-{name}.pkl")
        #     self.npm[name] = NPM(Retrieval(self.db[name], cache_path, embed_cache_path, "bm25"),
        #                          "npm-single",
        #                          cache_file=os.path.join(self.cache_dir, f"npm-{name}.pkl"))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_score(self,
              topics,
              generations,
              atomic_facts,
              gamma=10,
              knowledge_source=None):
    """Score each generation's atomic facts against the knowledge source.

    Returns a dict with the mean (length-penalized) support score, the
    fraction of non-abstained responses, the per-fact decisions, and the
    mean number of facts per (non-abstained) response. When `gamma` is
    truthy, short fact lists are penalized and the unpenalized mean is
    reported under "init_score".
    """
    if knowledge_source is None:
        # use the default knowledge source
        knowledge_source = "enwiki-20230401"

    if knowledge_source not in self.retrieval:
        self.register_knowledge_source(knowledge_source)

    # Promote a single (topic, generation) pair to the batched form.
    if type(topics) == type(generations) == str:
        topics, generations, atomic_facts = [topics], [generations], [atomic_facts]
    else:
        assert type(topics) == type(generations) == list, "`topics` and `generations` should be lists."
        assert len(topics) == len(generations), "`topics` and `generations` should have the same length"
        assert len(topics) == len(atomic_facts), "`topics` and `atomic_facts` should have the same length"

    # Fraction of responses that did not abstain.
    respond_ratio = np.mean([facts is not None for facts in atomic_facts])

    scores, init_scores, decisions = [], [], []
    for topic, generation, facts in zip(topics, generations, atomic_facts):
        if facts is None:
            decisions.append(None)
            continue

        # Each verdict sees the verdicts made so far (in-context conditioning).
        decision = []
        for fact in facts:
            decision.append(
                self._get_score(topic, generation, fact, knowledge_source, decision)
            )
        score = np.mean([d["is_supported"] for d in decision])

        if gamma:
            init_scores.append(score)
            # Length penalty: responses with <= gamma facts are discounted.
            n_facts = len(facts)
            penalty = 1.0 if n_facts > gamma else np.exp(1 - gamma / max(n_facts, 1))
            score *= penalty

        decisions.append(decision)
        scores.append(score)

    out = {
        "score": np.mean(scores),
        "respond_ratio": respond_ratio,
        "decisions": decisions,
        "num_facts_per_response": np.mean([len(d) for d in decisions if d is not None]),
    }

    if gamma:
        out["init_score"] = np.mean(init_scores)

    return out
|
| 153 |
+
|
| 154 |
+
def _get_score(self, topic, generation, atom, knowledge_source, prev_decisions = []):
|
| 155 |
+
definition = f"Answer the question about {topic} based on the given context and your previous answers.\n\n"
|
| 156 |
+
atom = atom.strip()
|
| 157 |
+
if knowledge_source == "medlfqa":
|
| 158 |
+
context = self.retrieval[knowledge_source][topic]['context']
|
| 159 |
+
else:
|
| 160 |
+
passages = self.retrieval[knowledge_source].get_passages(topic, atom, k=5)
|
| 161 |
+
context = ""
|
| 162 |
+
for psg in reversed(passages):
|
| 163 |
+
context += "Title: {}\nText: {}\n\n".format(psg["title"], psg["text"].replace("<s>", "").replace("</s>", ""))
|
| 164 |
+
definition += context.strip()
|
| 165 |
+
if not definition[-1] in string.punctuation:
|
| 166 |
+
definition += "."
|
| 167 |
+
prompt = f"{definition.strip()}\n\n"
|
| 168 |
+
for prev_decision in prev_decisions:
|
| 169 |
+
prev_score = "True" if prev_decision["is_supported"] else "False"
|
| 170 |
+
prompt += f"Previous input: {prev_decision['atom']}\nTrue or False? Output: {prev_score}\n"
|
| 171 |
+
|
| 172 |
+
prompt += f"Input: {atom.strip()} True or False?\nOutput:"
|
| 173 |
+
# output = [{'message': 'blah blah blah'}]
|
| 174 |
+
output = self.client.query(prompt)
|
| 175 |
+
|
| 176 |
+
# if type(output[1])==np.ndarray:
|
| 177 |
+
# # when logits are available
|
| 178 |
+
# logits = np.array(output[1])
|
| 179 |
+
# assert logits.shape[0] in [32000, 32001]
|
| 180 |
+
# true_score = logits[5852]
|
| 181 |
+
# false_score = logits[7700]
|
| 182 |
+
# is_supported = true_score > false_score
|
| 183 |
+
# else:
|
| 184 |
+
# when logits are unavailable
|
| 185 |
+
generated_answer = output[0]['message'].lower()
|
| 186 |
+
if "true" in generated_answer or "false" in generated_answer:
|
| 187 |
+
if "true" in generated_answer and "false" not in generated_answer:
|
| 188 |
+
is_supported = True
|
| 189 |
+
elif "false" in generated_answer and "true" not in generated_answer:
|
| 190 |
+
is_supported = False
|
| 191 |
+
else:
|
| 192 |
+
is_supported = generated_answer.index("true") > generated_answer.index("false")
|
| 193 |
+
else:
|
| 194 |
+
is_supported = all([keyword not in generated_answer.lower().translate(str.maketrans("", "", string.punctuation)).split() for keyword in ["not", "cannot", "unknown", "information"]])
|
| 195 |
+
|
| 196 |
+
if is_supported and "npm" in self.model_name:
|
| 197 |
+
npprob = self.npm[knowledge_source].get_probabilty(topic, atom)
|
| 198 |
+
is_supported = npprob > 0.3
|
| 199 |
+
|
| 200 |
+
decision = {"atom": atom, "is_supported": is_supported}
|
| 201 |
+
|
| 202 |
+
return decision
|
MACI-main/conformal/__pycache__/adaptive_conformal.cpython-39.pyc
ADDED
|
Binary file (19.1 kB). View file
|
|
|
MACI-main/conformal/__pycache__/basic_conformal.cpython-39.pyc
ADDED
|
Binary file (5.87 kB). View file
|
|
|
MACI-main/conformal/__pycache__/conditional_conformal.cpython-39.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
MACI-main/conformal/adaptive_conformal.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import logging
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
from sklearn.metrics import roc_auc_score
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from scipy.optimize import minimize
|
| 7 |
+
from typing import Callable, List, Dict, Any, Optional, Tuple
|
| 8 |
+
import cvxpy as cp
|
| 9 |
+
|
| 10 |
+
class MACIAdaptiveConformal:
    """Split-conformal claim filter: calibrates a removal budget (tau_hat) on
    a calibration set, then drops each response's lowest-scoring claims
    until the budget is exhausted."""

    def __init__(
        self,
        score_function: Callable,
        random_state: Optional[int] = None,
        eps: float = 1e-6,
        **kwargs,
    ) -> None:
        # Callable mapping a list of data entries to per-claim raw scores.
        self.score_function = score_function
        self.random_state = random_state
        # Numerical floor that keeps log(1 - s) finite when scores hit 1.
        self.eps = float(eps)
        # Calibrated budget; populated by fit_on_calib().
        self.tau_hat: Optional[float] = None
        # Private RNG for boundary randomization in predict().
        self._rng = np.random.default_rng(self.random_state)
|
| 23 |
+
|
| 24 |
+
def _process_raw_scores(self, raw_scores: List, data: List[Dict]) -> List[np.ndarray]:
|
| 25 |
+
if raw_scores and isinstance(raw_scores[0], np.ndarray):
|
| 26 |
+
return [np.asarray(s, dtype=float) for s in raw_scores]
|
| 27 |
+
per_sample_scores: List[np.ndarray] = []
|
| 28 |
+
samples = [d.get('sample', d) for d in data]
|
| 29 |
+
for i, s_i in enumerate(raw_scores):
|
| 30 |
+
n_claims = len(samples[i].get("atomic_facts", []))
|
| 31 |
+
s_arr = np.asarray(list(s_i), dtype=float)[:n_claims]
|
| 32 |
+
per_sample_scores.append(np.nan_to_num(s_arr, nan=0.0))
|
| 33 |
+
return per_sample_scores
|
| 34 |
+
|
| 35 |
+
def _compute_nonconformity_score(self, sample: dict, scores_i: np.ndarray) -> float:
|
| 36 |
+
atomic_facts = sample.get("atomic_facts", [])
|
| 37 |
+
if not atomic_facts or scores_i.size == 0: return 0.0
|
| 38 |
+
labels = np.asarray([af.get("is_supported", False) for af in atomic_facts], dtype=bool)
|
| 39 |
+
s_raw = np.asarray(scores_i, dtype=float)
|
| 40 |
+
s_raw = np.nan_to_num(s_raw, nan=0.0, posinf=1.0, neginf=0.0)
|
| 41 |
+
s = np.clip(s_raw, 0.0, 1.0 - self.eps)
|
| 42 |
+
idx = np.argsort(s, kind='mergesort')
|
| 43 |
+
s_sorted_asc, labels_asc = s[idx], labels[idx]
|
| 44 |
+
false_positions = np.where(~labels_asc)[0]
|
| 45 |
+
if not false_positions.size: return 0.0
|
| 46 |
+
k_star = int(false_positions.max())
|
| 47 |
+
costs = -np.log(1.0 - s_sorted_asc)
|
| 48 |
+
return float(np.sum(costs[:k_star + 1]))
|
| 49 |
+
|
| 50 |
+
def fit_on_calib(self, calib_data: List[dict], alpha: float = 0.1) -> "MACIAdaptiveConformal":
    """Calibrate tau_hat as the finite-sample conformal (1 - alpha) quantile
    of nonconformity scores over the calibration set. Returns self."""
    raw = self.score_function(calib_data)
    per_sample = self._process_raw_scores(raw, calib_data)
    samples = [entry.get('sample', entry) for entry in calib_data]
    s_values = [
        self._compute_nonconformity_score(sample, sc)
        for sample, sc in zip(samples, per_sample)
    ]

    logging.info(f" - Calibration set size: {len(calib_data)} samples")
    if not s_values:
        raise ValueError("Cannot compute scores from calibration data.")

    logging.info(f" - Nonconformity stats: min={min(s_values):.4f}, max={max(s_values):.4f}, mean={np.mean(s_values):.4f}")

    # Conformal quantile: the ceil((1-alpha)(n+1))-th order statistic,
    # capped at the largest observed score.
    n = len(s_values)
    rank = min(int(np.ceil((1.0 - alpha) * (n + 1))) - 1, n - 1)
    self.tau_hat = np.sort(s_values)[rank]

    logging.info(f" - Assigned tau_hat: {self.tau_hat:.4f}")
    return self
|
| 71 |
+
|
| 72 |
+
def predict(self, data: List[dict]) -> Tuple[List[dict], List[float]]:
    """Filter each sample's claims under the calibrated budget tau_hat.

    Claims are sorted by ascending score; the cheapest (lowest-score)
    claims are dropped while their cumulative -log(1 - s) cost stays
    within tau_hat, with a randomized tie-break at the budget boundary.
    Returns (filtered samples with a "filtered_claims" key, per-sample
    retention rates). Raises ValueError if fit_on_calib was not run.
    """
    if self.tau_hat is None: raise ValueError("Model is not calibrated.")
    raw_scores = self.score_function(data)
    per_sample_scores = self._process_raw_scores(raw_scores, data)
    samples = [d.get('sample', d) for d in data]

    filtered_data, retention_rates = [], []
    for sample, s_raw in zip(samples, per_sample_scores):
        atomic_facts = sample.get("atomic_facts", [])
        # Shallow copy: the input sample dict is left untouched.
        new_sample = dict(sample)
        if not atomic_facts or s_raw.size == 0:
            new_sample["filtered_claims"] = []
            # No claims -> trivially fully retained; claims without scores -> all dropped.
            retention_rates.append(1.0 if not atomic_facts else 0.0)
        else:
            # Same score sanitization as in calibration: NaN->0, clip to [0, 1-eps).
            s_tmp = np.asarray(s_raw, dtype=float)
            s_tmp = np.nan_to_num(s_tmp, nan=0.0, posinf=1.0, neginf=0.0)
            s = np.clip(s_tmp, 0.0, 1.0 - self.eps)
            # Pair each claim with its score and sort ascending by score.
            indexed_items = sorted(list(zip(s, atomic_facts)), key=lambda x: x[0])
            s_sorted_asc = np.array([item[0] for item in indexed_items])
            costs = -np.log(1.0 - s_sorted_asc)
            # cumulative_costs[K] = cost of dropping the K cheapest claims.
            cumulative_costs = np.concatenate(([0.0], np.cumsum(costs)))
            possible_K_indices = np.where(cumulative_costs <= self.tau_hat)[0]
            K = int(possible_K_indices.max()) if possible_K_indices.size > 0 else 0
            # Boundary randomization: with probability proportional to leftover budget,
            # include one more boundary item (i.e., increase K by 1) if feasible.
            # This randomization reduces discretization bias at the threshold.
            if K < len(costs):
                leftover = float(self.tau_hat - cumulative_costs[K])
                next_cost = float(costs[K])  # cost of the (K)-th item in sorted order
                if np.isfinite(next_cost) and next_cost > 0.0 and leftover > 0.0:
                    p = float(np.clip(leftover / next_cost, 0.0, 1.0))
                    if self._rng.uniform(0.0, 1.0) < p:
                        K = K + 1
            # Keep everything above the drop boundary (highest-scoring claims).
            new_sample["filtered_claims"] = [item[1] for item in indexed_items[K:]]
            retention_rates.append(len(new_sample["filtered_claims"]) / len(atomic_facts))
        filtered_data.append(new_sample)
    return filtered_data, retention_rates
|
| 109 |
+
|
| 110 |
+
class SubgroupOptimizedMACI:
    """MACI variant that learns ensemble weights and calibrates one conformal
    model per subgroup (subgroups come from binning a grouper value)."""

    def __init__(self, model_names: List[str], grouper: Any, n_bins: int = 3, **kwargs):
        self.model_names = model_names
        self.grouper = grouper
        self.n_bins = n_bins
        self.kwargs = kwargs

        # Per-subgroup learned ensemble weights and conformal models.
        self.weights = {}
        self.conformal_models = {}
        self.fallback_weights = None
        self.bin_edges = None

        # Human-readable subgroup names.
        if n_bins == 3:
            self.bin_labels = ['low', 'medium', 'high']
        else:
            self.bin_labels = [f'group_{i}' for i in range(n_bins)]

        # Timing accumulators
        self._timing: Dict[str, float] = {
            'weight_optimization_s': 0.0,
            'calibration_s': 0.0
        }
|
| 121 |
+
|
| 122 |
+
def _get_subgroup_label(self, value: float) -> str:
|
| 123 |
+
if self.bin_edges is None or not np.isfinite(value):
|
| 124 |
+
return self.bin_labels[0]
|
| 125 |
+
bin_index = np.digitize(value, self.bin_edges)
|
| 126 |
+
return self.bin_labels[min(bin_index, len(self.bin_labels) - 1)]
|
| 127 |
+
|
| 128 |
+
def _group_data_by_bins(self, data: List[Dict], bin_edges: np.ndarray) -> Dict[str, List[Dict]]:
|
| 129 |
+
grouped_data = defaultdict(list)
|
| 130 |
+
values = self.grouper.compute_values([d['sample'] for d in data])
|
| 131 |
+
for item, value in zip(data, values):
|
| 132 |
+
label = self._get_subgroup_label(value)
|
| 133 |
+
grouped_data[label].append(item)
|
| 134 |
+
return grouped_data
|
| 135 |
+
def _learn_robust_weights_by_retention(self, training_data: List[Dict], target_tpr: float = 0.95) -> np.ndarray:
|
| 136 |
+
"""
|
| 137 |
+
Stable convex program for learning ensemble weights on the probability simplex.
|
| 138 |
+
|
| 139 |
+
Uses an epigraph reformulation with explicit nonnegative slack variables and
|
| 140 |
+
Tikhonov regularization to improve numerical stability across solvers.
|
| 141 |
+
"""
|
| 142 |
+
all_scores, all_labels = [], []
|
| 143 |
+
for entry in training_data:
|
| 144 |
+
sample, scores_dict = entry.get('sample', {}), entry.get('scores', {})
|
| 145 |
+
labels = [af.get("is_supported", False) for af in sample.get("atomic_facts", [])]
|
| 146 |
+
scores_per_model = [scores_dict.get(m, []) for m in self.model_names]
|
| 147 |
+
min_len = min(len(labels), *[len(s) for s in scores_per_model])
|
| 148 |
+
if min_len == 0:
|
| 149 |
+
continue
|
| 150 |
+
for i in range(min_len):
|
| 151 |
+
all_labels.append(labels[i])
|
| 152 |
+
all_scores.append([s[i] for s in scores_per_model])
|
| 153 |
+
|
| 154 |
+
if len(all_labels) < 2 or len(np.unique(all_labels)) < 2:
|
| 155 |
+
logging.warning("Skipping weight optimization: insufficient or single-class labels.")
|
| 156 |
+
return np.ones(len(self.model_names)) / len(self.model_names)
|
| 157 |
+
|
| 158 |
+
scores_matrix = np.nan_to_num(np.array(all_scores, dtype=float))
|
| 159 |
+
labels_array = np.array(all_labels, dtype=int)
|
| 160 |
+
n_models = scores_matrix.shape[1]
|
| 161 |
+
|
| 162 |
+
pos = scores_matrix[labels_array == 1]
|
| 163 |
+
neg = scores_matrix[labels_array == 0]
|
| 164 |
+
if pos.shape[0] == 0 or neg.shape[0] == 0:
|
| 165 |
+
logging.warning("Skipping weight optimization: missing positive or negative samples.")
|
| 166 |
+
return np.ones(len(self.model_names)) / len(self.model_names)
|
| 167 |
+
|
| 168 |
+
neg_proxy = np.mean(neg, axis=1)
|
| 169 |
+
neg_w = np.clip(neg_proxy, 0.0, 1.0) ** 2
|
| 170 |
+
neg_w = neg_w / (np.mean(neg_w) + 1e-12)
|
| 171 |
+
|
| 172 |
+
pos_w = np.ones(pos.shape[0], dtype=float)
|
| 173 |
+
sum_pos = np.sum(pos_w)
|
| 174 |
+
sum_neg = np.sum(neg_w)
|
| 175 |
+
if sum_pos > 0 and sum_neg > 0:
|
| 176 |
+
scale = sum_pos / sum_neg
|
| 177 |
+
neg_w = neg_w * scale
|
| 178 |
+
|
| 179 |
+
alpha = 1.0
|
| 180 |
+
beta = 5.0 * (target_tpr / max(1.0 - target_tpr, 1e-6))
|
| 181 |
+
|
| 182 |
+
def solve_with(ridge: float, eps_w: float, solver_name: str) -> Optional[np.ndarray]:
|
| 183 |
+
try:
|
| 184 |
+
w = cp.Variable(n_models)
|
| 185 |
+
t = cp.Variable()
|
| 186 |
+
slack_neg = cp.Variable(neg.shape[0], nonneg=True)
|
| 187 |
+
slack_pos = cp.Variable(pos.shape[0], nonneg=True)
|
| 188 |
+
|
| 189 |
+
constraints = [
|
| 190 |
+
neg @ w - t <= slack_neg,
|
| 191 |
+
t - pos @ w <= slack_pos,
|
| 192 |
+
w >= eps_w,
|
| 193 |
+
cp.sum(w) == 1,
|
| 194 |
+
t >= 0,
|
| 195 |
+
t <= 1
|
| 196 |
+
]
|
| 197 |
+
objective = (
|
| 198 |
+
alpha * cp.sum(cp.multiply(neg_w, slack_neg)) +
|
| 199 |
+
beta * cp.sum(cp.multiply(pos_w, slack_pos)) +
|
| 200 |
+
ridge * cp.sum_squares(w)
|
| 201 |
+
)
|
| 202 |
+
prob = cp.Problem(cp.Minimize(objective), constraints)
|
| 203 |
+
|
| 204 |
+
if solver_name == 'osqp':
|
| 205 |
+
prob.solve(solver=cp.OSQP, verbose=False, eps_abs=1e-6, eps_rel=1e-6, max_iter=20000, polishing=True, linsys_solver='qdldl')
|
| 206 |
+
elif solver_name == 'ecos':
|
| 207 |
+
prob.solve(solver=cp.ECOS, verbose=False, max_iters=200000, abstol=1e-7, reltol=1e-7, feastol=1e-7)
|
| 208 |
+
elif solver_name == 'scs':
|
| 209 |
+
prob.solve(solver=cp.SCS, verbose=False, max_iters=300000, eps=2e-5, acceleration_lookback=20)
|
| 210 |
+
else:
|
| 211 |
+
return None
|
| 212 |
+
|
| 213 |
+
if w.value is None:
|
| 214 |
+
return None
|
| 215 |
+
|
| 216 |
+
w_val = np.array(w.value, dtype=float).reshape(-1)
|
| 217 |
+
if not np.all(np.isfinite(w_val)):
|
| 218 |
+
return None
|
| 219 |
+
w_val = np.clip(w_val, 0.0, None)
|
| 220 |
+
s = np.sum(w_val)
|
| 221 |
+
if s <= 1e-12:
|
| 222 |
+
return None
|
| 223 |
+
w_val = w_val / s
|
| 224 |
+
logging.info(" - Weight optimization completed")
|
| 225 |
+
return w_val
|
| 226 |
+
except Exception as e:
|
| 227 |
+
logging.debug(f"{solver_name.upper()} attempt failed (ridge={ridge}, eps_w={eps_w}): {e}")
|
| 228 |
+
return None
|
| 229 |
+
|
| 230 |
+
solver_order = []
|
| 231 |
+
solver_pref = (self.kwargs or {}).get('solver', 'auto')
|
| 232 |
+
if solver_pref in ('osqp', 'ecos', 'scs'):
|
| 233 |
+
solver_order = [solver_pref] + [s for s in ('osqp', 'ecos', 'scs') if s != solver_pref]
|
| 234 |
+
else:
|
| 235 |
+
solver_order = ['osqp', 'ecos', 'scs']
|
| 236 |
+
|
| 237 |
+
for ridge in (5e-3, 5e-2, 1e-1, 5e-1):
|
| 238 |
+
for eps_w in (0.0, 1e-6, 1e-4):
|
| 239 |
+
for slv in solver_order:
|
| 240 |
+
sol = solve_with(ridge=ridge, eps_w=eps_w, solver_name=slv)
|
| 241 |
+
if sol is not None:
|
| 242 |
+
return sol
|
| 243 |
+
|
| 244 |
+
logging.warning("CVXPY solvers failed repeatedly; falling back to AUC-based SLSQP optimizer as last resort.")
|
| 245 |
+
return self._learn_robust_weights(training_data)
|
| 246 |
+
|
| 247 |
+
    def _learn_robust_weights(self, training_data: List[Dict]) -> np.ndarray:
        """Fallback weight learner: maximize ensemble ROC-AUC over the simplex.

        Pools claim-level scores and labels across all samples, then runs
        several SLSQP restarts from random Dirichlet initializations and
        keeps the best-AUC weight vector.  Returns uniform weights when the
        pooled labels are missing or single-class (AUC undefined).

        NOTE(review): the Dirichlet restarts use the *global* NumPy RNG, so
        results are not reproducible unless the caller seeds numpy — confirm
        whether this is intended.
        """
        all_scores, all_labels = [], []
        for entry in training_data:
            sample, scores_dict = entry.get('sample', {}), entry.get('scores', {})
            labels = [af.get("is_supported", False) for af in sample.get("atomic_facts", [])]
            # Skip samples that lack scores for any ensemble member.
            if not all(m in scores_dict for m in self.model_names): continue
            scores_per_model = [scores_dict.get(m, []) for m in self.model_names]
            # Align labels and every model's scores to the shortest length.
            min_len = min(len(labels), *[len(s) for s in scores_per_model])
            if min_len == 0: continue
            for i in range(min_len):
                all_labels.append(labels[i])
                all_scores.append([s[i] for s in scores_per_model])

        # AUC needs at least two samples and both classes present.
        if len(all_labels) < 2 or len(np.unique(all_labels)) < 2:
            return np.ones(len(self.model_names)) / len(self.model_names)

        scores_matrix = np.nan_to_num(np.array(all_scores, dtype=float))
        labels_array = np.array(all_labels, dtype=int)
        n_models = scores_matrix.shape[1]

        def objective_fn(weights: np.ndarray) -> float:
            # Negative AUC of the weighted ensemble (SLSQP minimizes).
            w = weights / np.sum(weights) if np.sum(weights) > 0 else weights
            ensemble_scores = scores_matrix @ w
            try: return -roc_auc_score(labels_array, ensemble_scores)
            except ValueError: return 0.0

        best_score, best_weights = -1.0, np.ones(n_models) / n_models
        for _ in range(10):
            # Random restart on the probability simplex.
            w0 = np.random.dirichlet(np.ones(n_models))
            res = minimize(objective_fn, w0, method='SLSQP', bounds=[(0, 1)] * n_models, constraints=({'type': 'eq', 'fun': lambda w: np.sum(w) - 1.0}))
            if res.success and -res.fun > best_score:
                best_score, best_weights = -res.fun, res.x / np.sum(res.x)
        return best_weights
|
| 280 |
+
|
| 281 |
+
def get_budgets(self):
|
| 282 |
+
return {subgroup: model.tau_hat for subgroup, model in self.conformal_models.items()}
|
| 283 |
+
|
| 284 |
+
def get_weights(self):
|
| 285 |
+
return {
|
| 286 |
+
'subgroup_weights': self.weights,
|
| 287 |
+
'fallback_weights': self.fallback_weights,
|
| 288 |
+
'bin_edges': None if self.bin_edges is None else np.asarray(self.bin_edges).tolist(),
|
| 289 |
+
'bin_labels': list(self.bin_labels) if self.bin_labels is not None else None,
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
def _compute_ensemble_scores(self, data: List[Dict], subgroup_label: str) -> List[np.ndarray]:
|
| 293 |
+
subgroup_weights = self.weights.get(subgroup_label, self.fallback_weights)
|
| 294 |
+
if subgroup_weights is None:
|
| 295 |
+
raise RuntimeError(f"Weights not learned for subgroup '{subgroup_label}'.")
|
| 296 |
+
|
| 297 |
+
final_scores = []
|
| 298 |
+
for entry in data:
|
| 299 |
+
scores_dict = entry.get('scores', {})
|
| 300 |
+
scores_per_model = [scores_dict.get(m, []) for m in self.model_names]
|
| 301 |
+
min_len = min(len(entry['sample']['atomic_facts']), *[len(s) for s in scores_per_model])
|
| 302 |
+
if min_len == 0:
|
| 303 |
+
final_scores.append(np.array([]))
|
| 304 |
+
else:
|
| 305 |
+
scores_matrix = np.array([np.nan_to_num(s[:min_len]) for s in scores_per_model]).T
|
| 306 |
+
final_scores.append(scores_matrix @ subgroup_weights)
|
| 307 |
+
return final_scores
|
| 308 |
+
|
| 309 |
+
def fit(self, data: List[dict], alpha: float = 0.1, ensemble_train_ratio: float = 0.5, target_tpr: float = 0.95):
|
| 310 |
+
"""Learn subgroup-specific ensemble weights and conformal thresholds."""
|
| 311 |
+
random_state = self.kwargs.get("random_state")
|
| 312 |
+
grouper_name = self.grouper.__class__.__name__
|
| 313 |
+
logging.info(f"SubgroupOptimizedMACI training started (grouper: '{grouper_name}')")
|
| 314 |
+
|
| 315 |
+
ensemble_train_data, calib_data = train_test_split(
|
| 316 |
+
data,
|
| 317 |
+
test_size=1.0 - ensemble_train_ratio,
|
| 318 |
+
random_state=random_state
|
| 319 |
+
)
|
| 320 |
+
logging.info(f" - Data split: ensemble training {len(ensemble_train_data)} / conformal calibration {len(calib_data)}")
|
| 321 |
+
|
| 322 |
+
logging.info(f" - Learning bin edges by '{grouper_name}' values...")
|
| 323 |
+
train_values = self.grouper.compute_values([d['sample'] for d in ensemble_train_data])
|
| 324 |
+
finite_train_values = train_values[np.isfinite(train_values)]
|
| 325 |
+
quantiles = np.linspace(0, 1, self.n_bins + 1)[1:-1]
|
| 326 |
+
self.bin_edges = np.quantile(finite_train_values, quantiles) if len(finite_train_values) > 0 else np.array([])
|
| 327 |
+
logging.info(f" - Learned bin edges: {self.bin_edges}")
|
| 328 |
+
|
| 329 |
+
grouped_ensemble_data = self._group_data_by_bins(ensemble_train_data, self.bin_edges)
|
| 330 |
+
grouped_calib_data = self._group_data_by_bins(calib_data, self.bin_edges)
|
| 331 |
+
|
| 332 |
+
for label in self.bin_labels:
|
| 333 |
+
logging.info(f"--- Processing group '{label}' ---")
|
| 334 |
+
sub_ensemble_data = grouped_ensemble_data.get(label, [])
|
| 335 |
+
sub_calib_data = grouped_calib_data.get(label, [])
|
| 336 |
+
|
| 337 |
+
if not sub_ensemble_data or not sub_calib_data:
|
| 338 |
+
logging.warning(f"Skipping group '{label}' due to insufficient data.")
|
| 339 |
+
continue
|
| 340 |
+
|
| 341 |
+
logging.info(f" - Learning ensemble weights (n={len(sub_ensemble_data)})...")
|
| 342 |
+
_t0 = __import__('time').perf_counter()
|
| 343 |
+
self.weights[label] = self._learn_robust_weights_by_retention(sub_ensemble_data, target_tpr=target_tpr)
|
| 344 |
+
self._timing['weight_optimization_s'] += __import__('time').perf_counter() - _t0
|
| 345 |
+
|
| 346 |
+
logging.info(f" - Calibrating threshold (n={len(sub_calib_data)})...")
|
| 347 |
+
score_func = lambda data, l=label: self._compute_ensemble_scores(data, l)
|
| 348 |
+
|
| 349 |
+
conformal_model = MACIAdaptiveConformal(score_function=score_func, **self.kwargs)
|
| 350 |
+
_t1 = __import__('time').perf_counter()
|
| 351 |
+
conformal_model.fit_on_calib(sub_calib_data, alpha)
|
| 352 |
+
self._timing['calibration_s'] += __import__('time').perf_counter() - _t1
|
| 353 |
+
self.conformal_models[label] = conformal_model
|
| 354 |
+
|
| 355 |
+
logging.info("--- Training fallback model on all data ---")
|
| 356 |
+
self.fallback_weights = self._learn_robust_weights_by_retention(ensemble_train_data, target_tpr=target_tpr)
|
| 357 |
+
|
| 358 |
+
logging.info("✅ Training complete.")
|
| 359 |
+
return self
|
| 360 |
+
|
| 361 |
+
def get_timing(self) -> Dict[str, float]:
|
| 362 |
+
return dict(self._timing)
|
| 363 |
+
|
| 364 |
+
def predict(self, data: List[dict]) -> Tuple[List[dict], List[float]]:
|
| 365 |
+
if not self.conformal_models: raise ValueError("모델이 학습되지 않았습니다.")
|
| 366 |
+
|
| 367 |
+
grouped_data_with_indices = defaultdict(list)
|
| 368 |
+
values = self.grouper.compute_values([d['sample'] for d in data])
|
| 369 |
+
for i, (item, value) in enumerate(zip(data, values)):
|
| 370 |
+
label = self._get_subgroup_label(value)
|
| 371 |
+
grouped_data_with_indices[label].append((i, item))
|
| 372 |
+
|
| 373 |
+
results_placeholder = [None] * len(data)
|
| 374 |
+
rates_placeholder = [None] * len(data)
|
| 375 |
+
|
| 376 |
+
for label, indexed_subgroup_data in grouped_data_with_indices.items():
|
| 377 |
+
if not indexed_subgroup_data: continue
|
| 378 |
+
original_indices = [item[0] for item in indexed_subgroup_data]
|
| 379 |
+
subgroup_data = [item[1] for item in indexed_subgroup_data]
|
| 380 |
+
model = self.conformal_models.get(label)
|
| 381 |
+
|
| 382 |
+
if model:
|
| 383 |
+
logging.info(f" - Predicting for group '{label}' (n={len(subgroup_data)})...")
|
| 384 |
+
predicted_samples, rates = model.predict(subgroup_data)
|
| 385 |
+
|
| 386 |
+
for i, original_item, predicted_sample, rate in zip(original_indices, subgroup_data, predicted_samples, rates):
|
| 387 |
+
new_result_item = original_item.copy()
|
| 388 |
+
new_result_item['sample'] = predicted_sample
|
| 389 |
+
results_placeholder[i] = new_result_item
|
| 390 |
+
rates_placeholder[i] = rate
|
| 391 |
+
else:
|
| 392 |
+
logging.warning(f"No trained model for group '{label}'. Using fallback weights for prediction.")
|
| 393 |
+
fallback_score_func = lambda data_list: self._compute_ensemble_scores(data_list, label)
|
| 394 |
+
fallback_model = MACIAdaptiveConformal(score_function=fallback_score_func, **self.kwargs)
|
| 395 |
+
fallback_model.tau_hat = 0.0
|
| 396 |
+
predicted_samples, rates = fallback_model.predict(subgroup_data)
|
| 397 |
+
for i, original_item, predicted_sample, rate in zip(original_indices, subgroup_data, predicted_samples, rates):
|
| 398 |
+
new_result_item = original_item.copy()
|
| 399 |
+
new_result_item['sample'] = predicted_sample
|
| 400 |
+
results_placeholder[i] = new_result_item
|
| 401 |
+
rates_placeholder[i] = rate
|
| 402 |
+
|
| 403 |
+
return results_placeholder, rates_placeholder
|
MACI-main/conformal/basic_conformal.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Basic Conformal Implementation for Factuality Assessment
|
| 3 |
+
|
| 4 |
+
This module implements a basic conformal prediction method for assessing
|
| 5 |
+
the factuality of generated text by filtering claims based on conformity scores.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import List, Tuple, Optional, Callable
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BasicConformal:
    """Split conformal filter over claim-level factuality scores.

    Calibrates a (1 - alpha) threshold on the per-sample maximum score of
    unsupported ("false") claims, then filters test claims whose scores do
    not exceed that threshold.  Ties at the threshold are broken at random
    with a calibrated keep-probability so coverage is exact.
    """

    def __init__(
        self,
        score_function: Callable,
        random_state: Optional[int] = None
    ):
        # score_function maps a list of samples to either one scalar per
        # sample or one per-claim score list per sample.
        self.score_function = score_function
        self.random_state = random_state
        self.calibration_scores = None
        self.threshold = None
        self._rng = np.random.default_rng(random_state)
        # Probability of KEEPING a claim whose score ties the threshold.
        self._tie_gamma_keep: float = 1.0

    def _per_sample_scores(self, data: List) -> List[List[float]]:
        """Normalize score_function output to one per-claim list per sample.

        Accepts either nested (per-claim) scores or one scalar per sample;
        scalars are broadcast over the sample's atomic facts.  NaNs map to
        -inf so such claims are never retained and never raise the
        calibration threshold.

        Raises:
            ValueError: if the output length does not match ``data``.
        """
        raw_scores = self.score_function(data)
        per_sample_scores: List[List[float]] = []
        nested = (
            len(raw_scores) == len(data)
            and len(raw_scores) > 0  # guard the raw_scores[0] probe on empty input
            and hasattr(raw_scores[0], "__iter__")
            and not isinstance(raw_scores[0], (str, bytes))
        )
        if nested:
            for i, sample in enumerate(data):
                if 'atomic_facts' in sample:
                    s_i = np.asarray(list(raw_scores[i]), dtype=float)
                else:
                    s_i = np.asarray([float(raw_scores[i])], dtype=float)
                s_i = np.where(np.isnan(s_i), -np.inf, s_i)
                per_sample_scores.append(s_i.tolist())
        else:
            if len(raw_scores) != len(data):
                raise ValueError("score_function must return one score per sample or per-claim score lists per sample")
            for i, sample in enumerate(data):
                if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
                    # Broadcast the scalar sample score to every claim.
                    s_i = np.asarray([float(raw_scores[i])] * len(sample['atomic_facts']), dtype=float)
                else:
                    s_i = np.asarray([float(raw_scores[i])], dtype=float)
                s_i = np.where(np.isnan(s_i), -np.inf, s_i)
                per_sample_scores.append(s_i.tolist())
        return per_sample_scores

    def fit_on_calib(self, calib_data: List, alpha: float = 0.1) -> 'BasicConformal':
        """Calibrate the retention threshold at miscoverage level ``alpha``.

        Conformity score per sample = max score among unsupported claims
        (-inf when the sample has no unsupported claims or no claims).

        Raises:
            ValueError: if alpha is outside (0, 1) or there is no data.
        """
        if not 0 < alpha < 1:
            raise ValueError("alpha must be between 0 and 1")

        per_sample_scores = self._per_sample_scores(calib_data)

        S_values: List[float] = []
        for sample, scores_i in zip(calib_data, per_sample_scores):
            if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
                false_scores = [s for s, fact in zip(scores_i, sample['atomic_facts']) if not fact.get('is_supported', False)]
                if false_scores:
                    S_values.append(float(np.nanmax(np.asarray(false_scores, dtype=float))))
                else:
                    # All claims supported: no constraint from this sample.
                    S_values.append(float('-inf'))
            else:
                vals = np.asarray(scores_i, dtype=float)
                S_values.append(float(np.nanmax(vals)) if vals.size > 0 else float('-inf'))

        self.calibration_scores = np.array(S_values, dtype=float)
        n = len(self.calibration_scores)
        if n == 0:
            raise ValueError("No calibration samples available to compute threshold")
        quantile = 1 - alpha
        try:
            # 'higher' interpolation is the standard finite-sample choice.
            self.threshold = np.quantile(self.calibration_scores, quantile, method='higher')
        except TypeError:
            # Older numpy without the `method` keyword.
            self.threshold = np.quantile(self.calibration_scores, quantile)

        # Randomized-tie calibration: probability of keeping a claim whose
        # score exactly ties the conformal threshold t*.
        sorted_scores = np.sort(self.calibration_scores)
        k = int(np.ceil((1.0 - alpha) * (n + 1))) - 1
        k = min(max(k, 0), n - 1)
        t_star = float(sorted_scores[k])
        n_lt = int(np.sum(self.calibration_scores < t_star))
        n_eq = int(np.sum(np.isclose(self.calibration_scores, t_star)))
        if n_eq <= 0:
            gamma_standard = 0.0
        else:
            gamma_standard = ((1.0 - alpha) * (n + 1) - n_lt) / n_eq
        gamma_standard = float(np.clip(gamma_standard, 0.0, 1.0))
        self._tie_gamma_keep = 1.0 - gamma_standard
        return self

    def predict(self, data: List) -> Tuple[List, List]:
        """Filter each sample's claims by the calibrated threshold.

        Returns:
            (filtered_samples, retention_rates).  Samples are shallow copies
            with 'filtered_claims' (or 'is_retained' for claim-less samples).

        Raises:
            ValueError: if called before fit_on_calib().
        """
        if self.threshold is None:
            raise ValueError("Model must be fitted before prediction")
        per_sample_scores = self._per_sample_scores(data)

        filtered_data: List = []
        retention_rates: List[float] = []
        for sample, scores_i in zip(data, per_sample_scores):
            sample = dict(sample)  # shallow copy; never mutate caller data
            if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
                filtered_claims = []
                for claim, s in zip(sample['atomic_facts'], scores_i):
                    if s > self.threshold:
                        filtered_claims.append(claim)
                    elif np.isclose(s, self.threshold):
                        # Randomized tie-break for exact coverage.
                        if self._rng.uniform() < self._tie_gamma_keep:
                            filtered_claims.append(claim)
                sample['filtered_claims'] = filtered_claims
                retention_rate = len(filtered_claims) / len(sample['atomic_facts'])
            elif 'atomic_facts' in sample:
                # Present but empty claim list: nothing to retain.
                sample['filtered_claims'] = []
                retention_rate = 0.0
            else:
                # Sample-level decision when there are no atomic facts.
                if len(scores_i) == 0:
                    sample['is_retained'] = False
                    retention_rate = 0.0
                else:
                    s = float(scores_i[0])
                    sample['is_retained'] = (s > self.threshold) or (np.isclose(s, self.threshold) and self._rng.uniform() < self._tie_gamma_keep)
                    retention_rate = 1.0 if sample['is_retained'] else 0.0
            filtered_data.append(sample)
            retention_rates.append(retention_rate)
        return filtered_data, retention_rates

    def get_coverage(self, data: List) -> float:
        """Empirical coverage: fraction of samples whose worst unsupported
        claim scores at or below the threshold (trivially covered when a
        sample has no unsupported claims).

        Raises:
            ValueError: if called before fit_on_calib().
        """
        if self.threshold is None:
            raise ValueError("Model must be fitted before computing coverage")
        per_sample_scores = self._per_sample_scores(data)
        indicators = []
        for sample, scores_i in zip(data, per_sample_scores):
            if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
                false_scores = [s for s, fact in zip(scores_i, sample['atomic_facts']) if not fact.get('is_supported', False)]
                if len(false_scores) == 0:
                    indicators.append(1.0)
                else:
                    vals = np.asarray(false_scores, dtype=float)
                    max_false = float(np.nanmax(vals)) if vals.size > 0 else float('-inf')
                    indicators.append(1.0 if max_false <= self.threshold else 0.0)
            else:
                vals = np.asarray(scores_i, dtype=float)
                if vals.size == 0:
                    indicators.append(1.0)
                else:
                    indicators.append(1.0 if float(np.nanmax(vals)) <= self.threshold else 0.0)
        return float(np.mean(indicators))

    def get_threshold(self) -> float:
        """Return the calibrated threshold (None until fitted)."""
        return self.threshold
|
| 188 |
+
|
| 189 |
+
|
MACI-main/conformal/conditional_conformal.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from typing import List, Tuple, Optional, Callable
|
| 3 |
+
import torch
|
| 4 |
+
from scipy.optimize import linprog
|
| 5 |
+
from sklearn.model_selection import train_test_split
|
| 6 |
+
from sklearn.metrics import roc_auc_score, roc_curve
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
# Add conditional-conformal path to Python path (local vendor copy) using repo-relative path.
# The `conditionalconformal` package lives directly under `conditional-conformal/`
# (the repo also ships helper modules under `conditional-conformal/src/`), so the
# package *parent* must be on sys.path for the import below to resolve; the
# original code only appended the `src` subdirectory, which cannot satisfy
# `from conditionalconformal import ...`.  We append both for compatibility.
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
for _vendor_sub in ('conditional-conformal', os.path.join('conditional-conformal', 'src')):
    vendor_path = os.path.join(repo_root, _vendor_sub)
    if vendor_path not in sys.path:
        sys.path.append(vendor_path)
from conditionalconformal import CondConf
|
| 17 |
+
|
| 18 |
+
# ==============================================================================
|
| 19 |
+
# === Step 1: Classes and Helper Functions for Boosting ===
|
| 20 |
+
# ==============================================================================
|
| 21 |
+
|
| 22 |
+
def as_tensor(x, dtype, requires_grad=False):
    """Construct a ``torch.Tensor`` copy of *x* with the requested dtype.

    Args:
        x: Array-like data to copy into a new tensor.
        dtype: Target torch dtype.
        requires_grad: Whether autograd should track the tensor.
    """
    tensor = torch.tensor(x, dtype=dtype)
    if requires_grad:
        tensor.requires_grad_(True)
    return tensor
|
| 24 |
+
|
| 25 |
+
def get_current_basis(primals, duals, Phi, S, quantile):
    """Helper function to find a stable basis from LP solution.

    Marks the calibration rows whose dual variables are strictly interior
    to [quantile - 1, quantile] (the interpolating set of the quantile
    LP).  If fewer than Phi.shape[1] such rows exist, the set is topped up
    with rows whose scores lie closest to the fitted quantile surface,
    preferring a candidate subset whose stacked rows are well-conditioned.

    NOTE(review): `primals` is used as the coefficient vector (Phi @ primals)
    and `duals` as the per-row LP variables — confirm the caller passes the
    equality-constraint marginals / variable vector in that order.
    """
    # Rows not pinned at either dual bound are the interpolating rows.
    interp_bools = np.logical_and(~np.isclose(duals, quantile - 1), ~np.isclose(duals, quantile))
    if np.sum(interp_bools) == Phi.shape[1]:
        return interp_bools
    preds = (Phi @ primals).flatten()
    active_indices = np.where(interp_bools)[0]
    # Candidate rows: the Phi.shape[1] rows closest to the fitted surface.
    interp_indices = np.argsort(np.abs(S - preds))[:Phi.shape[1]]
    diff_indices = np.setdiff1d(interp_indices, active_indices)
    num_missing = Phi.shape[1] - np.sum(interp_bools)

    if num_missing < len(diff_indices):
        from itertools import combinations
        # Try candidate subsets until the implied basis is non-singular
        # (finite condition number).
        for cand_indices in combinations(diff_indices, num_missing):
            cand_phi = Phi[np.concatenate((active_indices, cand_indices))]
            if np.isfinite(np.linalg.cond(cand_phi)):
                interp_bools[np.asarray(cand_indices)] = True
                break
    else:
        # Not enough candidates to choose among: take them all.
        interp_bools[diff_indices] = True
    return interp_bools
|
| 46 |
+
|
| 47 |
+
def _choose_full_rank_rows(Phi: np.ndarray) -> np.ndarray:
|
| 48 |
+
"""Greedy row selection for full-rank basis"""
|
| 49 |
+
d = Phi.shape[1]
|
| 50 |
+
chosen = []
|
| 51 |
+
cur = np.empty((0, d))
|
| 52 |
+
for i in range(Phi.shape[0]):
|
| 53 |
+
cand = np.vstack([cur, Phi[i:i+1]])
|
| 54 |
+
if np.linalg.matrix_rank(cand) > np.linalg.matrix_rank(cur):
|
| 55 |
+
chosen.append(i)
|
| 56 |
+
cur = cand
|
| 57 |
+
if len(chosen) == d:
|
| 58 |
+
break
|
| 59 |
+
if len(chosen) < d:
|
| 60 |
+
chosen = list(range(Phi.shape[0]-d, Phi.shape[0]))
|
| 61 |
+
return np.asarray(chosen, dtype=int)
|
| 62 |
+
|
| 63 |
+
def solve_qr_for_boosting(Phi: np.ndarray, s: torch.Tensor, q: float, dtype: torch.dtype) -> torch.Tensor:
    """Differentiable tau calculation function for boosting - robust fallback included.

    Solves the quantile-regression dual LP (maximize s·eta subject to
    Phi.T @ eta = 0 and q - 1 <= eta <= q) with scipy's HiGHS solver, reads
    the primal quantile coefficients off the equality-constraint marginals,
    then re-derives tau by a torch least-squares solve on a full-rank row
    basis so the result stays connected to ``s``'s autograd graph.

    Returns:
        Tensor of shape (Phi.shape[1], 1); zeros if every solve path fails.
    """
    S_np = s.detach().cpu().numpy().reshape(-1)
    assert Phi.shape[0] == S_np.shape[0], "Phi rows must match len(s)"
    assert 0.0 < q < 1.0, "q must be in (0,1)"

    b_eq = np.zeros(Phi.shape[1])
    # Dual-variable box constraints of the pinball-loss quantile LP.
    bounds = [(q - 1.0, q)] * len(S_np)

    res = None
    try:
        res = linprog(-S_np, A_eq=Phi.T, b_eq=b_eq, bounds=bounds, method='highs')
    except Exception:
        # Any solver failure falls through to the full-rank-rows fallback.
        res = None

    tau_initial = None
    duals = None
    if res is not None and getattr(res, "success", False):
        # Equality-constraint marginals live under different attribute
        # names depending on the scipy version.
        marg = None
        if hasattr(res, "eqlin") and res.eqlin is not None and hasattr(res.eqlin, "marginals") and res.eqlin.marginals is not None:
            marg = res.eqlin.marginals
        elif hasattr(res, "dual_eq") and res.dual_eq is not None:
            marg = res.dual_eq

        if marg is not None:
            # This LP is the dual problem, so its marginals are negated taus.
            tau_initial = -np.asarray(marg, dtype=float)
        if hasattr(res, "x") and res.x is not None:
            duals = np.asarray(res.x, dtype=float)

    try:
        if tau_initial is not None and duals is not None:
            basis_mask = get_current_basis(tau_initial, duals, Phi, S_np, q)
            basis_idx = np.where(basis_mask)[0]
            if basis_idx.size != Phi.shape[1]:
                # Basis recovery failed to produce exactly d rows.
                basis_idx = _choose_full_rank_rows(Phi)
        else:
            basis_idx = _choose_full_rank_rows(Phi)

        Phi_basis = Phi[basis_idx]
        s_basis = s[basis_idx]

        # Least-squares via torch keeps the autograd path from s to tau.
        tau_sol = torch.linalg.lstsq(as_tensor(Phi_basis, dtype), s_basis).solution
        tau = tau_sol
    except Exception:
        # Last-resort fallback: zero coefficients (no gradient signal).
        tau = torch.zeros((Phi.shape[1],), dtype=dtype)

    return tau.reshape(-1, 1)
|
| 110 |
+
|
| 111 |
+
def torch_score_func_sample_level(features: List[np.ndarray], annotations: List[np.ndarray], beta: torch.Tensor) -> torch.Tensor:
    """Per-sample conformity score: the largest claim score among
    unsupported claims.

    Claim scores are ``-(features @ beta)``; a sample whose claims are all
    supported receives a large sentinel (1e9) instead.
    """
    n_samples = len(features)
    out = torch.zeros(n_samples, dtype=beta.dtype)
    for idx in range(n_samples):
        feats = torch.tensor(features[idx], dtype=beta.dtype)
        supported = torch.tensor(annotations[idx], dtype=torch.bool)
        claim_scores = -feats @ beta
        unsupported = claim_scores[~supported]
        if unsupported.numel() > 0:
            out[idx] = torch.max(unsupported)
        else:
            out[idx] = torch.tensor(1e9, dtype=beta.dtype)
    return out
|
| 119 |
+
|
| 120 |
+
def cond_score_loss(beta: torch.Tensor, dataset: Tuple, z_processed: np.ndarray, random_seed: int, q: float) -> torch.Tensor:
    """Claim-level loss function for boosting.

    Splits the dataset in half, fits the conditional quantile cutoffs at
    level ``q`` on the train half, then returns the *negated* mean
    sigmoid-smoothed per-claim retention on the calibration half, so that
    minimizing this loss maximizes expected retention.

    NOTE(review): assumes dataset = (features_list, annotations_list)
    aligned with rows of z_processed — confirm against the caller.
    """
    indices = np.arange(len(dataset[0]))
    ind_train, ind_calib = train_test_split(indices, test_size=0.5, random_state=random_seed)

    x_train, y_train = [dataset[0][i] for i in ind_train], [dataset[1][i] for i in ind_train]
    x_calib, y_calib = [dataset[0][i] for i in ind_calib], [dataset[1][i] for i in ind_calib]
    z_train, z_calib = z_processed[ind_train], z_processed[ind_calib]

    # Sample-level conformity scores on the train half drive the quantile fit.
    scores_train_sample = torch_score_func_sample_level(x_train, y_train, beta)
    tau = solve_qr_for_boosting(z_train, scores_train_sample, q, beta.dtype)

    # Per-sample conformal cutoffs implied by the fitted quantile plane.
    cutoffs = (as_tensor(z_calib, dtype=beta.dtype) @ tau).flatten()

    total_loss = torch.tensor(0.0, dtype=beta.dtype, requires_grad=True)
    count = 0
    for i, (f_c, a_c) in enumerate(zip(x_calib, y_calib)):
        claim_scores = -(as_tensor(f_c, dtype=beta.dtype) @ beta)
        # Smooth indicator of "claim score below its sample cutoff" (kept).
        perc = torch.sigmoid(cutoffs[i] - claim_scores)
        total_loss = total_loss + torch.mean(perc)
        count += 1

    total_loss = total_loss / count if count > 0 else total_loss
    # Negated because the optimizer minimizes and we maximize retention.
    return -total_loss
|
| 144 |
+
|
| 145 |
+
class ConditionalConformalBoosting:
    """Learns claim-scoring weights ``beta`` by gradient boosting against a
    conditional-conformal retention surrogate (see ``cond_score_loss``)."""

    def __init__(self, random_state: int = 0):
        self.rng = np.random.default_rng(random_state)
        # Learned claim-score weights (set by fit()).
        self.beta: Optional[np.ndarray] = None
        # SVD-based projector applied to the conditional features z.
        self.z_projector: Optional[np.ndarray] = None

    def _extract_features_for_boosting(self, data: List[dict]) -> Tuple[List[np.ndarray], np.ndarray, List[np.ndarray]]:
        """Split each record into claim features, conditional features z, and annotations.

        z per sample = [prompt_features..., logprob_mean, logprob_std, claim_count].
        NOTE(review): assumes every record's 'prompt_features' has the same
        length, otherwise np.array(...) below fails — confirm against callers.
        """
        basic_features = [d['features_4d'] for d in data]
        annotations = [d['annotations'] for d in data]
        conditional_features = []
        for d in data:
            sample = d.get('sample', {})
            scores_dict = d.get('scores', {})
            base_features = d.get('prompt_features', [])
            logprob_scores = scores_dict.get('logprob', np.array([]))
            logprob_mean = np.mean(logprob_scores) if logprob_scores.size > 0 else 0.0
            logprob_std = np.std(logprob_scores) if logprob_scores.size > 1 else 0.0
            claim_count = len(sample.get('atomic_facts', []))
            combined_features = np.concatenate([base_features, [logprob_mean, logprob_std, claim_count]])
            conditional_features.append(combined_features)
        z = np.array(conditional_features, dtype=float)
        if not np.isfinite(z).all():
            # NOTE(review): np.nan_to_num documents `nan` as a scalar; passing
            # per-column means relies on broadcasting — confirm on the numpy
            # version in use.
            z = np.nan_to_num(z, nan=np.nanmean(z, axis=0))

        return basic_features, z, annotations

    def _preprocess_z(self, z: np.ndarray) -> np.ndarray:
        """Append an intercept column and project z onto its numerical rank via SVD."""
        intercept = np.ones((z.shape[0], 1))
        z_aug = np.hstack([z, intercept])
        try:
            _, s, Vt = np.linalg.svd(z_aug, full_matrices=False)
            # Keep only directions with non-negligible singular values.
            rank = np.sum(s > 1e-10)
            self.z_projector = Vt.T[:, :rank]
        except np.linalg.LinAlgError:
            # SVD failed to converge: fall back to the identity projector.
            self.z_projector = np.eye(z_aug.shape[1])
        return z_aug @ self.z_projector

    def fit(self, data: List[dict], alpha: float = 0.1, boosting_epochs: int = 1000, boosting_lr: float = 0.005) -> np.ndarray:
        """Optimize beta with Adam over ``cond_score_loss`` and return it as a numpy array.

        Each epoch re-splits the data with a fresh seed, so the loss is a
        stochastic estimate of retention at quantile level 1 - alpha.
        """
        basic_features, z, annotations = self._extract_features_for_boosting(data)
        dataset_boost = (basic_features, annotations)
        z_processed = self._preprocess_z(z)

        feature_dim = basic_features[0].shape[1]
        # Uniform, mildly positive initialization for the claim-score weights.
        beta_tensor = torch.tensor([0.25] * feature_dim, dtype=torch.float, requires_grad=True)
        optimizer = torch.optim.Adam([beta_tensor], lr=boosting_lr)

        for epoch in range(boosting_epochs):
            optimizer.zero_grad()
            seed_epoch = self.rng.integers(1e7)
            loss = cond_score_loss(beta_tensor, dataset_boost, z_processed, seed_epoch, q=1 - alpha)
            # Abort early if the loss diverged.
            if torch.isnan(loss) or torch.isinf(loss): break
            loss.backward()
            # Skip the update when gradients are non-finite (keeps beta usable).
            if beta_tensor.grad is not None and torch.isfinite(beta_tensor.grad).all():
                optimizer.step()

        self.beta = beta_tensor.detach().cpu().numpy()
        return self.beta
|
| 205 |
+
# ==============================================================================
|
| 206 |
+
# === Step 2: Classes and Helper Functions for Calibration and Prediction ===
|
| 207 |
+
# ==============================================================================
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class ConditionalConformalInference:
    """Conditional conformal calibration and claim filtering.

    Wraps a ``CondConf`` model fitted on per-sample conformity scores
    (smallest false-claim score) against conditional features z, and uses it
    at predict time to derive a per-sample score threshold below which claims
    are retained. Optionally fits an adaptive quantile function targeting a
    retention rate instead of a fixed alpha.
    """

    def __init__(self, random_state: int = 0):
        self.rng = np.random.default_rng(random_state)
        self.alpha: Optional[float] = None
        self.beta: Optional[np.ndarray] = None
        self.model: Optional[CondConf] = None
        # Adaptive alpha components
        self.adaptive_enabled: bool = False
        self.retention_target: Optional[float] = None
        self.quantile_theta: Optional[np.ndarray] = None  # parameters for linear quantile_fn
        self._z_proj_for_quantile: Optional[np.ndarray] = None  # projector used for z in quantile fit

    def _make_z_only(self, data: List[dict]) -> np.ndarray:
        """z generation - same structure as boosting: [prompt_features..., logprob_mean, logprob_std, claim_count]

        Prompt-feature vectors are padded with zeros (or truncated) to the
        maximum length observed in ``data`` so rows stack into one matrix.
        """
        # First pass: find the longest prompt-feature vector.
        max_base_len = 0
        for d in data:
            base = d.get('prompt_features', np.array([]))
            try:
                base_len = int(np.asarray(base).size)
            except Exception:
                base_len = 0
            if base_len > max_base_len:
                max_base_len = base_len

        cond_feats: List[np.ndarray] = []
        for d in data:
            sample = d.get('sample', {})
            scores_dict = d.get('scores', {})

            base = np.asarray(d.get('prompt_features', np.array([])), dtype=float).ravel()
            if base.size < max_base_len:
                pad = np.zeros(max_base_len - base.size, dtype=float)
                base = np.concatenate([base, pad])
            elif base.size > max_base_len and max_base_len > 0:
                base = base[:max_base_len]

            logprob_scores = np.asarray(scores_dict.get('logprob', np.array([])), dtype=float).ravel()
            logprob_mean = float(np.mean(logprob_scores)) if logprob_scores.size > 0 else 0.0
            logprob_std = float(np.std(logprob_scores)) if logprob_scores.size > 1 else 0.0

            claim_count = float(len(sample.get('atomic_facts', [])))

            combined = np.concatenate([base, np.array([logprob_mean, logprob_std, claim_count], dtype=float)])
            cond_feats.append(combined)

        result = np.asarray(cond_feats, dtype=float)
        return result

    def _make_yz_for_calib(self, data: List[dict], beta: np.ndarray, eps: float = 0.0):
        """Build calibration targets y (per-sample conformity scores) and features z.

        y = min score among false claims minus eps; if a sample has no false
        claims, the max claim score (or 0.0 for no claims) is used instead.
        Non-finite y rows are dropped; the boolean keep-mask is returned too.
        """
        z = self._make_z_only(data)
        y_list = []
        for d in data:
            feats = d['features_4d']
            ann = np.asarray(d['annotations'], dtype=bool)
            s = -(feats @ beta)
            false_s = s[~ann]
            if false_s.size > 0:
                y_list.append(np.min(false_s) - eps)
            else:
                y_list.append((np.max(s) if s.size > 0 else 0.0))
        y = np.asarray(y_list, dtype=float)
        mask = np.isfinite(y)
        return y[mask], z[mask], mask

    def fit(self, calib_data: List[dict], alpha: float, beta: np.ndarray,
            adaptive_alpha: bool = False, retention_target: float = 0.7):
        """Set up and calibrate CondConf model.

        Args:
            calib_data: calibration records (need 'features_4d', 'annotations', ...).
            alpha: miscoverage level used when adaptive alpha is off.
            beta: claim-score weights (e.g. from ConditionalConformalBoosting.fit).
            adaptive_alpha: if True, also fit a per-sample quantile function.
            retention_target: desired claim-retention rate for adaptive mode.

        Returns:
            self (fluent interface).
        """

        self.alpha = alpha
        self.beta = beta
        self.adaptive_enabled = bool(adaptive_alpha)
        self.retention_target = float(retention_target) if adaptive_alpha else None
        if not self.adaptive_enabled:
            self.quantile_theta = None

        y_calib, z_calib, mask = self._make_yz_for_calib(calib_data, beta)
        self._last_calib_mask = mask

        # Identity score_fn/Phi_fn: y already is the conformity score and z the basis.
        self.model = CondConf(score_fn=lambda x, y: y, Phi_fn=lambda x: x, seed=self.rng.integers(1e6))
        self.model.setup_problem(x_calib=z_calib, y_calib=y_calib)

        if self.adaptive_enabled:
            try:
                self._fit_adaptive_quantile_fn(calib_data, z_calib, mask)
            except Exception as e:
                # Best-effort: fall back to the fixed-alpha path on any failure.
                self.adaptive_enabled = False
        return self

    def predict(self, test_data: List[dict]) -> List[dict]:
        """Filter each sample's claims: keep claims whose score is below the
        conformal threshold for that sample's conditional features.

        Returns a list of sample dicts augmented with 'filtered_claims'
        (empty on any per-sample failure — deliberate best-effort).
        """
        if not self.model or self.beta is None:
            raise RuntimeError("Model is not fitted. Call fit() first.")
        z_test = self._make_z_only(test_data)
        out = []

        for i, d in enumerate(test_data):
            sample = dict(d.get('sample', {}))
            claims = sample.get('atomic_facts', [])
            if not claims:
                sample['filtered_claims'] = []
                out.append(sample)
                continue

            feats = d['features_4d']
            scores = -(feats @ self.beta)
            z_i = z_test[i:i+1]

            # CondConf's score_inv_fn: the predicted quantile already is the threshold.
            get_threshold_fn = lambda threshold, x: threshold

            try:
                if self.adaptive_enabled and self.quantile_theta is not None:
                    q_i = float(self._quantile_fn(z_i))
                else:
                    q_i = float(self.alpha)

                thr = self.model.predict(
                    quantile=q_i,
                    x_test=z_i,
                    score_inv_fn=get_threshold_fn,
                    randomize=True,
                    exact=True
                )
                thr = float(np.squeeze(thr))
                # Clamp the threshold into the observed score range so that a
                # wildly extrapolated quantile cannot keep/drop everything.
                s_min = float(np.min(scores)) if scores.size > 0 else -np.inf
                s_max = float(np.max(scores)) if scores.size > 0 else np.inf
                if not np.isfinite(thr):
                    thr = s_max
                else:
                    thr = float(np.clip(thr, s_min, s_max))
                sample['filtered_claims'] = [c for j, c in enumerate(claims) if scores[j] <= thr]
            except Exception:
                # Best-effort: an unpredictable sample simply retains no claims.
                sample['filtered_claims'] = []

            out.append(sample)

        return out

    # ------------------------------------------------------------------
    # Adaptive alpha utilities
    # ------------------------------------------------------------------
    def _get_claim_scores_list(self, data: List[dict], beta: np.ndarray) -> List[np.ndarray]:
        """Per-sample claim score arrays, scores = -(features_4d @ beta)."""
        scores_list = []
        for d in data:
            feats = d['features_4d']
            s = -(feats @ beta)
            scores_list.append(s)
        return scores_list

    def _compute_retention_given_threshold(self, claim_scores: np.ndarray, threshold: float) -> float:
        """Fraction of claims with score <= threshold (0.0 for no claims)."""
        if claim_scores.size == 0:
            return 0.0
        return float(np.mean(claim_scores <= threshold))

    def _fit_adaptive_quantile_fn(self, calib_data: List[dict], z_calib: np.ndarray, mask: np.ndarray):
        """Fit theta for a per-sample quantile predictor targeting retention.

        For each calibration sample, grid-search the smallest quantile whose
        CondConf threshold achieves the retention target (falling back to the
        closest-achieving or the largest grid quantile). Then ridge-regress
        those per-sample quantiles q* onto the basis phi(z) = [1, z, z^2].
        """
        assert self.model is not None and self.beta is not None and self.retention_target is not None

        calib_data_masked = [calib_data[i] for i, m in enumerate(mask) if m]
        claim_scores_list = self._get_claim_scores_list(calib_data_masked, self.beta)
        quantile_grid = np.linspace(0.01, 0.99, 31)
        q_star = np.zeros(len(z_calib), dtype=float)
        for i in range(len(z_calib)):
            z_i = z_calib[i:i+1]
            best_q = None          # first quantile meeting the target
            best_r = -1.0          # best retention seen so far
            best_q_near = None     # quantile achieving best_r
            for q in quantile_grid:
                try:
                    cutoff = self.model.predict(
                        quantile=float(q),
                        x_test=z_i,
                        score_inv_fn=lambda c, x: c,
                        randomize=True,
                        exact=True
                    )
                    T = float(np.asarray(cutoff).reshape(-1)[0])
                except Exception:
                    continue
                if not np.isfinite(T):
                    continue
                r = self._compute_retention_given_threshold(claim_scores_list[i], T)
                if r >= self.retention_target:
                    best_q = float(q)
                    break
                if r > best_r:
                    best_r = r
                    best_q_near = float(q)
            q_star[i] = float(best_q if best_q is not None else (best_q_near if best_q_near is not None else quantile_grid[-1]))

        def phi_alpha(x: np.ndarray) -> np.ndarray:
            # Quadratic basis: [1, z, z^2] columns.
            x = np.asarray(x)
            ones = np.ones((x.shape[0], 1))
            return np.concatenate([ones, x, x**2], axis=1)

        Phi = phi_alpha(z_calib)
        ridge = 1e-6  # small ridge keeps the normal equations well-conditioned
        theta = np.linalg.pinv(Phi.T @ Phi + ridge * np.eye(Phi.shape[1])) @ (Phi.T @ q_star)
        self.quantile_theta = theta
        self._z_proj_for_quantile = None

    def _quantile_fn(self, z_row: np.ndarray) -> float:
        """Given single-row z (1 x d), return clipped quantile using phi_alpha (1, z, z^2)."""
        assert self.quantile_theta is not None
        z = np.asarray(z_row)
        phi = np.concatenate([np.ones((z.shape[0], 1)), z, z**2], axis=1)
        q = float(phi @ self.quantile_theta)
        # Clip to the same range as the fitting grid.
        return float(np.clip(q, 0.01, 0.99))


    def evaluate_auroc(self, test_data: List[dict]) -> dict:
        """AUROC of claim scores as a detector of FALSE claims (label 1 = false).

        Returns a dict with 'auroc', ROC arrays and counts; on a degenerate
        label set (roc_auc_score ValueError) returns NaN AUROC plus the error.
        """
        if not self.model or self.beta is None:
            raise RuntimeError("Model is not fitted. Call fit() first.")
        all_scores = []
        all_labels = []

        for sample_data in test_data:
            features = sample_data['features_4d']
            annotations = np.array(sample_data['annotations'])

            nonconformity_scores = -features @ self.beta

            all_scores.extend(nonconformity_scores)
            # Annotation True = claim is true, so invert for "is false" labels.
            all_labels.extend((~annotations.astype(bool)).astype(int))

        all_scores = np.array(all_scores)
        all_labels = np.array(all_labels)

        try:
            auroc = roc_auc_score(all_labels, all_scores)
            fpr, tpr, thresholds = roc_curve(all_labels, all_scores)

            results = {
                'auroc': auroc,
                'fpr': fpr,
                'tpr': tpr,
                'thresholds': thresholds,
                'n_samples': len(all_scores),
                'n_false_claims': np.sum(all_labels),
                'n_true_claims': len(all_labels) - np.sum(all_labels)
            }

            return results

        except ValueError as e:
            # e.g. only one class present in labels.
            return {
                'auroc': np.nan,
                'error': str(e),
                'n_samples': len(all_scores),
                'n_false_claims': np.sum(all_labels),
                'n_true_claims': len(all_labels) - np.sum(all_labels)
            }

    def get_claim_scores(self, test_data: List[dict]) -> List[dict]:
        """Return claim-level scores for each sample"""
        if not self.model or self.beta is None:
            raise RuntimeError("Model is not fitted. Call fit() first.")

        results = []
        for sample_data in test_data:
            features = sample_data['features_4d']
            annotations = np.array(sample_data['annotations'])
            claims = sample_data.get('sample', {}).get('atomic_facts', [])

            nonconformity_scores = -features @ self.beta

            sample_result = {
                'sample_id': sample_data.get('sample_id', 'unknown'),
                'claims': claims,
                'nonconformity_scores': nonconformity_scores.tolist(),
                'annotations': annotations.tolist(),
                'is_false': (~annotations.astype(bool)).tolist()
            }
            results.append(sample_result)

        return results
|
MACI-main/data/med_scores/medlfqa_frequencies.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:75c946c6772d18c650b4484bd743032f861a1af4819ded8014cbd5a3b7102857
|
| 3 |
+
size 2225374
|
MACI-main/data/med_scores/medlfqa_logprobs.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:032e80b7aa0c1c73343aec30aca91c128fa9a3fa076333a07a601ba3495b1bd7
|
| 3 |
+
size 2199362
|
MACI-main/data/med_scores/medlfqa_scores_deepseek_deepseek-chat-v3-0324.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9dee348e1265e4f01029224c070ab7b75d6fdbf51b8f4705828c378430c97e38
|
| 3 |
+
size 426183
|
MACI-main/data/med_scores/medlfqa_scores_meta-llama_llama-3.3-70b-instruct.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12eebea7f09ff94aaba4944f6d3ccf3e3ad10e33cd154a6cc274218f2709f1bd
|
| 3 |
+
size 426183
|
MACI-main/data/med_scores/medlfqa_scores_qwen_qwen-2.5-72b-instruct.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67fe74653e69c421202cc4da4308f262657349a5bdc10bf65c43f083d92499e8
|
| 3 |
+
size 426183
|
MACI-main/data/med_scores/medlfqa_selfevals.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58bc690d686c367019bb016b68723747f5a311b3d56f35249c8ee36a61b77878
|
| 3 |
+
size 2226438
|
MACI-main/data/wiki_scores/wikibio_final.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
MACI-main/data/wiki_scores/wikibio_final_dataset.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:482d015cec319b80bf92c4adb8e6d65c20cc40808801561291c7d5bcf76ed551
|
| 3 |
+
size 20356478
|
MACI-main/data/wiki_scores/wikibio_final_frequencies.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f657a70d4e4307b91bdd80ee4b03daf5d3362e723ebae54ceefd5c7cc2330a37
|
| 3 |
+
size 3933826
|
MACI-main/data/wiki_scores/wikibio_final_logprobs.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:644865e9037b9139e277f5d4dc2da594314a4414eca3a9bad4dcad5f1c511319
|
| 3 |
+
size 4820424
|
MACI-main/data/wiki_scores/wikibio_final_self_evals.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77f61a212dabb3fd85724e78c793887ccbe90e0285c4055411888c64dd5d44d4
|
| 3 |
+
size 4848638
|
MACI-main/data/wiki_scores/wikibio_scores_deepseek-chat-v3-0324.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b9726f79ffdd9e6c1f89b64c66960cdc3cce5c3b868e750cc99ee28e4a666c50
|
| 3 |
+
size 621202
|
MACI-main/data/wiki_scores/wikibio_scores_meta-llama_llama-3.3-70b-instruct.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c36a3e500d2c08be38fcf07260a0496f498849e8ea57c9cf074f5f10aca855a
|
| 3 |
+
size 621202
|
MACI-main/data/wiki_scores/wikibio_scores_qwen_qwen-2.5-72b-instruct.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8047a2bcf802c082f6995fae735b634738f351cda58c37409a52376f667ac4a
|
| 3 |
+
size 621168
|
MACI-main/experiments/conditional_groupers.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Flexible conditional grouping utilities for subgroup analysis.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import re
|
| 9 |
+
import warnings
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Any, Tuple
|
| 13 |
+
from abc import ABC, abstractmethod
|
| 14 |
+
|
| 15 |
+
# Surface problems instead of silencing them: restore default warning display
# and make numpy emit warnings (not errors, not silence) for all
# floating-point anomalies.
warnings.filterwarnings('default')
np.seterr(all='warn')
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ConditionalGrouper(ABC):
    """Abstract base for per-sample grouping variables.

    Subclasses implement ``compute_values`` to map a dataset to one scalar per
    sample; ``create_bins`` then partitions those values into interval bins for
    subgroup analysis.
    """

    def __init__(self, name: str, description: str):
        self.name = name
        self.description = description

    @abstractmethod
    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Return one float per sample in ``dataset``."""
        pass

    def create_bins(self, values: np.ndarray, method: str = 'quartiles',
                    custom_bins: List[float] = None) -> List[Tuple[float, float]]:
        """Partition ``values`` into (lo, hi) interval bins.

        Args:
            values: per-sample scalars (non-finite entries are ignored when
                computing quantile edges).
            method: 'quartiles' (default), 'quintiles', 'deciles', 'tertiles',
                'median_split', or 'custom'. Unknown methods fall back to
                quartiles.
            custom_bins: explicit bin edges, used only with method='custom'.

        Returns:
            List of (lo, hi) tuples; empty list for empty input.
        """
        values = np.asarray(values, dtype=float)
        if values.size == 0:
            # Nothing to bin; previously this raised on np.min of empty array.
            return []
        finite_values = values[np.isfinite(values)]

        if len(finite_values) == 0:
            return [(float(np.min(values)), float(np.max(values)))]

        if method == 'custom' and custom_bins:
            # custom_bins are explicit edges, not quantile levels.
            qs = np.asarray(custom_bins, dtype=float)
        else:
            # Bug fix: method='custom' with no/empty custom_bins previously left
            # `qs` unbound (UnboundLocalError); now it falls back to quartiles,
            # as do unknown method names.
            quantile_levels = {
                'quartiles': [0.0, 0.25, 0.5, 0.75, 1.0],
                'quintiles': [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
                'deciles': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                'tertiles': [0.0, 0.33, 0.67, 1.0],
                'median_split': [0.0, 0.5, 1.0],
            }.get(method, [0.0, 0.25, 0.5, 0.75, 1.0])
            qs = np.quantile(finite_values, quantile_levels)

        bins = [(float(qs[i]), float(qs[i+1])) for i in range(len(qs)-1)]
        return bins

    def get_group_info(self, dataset: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
        """Summary statistics of this grouper's values over ``dataset``."""
        values = self.compute_values(dataset, **kwargs)
        finite_values = values[np.isfinite(values)]

        return {
            'name': self.name,
            'description': self.description,
            'total_samples': len(values),
            'valid_samples': len(finite_values),
            'min_value': float(np.min(finite_values)) if len(finite_values) > 0 else np.nan,
            'max_value': float(np.max(finite_values)) if len(finite_values) > 0 else np.nan,
            'mean_value': float(np.mean(finite_values)) if len(finite_values) > 0 else np.nan,
            'std_value': float(np.std(finite_values)) if len(finite_values) > 0 else np.nan,
        }
| 72 |
+
|
| 73 |
+
# View metadata configuration (globally overridable)
|
| 74 |
+
def _default_view_csv_path() -> str:
|
| 75 |
+
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 76 |
+
return os.path.join(repo_root, 'data', 'wiki_scores', 'wikibio_final.csv')
|
| 77 |
+
|
| 78 |
+
GLOBAL_VIEW_METADATA_CSV = _default_view_csv_path()
|
| 79 |
+
|
| 80 |
+
def set_view_metadata_csv(csv_path: str):
|
| 81 |
+
global GLOBAL_VIEW_METADATA_CSV
|
| 82 |
+
if isinstance(csv_path, str) and len(csv_path) > 0:
|
| 83 |
+
GLOBAL_VIEW_METADATA_CSV = csv_path
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ViewCountGrouper(ConditionalGrouper):
    """Groups samples by Wikipedia page view counts.

    Lazily loads a Name -> view-count mapping from the CSV pointed to by
    ``GLOBAL_VIEW_METADATA_CSV``, parses the subject name out of each sample's
    prompt, and falls back to the global minimum count for unknown names.
    """

    def __init__(self):
        super().__init__(
            name="view_count",
            description="Wikipedia view count (from wikibio_final.csv)"
        )
        self._loaded = False              # whether the CSV has been read
        self._csv_path = None             # path the current mapping was read from
        self._name_to_views = {}          # subject name -> view count (may be NaN)
        self._global_min_count = 0.0      # fallback for unknown/NaN names

    @staticmethod
    def _parse_name_from_prompt(prompt: str) -> str:
        """Extract the biography subject's name from a prompt string."""
        if not isinstance(prompt, str):
            try:
                prompt = str(prompt)
            except Exception:
                return ""
        txt = prompt.strip()
        # Typical pattern: "Please write one biographical paragraph about {NAME}."
        import re
        m = re.search(r"about\s+(.+?)(?:[\.]|\n|$)", txt, flags=re.IGNORECASE)
        if m:
            return m.group(1).strip()
        # Fallback: try after 'about '
        if 'about ' in txt:
            return txt.split('about ', 1)[-1].strip().rstrip('.').strip()
        return txt

    def _ensure_loaded(self):
        # Lazy-load and refresh if global path changed
        if (not self._loaded) or (self._csv_path != GLOBAL_VIEW_METADATA_CSV):
            try:
                df = pd.read_csv(GLOBAL_VIEW_METADATA_CSV)
                name_col = 'Name' if 'Name' in df.columns else None
                views_col = 'Views' if 'Views' in df.columns else None
                maxc_col = 'max_counts' if 'max_counts' in df.columns else None
                mapping = {}
                values_for_min = []
                if name_col and (views_col or maxc_col):
                    for _, row in df.iterrows():
                        name = str(row[name_col]).strip()
                        v = np.nan
                        # Per-row preference: Views if finite, else max_counts
                        if views_col is not None:
                            try:
                                vv = float(row[views_col])
                                if np.isfinite(vv):
                                    v = vv
                            except Exception:
                                pass
                        if (not np.isfinite(v)) and maxc_col is not None:
                            try:
                                mv = float(row[maxc_col])
                                if np.isfinite(mv):
                                    v = mv
                            except Exception:
                                pass
                        mapping[name] = v
                        if np.isfinite(v):
                            values_for_min.append(v)
                self._name_to_views = mapping
                self._csv_path = GLOBAL_VIEW_METADATA_CSV
                # Global minimum over available finite counts; default to 0.0 if none
                self._global_min_count = float(np.min(values_for_min)) if len(values_for_min) > 0 else 0.0
                self._loaded = True
            except Exception:
                # If loading fails, mark as loaded with empty mapping
                self._name_to_views = {}
                self._csv_path = GLOBAL_VIEW_METADATA_CSV
                self._global_min_count = 0.0
                self._loaded = True

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Return one view count per sample, parsed from each sample's 'prompt'."""
        self._ensure_loaded()
        values = []
        for sample in dataset:
            prompt = sample.get('prompt', '')
            name = self._parse_name_from_prompt(prompt)
            # Direct match first
            val = self._name_to_views.get(name)
            if val is None:
                # Try naive normalization: collapse spaces
                key2 = " ".join(name.split())
                val = self._name_to_views.get(key2, np.nan)
            # Fallback: global min count if missing or NaN
            if val is None or (isinstance(val, float) and not np.isfinite(val)):
                val = self._global_min_count
            values.append(float(val))
        return np.array(values, dtype=float)
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class FalseClaimRiskGrouper(ConditionalGrouper):
|
| 180 |
+
|
| 181 |
+
def __init__(self):
|
| 182 |
+
super().__init__(
|
| 183 |
+
name="false_claim_risk",
|
| 184 |
+
description="Text-based false-claim risk index (higher → more risk)"
|
| 185 |
+
)
|
| 186 |
+
self.abs_terms = [
|
| 187 |
+
'always', 'never', 'guarantee', 'guaranteed', 'cure', 'proven',
|
| 188 |
+
'will', 'must', 'definitely', 'certainly', 'undoubtedly', 'no doubt'
|
| 189 |
+
]
|
| 190 |
+
self.enum_keywords = [
|
| 191 |
+
'symptom', 'symptoms', 'signs', 'causes', 'cause', 'types', 'treatments',
|
| 192 |
+
'treatment', 'risk factors', 'complications', 'side effects', 'prevention'
|
| 193 |
+
]
|
| 194 |
+
self.citation_patterns = [
|
| 195 |
+
r'according\s+to', r'based\s+on', r'research\s+(?:shows?|indicates?|suggests?)',
|
| 196 |
+
r'studies?\s+(?:show|indicate|suggest|reveal|demonstrate)', r'\(\d{4}\)', r'\[[\d,\s-]+\]'
|
| 197 |
+
]
|
| 198 |
+
self.compiled_cite = [re.compile(p, re.IGNORECASE) for p in self.citation_patterns]
|
| 199 |
+
|
| 200 |
+
@staticmethod
|
| 201 |
+
def _num_sentences(text: str) -> int:
|
| 202 |
+
if not text:
|
| 203 |
+
return 0
|
| 204 |
+
return max(1, text.count('.') + text.count('!') + text.count('?') + text.count('\n'))
|
| 205 |
+
|
| 206 |
+
@staticmethod
|
| 207 |
+
def _listiness(text: str) -> int:
|
| 208 |
+
if not text:
|
| 209 |
+
return 0
|
| 210 |
+
markers = [',', ';', '\n', '-', '*', '•']
|
| 211 |
+
count = sum(text.count(m) for m in markers)
|
| 212 |
+
# Enumerations like "1.", "2)", "(3)"
|
| 213 |
+
count += len(re.findall(r'(?:(?<=\s)|^)(?:\d{1,2}[\.)\]])', text))
|
| 214 |
+
return count
|
| 215 |
+
|
| 216 |
+
def _citation_density(self, text: str) -> float:
|
| 217 |
+
if not text:
|
| 218 |
+
return 0.0
|
| 219 |
+
words = text.split()
|
| 220 |
+
if not words:
|
| 221 |
+
return 0.0
|
| 222 |
+
matches = 0
|
| 223 |
+
low = text.lower()
|
| 224 |
+
for pat in self.compiled_cite:
|
| 225 |
+
matches += len(pat.findall(low))
|
| 226 |
+
return matches / max(1, len(words))
|
| 227 |
+
|
| 228 |
+
def _absolute_density(self, text: str) -> float:
|
| 229 |
+
if not text:
|
| 230 |
+
return 0.0
|
| 231 |
+
words = re.findall(r"\b\w+\b", text.lower())
|
| 232 |
+
if not words:
|
| 233 |
+
return 0.0
|
| 234 |
+
abs_cnt = sum(1 for w in words if w in self.abs_terms)
|
| 235 |
+
return abs_cnt / max(1, len(words))
|
| 236 |
+
|
| 237 |
+
def _enum_keyword_score(self, prompt: str, response: str) -> float:
|
| 238 |
+
txt = f"{prompt} {response}".lower()
|
| 239 |
+
return float(sum(1 for k in self.enum_keywords if k in txt))
|
| 240 |
+
|
| 241 |
+
def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
    """Compute a heuristic false-claim risk score in [0, 1] for each sample.

    Combines length, sentence count, listiness, digit density, absolute-term
    density and enumeration cues (all risk-increasing) minus citation density
    (risk-reducing) into a fixed-weight linear composite, then clips to [0, 1].

    Args:
        dataset: Samples as dicts; each is read for 'prompt' and 'response'
            (missing/None values are treated as empty strings).
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        np.ndarray of float risk scores, one per sample.
    """
    vals = []
    for sample in dataset:
        prompt = sample.get('prompt', '') or ''
        response = sample.get('response', '') or ''
        resp = str(response)

        # Features — each normalized to [0, 1] by a fixed cap chosen
        # heuristically (400 words, 12 sentences, 40 list markers, 4 keywords).
        num_words = len(resp.split())
        len_norm = min(1.0, num_words / 400.0)
        sent_norm = min(1.0, self._num_sentences(resp) / 12.0)
        list_norm = min(1.0, self._listiness(resp) / 40.0)
        # Fraction of characters that are digits (dates, dosages, statistics).
        num_density = (sum(ch.isdigit() for ch in resp) / max(1, len(resp)))
        abs_density = self._absolute_density(resp)
        cite_density = self._citation_density(resp)
        enum_score = min(1.0, self._enum_keyword_score(str(prompt), resp) / 4.0)

        # Composite risk (clipped to [0,1]); citation cues subtract because
        # attributed statements are presumed less likely to be fabricated.
        risk = (
            0.30 * len_norm +
            0.15 * sent_norm +
            0.20 * list_norm +
            0.10 * num_density +
            0.15 * abs_density +
            0.10 * enum_score -
            0.10 * cite_density
        )
        vals.append(float(np.clip(risk, 0.0, 1.0)))
    return np.array(vals, dtype=float)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class MedicalContentGrouper(ConditionalGrouper):
    """Classify medical prompts into three intent groups.

    Group ids (returned as floats by compute_values):
        0: information-seeking (what/why questions about diseases, drugs, ...)
        1: interpretation-seeking (symptom meaning, "should I worry" phrasing)
        2: action-seeking ("should I", "how do I", imperative-style requests)
    """

    # Heuristic keyword sets, built once at class-definition time instead of on
    # every _classify() call (the original rebuilt all five lists per call and
    # re-ran `import re` inside the matching loop). Entries containing ".*"
    # are treated as regex patterns; the rest are plain substring tests.
    INFO_KEYWORDS = [
        "what is", "what are", "definition", "define", "symptom", "signs", "cause", "why",
        "prognosis", "life expectancy", "effect", "does .* do", "means?", "treatment", "therapy",
        "disease", "syndrome", "disorder", "cancer", "diabetes", "ards", "tay-sachs", "paget",
        "thalassemia", "psp", "rosacea", "empyema"
    ]
    DRUG_KEYWORDS = [
        "drug", "medication", "medicine", "dose", "dosage", "tablet", "pill", "mg", "patch",
        "paxlovid", "zoloft", "lexapro", "meloxicam", "naproxen", "fentanyl", "celexa", "restoril",
        "calcitonin", "latanoprost", "aldactazide", "nicoderm"
    ]
    SYMPTOM_KEYWORDS = [
        "pain", "ache", "swelling", "lump", "dark urine", "dizziness", "lightheaded", "fatigue",
        "muscle aches", "discharge", "sunburn", "hoarder", "smell"
    ]
    INTERPRET_KEYWORDS = [
        "what does it mean", "what does .* mean", "when should you worry", "should i worry",
    ]
    ACTION_KEYWORDS = [
        "should i", "do i need", "is it okay", "can i", "how to", "how do i", "stop", "start",
        "continue", "switch", "swap", "get tested", "try", "take", "drink", "use"
    ]
    # Query-word cues used as secondary conditions in _classify.
    _INFO_QUERY_CUES = ["what", "why", "signs", "symptom", "life expectancy", "treatment"]
    _GENERIC_QUERY_CUES = ["what", "why"]
    _IMPERATIVE_CUES = ["how to", "how do i"]

    def __init__(self):
        super().__init__(
            name="medical_content",
            description="Medical content (Information/Interpretation/Action)"
        )

    @staticmethod
    def _normalize(text: str) -> str:
        """Lowercase and collapse whitespace; coerce non-strings defensively."""
        if not isinstance(text, str):
            try:
                text = str(text)
            except Exception:
                return ""
        return " ".join(text.strip().lower().split())

    @staticmethod
    def _contains_any(keys: List[str], p: str) -> bool:
        """True if any key matches `p` (regex match for ".*" keys, else substring)."""
        for k in keys:
            if " .* " in k or ".*" in k:
                # Regex-style key; `re` is imported at module level.
                if re.search(k, p):
                    return True
            # Substring test also applies to regex keys, matching the original
            # fall-through behavior.
            if k in p:
                return True
        return False

    def _classify(self, prompt: str) -> int:
        """Return 0 (information), 1 (interpretation) or 2 (action) for a prompt."""
        p = self._normalize(prompt)

        # Action-seeking first (high precision phrases)
        if self._contains_any(self.ACTION_KEYWORDS, p):
            return 2

        # Information-seeking: has disease/drug entity cues and info-type query words
        # NOTE: the '?' check runs against the raw prompt, not the normalized form.
        if (self._contains_any(self.INFO_KEYWORDS, p) or self._contains_any(self.DRUG_KEYWORDS, p)) and ("?" in prompt or self._contains_any(self._INFO_QUERY_CUES, p)):
            return 0

        # Interpretation-seeking: general symptom phrases or interpret patterns
        if self._contains_any(self.INTERPRET_KEYWORDS, p) or self._contains_any(self.SYMPTOM_KEYWORDS, p):
            return 1

        # Fallback: map generic questions with what/why to information
        if self._contains_any(self._GENERIC_QUERY_CUES, p):
            return 0

        # Otherwise treat as action if imperative-like
        if self._contains_any(self._IMPERATIVE_CUES, p):
            return 2

        # Default to interpretation
        return 1

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Classify every sample's 'prompt' and return the group ids as floats."""
        values = []
        for sample in dataset:
            prompt = sample.get('prompt', '')
            values.append(self._classify(prompt))
        return np.array(values, dtype=float)

    def create_bins(self, values: np.ndarray, method: str = 'ignored', custom_bins: List[float] = None) -> List[Tuple[float, float]]:
        """Fixed discrete bins around the three integer labels; arguments are ignored."""
        return [(-0.5, 0.5), (0.5, 1.5), (1.5, 2.5)]
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class ExpertQAFieldGrouper(ConditionalGrouper):
    """ExpertQA official metadata.field based 3-group classifier

    - 0: Biology/Medicine (Biology, Chemistry, Psychology, Environmental Science, etc.)
    - 1: Engineering/Technology (Engineering and Technology, Physics and Astronomy, Architecture, etc.)
    - 2: Other (All other fields)

    The mapping is loaded from '/expertqa_prompt_to_field.json' by default.
    If the file does not exist, all samples are classified as Other(2).
    The values are integer labels, and create_bins is fixed to discrete intervals.
    """

    def __init__(self, mapping_path: str = "/expertqa_prompt_to_field.json"):
        super().__init__(
            name="expertqa_field",
            description="ExpertQA metadata.field → {Bio/Med, Eng/Tech, Other}"
        )
        self.mapping_path = mapping_path
        # The prompt→field mapping is loaded lazily on first use.
        self._loaded = False
        self._prompt_to_field = {}

        self.bio_med_fields = {
            "Healthcare / Medicine",
            "Biology",
            "Chemistry",
            "Psychology",
            "Environmental Science",
        }
        self.eng_tech_fields = {
            "Engineering and Technology",
            "Physics and Astronomy",
            "Architecture",
        }

    @staticmethod
    def _normalize(text: str) -> str:
        """Collapse whitespace runs (no lowercasing); coerce non-strings defensively."""
        if not isinstance(text, str):
            try:
                text = str(text)
            except Exception:
                return ""
        return " ".join(text.strip().split())

    def _ensure_loaded(self):
        """Load the prompt→field JSON mapping once; any failure leaves it empty."""
        if self._loaded:
            return
        try:
            if os.path.exists(self.mapping_path):
                with open(self.mapping_path, "r", encoding="utf-8") as fh:
                    raw = json.load(fh)
                self._prompt_to_field = {self._normalize(k): v for k, v in raw.items()}
            else:
                self._prompt_to_field = {}
        except Exception:
            # Best-effort loading: a broken mapping degrades to all-Other labels.
            self._prompt_to_field = {}
        finally:
            self._loaded = True

    def _field_to_group(self, field: str) -> int:
        """Map a metadata field name to its group id (default Other = 2)."""
        if not isinstance(field, str):
            return 2
        stripped = field.strip()
        if stripped in self.bio_med_fields:
            return 0
        if stripped in self.eng_tech_fields:
            return 1
        return 2

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Return one group label per sample, looked up by 'prompt' (fallback 'question')."""
        self._ensure_loaded()
        labels = []
        for sample in dataset:
            field = self._prompt_to_field.get(self._normalize(sample.get('prompt', '')))
            if field is None:
                # Some samples carry the text under 'question' instead of 'prompt'.
                field = self._prompt_to_field.get(self._normalize(sample.get('question', '')))
            labels.append(float(self._field_to_group(field)))
        return np.array(labels, dtype=float)

    def create_bins(self, values: np.ndarray, method: str = 'ignored', custom_bins: List[float] = None) -> List[Tuple[float, float]]:
        """Fixed discrete bins around the three integer labels; arguments are ignored."""
        return [(-0.5, 0.5), (0.5, 1.5), (1.5, 2.5)]
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def get_available_groupers() -> Dict[str, ConditionalGrouper]:
    """Build the registry of grouper instances keyed by their registry name."""
    registry: Dict[str, ConditionalGrouper] = {}
    registry['view_count'] = ViewCountGrouper()
    registry['medical_content'] = MedicalContentGrouper()
    registry['false_claim_risk'] = FalseClaimRiskGrouper()
    return registry
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def compute_conditional_coverage_by_grouper(
    filtered_dataset: List[Dict[str, Any]],
    grouping_values: np.ndarray,
    bins: List[Tuple[float, float]]
) -> List[float]:
    """Calculate conditional coverage by a specific grouper.

    For each (bin_min, bin_max) bin, takes the samples whose finite grouping
    value lies inside the bin (inclusive bounds) and returns the fraction of
    those samples with no unsupported retained claim. Empty bins yield NaN.
    """

    def _is_covered(entry: Dict[str, Any]) -> float:
        # Covered iff every retained claim is supported (no claims → covered).
        retained = entry.get('filtered_claims', [])
        return 1.0 if all(c.get('is_supported', False) for c in retained) else 0.0

    values = np.asarray(grouping_values, dtype=float)
    finite = np.isfinite(values)

    coverage_results: List[float] = []
    for lo, hi in bins:
        in_bin = finite & (values >= lo) & (values <= hi)
        members = [filtered_dataset[i] for i in np.nonzero(in_bin)[0]]
        if not members:
            coverage_results.append(np.nan)
            continue
        indicators = [_is_covered(entry) for entry in members]
        coverage_results.append(float(np.mean(indicators)))
    return coverage_results
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def compute_retention_by_grouper(
    filtered_dataset: List[Dict[str, Any]],
    grouping_values: np.ndarray,
    bins: List[Tuple[float, float]]
) -> List[Dict[str, Any]]:
    """Calculate retention rate by a specific grouper.

    For each (bin_min, bin_max) bin, returns a dict with the bin bounds,
    sample count, retained/total claim counts, and the retention rate
    (retained/total; NaN when the bin is empty or has zero total claims).
    Only samples with a finite grouping value inside the bin (inclusive
    bounds) are counted.
    """
    values = np.asarray(grouping_values, dtype=float)
    finite = np.isfinite(values)

    retention_results: List[Dict[str, Any]] = []
    for lo, hi in bins:
        in_bin = finite & (values >= lo) & (values <= hi)
        idxs = np.nonzero(in_bin)[0]

        if len(idxs) == 0:
            retention_results.append({
                'bin': (float(lo), float(hi)),
                'samples': 0,
                'retained': 0,
                'total': 0,
                'rate': np.nan,
            })
            continue

        total_claims = sum(len(filtered_dataset[j].get('atomic_facts', [])) for j in idxs)
        retained_claims = sum(len(filtered_dataset[j].get('filtered_claims', [])) for j in idxs)
        rate = (retained_claims / total_claims) if total_claims > 0 else np.nan

        retention_results.append({
            'bin': (float(lo), float(hi)),
            'samples': len(idxs),
            'retained': int(retained_claims),
            'total': int(total_claims),
            'rate': float(rate) if not np.isnan(rate) else np.nan,
        })
    return retention_results
|
MACI-main/experiments/run_experiment.py
ADDED
|
@@ -0,0 +1,1127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
import pickle
import os
import sys
import json
import argparse
import time
import logging
import warnings
from datetime import datetime
from typing import Optional, Dict, Any, List
from collections import defaultdict
# Surface warnings by default, but silence FutureWarning noise from libraries,
# and turn numpy floating-point errors into warnings instead of silence.
warnings.filterwarnings('default')
warnings.simplefilter('ignore', category=FutureWarning)
np.seterr(all='warn')

# Make the repository root importable so the `conformal` package resolves.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from conformal.basic_conformal import BasicConformal
from conformal.adaptive_conformal import MACIAdaptiveConformal, SubgroupOptimizedMACI
from conditional_groupers import get_available_groupers
from conditional_groupers import set_view_metadata_csv

# LLM scorer identifiers; must match the tokens used in the score .npz filenames.
MODEL_NAMES = ['qwen-2.5-72b-instruct', 'deepseek-chat-v3-0324', 'llama-3.3-70b-instruct']
|
| 25 |
+
|
| 26 |
+
def setup_logging(log_dir: str):
    """Sets up logging to both console and file."""
    os.makedirs(log_dir, exist_ok=True)
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_path = os.path.join(log_dir, f"experiment_log_{stamp}.log")

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Drop previously attached handlers so repeated calls don't duplicate output.
    while root.handlers:
        root.removeHandler(root.handlers[0])

    # Message-only format, mirrored to the log file and to the console.
    plain = logging.Formatter('%(message)s')
    for handler in (logging.FileHandler(log_path), logging.StreamHandler()):
        handler.setFormatter(plain)
        root.addHandler(handler)

    logging.info(f"📝 Logging to {log_path}")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_1000_samples(data_dir: str, scores_dir: Optional[str] = None, dataset_type: str = "auto", limit_samples: int = 1000) -> List[Dict[str, Any]]:
    """Load up to `limit_samples` samples and attach LLM scores.

    Reads the pickled dataset for `dataset_type` ("wikibio" or "medlfqa";
    "auto" probes for the known file layouts under `data_dir`), loads the
    frequency/logprob/self-eval score archives plus per-model LLM score files,
    and aligns everything per sample. Samples with no atomic facts, or whose
    self-eval entries are all -1 (a sentinel for invalid evaluations —
    TODO confirm against the scoring pipeline), are skipped.

    Args:
        data_dir: Root directory containing "wiki_scores"/"med_scores".
        scores_dir: Optional override directory for the per-model .npz files.
        dataset_type: "wikibio", "medlfqa", or "auto" to detect from files.
        limit_samples: Maximum number of samples to read from the dataset.

    Returns:
        List of dicts, each with keys 'sample', 'annotations', 'scores',
        'features_4d' and 'prompt_features'.

    Raises:
        FileNotFoundError: "auto" detection finds neither dataset file.
        ValueError: Unknown explicit `dataset_type`.
    """
    logging.info(f"📁 Loading up to {limit_samples} samples with provided scores...")

    if dataset_type == "auto":
        wikibio_path = os.path.join(data_dir, "wiki_scores", "wikibio_final_dataset.pkl")
        medlfqa_path = os.path.join(data_dir, "med_scores", "medlfqa_dataset.pkl")

        if os.path.exists(wikibio_path):
            dataset_type = "wikibio"
            logging.info(f" 🔍 Auto-detected dataset type: {dataset_type}")
        elif os.path.exists(medlfqa_path):
            dataset_type = "medlfqa"
            logging.info(f" 🔍 Auto-detected dataset type: {dataset_type}")
        else:
            raise FileNotFoundError(f"Could not find dataset files in {data_dir}")

    # Resolve the dataset pickle and the three "basic" score archives.
    if dataset_type == "wikibio":
        dataset_path = os.path.join(data_dir, "wiki_scores", "wikibio_final_dataset.pkl")
        base_scores_dir = os.path.join(data_dir, "wiki_scores")
        score_prefix = "wikibio_scores"
        basic_scores = {
            'frequencies': os.path.join(base_scores_dir, "wikibio_final_frequencies.npz"),
            'logprobs': os.path.join(base_scores_dir, "wikibio_final_logprobs.npz"),
            'selfevals': os.path.join(base_scores_dir, "wikibio_final_self_evals.npz")
        }
    elif dataset_type == "medlfqa":
        dataset_path = os.path.join(data_dir, "med_scores", "medlfqa_dataset.pkl")
        base_scores_dir = os.path.join(data_dir, "med_scores")
        score_prefix = "medlfqa_scores"
        basic_scores = {
            'frequencies': os.path.join(base_scores_dir, "medlfqa_frequencies.npz"),
            'logprobs': os.path.join(base_scores_dir, "medlfqa_logprobs.npz"),
            'selfevals': os.path.join(base_scores_dir, "medlfqa_selfevals.npz")
        }
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    logging.info(f" 📊 Dataset: {dataset_path}")
    logging.info(f" 🎯 Score prefix: {score_prefix}")

    # NOTE(review): pickle.load on an external file — trusted-repo data assumed.
    with open(dataset_path, 'rb') as f:
        dataset = pickle.load(f)

    dataset_1000 = dataset[:limit_samples]

    # Each becomes an NpzFile keyed either by prompt (medlfqa) or 'arr_{i}' (wikibio).
    frequencies = {}
    logprobs = {}
    selfevals = {}

    for score_type, score_path in basic_scores.items():
        try:
            if score_type == 'frequencies':
                frequencies = np.load(score_path, allow_pickle=True)
            elif score_type == 'logprobs':
                logprobs = np.load(score_path, allow_pickle=True)
            elif score_type == 'selfevals':
                selfevals = np.load(score_path, allow_pickle=True)
            logging.info(f" ✅ Loaded {score_type}: {score_path}")
        except FileNotFoundError:
            # Missing archives degrade gracefully to zero-filled scores below.
            logging.warning(f" ⚠️ {score_type} not found: {score_path}")

    # Per-model LLM score files live next to the basic scores unless overridden.
    if scores_dir is not None and os.path.isdir(scores_dir):
        score_files_dir = scores_dir
    else:
        score_files_dir = base_scores_dir

    logging.info(f" 🎯 Score files directory: {score_files_dir}")

    import glob
    all_npz_files = sorted(glob.glob(os.path.join(score_files_dir, f"{score_prefix}_*.npz")))
    def find_by_tokens(token_options: List[List[str]]):
        # Return the first file whose lowercase basename contains every token
        # of some option, trying the most specific token sets first.
        for tokens in token_options:
            for fp in all_npz_files:
                name = os.path.basename(fp).lower()
                if all(t in name for t in tokens):
                    return fp
        return None

    score_files = {
        'qwen-2.5-72b-instruct': find_by_tokens([
            ['qwen-2.5-72b','instruct'], ['qwen','instruct'], ['qwen']
        ]),
        'deepseek-chat-v3-0324': find_by_tokens([
            ['deepseek','chat','v3'], ['deepseek','chat'], ['deepseek']
        ]),
        'llama-3.3-70b-instruct': find_by_tokens([
            ['llama-3.3-70b','instruct'], ['llama-3.3','instruct'], ['llama']
        ]),
    }

    llm_scores = {}
    for model_name, filename in score_files.items():
        try:
            # np.load(None) raises TypeError when find_by_tokens found nothing.
            model_data = np.load(filename, allow_pickle=True)
            model_prompts = model_data['prompts'].tolist()
            model_scores_list = model_data['scores_list'].tolist()
            llm_scores[model_name] = {p: s for p, s in zip(model_prompts, model_scores_list)}
            logging.info(f" ✅ Loaded {model_name} scores")
        except (FileNotFoundError, TypeError):
            logging.warning(f" ⚠️ {model_name} scores not found or invalid: (unknown)")
            llm_scores[model_name] = {}

    aligned_data = []
    for i, sample in enumerate(dataset_1000):
        prompt = sample['prompt']
        atomic_facts = sample.get('atomic_facts', [])
        n_claims = len(atomic_facts)

        if n_claims == 0:
            continue

        # Skip samples whose self-evals are all the -1 sentinel (invalid eval).
        if prompt in selfevals:
            selfeval_vals = selfevals[prompt]
            if hasattr(selfeval_vals, 'ndim') and selfeval_vals.ndim == 1:
                if np.allclose(selfeval_vals, -1):
                    continue
            elif np.allclose(selfeval_vals, -1):
                continue

        # Ground-truth support labels, one bool per atomic fact.
        annotations = np.array([af.get('is_supported', False) for af in atomic_facts])

        # --- frequency scores (wikibio keys by index, medlfqa by prompt) ---
        freq_scores = np.zeros(n_claims)
        if dataset_type == 'wikibio':
            key = f'arr_{i}'
            if key in frequencies:
                freq_vals = frequencies[key]
                if hasattr(freq_vals, 'ndim') and freq_vals.ndim == 1:
                    freq_scores = freq_vals[:n_claims]
                else:
                    # Scalar entry: broadcast it to every claim.
                    freq_scores = np.full(n_claims, freq_vals.item() if hasattr(freq_vals, 'item') else freq_vals)
                freq_scores = np.nan_to_num(freq_scores, nan=0.0)
        else:
            if prompt in frequencies:
                freq_vals = frequencies[prompt]
                if hasattr(freq_vals, 'ndim') and freq_vals.ndim == 1:
                    freq_scores = freq_vals[:n_claims]
                else:
                    freq_val = freq_vals.item() if hasattr(freq_vals, 'item') else freq_vals
                    freq_val = 0.0 if np.isnan(freq_val) else freq_val
                    freq_scores = np.full(n_claims, freq_val)
                freq_scores = np.nan_to_num(freq_scores, nan=0.0)

        # --- logprob scores (same key convention as frequencies) ---
        if dataset_type == 'wikibio':
            key = f'arr_{i}'
            if key in logprobs:
                lp_vals = logprobs[key]
                if hasattr(lp_vals, 'ndim') and lp_vals.ndim == 1:
                    logprob_scores = np.nan_to_num(lp_vals[:n_claims], nan=0.0)
                else:
                    v = lp_vals.item() if hasattr(lp_vals, 'item') else lp_vals
                    v = 0.0 if np.isnan(v) else v
                    logprob_scores = np.full(n_claims, v)
            else:
                logprob_scores = np.zeros(n_claims)
        else:
            if prompt in logprobs:
                logprob_vals = logprobs[prompt]
                if hasattr(logprob_vals, 'ndim') and logprob_vals.ndim == 1:
                    logprob_scores = logprob_vals[:n_claims]
                    logprob_scores = np.nan_to_num(logprob_scores, nan=0.0)
                else:
                    logprob_val = logprob_vals.item() if hasattr(logprob_vals, 'item') else logprob_vals
                    logprob_val = 0.0 if np.isnan(logprob_val) else logprob_val
                    logprob_scores = np.full(n_claims, logprob_val)
            else:
                logprob_scores = np.zeros(n_claims)

        # --- self-eval scores (same key convention) ---
        if dataset_type == 'wikibio':
            key = f'arr_{i}'
            if key in selfevals:
                se_vals = selfevals[key]
                if hasattr(se_vals, 'ndim') and se_vals.ndim == 1:
                    selfeval_scores = np.nan_to_num(se_vals[:n_claims], nan=0.0)
                else:
                    v = se_vals.item() if hasattr(se_vals, 'item') else se_vals
                    v = 0.0 if np.isnan(v) else v
                    selfeval_scores = np.full(n_claims, v)
            else:
                selfeval_scores = np.zeros(n_claims)
        else:
            if prompt in selfevals:
                selfeval_vals = selfevals[prompt]
                if hasattr(selfeval_vals, 'ndim') and selfeval_vals.ndim == 1:
                    selfeval_scores = selfeval_vals[:n_claims]
                    selfeval_scores = np.nan_to_num(selfeval_scores, nan=0.0)
                else:
                    selfeval_val = selfeval_vals.item() if hasattr(selfeval_vals, 'item') else selfeval_vals
                    selfeval_val = 0.0 if np.isnan(selfeval_val) else selfeval_val
                    selfeval_scores = np.full(n_claims, selfeval_val)
            else:
                selfeval_scores = np.zeros(n_claims)

        # Claim position normalized to [0, 1]; a lone claim sits at 0.5.
        ordinal_scores = np.arange(n_claims)
        if n_claims > 1:
            ordinal_scores = ordinal_scores / (n_claims - 1)
        else:
            ordinal_scores = np.array([0.5])

        # Per-model LLM scores clipped to [0, 1]; 0.5 when a model has no entry.
        scores_dict = {}
        for model_name, model_data in llm_scores.items():
            if prompt in model_data:
                scores_dict[model_name] = np.array(model_data[prompt][:n_claims])
                scores_dict[model_name] = np.clip(scores_dict[model_name], 0.0, 1.0)
            else:
                scores_dict[model_name] = np.full(n_claims, 0.5)

        valid_llm_scores = []
        for model_name in MODEL_NAMES:
            if model_name in scores_dict:
                valid_llm_scores.append(scores_dict[model_name])

        # Ensemble: mean across models; lambda_uncertainty=0.0 currently
        # disables the std-based penalty (kept as a tunable knob).
        if valid_llm_scores:
            ensemble_mean = np.mean(valid_llm_scores, axis=0)
            ensemble_std = np.std(valid_llm_scores, axis=0)
            lambda_uncertainty = 0.0
            ensemble_scores = ensemble_mean - lambda_uncertainty * ensemble_std
            ensemble_scores = np.clip(ensemble_scores, 0.0, 1.0)
        else:
            ensemble_scores = np.full(n_claims, 0.5)

        # (n_claims, 4) feature matrix: frequency, self-eval, max-normalized
        # logprob (1e-8 guards divide-by-zero), and ordinal position.
        features_4d = np.concatenate((
            freq_scores.reshape(-1, 1),
            selfeval_scores.reshape(-1, 1),
            (logprob_scores / (np.max(logprob_scores) + 1e-8)).reshape(-1, 1),
            ordinal_scores.reshape(-1, 1)
        ), axis=1)

        aligned_data.append({
            'sample': sample,
            'annotations': annotations,
            'scores': {
                'frequency': freq_scores,
                'selfeval': selfeval_scores,
                'logprob': logprob_scores,
                'ensemble': ensemble_scores,
                **scores_dict
            },
            'features_4d': features_4d,
            # Intercept term plus response/prompt lengths as prompt-level features.
            'prompt_features': np.array([1.0, len(sample.get('response', '')), len(prompt)])
        })

    logging.info(f"✅ Loaded {len(aligned_data)} valid samples")
    return aligned_data
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def create_splits(data, calib_ratio=0.7, test_ratio=0.3, random_seed=42):
    """Randomly partition ``data`` into calibration and test splits.

    Sizes are ``int(len(data) * ratio)`` each; if the two requested sizes
    overshoot the dataset, the test split is shrunk to fit. Seeds NumPy's
    global RNG (intentionally, so downstream code sees the same stream)
    and draws one permutation of the indices.

    Returns (calib_data, test_data, calib_idx, test_idx), where the idx
    outputs are the permuted positions used for each split.
    """
    n = len(data)
    n_calib = int(n * calib_ratio)
    n_test = int(n * test_ratio)

    # Clamp the test split so the two splits never exceed the dataset.
    if n_calib + n_test > n:
        n_test = n - n_calib

    logging.info(f"📊 Creating splits: {n_calib} calib ({calib_ratio*100:.0f}%), {n_test} test ({test_ratio*100:.0f}%)")

    np.random.seed(random_seed)
    perm = np.random.permutation(n)

    calib_idx = perm[:n_calib]
    test_idx = perm[n_calib:n_calib + n_test]

    calib_data = [data[j] for j in calib_idx]
    test_data = [data[j] for j in test_idx]

    logging.info(f"🎲 Random split with seed {random_seed}: calib indices {calib_idx[:5]}..., test indices {test_idx[:5]}...")

    return calib_data, test_data, calib_idx, test_idx
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def run_bcp_experiment(calib_data, test_data, score_type='frequency', alpha=0.1, **kwargs):
    """
    Run BCP (Split Conformal) experiment.

    [FIXED] Uses a unified score_function that relies on pre-aligned data.

    Parameters
    ----------
    calib_data / test_data : lists of aligned items, each holding 'sample'
        (with 'prompt' and 'atomic_facts') and a 'scores' dict of arrays.
    score_type : key into each item's 'scores' dict used as the confidence.
    alpha : conformal miscoverage level.

    Returns a dict with marginal coverage, retention statistics and the
    filtered test results.
    """
    logging.info(f"📈 Running BCI (Split Conformal) with {score_type} scores...")

    calib_samples = [item['sample'] for item in calib_data]
    test_samples = [item['sample'] for item in test_data]

    # Prompt -> aligned item lookup. Built ONCE here; the original rebuilt
    # this dict on every score_function invocation.
    sample_to_data = {item['sample']['prompt']: item for item in calib_data + test_data}

    def score_function(samples):
        """Map each sample to its non-conformity scores (1 - confidence)."""
        result = []
        for sample in samples:
            item = sample_to_data.get(sample['prompt'])
            scores = item['scores'].get(score_type) if item is not None else None
            if scores is not None:
                # All score types here are confidences in [0, 1], so the
                # non-conformity score is uniformly 1 - score. (The original
                # had an if/else with two byte-identical branches — collapsed.)
                result.append(1.0 - scores)
            else:
                # Unknown prompt or missing score type: neutral 0.5 per claim.
                n_claims = len(sample.get('atomic_facts', []))
                result.append(np.full(n_claims, 0.5))
        return result

    basic_conformal = BasicConformal(score_function=score_function, random_state=0)
    basic_conformal.fit_on_calib(calib_samples, alpha=alpha)
    filtered_results, _ = basic_conformal.predict(test_samples)

    coverage = compute_marginal_coverage(filtered_results)
    retention = evaluate_retention(filtered_results, "BCP")

    return {
        'coverage': coverage,
        'retention_rate': retention['overall_retention_rate'],
        'retained_claims': retention['retained_claims'],
        'total_claims': retention['total_claims'],
        'filtered_results': filtered_results
    }
|
| 367 |
+
|
| 368 |
+
def run_as_experiment(calib_data: List[Dict], test_data: List[Dict],
                      model_names: List[str],
                      alpha: float,
                      as_mode: str,
                      subgroup_name: str, **kwargs) -> Dict:
    """Run MACI (Adaptive Subclaims) experiment for a given subgroup.

    Two modes:
      - 'subgroup_optimized': fits SubgroupOptimizedMACI with a grouper looked
        up by ``subgroup_name`` (raises ValueError if unknown).
      - any other value: fits MACIAdaptiveConformal on an ensemble-mean score
        function over ``model_names``.

    Recognized ``kwargs``: 'random_state' (default 0), 'target_tpr'
    (default 0.95, subgroup mode only), 'as_score_type' (read but unused --
    see NOTE below).

    Returns a dict with coverage/retention metrics, per-group budgets,
    optional weights, the filtered test results, and a timing breakdown.
    """
    logging.info(f"📊 Running MACI experiment with mode: {as_mode} for subgroup: '{subgroup_name}'...")

    timing: Dict[str, float] = {}  # timing breakdown, keys prefixed 'maci_'

    if as_mode == 'subgroup_optimized':
        available_groupers = get_available_groupers()
        if subgroup_name not in available_groupers:
            raise ValueError(f"Unknown subgroup: {subgroup_name}")
        grouper = available_groupers[subgroup_name]

        as_model = SubgroupOptimizedMACI(
            model_names=model_names,
            grouper=grouper,
            n_bins=3,
            random_state=kwargs.get('random_state', 0),
            solver='osqp',
        )
        # NOTE(review): t0 is assigned but unused in this branch -- timing is
        # taken from the model's own get_timing() instead.
        t0 = time.perf_counter()
        as_model.fit(calib_data, alpha=alpha, ensemble_train_ratio=0.5, target_tpr=kwargs.get('target_tpr', 0.95))
        timing_details = as_model.get_timing()
        timing['maci_weight_optimization_s'] = timing_details.get('weight_optimization_s', 0.0)
        timing['maci_calibration_s'] = timing_details.get('calibration_s', 0.0)

        t1 = time.perf_counter()
        filtered_results, _ = as_model.predict(test_data)
        timing['maci_inference_s'] = time.perf_counter() - t1
        budgets = as_model.get_budgets()
        weights = as_model.get_weights()

    else:
        # NOTE(review): score_type is read from kwargs but never used below;
        # the score function always averages all available model scores.
        score_type = kwargs.get("as_score_type", "ensemble")

        def score_function(data_list: List[Dict]) -> List[np.ndarray]:
            # Per-item ensemble mean over the models present in 'scores';
            # items with no model scores fall back to a neutral 0.5 per claim.
            scores_list = []
            for item in data_list:
                valid_scores = [item['scores'][m] for m in model_names if m in item['scores']]
                if valid_scores:
                    scores_list.append(np.mean(valid_scores, axis=0))
                else:
                    scores_list.append(np.array([0.5] * len(item.get('sample', {}).get('atomic_facts', []))))
            return scores_list

        as_model = MACIAdaptiveConformal(score_function=score_function, random_state=kwargs.get('random_state', 0))
        t0 = time.perf_counter()
        as_model.fit_on_calib(calib_data, alpha=alpha)
        timing['maci_calibration_s'] = time.perf_counter() - t0
        t1 = time.perf_counter()
        filtered_results, _ = as_model.predict(test_data)
        timing['maci_inference_s'] = time.perf_counter() - t1
        budgets = {'overall': as_model.tau_hat}  # single global threshold
        weights = None  # no per-model weights in the standard variant

    coverage = compute_marginal_coverage(filtered_results)
    retention = evaluate_retention(filtered_results, "MACI")
    return {
        'coverage': coverage,
        'retention_rate': retention['overall_retention_rate'],
        'retained_claims': retention['retained_claims'],
        'total_claims': retention['total_claims'],
        'budgets': budgets,
        'weights': weights,
        'filtered_results': filtered_results,
        'timing': timing
    }
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def run_cci_experiment(
    calib_data,
    test_data,
    alpha=0.1,
    boosting_epochs=1000,
    boosting_lr=0.005,
    calib_split_for_boost=0.3,
    random_seed=0,
    adaptive_alpha: bool = False,
    retention_target: float = 0.7
):
    """
    Two-stage CCI:
    - Stage 1 (Boosting): learn beta on a subset of calib_data
    - Stage 2 (CondConf): calibrate CondConf on the remaining calib_data using learned beta
    - Predict on test_data

    If the conditional-conformal dependencies are missing, returns a
    placeholder result dict (None metrics) instead of raising.
    """
    logging.info("🎯 Running CCI (Boosting -> CondConf) with internal calib split...")

    try:
        from conformal.conditional_conformal import ConditionalConformalBoosting, ConditionalConformalInference
    except Exception as e:
        logging.error(f"CCI unavailable due to missing dependencies: {e}")
        return {
            "coverage": None,
            "retention_rate": None,
            "retained_claims": 0,
            "total_claims": 0,
            "filtered_results": [],
            "timing": {"cci_skipped": True, "error": str(e)}
        }

    # Internal split of the calibration set: one subset drives boosting,
    # the remainder is held out for CondConf calibration.
    rng = np.random.default_rng(random_seed)
    idx = np.arange(len(calib_data))
    rng.shuffle(idx)
    k = int(len(idx) * calib_split_for_boost)
    idx_boost, idx_conf = idx[:k], idx[k:]
    if len(idx_conf) == 0:
        # Degenerate split (calib_split_for_boost >= 1): keep at least one
        # sample for calibration. [FIXED] removed a stray trailing "[1]"
        # subscript that raised IndexError whenever this branch executed.
        idx_boost, idx_conf = idx[:-1], idx[-1:]
    calib_boost = [calib_data[i] for i in idx_boost]
    calib_conf = [calib_data[i] for i in idx_conf]
    logging.info(f" 🔧 calib split -> boost:{len(calib_boost)} | conf:{len(calib_conf)} (seed={random_seed})")

    # Stage 1: boosting learns the reweighting vector beta.
    booster = ConditionalConformalBoosting(random_state=random_seed)
    t_boost_0 = time.perf_counter()
    beta = booster.fit(
        calib_boost,
        boosting_epochs=boosting_epochs,
        boosting_lr=boosting_lr
    )
    t_boost_1 = time.perf_counter()

    # Stage 2: CondConf calibration on the held-out conformal subset.
    infer = ConditionalConformalInference(random_state=random_seed)
    t_fit_0 = time.perf_counter()
    infer.fit(calib_conf, alpha=alpha, beta=beta, adaptive_alpha=adaptive_alpha, retention_target=retention_target)
    t_fit_1 = time.perf_counter()
    # Called for its side effects (AUROC diagnostics); the return value was
    # bound to an unused local in the original, so the binding is dropped.
    infer.evaluate_auroc(test_data)
    t_pred_0 = time.perf_counter()
    filtered_results = infer.predict(test_data)
    t_pred_1 = time.perf_counter()

    coverage = compute_marginal_coverage(filtered_results)
    retention = evaluate_retention(filtered_results, "CCI")

    return {
        "coverage": coverage,
        "retention_rate": retention["overall_retention_rate"],
        "retained_claims": retention["retained_claims"],
        "total_claims": retention["total_claims"],
        "filtered_results": filtered_results,
        "beta": beta,
        "calib_sizes": {"boost": len(calib_boost), "conf": len(calib_conf)},
        "split_seed": random_seed,
        "timing": {
            "cci_boost_fit_s": t_boost_1 - t_boost_0,
            "cci_condconf_fit_s": t_fit_1 - t_fit_0,
            "cci_inference_s": t_pred_1 - t_pred_0,
            "cci_adaptive_alpha_enabled": bool(adaptive_alpha)
        }
    }
|
| 518 |
+
|
| 519 |
+
def evaluate_retention(filtered_dataset: List[Dict], method_name: str = "") -> Dict:
    """Aggregate claim-retention statistics over a filtered dataset.

    Each entry may wrap its payload under 'sample' or be the payload dict
    itself; non-dict payloads are skipped with a warning. ``method_name``
    is accepted for call-site compatibility and not used in the tally.

    Returns a dict with 'overall_retention_rate', 'retained_claims' and
    'total_claims' (all zero for an empty dataset).
    """
    if not filtered_dataset:
        return {'overall_retention_rate': 0.0, 'retained_claims': 0, 'total_claims': 0}

    n_original = 0
    n_retained = 0
    for entry in filtered_dataset:
        payload = entry.get('sample', entry)
        if not isinstance(payload, dict):
            logging.warning(f"Skipping invalid item in retention evaluation: {type(payload)}")
            continue
        n_original += len(payload.get('atomic_facts', []))
        n_retained += len(payload.get('filtered_claims', []))

    rate = n_retained / n_original if n_original > 0 else 0.0
    return {
        'overall_retention_rate': rate,
        'retained_claims': n_retained,
        'total_claims': n_original
    }
|
| 548 |
+
|
| 549 |
+
def compute_marginal_coverage(filtered_dataset: List[Dict]):
    """Fraction of samples whose retained claims are all supported.

    A sample that retains no claims counts as covered (vacuously true).
    Non-dict payloads are skipped with a warning; non-dict claims are
    ignored when checking support. Returns 0.0 for an empty dataset.
    """
    indicators = []
    for entry in filtered_dataset:
        payload = entry.get('sample', entry)
        if not isinstance(payload, dict):
            logging.warning(f"Skipping invalid item in coverage calculation: {type(payload)}")
            continue

        kept = payload.get('filtered_claims', [])
        if not kept:
            indicators.append(1.0)
            continue

        # Covered iff every retained (dict) claim is marked supported.
        all_supported = all(
            claim.get('is_supported', False)
            for claim in kept
            if isinstance(claim, dict)
        )
        indicators.append(1.0 if all_supported else 0.0)

    return np.mean(indicators) if indicators else 0.0
|
| 566 |
+
|
| 567 |
+
def compute_conditional_coverage(test_data, filtered_results, grouper, alpha=0.1, binning_method='quartiles'):
    """Compute conditional coverage for subgroups.

    Bins test samples by the grouper's per-sample values, then reports
    coverage / retention per bin. Note that every supported
    ``binning_method`` currently maps to 'tertiles', so three bins
    ('low'/'medium'/'high') is the expected layout.
    """

    # Merge each original sample with its filtered result so the grouper can
    # see scores and retained claims together.
    combined_data = []
    for orig_sample, filtered_sample in zip(test_data, filtered_results):
        combined_sample = dict(orig_sample['sample'])
        combined_sample['scores'] = orig_sample['scores']
        combined_sample['filtered_claims'] = filtered_sample.get('filtered_claims', [])
        combined_data.append(combined_sample)

    # All requested binning methods collapse to tertiles for the grouper API.
    method_mapping = {
        'quantile': 'tertiles',
        'equal_width': 'tertiles',
        'quartiles': 'tertiles'
    }
    method = method_mapping.get(binning_method, 'tertiles')

    # One scalar per sample from the grouper; drives the bin assignment.
    values = grouper.compute_values(combined_data)

    if len(values) == 0:
        logging.warning(f" ⚠️ Warning: {grouper.__class__.__name__} returned no values")
        return {}

    # Degenerate grouper output (all samples identical) still proceeds, but
    # the bins will be uninformative — warn so the run log shows it.
    if np.all(values == values[0]):
        logging.warning(f" ⚠️ Warning: {grouper.__class__.__name__} all values identical ({values[0]:.4f})")

    bins = grouper.create_bins(values, method=method)

    groups = {}
    group_names = ['low', 'medium', 'high'] if len(bins) == 3 else [f'bin_{i}' for i in range(len(bins))]

    # Assign sample indices to bins; the last bin is closed on the right so
    # the maximum value is not dropped.
    for i, (bin_min, bin_max) in enumerate(bins):
        if i == len(bins) - 1:
            mask = (values >= bin_min) & (values <= bin_max)
        else:
            mask = (values >= bin_min) & (values < bin_max)

        indices = np.where(mask)[0].tolist()
        bin_name = group_names[i] if i < len(group_names) else f'bin_{i}'
        groups[bin_name] = indices

    results = {}
    for group_name, indices in groups.items():
        if len(indices) == 0:
            continue

        group_indicators = []
        group_total_claims = 0
        group_retained_claims = 0

        for idx in indices:
            filtered_sample = filtered_results[idx]
            retained = filtered_sample.get('filtered_claims', [])
            original_claims = test_data[idx]['sample'].get('atomic_facts', [])

            # Covered iff no retained claim is unsupported (empty set of
            # retained claims is vacuously covered).
            has_false = any(not claim.get('is_supported', False) for claim in retained)
            group_indicators.append(0.0 if has_false else 1.0)

            group_total_claims += len(original_claims)
            group_retained_claims += len(retained)

        coverage = np.mean(group_indicators) if group_indicators else 0.0
        retention_rate = group_retained_claims / group_total_claims if group_total_claims > 0 else 0.0
        results[group_name] = {
            'size': len(indices),
            'coverage': coverage,
            'retention_rate': retention_rate,
            'retained_claims': group_retained_claims,
            'total_claims': group_total_claims,
            'target_coverage': 1 - alpha,
        }

    return results
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def save_aggregated_results_to_json(results: Dict, args: argparse.Namespace):
    """Serialize aggregated experiment results to a timestamped JSON file.

    Output directory: ``args.time_out`` if set, otherwise
    ``<repo_root>/analysis/experiment_results``. Bulky or non-serializable
    keys are dropped before writing. Failures are logged, never raised.
    """
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    default_output_dir = os.path.join(repo_root, 'analysis', 'experiment_results')
    output_dir = getattr(args, 'time_out', None) or default_output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Filename encodes dataset, model set, the analysed groups and a timestamp.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    groups_str = "_".join(sorted(args.conditional_groups))
    filename = f"results_{args.dataset_type}_{args.model_set}_{groups_str}_{timestamp}.json"
    filepath = os.path.join(output_dir, filename)

    logging.info(f"\n💾 Saving aggregated results to {filepath}...")

    def convert_to_native_types(obj):
        # json.dumps ``default=`` hook: numpy scalars/arrays and defaultdicts
        # become native Python types; anything else falls back to str(obj).
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, defaultdict):
            return dict(obj)
        # Probe serializability. Since ``default`` is only invoked for objects
        # json could not already encode, this normally raises TypeError and we
        # stringify; the ``return obj`` path is a defensive no-op.
        try:
            json.dumps(obj)
            return obj
        except TypeError:
            return str(obj)

    # Large per-sample payloads and run metadata excluded from the saved JSON.
    keys_to_exclude = {'filtered_results', 'beta', 'weights', 'budgets', 'calib_sizes', 'split_seed'}

    serializable_data = {}
    for method_name, method_data in results.items():
        serializable_data[method_name] = {}
        for key, value in method_data.items():
            if key in keys_to_exclude:
                continue

            # Round-trip through json to deep-convert nested numpy values;
            # any value that still fails is skipped with a warning.
            try:
                cleaned_value = json.loads(json.dumps(value, default=convert_to_native_types))
                serializable_data[method_name][key] = cleaned_value
            except Exception as e:
                logging.warning(f"Could not serialize key '{key}' for method '{method_name}'. Skipping. Error: {e}")

    try:
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(serializable_data, f, indent=4, ensure_ascii=False)
        logging.info(f"✅ Successfully saved results.")
    except Exception as e:
        logging.error(f"❌ Failed to save results to JSON: {e}")
|
| 692 |
+
|
| 693 |
+
def main():
|
| 694 |
+
parser = argparse.ArgumentParser(description="Experiment with three conformal methods")
|
| 695 |
+
parser.add_argument("--random-seed", type=int, default=123, help="Random seed")
|
| 696 |
+
parser.add_argument("--data-dir", type=str, default=None, help="Data directory (defaults to repo_root/data)")
|
| 697 |
+
parser.add_argument("--log-dir", type=str, default=None, help="Directory to save logs (defaults to repo_root/logs)")
|
| 698 |
+
parser.add_argument("--dataset-type", type=str, default="auto", choices=["auto", "wikibio", "medlfqa"],
|
| 699 |
+
help="Dataset type (auto-detected if not specified)")
|
| 700 |
+
parser.add_argument("--alpha", type=float, default=0.1, help="Significance level (fixed if --adaptive-alpha is false)")
|
| 701 |
+
parser.add_argument("--adaptive-alpha", action='store_true', help="Enable per-sample adaptive alpha (learn q*(z) for retention target)")
|
| 702 |
+
parser.add_argument("--retention-target", type=float, default=0.4, help="Target retention used to learn adaptive alpha")
|
| 703 |
+
parser.add_argument("--scores-dir", type=str, default=None, help="Directory containing final NPZ score files (optional)")
|
| 704 |
+
parser.add_argument("--calib-ratio", type=float, default=0.75, help="Calibration set ratio")
|
| 705 |
+
parser.add_argument("--test-ratio", type=float, default=0.25, help="Test set ratio")
|
| 706 |
+
parser.add_argument("--boosting-epochs", type=int, default=100, help="Boosting epochs")
|
| 707 |
+
parser.add_argument("--n-runs", type=int, default=10, help="Number of repeated runs with different random splits")
|
| 708 |
+
parser.add_argument("--model-set", type=str, default="fixed", choices=["fixed"], help="Model set (fixed 3 models)")
|
| 709 |
+
parser.add_argument("--bcp-score-type", type=str, default="frequency",
|
| 710 |
+
choices=['frequency', 'selfeval', 'logprob', 'ensemble'],
|
| 711 |
+
help="Score type for BCI")
|
| 712 |
+
# --as-score-type removed; MACI uses ensemble by default
|
| 713 |
+
parser.add_argument("--as-mode", type=str, default="subgroup_optimized", choices=["standard", "subgroup_optimized"], help="AS variant")
|
| 714 |
+
parser.add_argument("--conditional-groups", type=str, nargs='*',
|
| 715 |
+
default=['false_claim_risk','medicalcontent','view_count'],
|
| 716 |
+
choices=['false_claim_risk','medicalcontent','view_count'],
|
| 717 |
+
help="Conditional groups to analyze")
|
| 718 |
+
parser.add_argument("--view-metadata-csv", type=str, default=None,
|
| 719 |
+
help="Optional CSV for view_count grouper; defaults to repo-relative data path")
|
| 720 |
+
parser.add_argument("--binning-method", type=str, default="quantile",
|
| 721 |
+
choices=['quantile', 'equal_width'],
|
| 722 |
+
help="Binning method for conditional groups")
|
| 723 |
+
parser.add_argument("--limit-samples", type=int, default=2000, help="Max number of samples to load")
|
| 724 |
+
parser.add_argument("--target-tpr", type=float, default=0.8, help="Target TPR for subgroup-optimized AS")
|
| 725 |
+
|
| 726 |
+
parser.add_argument("--time-profile", action='store_true', help="Enable timing profile output")
|
| 727 |
+
parser.add_argument("--time-out", type=str, default=None, help="Directory to save timing JSON (defaults to repo_root/analysis/experiment_results)")
|
| 728 |
+
|
| 729 |
+
args = parser.parse_args()
|
| 730 |
+
|
| 731 |
+
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 732 |
+
if not args.data_dir:
|
| 733 |
+
args.data_dir = os.path.join(repo_root, 'data')
|
| 734 |
+
if not args.log_dir:
|
| 735 |
+
args.log_dir = os.path.join(repo_root, 'logs')
|
| 736 |
+
if not getattr(args, 'time_out', None):
|
| 737 |
+
args.time_out = os.path.join(repo_root, 'analysis', 'experiment_results')
|
| 738 |
+
|
| 739 |
+
setup_logging(args.log_dir)
|
| 740 |
+
if args.view_metadata_csv:
|
| 741 |
+
set_view_metadata_csv(args.view_metadata_csv)
|
| 742 |
+
|
| 743 |
+
logging.info("=" * 80)
|
| 744 |
+
logging.info(f"📊 Setup: {args.calib_ratio*100:.0f}% calibration + {args.test_ratio*100:.0f}% test, α={args.alpha}, adaptive={args.adaptive_alpha}")
|
| 745 |
+
logging.info(f"🔄 Number of runs: {args.n_runs}")
|
| 746 |
+
logging.info(f"🏷️ BCI Score: {args.bcp_score_type}")
|
| 747 |
+
logging.info(f"🧠 CCI: enabled")
|
| 748 |
+
logging.info(f"🎯 MACI: enabled")
|
| 749 |
+
logging.info(f"📊 Conditional groups: {args.conditional_groups}")
|
| 750 |
+
logging.info(f"🔧 Binning Method: {args.binning_method}")
|
| 751 |
+
logging.info(f"🆕 Using provided scores and enhanced features")
|
| 752 |
+
|
| 753 |
+
limit_samples = args.limit_samples
|
| 754 |
+
data = load_1000_samples(args.data_dir, scores_dir=args.scores_dir, dataset_type=args.dataset_type, limit_samples=limit_samples)
|
| 755 |
+
|
| 756 |
+
from collections import defaultdict
|
| 757 |
+
all_runs_results = defaultdict(lambda: defaultdict(list))
|
| 758 |
+
|
| 759 |
+
groupers = []
|
| 760 |
+
available_groupers = get_available_groupers()
|
| 761 |
+
for group_name in args.conditional_groups:
|
| 762 |
+
if group_name in available_groupers:
|
| 763 |
+
groupers.append(available_groupers[group_name])
|
| 764 |
+
else:
|
| 765 |
+
logging.warning(f"⚠️ Unknown grouper: {group_name}")
|
| 766 |
+
|
| 767 |
+
detected_dataset_type = args.dataset_type
|
| 768 |
+
if detected_dataset_type == "auto":
|
| 769 |
+
if data and 'scores' in data[0] and isinstance(data[0]['scores'].get('frequency'), np.ndarray):
|
| 770 |
+
detected_dataset_type = 'medlfqa'
|
| 771 |
+
else:
|
| 772 |
+
detected_dataset_type = 'wikibio'
|
| 773 |
+
logging.info(f"➡️ Using detected dataset type: {detected_dataset_type}")
|
| 774 |
+
|
| 775 |
+
factscore_npz_path = None
|
| 776 |
+
if detected_dataset_type == 'wikibio':
|
| 777 |
+
wikibio_npz_path = os.path.join(args.data_dir, "wiki_scores", "wikibio_final_frequencies.npz")
|
| 778 |
+
|
| 779 |
+
model_names_to_use = MODEL_NAMES
|
| 780 |
+
logging.info(f" MACI Models: {', '.join(model_names_to_use)}")
|
| 781 |
+
|
| 782 |
+
for run_idx in range(args.n_runs):
|
| 783 |
+
logging.info(f"\n🔄 Run {run_idx + 1}/{args.n_runs}")
|
| 784 |
+
logging.info("-" * 50)
|
| 785 |
+
|
| 786 |
+
random_seed = args.random_seed + run_idx
|
| 787 |
+
calib_data, test_data, calib_idx, test_idx = create_splits(
|
| 788 |
+
data, args.calib_ratio, args.test_ratio, random_seed=random_seed
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
logging.info(f"📊 Run {run_idx + 1} sizes: {len(calib_data)} calib, {len(test_data)} test (seed: {random_seed})")
|
| 792 |
+
|
| 793 |
+
results = {}
|
| 794 |
+
|
| 795 |
+
try:
|
| 796 |
+
results['BCI'] = run_bcp_experiment(calib_data, test_data, score_type=args.bcp_score_type, alpha=args.alpha)
|
| 797 |
+
except Exception as e:
|
| 798 |
+
logging.error(f"❌ BCP failed: {e}")
|
| 799 |
+
import traceback
|
| 800 |
+
logging.error(f"Traceback: {traceback.format_exc()}")
|
| 801 |
+
results['BCI'] = None
|
| 802 |
+
|
| 803 |
+
try:
|
| 804 |
+
results['CCI'] = run_cci_experiment(
|
| 805 |
+
calib_data, test_data,
|
| 806 |
+
alpha=args.alpha,
|
| 807 |
+
boosting_epochs=args.boosting_epochs,
|
| 808 |
+
adaptive_alpha=args.adaptive_alpha,
|
| 809 |
+
retention_target=args.retention_target
|
| 810 |
+
)
|
| 811 |
+
except Exception as e:
|
| 812 |
+
logging.error(f"❌ CCI failed: {e}")
|
| 813 |
+
import traceback
|
| 814 |
+
logging.error(f"Traceback: {traceback.format_exc()}")
|
| 815 |
+
results['CCI'] = None
|
| 816 |
+
|
| 817 |
+
results['MACI'] = {
|
| 818 |
+
'coverage': [],
|
| 819 |
+
'retention_rate': [],
|
| 820 |
+
'retained_claims': [],
|
| 821 |
+
'total_claims': [],
|
| 822 |
+
'subgroup_results': {}
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
+
logging.info("--- Starting MACI Experiments ---")
|
| 826 |
+
mace_marginal_results_set = False
|
| 827 |
+
|
| 828 |
+
for subgroup_name in args.conditional_groups:
|
| 829 |
+
try:
|
| 830 |
+
mace_subgroup_result = run_as_experiment(
|
| 831 |
+
calib_data, test_data,
|
| 832 |
+
model_names=model_names_to_use,
|
| 833 |
+
alpha=args.alpha,
|
| 834 |
+
as_mode='subgroup_optimized',
|
| 835 |
+
subgroup_name=subgroup_name,
|
| 836 |
+
random_state=random_seed,
|
| 837 |
+
target_tpr=args.target_tpr
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
if mace_subgroup_result and 'filtered_results' in mace_subgroup_result:
|
| 841 |
+
flat_filtered_results = []
|
| 842 |
+
for res in mace_subgroup_result['filtered_results']:
|
| 843 |
+
flat_item = dict(res.get('sample', {}))
|
| 844 |
+
flat_item['filtered_claims'] = res.get('sample', {}).get('filtered_claims', [])
|
| 845 |
+
flat_filtered_results.append(flat_item)
|
| 846 |
+
mace_subgroup_result['filtered_results'] = flat_filtered_results
|
| 847 |
+
|
| 848 |
+
if not mace_marginal_results_set:
|
| 849 |
+
results['MACI']['coverage'] = mace_subgroup_result['coverage']
|
| 850 |
+
results['MACI']['retention_rate'] = mace_subgroup_result['retention_rate']
|
| 851 |
+
results['MACI']['retained_claims'] = mace_subgroup_result['retained_claims']
|
| 852 |
+
results['MACI']['total_claims'] = mace_subgroup_result['total_claims']
|
| 853 |
+
results['MACI']['filtered_results'] = mace_subgroup_result.get('filtered_results', [])
|
| 854 |
+
results['MACI']['timing'] = mace_subgroup_result.get('timing', {})
|
| 855 |
+
|
| 856 |
+
mace_marginal_results_set = True
|
| 857 |
+
|
| 858 |
+
target_grouper = available_groupers.get(subgroup_name)
|
| 859 |
+
if target_grouper:
|
| 860 |
+
try:
|
| 861 |
+
conditional_results = compute_conditional_coverage(
|
| 862 |
+
test_data,
|
| 863 |
+
mace_subgroup_result['filtered_results'],
|
| 864 |
+
target_grouper,
|
| 865 |
+
args.alpha,
|
| 866 |
+
args.binning_method
|
| 867 |
+
)
|
| 868 |
+
results['MACI']['subgroup_results'][target_grouper.__class__.__name__] = conditional_results
|
| 869 |
+
except Exception as e:
|
| 870 |
+
logging.error(f" ❌ MACI subgroup analysis for {target_grouper.__class__.__name__} failed: {e}")
|
| 871 |
+
|
| 872 |
+
except Exception as e:
|
| 873 |
+
logging.error(f"❌ MACI ({subgroup_name}) failed: {e}")
|
| 874 |
+
import traceback
|
| 875 |
+
logging.error(f"Traceback: {traceback.format_exc()}")
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
for method_name, result in results.items():
|
| 879 |
+
if not result or result.get('coverage') is None:
|
| 880 |
+
continue
|
| 881 |
+
|
| 882 |
+
all_runs_results[method_name]['coverage'].append(result['coverage'])
|
| 883 |
+
all_runs_results[method_name]['retention_rate'].append(result['retention_rate'])
|
| 884 |
+
all_runs_results[method_name]['retained_claims'].append(result['retained_claims'])
|
| 885 |
+
all_runs_results[method_name]['total_claims'].append(result['total_claims'])
|
| 886 |
+
|
| 887 |
+
run_subgroup_results = {}
|
| 888 |
+
if method_name == 'MACI':
|
| 889 |
+
run_subgroup_results = result.get('subgroup_results', {})
|
| 890 |
+
else:
|
| 891 |
+
for grouper in groupers:
|
| 892 |
+
try:
|
| 893 |
+
conditional_results = compute_conditional_coverage(
|
| 894 |
+
test_data,
|
| 895 |
+
result['filtered_results'],
|
| 896 |
+
grouper,
|
| 897 |
+
args.alpha,
|
| 898 |
+
args.binning_method
|
| 899 |
+
)
|
| 900 |
+
run_subgroup_results[grouper.__class__.__name__] = conditional_results
|
| 901 |
+
except Exception as e:
|
| 902 |
+
logging.error(f" ❌ {grouper.__class__.__name__} failed for {method_name}: {e}")
|
| 903 |
+
|
| 904 |
+
all_runs_results[method_name]['subgroup_results'].append(run_subgroup_results)
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
logging.info(f"\n📊 Run {run_idx + 1} Results:")
|
| 908 |
+
for method_name, result in results.items():
|
| 909 |
+
if not result or result.get('coverage') is None:
|
| 910 |
+
logging.info(f" {method_name}: ❌ FAILED or SKIPPED")
|
| 911 |
+
continue
|
| 912 |
+
logging.info(f" {method_name}: Coverage={result['coverage']:.4f}, Retention={result['retention_rate']:.3f}, Claims={result['retained_claims']}/{result['total_claims']}")
|
| 913 |
+
|
| 914 |
+
if args.time_profile:
|
| 915 |
+
timing_payload = {
|
| 916 |
+
'dataset_type': detected_dataset_type,
|
| 917 |
+
'model_set': args.model_set,
|
| 918 |
+
'boosting_epochs': args.boosting_epochs,
|
| 919 |
+
'adaptive_alpha': args.adaptive_alpha,
|
| 920 |
+
'retention_target': args.retention_target,
|
| 921 |
+
'run_idx': run_idx,
|
| 922 |
+
'CCI': results.get('CCI', {}).get('timing', {}),
|
| 923 |
+
'MACI': {}
|
| 924 |
+
}
|
| 925 |
+
try:
|
| 926 |
+
first_subgroup = next(iter(results['MACI'].get('subgroup_results', {}).keys()), None)
|
| 927 |
+
if first_subgroup:
|
| 928 |
+
mace_timing = None
|
| 929 |
+
mace_timing = results['MACI'].get('timing')
|
| 930 |
+
timing_payload['MACI'] = mace_timing if mace_timing else {}
|
| 931 |
+
except Exception:
|
| 932 |
+
pass
|
| 933 |
+
|
| 934 |
+
if not timing_payload['MACI']:
|
| 935 |
+
try:
|
| 936 |
+
timing_payload['MACI'] = {}
|
| 937 |
+
except Exception:
|
| 938 |
+
timing_payload['MACI'] = {}
|
| 939 |
+
|
| 940 |
+
os.makedirs(args.time_out, exist_ok=True)
|
| 941 |
+
tstamp = datetime.now().strftime('%Y%m%d-%H%M%S')
|
| 942 |
+
timing_path = os.path.join(args.time_out, f"time_profile_{detected_dataset_type}_{args.model_set}_{tstamp}.json")
|
| 943 |
+
with open(timing_path, 'w', encoding='utf-8') as f:
|
| 944 |
+
json.dump(timing_payload, f, indent=2, ensure_ascii=False)
|
| 945 |
+
logging.info(f"⏱️ Saved timing profile to {timing_path}")
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
if run_idx == 0 and getattr(args, 'show_sample_idx', None) is not None and args.show_sample_idx >= 0:
|
| 949 |
+
idx = int(args.show_sample_idx)
|
| 950 |
+
if 0 <= idx < len(test_data):
|
| 951 |
+
def _get_claim_text(c: Dict[str, Any]) -> str:
|
| 952 |
+
if not isinstance(c, dict):
|
| 953 |
+
return str(c)
|
| 954 |
+
return c.get('atom') or c.get('text') or c.get('claim') or c.get('fact') or str(c)
|
| 955 |
+
def _get_claim_support(c: Dict[str, Any]) -> str:
|
| 956 |
+
if isinstance(c, dict):
|
| 957 |
+
v = c.get('is_supported')
|
| 958 |
+
if isinstance(v, (bool, np.bool_)):
|
| 959 |
+
return 'T' if bool(v) else 'F'
|
| 960 |
+
return '?'
|
| 961 |
+
|
| 962 |
+
sample = test_data[idx]['sample']
|
| 963 |
+
prompt = sample.get('prompt', '')
|
| 964 |
+
response = sample.get('response', '')
|
| 965 |
+
original_claims = sample.get('atomic_facts', [])
|
| 966 |
+
original_pairs = [(_get_claim_text(c), _get_claim_support(c)) for c in original_claims]
|
| 967 |
+
|
| 968 |
+
bci_item = results.get('BCI', {}).get('filtered_results', [None]*len(test_data))[idx]
|
| 969 |
+
cci_item = results.get('CCI', {}).get('filtered_results', [None]*len(test_data))[idx]
|
| 970 |
+
mace_item = results.get('MACE', {}).get('filtered_results', [None]*len(test_data))[idx]
|
| 971 |
+
|
| 972 |
+
def _filtered_claims(item):
    """Return (text, support-label) pairs for an item's filtered claims.

    Falls back to ``item['sample']['filtered_claims']`` when the
    top-level key is absent; falsy items or claim lists give ``[]``.
    """
    if not item:
        return []
    claims = item.get('filtered_claims')
    if claims is None:
        nested = item.get('sample')
        if isinstance(nested, dict):
            claims = nested.get('filtered_claims', [])
    pairs = []
    for claim in (claims or []):
        pairs.append((_get_claim_text(claim), _get_claim_support(claim)))
    return pairs
|
| 979 |
+
|
| 980 |
+
logging.info("\n=== SAMPLE CLAIMS DUMP ===")
|
| 981 |
+
logging.info(f"[Test idx={idx}] Prompt: {prompt}")
|
| 982 |
+
logging.info(f"Original claims ({len(original_pairs)}):")
|
| 983 |
+
for i, (t, lab) in enumerate(original_pairs, 1):
|
| 984 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 985 |
+
|
| 986 |
+
bci_pairs = _filtered_claims(bci_item)
|
| 987 |
+
cci_pairs = _filtered_claims(cci_item)
|
| 988 |
+
mace_pairs = _filtered_claims(mace_item)
|
| 989 |
+
|
| 990 |
+
logging.info(f"\n[BCI] filtered claims ({len(bci_pairs)}):")
|
| 991 |
+
for i, (t, lab) in enumerate(bci_pairs, 1):
|
| 992 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 993 |
+
|
| 994 |
+
logging.info(f"\n[CCI] filtered claims ({len(cci_pairs)}):")
|
| 995 |
+
for i, (t, lab) in enumerate(cci_pairs, 1):
|
| 996 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 997 |
+
|
| 998 |
+
logging.info(f"\n[MACI] filtered claims ({len(mace_pairs)}):")
|
| 999 |
+
for i, (t, lab) in enumerate(mace_pairs, 1):
|
| 1000 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 1001 |
+
|
| 1002 |
+
if run_idx == 0 and getattr(args, 'show_sample_count', 0) > 0:
|
| 1003 |
+
dump_n = min(int(args.show_sample_count), len(test_data))
|
| 1004 |
+
def _get_claim_text(c: Dict[str, Any]) -> str:
|
| 1005 |
+
if not isinstance(c, dict):
|
| 1006 |
+
return str(c)
|
| 1007 |
+
return c.get('atom') or c.get('text') or c.get('claim') or c.get('fact') or str(c)
|
| 1008 |
+
def _get_claim_support(c: Dict[str, Any]) -> str:
|
| 1009 |
+
if isinstance(c, dict):
|
| 1010 |
+
v = c.get('is_supported')
|
| 1011 |
+
if isinstance(v, (bool, np.bool_)):
|
| 1012 |
+
return 'T' if bool(v) else 'F'
|
| 1013 |
+
return '?'
|
| 1014 |
+
def _filtered_pairs(item):
    """Collect (text, support) tuples from a result item's filtered claims.

    Prefers the item-level 'filtered_claims' list; when that key is
    missing, looks inside a dict-valued 'sample'. Returns [] for falsy
    items and for missing/empty claim lists.
    """
    if not item:
        return []
    claims = item.get('filtered_claims')
    if claims is None and isinstance(item.get('sample'), dict):
        claims = item['sample'].get('filtered_claims', [])
    return [
        (_get_claim_text(claim), _get_claim_support(claim))
        for claim in (claims or [])
    ]
|
| 1021 |
+
for idx in range(dump_n):
|
| 1022 |
+
sample = test_data[idx]['sample']
|
| 1023 |
+
prompt = sample.get('prompt', '')
|
| 1024 |
+
original_pairs = [(_get_claim_text(c), _get_claim_support(c)) for c in sample.get('atomic_facts', [])]
|
| 1025 |
+
bci_item = results.get('BCI', {}).get('filtered_results', [None]*len(test_data))[idx]
|
| 1026 |
+
cci_item = results.get('CCI', {}).get('filtered_results', [None]*len(test_data))[idx]
|
| 1027 |
+
mace_item = results.get('MACE', {}).get('filtered_results', [None]*len(test_data))[idx]
|
| 1028 |
+
logging.info("\n=== SAMPLE CLAIMS DUMP ===")
|
| 1029 |
+
logging.info(f"[Test idx={idx}] Prompt: {prompt}")
|
| 1030 |
+
logging.info(f"Original claims ({len(original_pairs)}):")
|
| 1031 |
+
for i, (t, lab) in enumerate(original_pairs, 1):
|
| 1032 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 1033 |
+
bci_pairs = _filtered_pairs(bci_item)
|
| 1034 |
+
cci_pairs = _filtered_pairs(cci_item)
|
| 1035 |
+
mace_pairs = _filtered_pairs(mace_item)
|
| 1036 |
+
logging.info(f"\n[BCI] filtered claims ({len(bci_pairs)}):")
|
| 1037 |
+
for i, (t, lab) in enumerate(bci_pairs, 1):
|
| 1038 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 1039 |
+
logging.info(f"\n[CCI] filtered claims ({len(cci_pairs)}):")
|
| 1040 |
+
for i, (t, lab) in enumerate(cci_pairs, 1):
|
| 1041 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 1042 |
+
logging.info(f"\n[MACI] filtered claims ({len(mace_pairs)}):")
|
| 1043 |
+
for i, (t, lab) in enumerate(mace_pairs, 1):
|
| 1044 |
+
logging.info(f" {i:2d}. [{lab}] {t}")
|
| 1045 |
+
|
| 1046 |
+
logging.info("\n" + "=" * 100)
|
| 1047 |
+
logging.info("📊 AGGREGATED RESULTS (All Runs)")
|
| 1048 |
+
logging.info("=" * 100)
|
| 1049 |
+
for method_name in sorted(all_runs_results.keys()):
|
| 1050 |
+
method_results = all_runs_results[method_name]
|
| 1051 |
+
|
| 1052 |
+
if not method_results['coverage']:
|
| 1053 |
+
logging.info(f"\n{method_name}: ❌ NO SUCCESSFUL RUNS")
|
| 1054 |
+
continue
|
| 1055 |
+
|
| 1056 |
+
n_runs = len(method_results['coverage'])
|
| 1057 |
+
coverage_mean = np.mean(method_results['coverage'])
|
| 1058 |
+
coverage_std = np.std(method_results['coverage'])
|
| 1059 |
+
retention_mean = np.mean(method_results['retention_rate'])
|
| 1060 |
+
retention_std = np.std(method_results['retention_rate'])
|
| 1061 |
+
retained_claims_mean = np.mean(method_results['retained_claims'])
|
| 1062 |
+
retained_claims_std = np.std(method_results['retained_claims'])
|
| 1063 |
+
total_claims_mean = np.mean(method_results['total_claims'])
|
| 1064 |
+
|
| 1065 |
+
logging.info(f"\n{'='*20} {method_name} ({n_runs} runs) {'='*20}")
|
| 1066 |
+
logging.info(f"📈 MARGINAL RESULTS:")
|
| 1067 |
+
logging.info(f" Coverage: {coverage_mean:.4f} ± {coverage_std:.4f}")
|
| 1068 |
+
logging.info(f" Retention Rate: {retention_mean:.3f} ± {retention_std:.3f}")
|
| 1069 |
+
logging.info(f" Claims: {retained_claims_mean:.1f} ± {retained_claims_std:.1f}/{total_claims_mean:.1f}")
|
| 1070 |
+
|
| 1071 |
+
if method_results['subgroup_results']:
|
| 1072 |
+
logging.info(f"\n📊 SUBGROUP RESULTS:")
|
| 1073 |
+
|
| 1074 |
+
subgroup_data = {}
|
| 1075 |
+
for run_results in method_results['subgroup_results']:
|
| 1076 |
+
for grouper_name, grouper_results in run_results.items():
|
| 1077 |
+
if grouper_name not in subgroup_data:
|
| 1078 |
+
subgroup_data[grouper_name] = {}
|
| 1079 |
+
|
| 1080 |
+
for group_name, group_result in grouper_results.items():
|
| 1081 |
+
if group_name not in subgroup_data[grouper_name]:
|
| 1082 |
+
subgroup_data[grouper_name][group_name] = {
|
| 1083 |
+
'coverage': [], 'retention_rate': [], 'retained_claims': [],
|
| 1084 |
+
'total_claims': [], 'size': []
|
| 1085 |
+
}
|
| 1086 |
+
|
| 1087 |
+
subgroup_data[grouper_name][group_name]['coverage'].append(group_result['coverage'])
|
| 1088 |
+
subgroup_data[grouper_name][group_name]['retention_rate'].append(group_result['retention_rate'])
|
| 1089 |
+
subgroup_data[grouper_name][group_name]['retained_claims'].append(group_result['retained_claims'])
|
| 1090 |
+
subgroup_data[grouper_name][group_name]['total_claims'].append(group_result['total_claims'])
|
| 1091 |
+
subgroup_data[grouper_name][group_name]['size'].append(group_result['size'])
|
| 1092 |
+
|
| 1093 |
+
for grouper_name, groups in subgroup_data.items():
|
| 1094 |
+
logging.info(f"\n 🔍 {grouper_name}:")
|
| 1095 |
+
|
| 1096 |
+
for group_name, group_data in groups.items():
|
| 1097 |
+
if not group_data['coverage']:
|
| 1098 |
+
continue
|
| 1099 |
+
|
| 1100 |
+
group_coverage_mean = np.mean(group_data['coverage'])
|
| 1101 |
+
group_coverage_std = np.std(group_data['coverage'])
|
| 1102 |
+
group_retention_mean = np.mean(group_data['retention_rate'])
|
| 1103 |
+
group_retention_std = np.std(group_data['retention_rate'])
|
| 1104 |
+
group_retained_claims_mean = np.mean(group_data['retained_claims'])
|
| 1105 |
+
group_retained_claims_std = np.std(group_data['retained_claims'])
|
| 1106 |
+
group_total_claims_mean = np.mean(group_data['total_claims'])
|
| 1107 |
+
group_size_mean = np.mean(group_data['size'])
|
| 1108 |
+
|
| 1109 |
+
target_coverage = 1 - args.alpha
|
| 1110 |
+
violation_marker = "⚠️ " if abs(group_coverage_mean - target_coverage) > 0.014 else "✅ "
|
| 1111 |
+
|
| 1112 |
+
logging.info(f" {violation_marker}{group_name}:")
|
| 1113 |
+
logging.info(f" Coverage: {group_coverage_mean:.3f} ± {group_coverage_std:.3f} (target: {target_coverage:.1f})")
|
| 1114 |
+
logging.info(f" Retention: {group_retention_mean:.3f} ± {group_retention_std:.3f}")
|
| 1115 |
+
logging.info(f" Claims: {group_retained_claims_mean:.1f} ± {group_retained_claims_std:.1f}/{group_total_claims_mean:.1f}")
|
| 1116 |
+
logging.info(f" Group size: {group_size_mean:.1f} samples")
|
| 1117 |
+
logging.info(f" Coverage gap: {group_coverage_mean - target_coverage:+.3f}")
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
logging.info("\n" + "=" * 100)
|
| 1122 |
+
|
| 1123 |
+
save_aggregated_results_to_json(all_runs_results, args)
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
# Script entry point: run the full experiment pipeline when executed directly.
if __name__ == "__main__":
    main()
|
MACI-main/requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==2.0.2
|
| 2 |
+
scipy==1.13.1
|
| 3 |
+
scikit-learn==1.6.1
|
| 4 |
+
pandas==2.3.1
|
| 5 |
+
matplotlib==3.9.4
|
| 6 |
+
seaborn==0.13.2
|
| 7 |
+
tqdm==4.67.1
|
| 8 |
+
cvxpy==1.7.1
|
| 9 |
+
conditionalconformal==0.0.5
|
| 10 |
+
torch==2.8.0
|
| 11 |
+
torchvision==0.23.0
|
| 12 |
+
torchaudio==2.8.0
|