KangjunNoh commited on
Commit
906e061
·
verified ·
1 Parent(s): d238913

Upload 47 files

Browse files
Files changed (47) hide show
  1. MACI-main/LICENSE +21 -0
  2. MACI-main/README.md +46 -0
  3. MACI-main/conditional-conformal/conditionalconformal/__init__.py +1 -0
  4. MACI-main/conditional-conformal/conditionalconformal/condconf.py +877 -0
  5. MACI-main/conditional-conformal/conditionalconformal/experiment_utils.py +182 -0
  6. MACI-main/conditional-conformal/conditionalconformal/synthetic_data.py +55 -0
  7. MACI-main/conditional-conformal/src/atomizer.py +347 -0
  8. MACI-main/conditional-conformal/src/aws_utils.py +15 -0
  9. MACI-main/conditional-conformal/src/client.py +89 -0
  10. MACI-main/conditional-conformal/src/config.py +7 -0
  11. MACI-main/conditional-conformal/src/conformal.py +68 -0
  12. MACI-main/conditional-conformal/src/data_utils/sample_names.py +86 -0
  13. MACI-main/conditional-conformal/src/dataset.py +279 -0
  14. MACI-main/conditional-conformal/src/featurizer.py +352 -0
  15. MACI-main/conditional-conformal/src/gpt.py +58 -0
  16. MACI-main/conditional-conformal/src/llm_utils.py +111 -0
  17. MACI-main/conditional-conformal/src/postprocess_factscore.py +34 -0
  18. MACI-main/conditional-conformal/src/prob_model.py +101 -0
  19. MACI-main/conditional-conformal/src/query.py +112 -0
  20. MACI-main/conditional-conformal/src/ray_data.py +192 -0
  21. MACI-main/conditional-conformal/src/retrieval.py +268 -0
  22. MACI-main/conditional-conformal/src/retrieve_data.py +86 -0
  23. MACI-main/conditional-conformal/src/run.py +119 -0
  24. MACI-main/conditional-conformal/src/scorer.py +202 -0
  25. MACI-main/conformal/__pycache__/adaptive_conformal.cpython-39.pyc +0 -0
  26. MACI-main/conformal/__pycache__/basic_conformal.cpython-39.pyc +0 -0
  27. MACI-main/conformal/__pycache__/conditional_conformal.cpython-39.pyc +0 -0
  28. MACI-main/conformal/adaptive_conformal.py +403 -0
  29. MACI-main/conformal/basic_conformal.py +189 -0
  30. MACI-main/conformal/conditional_conformal.py +489 -0
  31. MACI-main/data/med_scores/medlfqa_frequencies.npz +3 -0
  32. MACI-main/data/med_scores/medlfqa_logprobs.npz +3 -0
  33. MACI-main/data/med_scores/medlfqa_scores_deepseek_deepseek-chat-v3-0324.npz +3 -0
  34. MACI-main/data/med_scores/medlfqa_scores_meta-llama_llama-3.3-70b-instruct.npz +3 -0
  35. MACI-main/data/med_scores/medlfqa_scores_qwen_qwen-2.5-72b-instruct.npz +3 -0
  36. MACI-main/data/med_scores/medlfqa_selfevals.npz +3 -0
  37. MACI-main/data/wiki_scores/wikibio_final.csv +0 -0
  38. MACI-main/data/wiki_scores/wikibio_final_dataset.pkl +3 -0
  39. MACI-main/data/wiki_scores/wikibio_final_frequencies.npz +3 -0
  40. MACI-main/data/wiki_scores/wikibio_final_logprobs.npz +3 -0
  41. MACI-main/data/wiki_scores/wikibio_final_self_evals.npz +3 -0
  42. MACI-main/data/wiki_scores/wikibio_scores_deepseek-chat-v3-0324.npz +3 -0
  43. MACI-main/data/wiki_scores/wikibio_scores_meta-llama_llama-3.3-70b-instruct.npz +3 -0
  44. MACI-main/data/wiki_scores/wikibio_scores_qwen_qwen-2.5-72b-instruct.npz +3 -0
  45. MACI-main/experiments/conditional_groupers.py +542 -0
  46. MACI-main/experiments/run_experiment.py +1127 -0
  47. MACI-main/requirements.txt +12 -0
MACI-main/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Anonymous2026conf
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MACI-main/README.md ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MACI
2
+ This repository contains an anonymized version of our Multi-LLM Adaptive Conformal Inference experiments. The entry point is `experiments/run_experiment.py`.
3
+
4
+ ## Abstract
5
+
6
+ Ensuring factuality is essential for the safe use of Large Language Models (LLMs) in high-stakes domains such as medicine and law. Conformal inference provides distribution-free guarantees, but existing approaches are either overly conservative, discarding many true claims, or rely on adaptive error rates and simple linear models that fail to capture complex group structures. To address these challenges, we reformulate conformal inference in a multiplicative filtering setting, modeling factuality as a product of claim-level scores. Our method, Multi-LLM Adaptive Conformal Inference (MACI), leverages ensembles to produce more accurate factuality scores, which in our experiments led to higher retention, while validity is preserved through group-conditional calibration. Experiments show that MACI consistently achieves user-specified coverage with substantially higher retention and lower time cost than baselines.
7
+
8
+ ## Running
9
+
10
+ Step 1) Create a fresh Conda environment (Python 3.9)
11
+
12
+ ```bash
13
+ conda create -y -n maci python=3.9
14
+ ```
15
+
16
+ Step 2) Install dependencies from requirements.txt
17
+
18
+ ```bash
19
+ conda run -n maci \
20
+ python -m pip install -r requirements.txt --no-input
21
+ ```
22
+
23
+ Step 3) Prepare data layout (repo-relative defaults)
24
+
25
+ - Place data under `data/` in the repository root (or pass `--data-dir`).
26
+ - For MedLFQA: put files under `data/med_scores/`.
27
+ - For WikiBio: put files under `data/wiki_scores/`.
28
+
29
+ Step 4) Run a quick experiment (MedLFQA example)
30
+
31
+ ```bash
32
+ conda run -n maci \
33
+ python experiments/run_experiment.py \
34
+ --dataset-type medlfqa \
35
+ --conditional-groups false_claim_risk
36
+ ```
37
+
38
+ Step 5) Where outputs go
39
+
40
+ - Logs: `logs/` (repo-root-relative by default)
41
+ - Results JSON: `analysis/experiment_results/`
42
+
43
+
44
+
45
+ ## CCI Attribution
46
+ Our implementation of the Conditional Conformal Inference (CCI) baseline is adopted directly from the [conformal-safety](https://github.com/jjcherian/conformal-safety.git) repository. To ensure full reproducibility, we have included a local copy of the necessary modules in the `conditional-conformal/` directory. We explicitly state that the code within this directory is not the work of the MACI project. For all details, please refer to the original repository: [conformal-safety](https://github.com/jjcherian/conformal-safety.git)
MACI-main/conditional-conformal/conditionalconformal/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .condconf import CondConf
MACI-main/conditional-conformal/conditionalconformal/condconf.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cvxpy as cp
2
+ import numpy as np
3
+
4
+ from functools import partial, lru_cache
5
+ from scipy.optimize import linprog
6
+ from sklearn.metrics.pairwise import pairwise_kernels
7
+ from typing import Callable
8
+
9
+ FUNCTION_DEFAULTS = {"kernel": None, "gamma" : 1, "lambda": 1}
10
+
11
+ class CondConf:
12
+ def __init__(
13
+ self,
14
+ score_fn : Callable,
15
+ Phi_fn : Callable,
16
+ quantile_fn : Callable = None,
17
+ infinite_params : dict = {},
18
+ seed : int = 0
19
+ ):
20
+ """
21
+ Constructs the CondConf object that caches relevant information for
22
+ generating conditionally valid prediction sets.
23
+
24
+ We define the score function and set of conditional guarantees
25
+ that we care about in this function.
26
+
27
+ Parameters
28
+ ---------
29
+ score_fn : Callable[np.ndarray, np.ndarray] -> np.ndarray
30
+ Fixed (vectorized) conformity score function that takes in
31
+ X and Y as inputs and returns S as output
32
+
33
+ Phi_fn : Callable[np.ndarray] -> np.ndarray
34
+ Function that defines finite basis set that we provide
35
+ exact conditional guarantees over
36
+
37
+ infinite_params : dict = {}
38
+ Dictionary containing parameters for the RKHS component of the fit
39
+ Valid keys are ('kernel', 'gamma', 'lambda')
40
+ 'kernel' should be a valid kernel name for sklearn.metrics.pairwise_kernels
41
+ 'gamma' is a hyperparameter for certain kernels
42
+ 'lambda' is the regularization penalty applied to the RKHS component
43
+ """
44
+ self.score_fn = score_fn
45
+ self.Phi_fn = Phi_fn
46
+ self.quantile_fn = quantile_fn
47
+ self.infinite_params = infinite_params
48
+ self.rng = np.random.default_rng(seed=seed)
49
+
50
+ def setup_problem(
51
+ self,
52
+ x_calib : np.ndarray,
53
+ y_calib : np.ndarray
54
+ ):
55
+ """
56
+ setup_problem sets up the final fitting problem for a
57
+ particular calibration set
58
+
59
+ The resulting cvxpy Problem object is stored inside the CondConf parent.
60
+
61
+ Arguments
62
+ ---------
63
+ x_calib : np.ndarray
64
+ Covariate data for the calibration set
65
+
66
+ y_calib : np.ndarray
67
+ Labels for the calibration set
68
+ """
69
+ self.x_calib = x_calib
70
+ self.y_calib = y_calib
71
+ phi_calib = self.Phi_fn(x_calib)
72
+
73
+ _, s, Vt = np.linalg.svd(phi_calib, full_matrices=False)
74
+
75
+ # Set a tolerance to decide which singular values are nonzero
76
+ tol = 1e-10
77
+ r = np.sum(s > tol)
78
+
79
+ if r < len(s):
80
+ self.Phi_fn_orig = self.Phi_fn
81
+ T = Vt.T[:, :r]
82
+ self.Phi_fn = lambda x: (self.Phi_fn_orig(x) @ T)
83
+ phi_calib = self.Phi_fn(x_calib)
84
+
85
+ self.phi_calib = phi_calib
86
+ self.scores_calib = self.score_fn(x_calib, y_calib)
87
+
88
+ if self.quantile_fn is not None:
89
+ self.quantile_calib = self.quantile_fn(x_calib).reshape(-1,1)
90
+
91
+ self.cvx_problem = setup_cvx_problem(
92
+ self.x_calib,
93
+ self.scores_calib,
94
+ self.phi_calib,
95
+ self.infinite_params
96
+ )
97
+
98
+
99
+ @lru_cache()
100
+ def _get_calibration_solution(
101
+ self,
102
+ quantile : float
103
+ ):
104
+ S = self.scores_calib.reshape(-1,1)
105
+ Phi = self.phi_calib.astype(float)
106
+ zeros = np.zeros((Phi.shape[1],))
107
+
108
+ if quantile is None:
109
+ bounds = np.concatenate((self.quantile_calib - 1, self.quantile_calib), axis=1)
110
+ else:
111
+ bounds = np.asarray([quantile - 1, quantile])
112
+ bounds = np.tile(bounds.reshape(1,-1), (len(S), 1))
113
+
114
+ res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds, method='highs')
115
+ primal_vars = -1 * res.eqlin.marginals.reshape(-1,1)
116
+ dual_vars = res.x.reshape(-1,1)
117
+
118
+ residuals = S - (Phi @ primal_vars)
119
+ interpolated_pts = np.isclose(residuals, 0)
120
+
121
+ # if I didn't converge to a solution that interpolates at least Phi.shape[1] pts,
122
+ # I need to manually find one via a modified simplex iteration
123
+ if interpolated_pts.sum() < Phi.shape[1]:
124
+ num_to_add = Phi.shape[1] - interpolated_pts.sum()
125
+ for _ in range(num_to_add):
126
+ candidate_pts = interpolated_pts.copy().flatten()
127
+
128
+ # find candidate idx for interpolation, e.g., new covariate that is
129
+ # linearly independent of the previously interpolated points
130
+ Q, _ = np.linalg.qr(Phi[candidate_pts].T)
131
+ projections = Phi @ Q @ Q.T
132
+ norms = np.linalg.norm(Phi - projections, axis=1)
133
+ candidate_idx = np.where(norms > 1e-5)[0][0]
134
+ candidate_pts[candidate_idx] = True
135
+
136
+ # find direction to solution that would interpolate the new point
137
+ gamma, _, _, _ = np.linalg.lstsq(Phi[candidate_pts], S[candidate_pts], rcond=None)
138
+ direction = gamma.reshape(-1,1) - primal_vars
139
+ step_sizes = residuals / (Phi @ direction)
140
+
141
+ # check the non-basic indices for which a step in this direction could have led to interpolation
142
+ # e.g., those for which the step size is positive and the point is not already interpolated
143
+ positive_indices = np.where((step_sizes > 0) & ~interpolated_pts)[0]
144
+
145
+ # take smallest possible step that would lead to interpolation
146
+ primal_vars += np.min(step_sizes[positive_indices]) * direction
147
+
148
+ residuals = S - (Phi @ primal_vars)
149
+ interpolated_pts = np.isclose(residuals, 0)
150
+
151
+ return dual_vars, primal_vars
152
+
153
+ def _compute_exact_cutoff(
154
+ self,
155
+ quantiles,
156
+ primals,
157
+ duals,
158
+ phi_test,
159
+ dual_threshold
160
+ ):
161
+ def get_current_basis(primals, duals, Phi, S, quantiles):
162
+ interp_bools = np.logical_and(~np.isclose(duals, quantiles - 1), ~np.isclose(duals, quantiles))
163
+ if np.sum(interp_bools) == Phi.shape[1]:
164
+ return interp_bools
165
+ preds = (Phi @ primals).flatten()
166
+ active_indices = np.where(interp_bools)[0]
167
+ interp_indices = np.where(np.isclose(np.abs(S - preds), 0))[0]
168
+ diff_indices = np.setdiff1d(interp_indices, active_indices)
169
+ num_missing = Phi.shape[1] - np.sum(interp_bools)
170
+ if num_missing < len(diff_indices):
171
+ from itertools import combinations
172
+ for cand_indices in combinations(diff_indices, num_missing):
173
+ cand_phi = Phi[np.concatenate((active_indices, cand_indices))]
174
+ if np.isfinite(np.linalg.cond(cand_phi)):
175
+ interp_bools[np.asarray(cand_indices)] = True
176
+ break
177
+ else:
178
+ interp_bools[diff_indices] = True
179
+ if np.sum(interp_bools) != Phi.shape[1]:
180
+ raise ValueError("Initial basis could not be found - retry with exact=False.")
181
+ return interp_bools
182
+
183
+ if np.allclose(phi_test, 0):
184
+ return np.inf if quantiles[-1] >= 0.5 else -np.inf
185
+
186
+ basis = get_current_basis(primals, duals, self.phi_calib, self.scores_calib, quantiles[:-1])
187
+ S_test = phi_test @ primals
188
+
189
+ duals = np.concatenate((duals.flatten(), [0]))
190
+ basis = np.concatenate((basis.flatten(), [False]))
191
+ phi = np.concatenate((self.phi_calib, phi_test.reshape(1,-1)), axis=0)
192
+ S = np.concatenate((self.scores_calib.reshape(-1,1), S_test.reshape(-1,1)), axis=0)
193
+
194
+ candidate_idx = phi.shape[0] - 1
195
+ num_iters = 0
196
+ while True:
197
+ # get direction vector for dual variable step
198
+ direction = -1 * np.linalg.solve(phi[basis].T, phi[candidate_idx].reshape(-1,1)).flatten()
199
+
200
+ # only consider non-zero entries of the direction vector
201
+ active_indices = ~np.isclose(direction, 0)
202
+ active_direction = direction[active_indices]
203
+ active_basis = basis.copy()
204
+ active_basis[np.where(basis)[0][~active_indices]] = False
205
+
206
+ positive_step = True if duals[candidate_idx] <= 0 else False
207
+ if candidate_idx == phi.shape[0] - 1:
208
+ positive_step = True if dual_threshold >= 0 else False
209
+
210
+ if positive_step:
211
+ gap_to_bounds = np.maximum(
212
+ (quantiles[active_basis].flatten() - duals[active_basis]) / active_direction,
213
+ ((quantiles[active_basis].flatten() - 1) - duals[active_basis]) / active_direction
214
+ )
215
+ step_size = np.min(gap_to_bounds)
216
+ departing_idx = np.where(active_basis)[0][np.argmin(gap_to_bounds)]
217
+ else:
218
+ gap_to_bounds = np.minimum(
219
+ (quantiles[active_basis].flatten() - duals[active_basis]) / active_direction,
220
+ ((quantiles[active_basis].flatten() - 1) - duals[active_basis]) / active_direction
221
+ )
222
+ step_size = np.max(gap_to_bounds)
223
+ departing_idx = np.where(active_basis)[0][np.argmax(gap_to_bounds)]
224
+ step_size_clip = np.clip(
225
+ step_size,
226
+ a_max=quantiles[candidate_idx] - duals[candidate_idx],
227
+ a_min=(quantiles[candidate_idx] - 1) - duals[candidate_idx]
228
+ )
229
+
230
+ duals[basis] += step_size_clip * direction
231
+ duals[candidate_idx] += step_size_clip
232
+ # print("Current value of final dual", duals[-1], "target threshold", dual_threshold)
233
+
234
+ if dual_threshold > 0 and duals[-1] > dual_threshold:
235
+ break
236
+
237
+ if dual_threshold < 0 and duals[-1] < dual_threshold:
238
+ break
239
+
240
+ if step_size_clip == step_size:
241
+ basis[departing_idx] = False
242
+ basis[candidate_idx] = True
243
+
244
+ if np.isclose(duals[-1], dual_threshold):
245
+ break
246
+
247
+ # TODO: make this a SMW update and reuse in the direction vector calc...
248
+ reduced_A = np.linalg.solve(phi[basis].T, phi[~basis].T)
249
+ reduced_costs = (S[~basis].T - S[basis].T @ reduced_A).flatten()
250
+ bottom = reduced_A[-1]
251
+ bottom[np.isclose(bottom, 0)] = np.inf
252
+ req_change = reduced_costs / bottom
253
+ if dual_threshold >= 0:
254
+ ignore_entries = (np.isclose(bottom, 0) | np.asarray(req_change <= 1e-5))
255
+ else:
256
+ ignore_entries = (np.isclose(bottom, 0) | np.asarray(req_change >= -1e-5))
257
+ if np.sum(~ignore_entries) == 0:
258
+ S[-1] = np.inf if quantiles[-1] >= 0.5 else -np.inf
259
+ break
260
+ if dual_threshold >= 0:
261
+ candidate_idx = np.where(~basis)[0][np.where(~ignore_entries, req_change, np.inf).argmin()]
262
+ S[-1] += np.min(req_change[~ignore_entries])
263
+ else:
264
+ candidate_idx = np.where(~basis)[0][np.where(~ignore_entries, req_change, -np.inf).argmax()]
265
+ S[-1] += np.max(req_change[~ignore_entries])
266
+ num_iters += 1
267
+ if num_iters > 10000:
268
+ S[-1] = np.inf if dual_threshold > 0 else -1 * np.inf
269
+ return S[-1]
270
+
271
+ def predict(
272
+ self,
273
+ quantile : float,
274
+ x_test : np.ndarray,
275
+ score_inv_fn : Callable,
276
+ S_min : float = None,
277
+ S_max : float = None,
278
+ randomize : bool = False,
279
+ exact : bool = True,
280
+ threshold : float = None
281
+ ):
282
+ """
283
+ Returns the (conditionally valid) prediction set for a given
284
+ test point
285
+
286
+ Arguments
287
+ ---------
288
+ quantile : float
289
+ Nominal quantile level
290
+ x_test : np.ndarray
291
+ Single test point
292
+ score_inv_fn : Callable[float, np.ndarray] -> .
293
+ Function that takes in a score threshold S^* and test point x and
294
+ outputs all values of y such that S(x, y) <= S^*
295
+ S_min : float = None
296
+ Lower bound (if available) on the conformity scores
297
+ S_max : float = None
298
+ Upper bound (if available) on the conformity scores
299
+ randomize : bool = False
300
+ Randomize prediction set for exact coverage
301
+ exact : bool = True
302
+ Avoid binary search and compute threshold exactly
303
+
304
+ Returns
305
+ -------
306
+ prediction_set
307
+ """
308
+ if quantile is None:
309
+ quantile_test = self.quantile_fn(x_test).reshape(-1,1)
310
+ quantiles = np.concatenate((self.quantile_calib, quantile_test), axis=0)
311
+ else:
312
+ quantile_test = quantile
313
+ quantiles = np.ones((len(self.scores_calib) + 1,1)) * quantile
314
+ if threshold is None:
315
+ if randomize:
316
+ threshold = self.rng.uniform(low=quantile_test - 1, high=quantile_test)
317
+ else:
318
+ if quantile_test < 0.5:
319
+ threshold = quantile_test - 1
320
+ else:
321
+ threshold = quantile_test
322
+
323
+ if exact:
324
+ if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
325
+ raise ValueError("Exact computation doesn't support RKHS quantile regression for now.")
326
+ if np.allclose(quantiles[0], quantiles):
327
+ naive_duals, naive_primals = self._get_calibration_solution(
328
+ quantiles.flatten()[0]
329
+ )
330
+ else:
331
+ naive_duals, naive_primals = self._get_calibration_solution(
332
+ None
333
+ )
334
+ score_cutoff = self._compute_exact_cutoff(
335
+ quantiles,
336
+ naive_primals,
337
+ naive_duals,
338
+ self.Phi_fn(x_test),
339
+ threshold
340
+ )
341
+ else:
342
+ _solve = partial(_solve_dual, gcc=self, x_test=x_test, quantiles=quantiles, threshold=threshold)
343
+
344
+ if S_min is None:
345
+ S_min = np.min(self.scores_calib)
346
+ if S_max is None:
347
+ S_max = np.max(self.scores_calib)
348
+ lower, upper = binary_search(_solve, S_min, S_max * 2)
349
+
350
+ if quantile < 0.5:
351
+ score_cutoff = self._get_threshold(lower, x_test, quantiles)
352
+ else:
353
+ score_cutoff = self._get_threshold(upper, x_test, quantiles)
354
+ return score_inv_fn(score_cutoff, x_test.reshape(-1,1))
355
+
356
+ def estimate_coverage(
357
+ self,
358
+ quantile : float,
359
+ weights : np.ndarray,
360
+ x : np.ndarray = None
361
+ ):
362
+ """
363
+ estimate_coverage estimates the true percentile of the issued estimate of the
364
+ conditional quantile under the covariate shift induced by 'weights'
365
+
366
+ If we are ostensibly estimating the 0.95-quantile using an RKHS fit, we may
367
+ determine using our theory that the true percentile of this estimate is only 0.93
368
+
369
+ Arguments
370
+ ---------
371
+ quantile : float
372
+ Nominal quantile level
373
+ weights : np.ndarray
374
+ RKHS weights for tilt under which the coverage is estimated
375
+ x : np.ndarray = None
376
+ Points for which the RKHS weights are defined. If None, we assume
377
+ that weights corresponds to x_calib
378
+
379
+ Returns
380
+ -------
381
+ estimated_alpha : float
382
+ Our estimate for the realized quantile level
383
+ """
384
+ weights = weights.reshape(-1,1)
385
+ prob = setup_cvx_problem_calib(
386
+ quantile,
387
+ self.x_calib,
388
+ self.scores_calib,
389
+ self.phi_calib,
390
+ self.infinite_params
391
+ )
392
+ if "MOSEK" in cp.installed_solvers():
393
+ prob.solve(solver="MOSEK")
394
+ else:
395
+ prob.solve()
396
+
397
+ fitted_weights = prob.var_dict['weights'].value
398
+ if x is not None:
399
+ K = pairwise_kernels(
400
+ X=x,
401
+ Y=self.x_calib,
402
+ metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
403
+ gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
404
+ )
405
+ else:
406
+ K = pairwise_kernels(
407
+ X=self.x_calib,
408
+ metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
409
+ gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
410
+ )
411
+ inner_prod = weights.T @ K @ fitted_weights
412
+ expectation = np.mean(weights.T @ K)
413
+ #penalty = self.infinite_params['lambda'] * (inner_prod / expectation)
414
+ penalty = (1/(len(self.x_calib) + 1))*(inner_prod / expectation)
415
+ return quantile - penalty
416
+
417
+ def predict_naive(
418
+ self,
419
+ quantile : float,
420
+ x : np.ndarray,
421
+ score_inv_fn : Callable
422
+ ):
423
+ """
424
+ If we do not wish to include the imputed data point, we can sanity check that
425
+ the regression is appropriately adaptive to the conditional variability in the data
426
+ by running a quantile regression on the calibration set without any imputation.
427
+ When n_calib is large and the fit is stable, we expect these two sets to nearly coincide.
428
+
429
+ Arguments
430
+ ---------
431
+ quantile : float
432
+ Nominal quantile level
433
+ x : np.ndarray
434
+ Set of points for which we are issuing prediction sets
435
+ score_inv_fn : Callable[np.ndarray, np.ndarray] -> np.ndarray
436
+ Vectorized function that takes in a score threshold S^* and test point x and
437
+ outputs all values of y such that S(x, y) <= S^*
438
+
439
+ Returns
440
+ -------
441
+ prediction_sets
442
+
443
+ """
444
+ if len(x.shape) < 2:
445
+ raise ValueError("x needs to have shape (m, n), not {x_test.shape}.")
446
+
447
+ if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
448
+ prob = setup_cvx_problem_calib(
449
+ quantile,
450
+ self.x_calib,
451
+ self.scores_calib,
452
+ self.phi_calib,
453
+ self.infinite_params
454
+ )
455
+ if "MOSEK" in cp.installed_solvers():
456
+ prob.solve(solver="MOSEK", verbose=False)
457
+ else:
458
+ prob.solve()
459
+
460
+ weights = prob.var_dict['weights'].value
461
+ beta = prob.constraints[-1].dual_value
462
+ K = pairwise_kernels(
463
+ X=x,
464
+ Y=self.x_calib,
465
+ metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
466
+ gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
467
+ )
468
+ threshold = K @ weights + self.Phi_fn(x) @ beta
469
+ else:
470
+ S = np.concatenate([self.scores_calib, [S]], dtype=float)
471
+ Phi = self.phi_calib.astype(float)
472
+ zeros = np.zeros((Phi.shape[1],))
473
+
474
+ if quantile is None:
475
+ bounds = np.concatenate((self.quantile_calib - 1, self.quantile_calib), axis=1)
476
+ else:
477
+ bounds = [(quantile - 1, quantile)] * (len(self.scores_calib) + 1)
478
+ res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds, method='highs')
479
+ beta = -1 * res.eqlin.marginals
480
+ threshold = self.Phi_fn(x) @ beta
481
+
482
+ return score_inv_fn(threshold, x)
483
+
484
+ def verify_coverage(
485
+ self,
486
+ x : np.ndarray,
487
+ y : np.ndarray,
488
+ quantile : float,
489
+ randomize : bool = False,
490
+ resolve : bool = False,
491
+ return_dual : bool = False,
492
+ eps : float = 0.001
493
+ ):
494
+ """
495
+ In some experiments, we may simply be interested in verifying the coverage of our method.
496
+ In this case, we do not need to binary search for the threshold S^*, but only need to verify that
497
+ S <= f_S(x) for the true value of S. This function implements this check for test points
498
+ denoted by x and y
499
+
500
+ Arguments
501
+ ---------
502
+ x : np.ndarray
503
+ A vector of test covariates
504
+ y : np.ndarray
505
+ A vector of test labels
506
+ quantile : float
507
+ Nominal quantile level
508
+ resolve : bool
509
+ Resolve LP/QP with posited value to determine coverage
510
+
511
+ Returns
512
+ -------
513
+ coverage_booleans : np.ndarray
514
+ """
515
+ covers = []
516
+ duals = []
517
+
518
+ if quantile is None:
519
+ quantiles = np.concatenate((self.quantile_calib, [[0.]]), axis=0).flatten()
520
+ else:
521
+ quantiles = quantile * np.ones((len(self.scores_calib) + 1, 1))
522
+
523
+ if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
524
+ for x_val, y_val in zip(x, y):
525
+ S_true = self.score_fn(x_val.reshape(1,-1), y_val)
526
+ eta = self._get_dual_solution(S_true[0], x_val.reshape(1,-1), quantiles) # no need to recompute quantiles
527
+ if randomize:
528
+ threshold = self.rng.uniform(low=quantile - 1, high=quantile)
529
+ elif quantile > 0.5:
530
+ threshold = quantile - eps
531
+ else:
532
+ threshold = quantile - 1 + eps
533
+ if quantile > 0.5:
534
+ covers.append(eta[-1] < threshold)
535
+ else:
536
+ covers.append(eta[-1] > threshold)
537
+ duals.append(eta[-1])
538
+
539
+ else:
540
+ for x_val, y_val in zip(x, y):
541
+ if randomize:
542
+ threshold = self.rng.uniform(low=quantiles[-1] - 1, high=quantiles[-1])
543
+ elif quantiles[-1] > 0.5:
544
+ threshold = quantiles[-1]
545
+ else:
546
+ threshold = quantiles[-1] - 1
547
+
548
+ S_true = self.score_fn(x_val.reshape(1,-1), y_val)
549
+ if resolve:
550
+ eta = self._get_dual_solution(S_true[0], x_val.reshape(1,-1), quantile)
551
+ if quantile > 0.5:
552
+ covers.append(eta[-1] < threshold)
553
+ else:
554
+ covers.append(eta[-1] > threshold)
555
+ duals.append(eta[-1])
556
+ else:
557
+ naive_duals, naive_primals = self._get_calibration_solution(
558
+ quantile
559
+ )
560
+ score_cutoff = self._compute_exact_cutoff(
561
+ quantiles,
562
+ naive_primals,
563
+ naive_duals,
564
+ self.Phi_fn(x_val),
565
+ threshold
566
+ )
567
+ if quantile > 0.5:
568
+ covers.append(S_true < score_cutoff)
569
+ else:
570
+ covers.append(S_true > score_cutoff)
571
+ duals.append(np.nan)
572
+ if return_dual:
573
+ return np.asarray(covers), np.asarray(duals)
574
+ return np.asarray(covers)
575
+
576
+ def _get_dual_solution(
577
+ self,
578
+ S : float,
579
+ x : np.ndarray,
580
+ quantiles : np.ndarray
581
+ ):
582
+ if self.infinite_params.get("kernel", FUNCTION_DEFAULTS['kernel']):
583
+ prob = finish_dual_setup(
584
+ self.cvx_problem,
585
+ S,
586
+ x,
587
+ quantiles[-1][0],
588
+ self.Phi_fn(x),
589
+ self.x_calib,
590
+ self.infinite_params
591
+ )
592
+ if "MOSEK" in cp.installed_solvers():
593
+ prob.solve(solver="MOSEK")
594
+ else:
595
+ prob.solve()
596
+ # TODO: THIS IS WRONG
597
+ #raise ValueError("need to get variable out of problem and return its value")
598
+ return prob.var_dict['weights'].value
599
+ else:
600
+ S = np.concatenate([self.scores_calib, [S]])
601
+ Phi = np.concatenate([self.phi_calib, self.Phi_fn(x)], axis=0)
602
+ zeros = np.zeros((Phi.shape[1],))
603
+ bounds = np.concatenate((quantiles - 1, quantiles), axis=1)
604
+ res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds,
605
+ method='highs-ds', options={'presolve': False})
606
+ eta = res.x
607
+ return eta
608
+
609
+
610
+ def _get_primal_solution(
611
+ self,
612
+ S : float,
613
+ x : np.ndarray,
614
+ quantiles : np.ndarray
615
+ ):
616
+ if self.infinite_params.get("kernel", FUNCTION_DEFAULTS['kernel']):
617
+ prob = finish_dual_setup(
618
+ self.cvx_problem,
619
+ S,
620
+ x,
621
+ quantiles[-1][0],
622
+ self.Phi_fn(x),
623
+ self.x_calib,
624
+ self.infinite_params
625
+ )
626
+ if "MOSEK" in cp.installed_solvers():
627
+ prob.solve(solver="MOSEK")
628
+ else:
629
+ prob.solve()
630
+
631
+ weights = prob.var_dict['weights'].value
632
+ beta = prob.constraints[-1].dual_value
633
+ else:
634
+ S = np.concatenate([self.scores_calib, [S]])
635
+ Phi = np.concatenate([self.phi_calib, self.Phi_fn(x)], axis=0)
636
+ zeros = np.zeros((Phi.shape[1],))
637
+ bounds = np.concatenate((quantiles - 1, quantiles), axis=1)
638
+ res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds,
639
+ method='highs-ds', options={'presolve': False})
640
+ beta = -1 * res.eqlin.marginals
641
+ weights = None
642
+ return beta, weights
643
+
644
    def _get_threshold(
        self,
        S : float,
        x : np.ndarray,
        quantiles : np.ndarray
    ):
        """Compute the conformal score threshold at test point `x`.

        The threshold is the fitted quantile function evaluated at x:
        Phi(x) @ beta, plus the RKHS component (K @ weights) when a kernel
        is configured.
        """
        beta, weights = self._get_primal_solution(S, x, quantiles)

        threshold = self.Phi_fn(x) @ beta
        if self.infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel']):
            # Kernel evaluated over calibration points plus the test point;
            # only the last row (the test point) contributes to the threshold.
            K = pairwise_kernels(
                X=np.concatenate([self.x_calib, x.reshape(1,-1)], axis=0),
                Y=np.concatenate([self.x_calib, x.reshape(1,-1)], axis=0),
                metric=self.infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"]),
                gamma=self.infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])
            )
            threshold = (K @ weights)[-1] + threshold
        return threshold
662
+
663
+
664
def binary_search(func, min, max, tol=1e-3):
    """Locate the sign change of a monotone function by bisection.

    Maintains the invariant func(lo) <= 0 < func(hi) (given a bracketing
    initial interval) and shrinks the bracket until its width is <= tol.

    Parameters
    ----------
    func : callable
        Scalar function; positive values shrink the upper end of the bracket.
    min, max : float
        Initial bracket endpoints. (Parameter names kept for backward
        compatibility even though they shadow builtins; the body rebinds
        them immediately so the builtins stay usable.)
    tol : float
        Bracket width at which to stop.

    Returns
    -------
    (lo, hi) : tuple of float
        Final bracket containing the crossing point.
    """
    lo, hi = float(min), float(max)
    # guard: tol must be representable at this magnitude, else the loop never ends
    assert (hi + tol) > hi
    while (hi - lo) > tol:
        mid = (lo + hi) / 2
        if func(mid) > 0:
            hi = mid
        else:
            lo = mid
    return lo, hi
674
+
675
+
676
def _solve_dual(S, gcc, x_test, quantiles, threshold=None):
    """Solve the dual problem for candidate score S and return the slack
    weights[-1] - threshold; its sign is used (by the caller) to decide
    whether S is accepted.

    Parameters
    ----------
    S : float
        Candidate conformity score for the test point.
    gcc : CondConf-like object
        Holds calibration data, Phi_fn, and solver configuration.
    x_test : np.ndarray
        Test covariates.
    quantiles : np.ndarray
        Per-point quantile bounds; quantiles[-1] is the test-point level.
    threshold : float, optional
        Override for the acceptance threshold; defaults to quantiles[-1] - 1
        for lower-tail levels and quantiles[-1] otherwise.
    """
    if gcc.infinite_params.get('kernel', None):
        # RKHS branch: reuse the cached cvxpy problem and just update params.
        prob = finish_dual_setup(
            gcc.cvx_problem,
            S,
            x_test,
            quantiles[-1][0],
            gcc.Phi_fn(x_test),
            gcc.x_calib,
            gcc.infinite_params
        )
        if "MOSEK" in cp.installed_solvers():
            prob.solve(solver="MOSEK")
        else:
            prob.solve(solver="OSQP")
        weights = prob.var_dict['weights'].value
    else:
        # LP branch: quantile-regression dual over calibration scores plus S.
        S = np.concatenate([gcc.scores_calib, [S]], dtype=float)
        Phi = np.concatenate([gcc.phi_calib, gcc.Phi_fn(x_test)], axis=0, dtype=float)
        zeros = np.zeros((Phi.shape[1],))

        # box constraints: quantile - 1 <= eta_i <= quantile
        bounds = np.concatenate((quantiles - 1, quantiles), axis=1)
        res = linprog(-1 * S, A_eq=Phi.T, b_eq=zeros, bounds=bounds,
                      method='highs', options={'presolve': False})
        weights = res.x

    if threshold is None:
        # Default threshold is the active box bound at the test point's level.
        if quantiles[-1] < 0.5:
            threshold = quantiles[-1] - 1
        else:
            threshold = quantiles[-1]
    # if quantile < 0.5:
    #     return weights[-1] + (1 - quantile)
    return weights[-1] - threshold
710
+
711
+
712
def setup_cvx_problem(
    x_calib,
    scores_calib,
    phi_calib,
    infinite_params = {}
):
    """Build the parameterized cvxpy dual problem used for conditional
    conformal prediction, with placeholders for the (unknown) test point.

    Parameters
    ----------
    x_calib : array-like
        Calibration covariates (used only in the kernel branch).
    scores_calib : np.ndarray
        Calibration conformity scores.
    phi_calib : np.ndarray or None
        Finite basis evaluated on calibration points; defaults to an
        intercept-only column of ones when None.
    infinite_params : dict
        Optional RKHS configuration ("kernel", "gamma", ...); an empty dict
        selects the purely finite-dimensional problem.
        NOTE(review): mutable default argument — safe here because it is
        never mutated, but worth confirming.

    Returns
    -------
    cp.Problem
        Problem whose parameters ("S_test", "Phi_test", "quantile", and the
        kernel parameters when applicable) are filled in later by
        finish_dual_setup.
    """
    n_calib = len(scores_calib)
    if phi_calib is None:
        phi_calib = np.ones((n_calib,1))

    # dual variable: one weight per calibration point plus the test point
    eta = cp.Variable(name="weights", shape=n_calib + 1)

    quantile = cp.Parameter(name="quantile")

    # scores vector = calibration scores (constant) stacked over S_test (parameter)
    scores_const = cp.Constant(scores_calib.reshape(-1,1))
    scores_param = cp.Parameter(name="S_test", shape=(1,1))
    scores = cp.vstack([scores_const, scores_param])

    # basis matrix = calibration rows (constant) stacked over the test row (parameter)
    Phi_calibration = cp.Constant(phi_calib)
    Phi_test = cp.Parameter(name="Phi_test", shape=(1, phi_calib.shape[1]))
    Phi = cp.vstack([Phi_calibration, Phi_test])

    kernel = infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"])
    gamma = infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])

    if kernel is None: # no RKHS fitting
        # LP dual of quantile regression: box constraints plus orthogonality to Phi
        constraints = [
            (quantile - 1) <= eta,
            quantile >= eta,
            eta.T @ Phi == 0
        ]
        prob = cp.Problem(
            cp.Minimize(-1 * cp.sum(cp.multiply(eta, cp.vec(scores)))),
            constraints
        )
    else: # RKHS fitting
        radius = cp.Parameter(name="radius", nonneg=True)

        # Cholesky factor of the calibration kernel matrix (with jitter)
        _, L_11 = _get_kernel_matrix(x_calib, kernel, gamma)

        # pad with a zero column; the test row L_21_22 is filled in later
        L_11_const = cp.Constant(
            np.hstack([L_11, np.zeros((L_11.shape[0], 1))])
        )
        L_21_22_param = cp.Parameter(name="L_21_22", shape=(1, n_calib + 1))
        L = cp.vstack([L_11_const, L_21_22_param])

        C = radius / (n_calib + 1)

        # this is really C * (quantile - 1) and C * quantile
        constraints = [
            (quantile - 1) <= eta,
            quantile >= eta,
            eta.T @ Phi == 0]
        # QP dual: ridge-type quadratic in L.T @ eta plus the linear score term
        prob = cp.Problem(
            cp.Minimize(0.5 * C * cp.sum_squares(L.T @ eta) - cp.sum(cp.multiply(eta, cp.vec(scores)))),
            constraints
        )
    return prob
770
+
771
+
772
def _get_kernel_matrix(x_calib, kernel, gamma):
    """Gram matrix of the calibration points and its Cholesky factor.

    A small ridge (1e-5 on the diagonal) is added so the factorization is
    numerically stable even when the kernel matrix is near-singular.
    """
    n = len(x_calib)
    gram = pairwise_kernels(X=x_calib, metric=kernel, gamma=gamma)
    gram = gram + 1e-5 * np.eye(n)
    chol = np.linalg.cholesky(gram)
    return gram, chol
781
+
782
+
783
def finish_dual_setup(
    prob : cp.Problem,
    S : np.ndarray,
    X : np.ndarray,
    quantile : float,
    Phi : np.ndarray,
    x_calib : np.ndarray,
    infinite_params = {}
):
    """Fill in the test-point-dependent parameters of a problem built by
    setup_cvx_problem and return it ready to solve.

    Parameters
    ----------
    prob : cp.Problem
        Parameterized problem from setup_cvx_problem.
    S : float-like
        Candidate conformity score for the test point.
    X : np.ndarray
        Test covariates.
    quantile : float
        Quantile level for the test point.
    Phi : np.ndarray
        Finite basis evaluated at the test point.
    x_calib : np.ndarray
        Calibration covariates (for kernel evaluations).
    infinite_params : dict
        RKHS configuration ("kernel", "gamma", "lambda").
    """
    prob.param_dict['S_test'].value = np.asarray([[S]])
    prob.param_dict['Phi_test'].value = Phi.reshape(1,-1)
    prob.param_dict['quantile'].value = quantile

    kernel = infinite_params.get('kernel', FUNCTION_DEFAULTS['kernel'])
    gamma = infinite_params.get('gamma', FUNCTION_DEFAULTS['gamma'])
    # RKHS ball radius is the reciprocal of the regularization strength
    radius = 1 / infinite_params.get('lambda', FUNCTION_DEFAULTS['lambda'])

    if kernel is not None:
        # cross-kernel between (calibration + test) points and the test point
        K_12 = pairwise_kernels(
            X=np.concatenate([x_calib, X.reshape(1,-1)], axis=0),
            Y=X.reshape(1,-1),
            metric=kernel,
            gamma=gamma
        )

        if 'K_12' in prob.param_dict:
            prob.param_dict['K_12'].value = K_12[:-1]
            prob.param_dict['K_21'].value = K_12.T

        # extend the Cholesky factorization by one row for the test point:
        # L_21 solves L_11 L_21^T = K_12, L_22 is the Schur-complement root
        _, L_11 = _get_kernel_matrix(x_calib, kernel, gamma)
        K_22 = pairwise_kernels(
            X=X.reshape(1,-1),
            metric=kernel,
            gamma=gamma
        )
        L_21 = np.linalg.solve(L_11, K_12[:-1]).T
        L_22 = K_22 - L_21 @ L_21.T
        # clip tiny negative values from round-off before taking the sqrt
        L_22[L_22 < 0] = 0
        L_22 = np.sqrt(L_22)
        prob.param_dict['L_21_22'].value = np.hstack([L_21, L_22])

        prob.param_dict['radius'].value = radius

        # update quantile definition for silly cvxpy reasons
        prob.param_dict['quantile'].value = quantile
        #prob.param_dict['quantile'].value *= radius / (len(x_calib) + 1)

    return prob
831
+
832
def setup_cvx_problem_calib(
    quantile,
    x_calib,
    scores_calib,
    phi_calib,
    infinite_params = {}
):
    """Build the calibration-only cvxpy dual problem (no test-point
    placeholders), for solving the fit on calibration data alone.

    Mirrors setup_cvx_problem but with a fixed numeric `quantile`, an
    eta of size n_calib (no test slot), and all data as constants.
    """
    n_calib = len(scores_calib)
    if phi_calib is None:
        phi_calib = np.ones((n_calib,1))

    # dual variable: one weight per calibration point only
    eta = cp.Variable(name="weights", shape=n_calib)

    scores = cp.Constant(scores_calib.reshape(-1,1))

    Phi = cp.Constant(phi_calib)

    kernel = infinite_params.get("kernel", FUNCTION_DEFAULTS["kernel"])
    gamma = infinite_params.get("gamma", FUNCTION_DEFAULTS["gamma"])

    if kernel is None: # no RKHS fitting
        # LP dual of quantile regression: box constraints plus orthogonality
        constraints = [
            (quantile - 1) <= eta,
            quantile >= eta,
            eta.T @ Phi == 0
        ]
        prob = cp.Problem(
            cp.Minimize(-1 * cp.sum(cp.multiply(eta, cp.vec(scores)))),
            constraints
        )
    else: # RKHS fitting
        # RKHS ball radius is the reciprocal of the regularization strength
        radius = 1 / infinite_params.get('lambda', FUNCTION_DEFAULTS['lambda'])

        _, L = _get_kernel_matrix(x_calib, kernel, gamma)

        C = radius / (n_calib + 1)

        constraints = [
            (quantile - 1) <= eta,
            quantile >= eta,
            eta.T @ Phi == 0]
        # QP dual with the ridge-type quadratic term in L.T @ eta
        prob = cp.Problem(
            cp.Minimize(0.5 * C * cp.sum_squares(L.T @ eta) - cp.sum(cp.multiply(eta, cp.vec(scores)))),
            constraints
        )
    return prob
MACI-main/conditional-conformal/conditionalconformal/experiment_utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from sklearn.linear_model import LinearRegression
5
+ from quantile_forest import RandomForestQuantileRegressor
6
+
7
+ from conditionalconformal import CondConf
8
+
9
+ ## get base model for constructing scores
10
def fit_model(data_train, base_model):
    """Fit the base regression model used to construct conformity scores.

    Parameters
    ----------
    data_train : tuple
        (x_train, y_train) arrays.
    base_model : str
        One of "ols" (linear regression), "qrf" (quantile random forest),
        or "qr" (linear quantile regression via CondConf).

    Returns
    -------
    Fitted regressor exposing a `predict` method (for "qr" and "qrf" the
    predict call takes a quantile as a second argument).

    Raises
    ------
    ValueError
        If `base_model` is not one of the supported names (previously this
        fell through to an UnboundLocalError on `reg`).
    """
    x_train, y_train = data_train
    if base_model == "ols":
        reg = LinearRegression().fit(x_train, y_train)
    elif base_model == "qrf":
        reg = RandomForestQuantileRegressor()
        reg.fit(x_train, y_train)
    elif base_model == "qr":
        reg = CondConf(score_fn = lambda x, y: y, Phi_fn = lambda x: x)
        reg.setup_problem(x_train, y_train)
        # overwrite prediction function so it looks like a regression object
        reg.predict = lambda x, q: (x @ reg._get_calibration_solution(q)[1]).flatten() # expects x to be of form (n_points, n_feats)
    else:
        raise ValueError(f"unknown base_model: {base_model!r}")
    return reg
23
+
24
+ # helper function for splitting dataset
25
def split_dataset(dataset, n_test, n_calib, rng):
    """Randomly partition (X, Y) into train / calibration / test splits.

    The first n_test shuffled indices form the test set, the next n_calib
    form the calibration set, and the remainder form the training set.

    Returns
    -------
    ((X_train, Y_train), (X_calib, Y_calib), (X_test, Y_test))
    """
    X, Y = dataset
    perm = np.arange(len(X))
    rng.shuffle(perm)
    idx_test, idx_calib, idx_train = np.array_split(
        perm,
        np.cumsum([n_test, n_calib])
    )
    return (
        (X[idx_train], Y[idx_train]),
        (X[idx_calib], Y[idx_calib]),
        (X[idx_test], Y[idx_test]),
    )
43
+
44
+ # get coverages for each method type...
45
# get coverages for each method type...
def get_coverage(dataset_calib, dataset_test, score_fn, method, quantile):
    """Return a boolean coverage indicator per test point for one conformal method.

    Parameters
    ----------
    dataset_calib, dataset_test : tuple
        (X, Y) pairs for calibration and test.
    score_fn : callable
        Conformity score function score_fn(X, Y).
    method : str
        "split" for split conformal; any method containing "rand" uses the
        randomized conditional calibration; anything else uses the
        non-randomized conditional calibration.
    quantile : float
        Target quantile level; levels below 0.5 flip the coverage inequality.
    """
    if method == "split":
        scores_calib = score_fn(*dataset_calib)
        scores_test = score_fn(*dataset_test)

        # split-conformal cutoff with the (1 + 1/n) finite-sample correction
        score_cutoff = np.quantile(
            scores_calib,
            [quantile * (1 + 1/len(scores_calib))]
        )
        if quantile >= 0.5:
            covs = scores_test <= score_cutoff
        else:
            covs = scores_test >= score_cutoff
    elif "rand" in method:
        # conditional calibration with randomized acceptance
        condcalib = CondConf(score_fn, lambda x: x)
        condcalib.setup_problem(*dataset_calib)
        X_test, Y_test = dataset_test
        covs = condcalib.verify_coverage(X_test, Y_test, quantile, resolve=True, randomize=True)
    else:
        # conditional calibration, deterministic variant
        condcalib = CondConf(score_fn, lambda x: x)
        condcalib.setup_problem(*dataset_calib)
        X_test, Y_test = dataset_test
        covs = condcalib.verify_coverage(X_test, Y_test, quantile, resolve=True, randomize=False)
    return covs
69
+
70
+ # get coverages for each method type...
71
# get coverages for each method type...
def get_cutoff(dataset_calib, dataset_test, score_fn, method, quantile):
    """Compute per-test-point score cutoffs for one conformal method and the
    resulting coverage indicators.

    Returns
    -------
    (cutoffs, coverages) : tuple of np.ndarray
        cutoffs has one entry per test point (constant for "split");
        coverages compares test scores against the flattened cutoffs, with
        the inequality direction flipped for lower-tail quantiles.
    """
    print(method, quantile)  # NOTE(review): debug print left in by the authors
    scores_test = score_fn(*dataset_test)
    if method == "split":
        scores_calib = score_fn(*dataset_calib)
        # split-conformal cutoff with the (1 + 1/n) finite-sample correction
        score_cutoff = np.quantile(
            scores_calib,
            [quantile * (1 + 1/len(scores_calib))]
        )
        cutoffs = np.ones((len(scores_test,))) * score_cutoff
    elif "rand" in method:
        # conditional calibration with randomized prediction, per test point
        condcalib = CondConf(score_fn, lambda x: x)
        condcalib.setup_problem(*dataset_calib)
        cutoffs = []
        for x in dataset_test[0]:
            cutoff = condcalib.predict(quantile, x, lambda c, x: c, randomize=True)
            cutoffs.append(cutoff)
        cutoffs = np.asarray(cutoffs)
    else:
        # conditional calibration, deterministic variant
        condcalib = CondConf(score_fn, lambda x: x)
        condcalib.setup_problem(*dataset_calib)
        cutoffs = []
        for x in dataset_test[0]:
            cutoff = condcalib.predict(quantile, x, lambda c, x: c, randomize=False)
            cutoffs.append(cutoff)
        cutoffs = np.asarray(cutoffs)
    if quantile > 0.5:
        coverages = scores_test <= cutoffs.flatten()
    else:
        coverages = scores_test >= cutoffs.flatten()
    return cutoffs, coverages
102
+
103
+
104
def run_coverage_experiment(dataset, n_test, n_calib, alpha, methods = [], seed = 0):
    """Run a two-sided coverage experiment across conformal method variants.

    Parameters
    ----------
    dataset : tuple
        Full (X, Y) data; split into train/calibration/test internally.
    n_test, n_calib : int
        Split sizes; the remainder is used for training base models.
    alpha : float
        Miscoverage level; upper/lower cutoffs target 1 - alpha/2 and alpha/2.
    methods : list of str
        Entries of the form "BASE-CONFORMAL" (see comment below).
    seed : int
        Seed for the split RNG.
        NOTE(review): mutable default for `methods` — safe as it is only read.

    Returns
    -------
    (X_test, coverages) where coverages is a list of per-point boolean arrays,
    one per method, covering both tails jointly.
    """
    rng = np.random.default_rng(seed=seed)

    dataset_train, dataset_calib, dataset_test = split_dataset(
        dataset,
        n_test,
        n_calib,
        rng
    )

    ### Compute conformity scores
    # fit each distinct base model once, shared across its conformal variants
    base_methods = set([m.split('-')[0] for m in methods])
    base_model = {base : fit_model(dataset_train, base) for base in base_methods}

    coverages = []
    # example methods: (BASE_METHOD)-(CONFORMAL_METHOD)
    # BASE_METHOD valid choices: "ols", "qr", "qrf"
    # CONFORMAL_METHOD valid choices: "split", "cc", "ccrand", "lcp", "rlcp" (todo on last two)
    for method in methods:
        base_method, conformal_method = method.split('-')
        reg = base_model[base_method]
        if "q" in base_method: # if a quantile regression score needs to specify quantile
            score_fn_upper = lambda x, y: y - reg.predict(x, 1 - alpha/2)
            score_fn_lower = lambda x, y: y - reg.predict(x, alpha/2)
        else:
            # mean-regression residual score is the same for both tails
            score_fn_upper = lambda x, y: y - reg.predict(x)
            score_fn_lower = lambda x, y: y - reg.predict(x)
        covers_upper = get_coverage(dataset_calib, dataset_test, score_fn_upper, conformal_method, 1 - alpha/2)
        covers_lower = get_coverage(dataset_calib, dataset_test, score_fn_lower, conformal_method, alpha/2)
        # a point is covered only when both one-sided bounds hold
        covers = np.logical_and(covers_upper, covers_lower)
        coverages.append(covers)

    return dataset_test[0], coverages
137
+
138
+
139
def run_experiment(dataset, n_test, n_calib, alpha, methods = [], seed = 0):
    """Run the interval-length + coverage experiment across method variants.

    Like run_coverage_experiment, but also records per-point interval
    lengths (upper cutoff - lower cutoff, plus the base quantile-prediction
    gap for quantile-based models).

    Returns
    -------
    (X_test, (all_lengths, all_coverages)) with one entry per method.
    """
    rng = np.random.default_rng(seed=seed)

    dataset_train, dataset_calib, dataset_test = split_dataset(
        dataset,
        n_test,
        n_calib,
        rng
    )

    ### Compute conformity scores
    # all three base models are fit up front here (unlike run_coverage_experiment)
    base_model = {base : fit_model(dataset_train, base) for base in ["ols", "qrf", "qr"]}

    all_lengths = []
    all_coverages = []
    # example methods: (BASE_METHOD)-(CONFORMAL_METHOD)
    # BASE_METHOD valid choices: "ols", "qr", "qrf"
    # CONFORMAL_METHOD valid choices: "split", "cc", "ccrand", "lcp", "ccqp"
    for method in methods:
        base_method, conformal_method = method.split('-')
        reg = base_model[base_method]
        if "qrf" in base_method: # if a quantile regression score needs to specify quantile
            # tiny uniform jitter breaks ties in the forest's discrete predictions
            score_fn_upper = lambda x, y: y - reg.predict(x, 1 - alpha/2) + rng.uniform(0, 1e-5, size=len(x))
            score_fn_lower = lambda x, y: y - reg.predict(x, alpha/2) + rng.uniform(0, 1e-5, size=len(x))
        elif "q" in base_method:
            score_fn_upper = lambda x, y: y - reg.predict(x, 1 - alpha/2)
            score_fn_lower = lambda x, y: y - reg.predict(x, alpha/2)
        else:
            # mean-regression residual score is the same for both tails
            score_fn_upper = lambda x, y: y - reg.predict(x)
            score_fn_lower = lambda x, y: y - reg.predict(x)
        cutoffs_upper, cov_upper = get_cutoff(dataset_calib, dataset_test, score_fn_upper, conformal_method, 1 - alpha/2)
        cutoffs_lower, cov_lower = get_cutoff(dataset_calib, dataset_test, score_fn_lower, conformal_method, alpha/2)
        if "q" in base_method:
            # quantile models: residual cutoffs are relative to the predicted
            # quantiles, so add the base prediction gap back into the length
            pred_upper = reg.predict(dataset_test[0], 1 - alpha/2)
            pred_lower = reg.predict(dataset_test[0], alpha/2)
            pred_gap = pred_upper - pred_lower
        else:
            pred_gap = 0
        lengths = cutoffs_upper - cutoffs_lower + pred_gap
        coverage = np.logical_and(cov_upper, cov_lower)
        all_lengths.append(lengths)
        all_coverages.append(coverage)

    return dataset_test[0], (all_lengths, all_coverages)
MACI-main/conditional-conformal/conditionalconformal/synthetic_data.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
def generate_cqr_data(seed,n_train=2000,n_calib=1000,n_test=500):
    """Generate the 1-D synthetic regression dataset from the CQR paper setup:
    Poisson-modulated sine mean, heteroscedastic noise, and rare large outliers.

    Returns
    -------
    (x_train, y_train, x_calib, y_calib, x_test, y_test), where the first
    n_train points are training and the next n_calib are calibration.
    """
    np.random.seed(seed)

    # draw train + calibration together, then split at the end
    n_train = n_train + n_calib

    def f(x):
        ''' Construct data (1D example)
        '''
        ax = 0*x
        for i in range(len(x)):
            # Poisson signal + x-scaled Gaussian noise
            ax[i] = np.random.poisson(np.sin(x[i])**2+0.1) + 0.03*x[i]*np.random.randn(1)
            # 1% chance of a large outlier
            ax[i] += 25*(np.random.uniform(0,1,1)<0.01)*np.random.randn(1)
        return ax.astype(np.float32)

    # training features
    x_train = np.random.uniform(0, 5.0, size=n_train).astype(np.float32)

    # test features
    x_test = np.random.uniform(0, 5.0, size=n_test).astype(np.float32)

    # generate labels
    y_train = f(x_train)
    y_test = f(x_test)

    # reshape the features
    x_train = np.reshape(x_train,(n_train,1))
    x_test = np.reshape(x_test,(n_test,1))

    # carve the calibration set off the end of the training draw
    train_set_size = len(y_train) - n_calib
    x_train_final = x_train[ : train_set_size]
    x_calib = x_train[train_set_size : ]
    y_train_final = y_train[ : train_set_size]
    y_calib = y_train[train_set_size : ]

    return x_train_final, y_train_final, x_calib, y_calib, x_test, y_test
38
+
39
+
40
def indicator_matrix(scalar_values, disc):
    """One-hot interval-membership matrix.

    Entry (i, j) is 1 exactly when disc[j] <= scalar_values[i] < disc[j+1]
    (each upper edge is exclusive, including the last). Vectorized with
    NumPy broadcasting instead of the original O(n * m) Python double loop;
    the semantics are identical, including for unsorted `disc`.

    Parameters
    ----------
    scalar_values : array-like
        Values to bin.
    disc : array-like
        Interval edges; consecutive pairs define the intervals.

    Returns
    -------
    np.ndarray of shape (len(scalar_values), len(disc) - 1)
    """
    values = np.asarray(scalar_values)
    edges = np.asarray(disc)
    lower = edges[:-1]
    upper = edges[1:]
    # broadcast values (n, 1) against interval edges (1, m)
    matrix = (
        (values[:, None] >= lower[None, :]) & (values[:, None] < upper[None, :])
    ).astype(float)
    return matrix
MACI-main/conditional-conformal/src/atomizer.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import re
4
+ import string
5
+ import spacy
6
+ import nltk
7
+ from rank_bm25 import BM25Okapi
8
+ import os
9
+
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from nltk.tokenize import sent_tokenize
12
+
13
+ nltk.download("punkt")
14
+
15
+
16
class Atomizer(object):
    """Splits an LLM generation into sentences and decomposes each sentence
    into atomic facts by few-shot prompting `client`, with BM25-retrieved
    demonstrations (FactScore-style pipeline)."""

    def __init__(self, client, demo_dir):
        """client: LLM Client wrapper; demo_dir: directory holding demos.json."""
        self.nlp = spacy.load("en_core_web_sm")
        # hard-coded to the biography setting; the "complex" demo file is unused
        self.is_bio = True
        self.demo_path = os.path.join(demo_dir, "demos.json" if self.is_bio else "demos_complex.json")

        self.client = client

        # get the demos
        with open(self.demo_path, 'r') as f:
            self.demos = json.load(f)

        # BM25 index over the demo sentences, used to pick in-context examples
        tokenized_corpus = [doc.split(" ") for doc in self.demos.keys()]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def save_cache(self):
        """Persist the client's prompt/response cache to disk."""
        self.client.save_cache()

    def run(self, generation, cost_estimate=None):
        """Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None."""
        assert isinstance(generation, str), "generation must be a string"
        # split into non-empty paragraphs
        paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]
        return self.get_atomic_facts_from_paragraph(paragraphs, cost_estimate=cost_estimate)

    def get_atomic_facts_from_paragraph(self, paragraphs, cost_estimate=None):
        """Sentence-split each paragraph, atomize the sentences, and return
        (list of (sentence, facts) pairs, paragraph-break indices)."""
        sentences = []
        para_breaks = []
        for para_idx, paragraph in enumerate(paragraphs):
            if para_idx > 0 :
                para_breaks.append(len(sentences))

            initials = detect_initials(paragraph)

            curr_sentences = sent_tokenize(paragraph)
            curr_sentences_2 = sent_tokenize(paragraph)

            curr_sentences = fix_sentence_splitter(curr_sentences, initials)
            curr_sentences_2 = fix_sentence_splitter(curr_sentences_2, initials)

            # checking this, just to ensure the crediability of the sentence splitter fixing algorithm
            assert curr_sentences == curr_sentences_2, (paragraph, curr_sentences, curr_sentences_2)

            sentences += curr_sentences

        # filter out boilerplate first/last sentences in the non-bio setting
        atoms_or_estimate = self.get_init_atomic_facts_from_sentence([sent for i, sent in enumerate(sentences) if not (not self.is_bio and ( \
            (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or \
            (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))))], cost_estimate=cost_estimate)

        if cost_estimate:
            return atoms_or_estimate
        else:
            atoms = atoms_or_estimate
            atomic_facts_pairs = []
            for i, sent in enumerate(sentences):
                # boilerplate sentences get an empty fact list
                if not self.is_bio and ( \
                    (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or \
                    (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))):
                    atomic_facts_pairs.append((sent, []))
                elif self.is_bio and sent.startswith("This sentence does not contain any facts"):
                    atomic_facts_pairs.append((sent, []))
                elif sent.startswith("Sure") or sent.startswith("Please") or (i==0 and sent.startswith("Here are")):
                    atomic_facts_pairs.append((sent, []))
                else:
                    atomic_facts_pairs.append((sent, atoms[sent]))

            # postprocess_atomic_facts will fix minor issues from InstructGPT
            # it is supposed to handle sentence splitter issue too, but since here
            # we fixed sentence splitter issue already,
            # the new para_breaks should be identical to the original para_breaks
            if self.is_bio:
                atomic_facts_pairs, para_breaks = postprocess_atomic_facts(atomic_facts_pairs, list(para_breaks), self.nlp)

            return atomic_facts_pairs, para_breaks


    def get_init_atomic_facts_from_sentence(self, sentences, cost_estimate=None):
        """Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None."""

        is_bio = self.is_bio
        demos = self.demos

        # k: BM25-retrieved demos per sentence; n: fixed leading demos
        k = 1 if is_bio else 0
        n = 7 if is_bio else 8

        prompts = []
        prompt_to_sent = {}
        atoms = {}
        for sentence in sentences:
            if sentence in atoms:
                continue
            top_matchings = best_demos(sentence, self.bm25, list(demos.keys()), k)
            prompt = ""

            # fixed in-context examples: the first n demo sentences
            for i in range(n):
                prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(list(demos.keys())[i])
                for fact in demos[list(demos.keys())[i]]:
                    prompt = prompt + "- {}\n".format(fact)
                prompt = prompt + "\n"

            # retrieved in-context examples: BM25 nearest demo sentences
            for match in top_matchings:
                prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(match)
                for fact in demos[match]:
                    prompt = prompt + "- {}\n".format(fact)
                prompt = prompt + "\n"
            prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(sentence)
            prompts.append(prompt)
            prompt_to_sent[prompt] = sentence

        if cost_estimate:
            # dry run: just count prompt words (optionally skipping cached prompts)
            total_words_estimate = 0
            for prompt in prompts:
                if cost_estimate == "consider_cache" and (prompt.strip() + "_0") in self.client.cache_dict:
                    continue
                total_words_estimate += len(prompt.split())
            return total_words_estimate
        else:
            outputs = []

            # fan out all prompts to the LLM concurrently
            with ThreadPoolExecutor(max_workers=len(prompts)) as executor:
                outputs = list(
                    executor.map(
                        lambda x : self.client.query(x),
                        prompts
                    )
                )
            for prompt, output in zip(prompts, outputs):
                atoms[prompt_to_sent[prompt]] = text_to_sentences(output[0]['message'])
            # for prompt in prompts:
            #     output = self.client.query(prompt)
            #     outputs.append(output)
            #     atoms[prompt_to_sent[prompt]] = text_to_sentences(output[0]['message'])

            self.client.cache_outputs(
                prompts=prompts,
                sample_indices=np.zeros((len(prompts),), dtype=int),
                outputs=outputs
            )

            # demo sentences already have gold fact decompositions
            for key, value in demos.items():
                if key not in atoms:
                    atoms[key] = value

            return atoms
159
+
160
+
161
def best_demos(query, bm25, demos_sents, k):
    """Return the k demo sentences ranked most similar to `query` by BM25."""
    query_tokens = query.split(" ")
    return bm25.get_top_n(query_tokens, demos_sents, k)
165
+
166
+
167
+ # transform InstructGPT output into sentences
168
# transform InstructGPT output into sentences
def text_to_sentences(text):
    """Parse an InstructGPT bullet list ("- fact\\n- fact") into fact strings.

    Fixes two issues in the original: empty bullet fragments (e.g. a stray
    trailing "- ") raised IndexError on `sent.strip()[-1]`, and the
    trailing-'\\n' check was dead code since strip() already removes it.
    A terminating period is appended to the last fact when missing.
    """
    fragments = text.split("- ")[1:]
    # strip each fragment; drop empties so a dangling bullet can't crash us
    sentences = [frag.strip() for frag in fragments if frag.strip()]
    if sentences and not sentences[-1].endswith('.'):
        sentences[-1] = sentences[-1] + '.'
    return sentences
177
+
178
+
179
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    # same pipeline as before: lowercase -> strip punctuation ->
    # drop articles -> collapse whitespace
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(re.compile(r'\b(a|an|the)\b', re.UNICODE), ' ', text)
    return ' '.join(text.split())
192
+
193
# Calendar month names, lower-cased once for case-insensitive matching.
MONTHS = [
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
]
195
+
196
def is_num(text):
    """True when `text` can be parsed as an integer."""
    try:
        int(text)
    except Exception:
        return False
    return True
202
+
203
def is_date(text):
    """True when every token of the normalized text is a number or a month name."""
    tokens = normalize_answer(text).split(" ")
    return all(is_num(token) or token in MONTHS for token in tokens)
209
+
210
def extract_numeric_values(text):
    """Return the set of standalone integer literals in `text`, as strings."""
    # \b\d+\b matches whole runs of digits bounded by non-word characters
    return set(re.findall(r'\b\d+\b', text))
214
+
215
+
216
def detect_entities(text, nlp):
    """Collect date-like and numeric entity strings from `text` using the
    given spacy pipeline plus a regex pass for bare integers.

    Returns a set of entity strings; hyphenated spans are split on '-'.
    """
    doc = nlp(text)
    entities = set()

    def _add_to_entities(text):
        # split hyphenated spans (e.g. date ranges) into their parts
        if "-" in text:
            for _text in text.split("-"):
                entities.add(_text.strip())
        else:
            entities.add(text)


    for ent in doc.ents:
        # spacy often has errors with other types of entities
        if ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:

            if is_date(ent.text):
                _add_to_entities(ent.text)
            else:
                # keep only the date-like tokens of a mixed span
                for token in ent.text.split():
                    if is_date(token):
                        _add_to_entities(token)

    # add bare integers the NER pass missed, unless already covered by
    # (a substring of) an existing entity
    for new_ent in extract_numeric_values(text):
        if not np.any([new_ent in ent for ent in entities]):
            entities.add(new_ent)

    return entities
244
+
245
def postprocess_atomic_facts(_atomic_facts, para_breaks, nlp):
    """Clean up model-produced atomic facts: merge dangling one-word
    sentences into their predecessor, drop vague verb-only facts that are
    subsumed by other facts, and repair truncated numeric/date entities.

    Parameters
    ----------
    _atomic_facts : list of (sentence, facts)
    para_breaks : list of int
        Indices where new paragraphs start; recomputed for merged sentences.
    nlp : spacy pipeline
        Used for entity detection.

    Returns
    -------
    (new_atomic_facts, new_para_breaks)
    """
    # facts ending in these verbs are considered too vague to stand alone
    verbs = ["born.", " appointed.", " characterized.", " described.", " known.", " member.", " advocate.", "served.", "elected."]
    permitted_verbs = ["founding member."]

    atomic_facts = []
    new_atomic_facts = []
    new_para_breaks = []

    # pass 1: merge one-word "sentences" (splitter artifacts) into the
    # previous entry, keeping paragraph-break indices consistent
    for i, (sent, facts) in enumerate(_atomic_facts):
        sent = sent.strip()
        if len(sent.split())==1 and i not in para_breaks and i > 0:
            assert i not in para_breaks
            atomic_facts[-1][0] += " " + sent
            atomic_facts[-1][1] += facts
        else:
            if i in para_breaks:
                new_para_breaks.append(len(atomic_facts))
            atomic_facts.append([sent, facts])

    # pass 2: per sentence, drop subsumed vague facts and repair entities
    for i, (sent, facts) in enumerate(atomic_facts):
        entities = detect_entities(sent, nlp)
        covered_entities = set()
        # print (entities)
        new_facts = []
        # NOTE(review): this inner loop reuses `i`, shadowing the outer index
        for i, fact in enumerate(facts):
            if any([fact.endswith(verb) for verb in verbs]) and not any([fact.endswith(verb) for verb in permitted_verbs]):
                # skip a vague fact that appears verbatim inside another fact
                if any([fact[:-1] in other_fact for j, other_fact in enumerate(facts) if j != i]):
                    continue
            sent_entities = detect_entities(fact, nlp)
            covered_entities |= set([e for e in sent_entities if e in entities])
            new_entities = sent_entities - entities
            if len(new_entities) > 0:
                # the fact mentions entities not in the sentence: try to map
                # each one back to a sentence entity it is a prefix of
                # (e.g. a truncated year); otherwise drop the fact
                do_pass = False
                for new_ent in new_entities:
                    pre_ent = None
                    for ent in entities:
                        if ent.startswith(new_ent):
                            pre_ent = ent
                            break
                    if pre_ent is None:
                        do_pass = True
                        break
                    fact = fact.replace(new_ent, pre_ent)
                    covered_entities.add(pre_ent)
                if do_pass:
                    continue
            if fact in new_facts:
                continue
            new_facts.append(fact)
        try:
            # sanity check: every sentence entity should appear in some fact
            assert entities==covered_entities
        except Exception:
            new_facts = facts # there is a bug in spacy entity linker, so just go with the previous facts

        new_atomic_facts.append((sent, new_facts))

    return new_atomic_facts, new_para_breaks
303
+
304
def is_integer(s):
    """True when `s` parses as an int (same contract as is_num above)."""
    try:
        int(s)
    except Exception:
        return False
    return True
310
+
311
def detect_initials(text):
    """Find pairs of initials such as "J. R." or "J.R." in `text`."""
    # two capital letters, each followed by a period, optionally space-separated
    pattern = r"[A-Z]\. ?[A-Z]\."
    return list(re.findall(pattern, text))
315
+
316
def fix_sentence_splitter(curr_sentences, initials):
    """Repair sentence splits broken by initials and dangling fragments.

    Two passes: (1) re-merge adjacent sentences that were split in the
    middle of an initial pair like "J. R."; (2) glue one-word fragments and
    lowercase-starting continuations onto the previous sentence.

    Bug fix: the original wrote `combined_with_previous = False` (a typo for
    `combine_with_previous`) in the one-word-fragment branch, so the flag
    stayed True and the NEXT full sentence was also wrongly merged.
    """
    # pass 1: merge sentence pairs split inside an initial pair
    for initial in initials:
        if not np.any([initial in sent for sent in curr_sentences]):
            alpha1, alpha2 = [t.strip() for t in initial.split(".") if len(t.strip()) > 0]
            for i, (sent1, sent2) in enumerate(zip(curr_sentences, curr_sentences[1:])):
                if sent1.endswith(alpha1 + ".") and sent2.startswith(alpha2 + "."):
                    # merge sentence i and i+1
                    curr_sentences = curr_sentences[:i] + [curr_sentences[i] + " " + curr_sentences[i+1]] + curr_sentences[i+2:]
                    break
    # pass 2: glue fragments onto their predecessor
    sentences = []
    combine_with_previous = None
    for sent_idx, sent in enumerate(curr_sentences):
        if len(sent.split()) <= 1 and sent_idx == 0:
            # a leading one-word fragment: keep it, merge the next sentence in
            assert not combine_with_previous
            combine_with_previous = True
            sentences.append(sent)
        elif len(sent.split()) <= 1:
            # mid-text one-word fragment: attach to the previous sentence
            assert sent_idx > 0
            sentences[-1] += " " + sent
            combine_with_previous = False  # fixed: was the typo `combined_with_previous`
        elif sent[0].isalpha() and not sent[0].isupper() and sent_idx > 0:
            # lowercase start means this is a continuation, not a new sentence
            assert sent_idx > 0, curr_sentences
            sentences[-1] += " " + sent
            combine_with_previous = False
        elif combine_with_previous:
            assert sent_idx > 0
            sentences[-1] += " " + sent
            combine_with_previous = False
        else:
            assert not combine_with_previous
            sentences.append(sent)
    return sentences
MACI-main/conditional-conformal/src/aws_utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import boto3
2
+ import io
3
+
4
def s3_open(bucket_name, key):
    """Download s3://bucket_name/key and return its bytes as a file-like object.

    Returning io.BytesIO lets callers treat the S3 object exactly like a
    local binary file handle.
    """
    # Create a session using your AWS credentials and get an S3 client from it
    s3_client = boto3.Session().client('s3')

    # Download the file object and read its full contents
    response = s3_client.get_object(Bucket=bucket_name, Key=key)
    payload = response['Body'].read()

    # Return a BytesIO object to mimic a file object
    return io.BytesIO(payload)
MACI-main/conditional-conformal/src/client.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import time
4
+
5
+ from typing import Any, List
6
+
7
class Client:
    """
    Wrapper class for language models that we query. It keeps a cache of prompts and
    responses so that we don't have to requery things in experiments.

    NOTE(review): this is an abstract base — subclasses must implement
    `_query` (called by `query`) and `load_model`; neither is defined here.
    """

    def __init__(self, cache_file, model : str = 'gpt-3.5-turbo'):
        # cache_file may be a local path or an "s3://..." URI (see load_cache)
        self.cache_file = cache_file
        self.cache_dict = self.load_cache()
        self.model = model
        # tracks whether save_cache needs to write anything
        self.modified_cache = False

    def load_model(self):
        # load the model and put it as self.model
        raise NotImplementedError()

    def query(
        self,
        prompt : str,
        sample_idx : int = 0,
        **kwargs
    ):
        """Return the cached response for (prompt, sample_idx), or call the
        model via `_query`. Does NOT write to the cache; callers use
        cache_outputs for that."""
        prompt = prompt.strip() # it's important not to end with a whitespace
        cache_key = f"{prompt}_{sample_idx}"

        if cache_key in self.cache_dict:
            return self.cache_dict[cache_key]

        if self.model is None:
            self.load_model()
        # print("I didn't find a cached copy!")
        output = self._query(prompt, **kwargs)

        return output

    def cache_outputs(
        self,
        prompts : List[str],
        sample_indices : List[int],
        outputs : List[Any]
    ):
        """Store a batch of model outputs in the in-memory cache and mark it dirty."""
        for prompt, sample_idx, output in zip(prompts, sample_indices, outputs):
            prompt = prompt.strip()
            cache_key = f"{prompt}_{sample_idx}"
            self.cache_dict[cache_key] = output
        self.modified_cache = True

    def save_cache(self):
        """Merge in any on-disk updates from concurrent processes, then
        pickle the cache to cache_file. No-op when nothing changed."""
        if self.modified_cache == False:
            return

        # load the latest cache first, since if there were other processes running in parallel, cache might have been updated
        for k, v in self.load_cache().items():
            self.cache_dict[k] = v

        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)

    def load_cache(self, allow_retry=True):
        """Load the pickled cache from a local path or S3; returns {} when
        the file does not exist.

        NOTE(review): unpickling a shared cache file is unsafe if the file
        can be written by untrusted parties.
        """
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pickle.load(f)
                    break
                except Exception: # if there are concurent processes, things can fail
                    if not allow_retry:
                        assert False
                    print ("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif 's3' in self.cache_file:
            # cache lives in S3: "s3://bucket/path/to/file"
            from aws_utils import s3_open
            s3_path = self.cache_file.removeprefix('s3://')
            bucket_name = s3_path.split('/')[0]
            path_to_file = '/'.join(s3_path.split('/')[1:])
            with s3_open(bucket_name, path_to_file) as fp:
                cache = pickle.load(fp)
        else:
            cache = {}
        return cache
87
+
88
+
89
+
MACI-main/conditional-conformal/src/config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import munch
2
+ import toml
3
+
4
def get_config(filepath: str = 'configs/default.toml'):
    """Parse a TOML config file into an attribute-accessible Munch object."""
    raw_config = toml.load(filepath)
    return munch.munchify(raw_config)
MACI-main/conditional-conformal/src/conformal.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from typing import Callable, List
4
+
5
def compute_conformity_scores(
    dataset : List,
    scores_list : List,
):
    """Per-unit conformity score: max claim score among unsupported claims.

    Each dataset unit carries 'atomic_facts' with a boolean-like
    'is_supported' field (assumed to form a boolean numpy mask — TODO
    confirm against the annotation pipeline). A unit whose claims are all
    supported scores 0.
    """
    conf_scores = []
    for unit, unit_scores in zip(scores_list, dataset) and ():
        pass  # unreachable placeholder removed below
    conf_scores = []
    for unit, unit_scores in zip(dataset, scores_list):
        supported = np.asarray(
            [claim['is_supported'] for claim in unit['atomic_facts']]
        )
        conf_scores.append(np.max(unit_scores[~supported], initial=0))
    return conf_scores
15
+
16
def calibrate_thresholds(
    feats_test : List,
    feats_valid : List,
    scores_valid : List,
    alpha_fn : Callable
) -> List[float]:
    """Return one (constant) split-conformal threshold per test point.

    Computes the finite-sample-corrected quantile
    ceil((1 - alpha) * (n + 1)) / n of the validation scores and replicates
    it across the test set. Only ``alpha_valid[0]`` is consulted, i.e.
    ``alpha_fn`` is assumed to return a constant level over the validation
    features — TODO confirm for non-marginal calibrations.

    Args:
        feats_test: test features (only their count is used).
        feats_valid: validation features, passed through to ``alpha_fn``.
        scores_valid: validation conformity scores.
        alpha_fn: maps features to per-point miscoverage levels.

    Returns:
        A list of ``len(feats_test)`` identical thresholds.
    """
    alpha_valid = alpha_fn(feats_valid)
    n_valid = len(feats_valid)
    quantile = np.ceil((1 - alpha_valid[0]) * (n_valid + 1)) / n_valid
    # BUGFIX: for alpha < 1/(n+1) the corrected level exceeds 1 and
    # np.quantile raises ValueError; clip to the maximum score instead.
    quantile = min(quantile, 1.0)
    threshold = np.quantile(scores_valid, q=quantile)
    return [threshold] * len(feats_test)
28
+
29
def conformal_filter(
    dataset : List,
    scores_list : List,
    thresholds : List
) -> List:
    """Attach 'filtered_claims' to each unit: claims scoring >= threshold.

    Mutates the units in place and returns the same list for convenience.
    """
    for unit, unit_scores, cutoff in zip(dataset, scores_list, thresholds):
        kept = []
        for claim, score in zip(unit['atomic_facts'], unit_scores):
            if score >= cutoff:
                kept.append(claim)
        unit['filtered_claims'] = kept
    return dataset
40
+
41
+
42
def assess_factscore_coverage(
    dataset : List,
    nominal_alpha : float
) -> None:
    """Print nominal vs. realized marginal miscoverage over filtered claims.

    A unit counts as nonfactual when any of its surviving claims carries an
    'is_supported' annotation of 'F'. (Per-group breakdowns — metadata keyed —
    were removed as dead commented-out code; revive from history if needed.)
    """
    nonfactual_flags = []
    for unit in dataset:
        labels = [claim['is_supported'] for claim in unit['filtered_claims']]
        nonfactual_flags.append('F' in labels)

    print(f"Nominal coverage: {nominal_alpha}")
    print(f"Realized marginal coverage: {np.mean(nonfactual_flags)}")
MACI-main/conditional-conformal/src/data_utils/sample_names.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ import requests
4
+
5
+ from typing import Dict
6
+ from tqdm import tqdm
7
+
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ ENTITY_PATH = '/data/jcherian/wikipedia_entity_map.npz'
11
+ WIKIDATA_URL = "https://www.wikidata.org/w/api.php"
12
+ logger = logging.getLogger(__name__)
13
+ logging.basicConfig(filename='human.log', level=logging.INFO)
14
+
15
+
16
def get_id(response : Dict) -> str:
    """Extract the single Wikidata entity id from a wbgetentities response.

    Returns None when the response carries no 'entities' key; asserts that
    exactly one entity was returned.
    """
    entities = response.get("entities", None)
    if entities is None:
        return None
    codes = list(entities.keys())
    assert len(codes) == 1
    return codes[0]
22
+
23
+
24
def is_human(response : Dict, id: str) -> bool:
    """True iff the entity's instance-of (P31) claims include human (Q5)."""
    instance_claims = response['entities'][id]['claims'].get('P31', [])
    return any(
        claim['mainsnak']['datavalue']['value']['id'] == 'Q5'
        for claim in instance_claims
    )
30
+
31
def validate_entity(k):
    """Check via the Wikidata API whether entity ``k`` is a human (Q5).

    Args:
        k: entity identifier or URL; only the final '/'-separated segment
            (the page title) is used.

    Returns:
        (name, is_human) tuple; any lookup or parsing failure yields
        (name, False).
    """
    name = k.split('/')[-1]
    adapter = requests.adapters.HTTPAdapter(max_retries=10)
    with requests.session() as s:
        s.mount("https://", adapter)
        response = s.get(url=WIKIDATA_URL, params={"action" : "wbgetentities",
                                                   "sites" : "enwiki",
                                                   "titles" : name,
                                                   "normalize": "1",
                                                   "languages": "en",
                                                   "format": "json",
                                                   "props": "claims"})

    try:
        response = response.json()
    except ValueError:
        # BUGFIX: previously only printed and fell through with `response`
        # still a Response object, which then crashed inside get_id().
        # Treat an unparseable payload as "not human".
        print(response.text)
        return name, False

    wiki_id = get_id(response)

    if wiki_id is None:
        return name, False

    try:
        human = is_human(response, wiki_id)
    except (KeyError, TypeError):
        # Malformed claim structure — conservatively report non-human.
        return name, False
    logger.info(f"{name}, {human}")
    return name, human
60
+
61
+
62
+ if __name__ == "__main__":
63
+ wiki_entities = np.load(ENTITY_PATH)
64
+ entity_names = list(wiki_entities.keys())
65
+ try:
66
+ with ThreadPoolExecutor(max_workers=5) as executor:
67
+ res = list(
68
+ tqdm(
69
+ executor.map(
70
+ lambda k : validate_entity(k),
71
+ entity_names
72
+ ),
73
+ total=len(entity_names)
74
+ )
75
+ )
76
+ except:
77
+ import pickle
78
+ with open('human.pkl', 'wb') as fp:
79
+ pickle.dump(res, fp)
80
+
81
+
82
+ import pickle
83
+ with open('human.pkl', 'wb') as fp:
84
+ pickle.dump(res, fp)
85
+
86
+ import IPython; IPython.embed()
MACI-main/conditional-conformal/src/dataset.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from tqdm import tqdm
4
+ from typing import List, Tuple
5
+
6
+ import json
7
+ import pandas as pd
8
+ import numpy as np
9
+ import os
10
+
11
+ from atomizer import Atomizer, text_to_sentences
12
+ from gpt import GPTClient
13
+ from scorer import Scorer
14
+
15
def get_prompts(
    dataset : str,
    data_path : str = None
) -> List:
    """Return (topics, prompts) for a named dataset.

    For the factscore variants, topics are entity names read from a text
    file and each prompt requests a one-paragraph biography. For the
    MedLFQA variants, each question serves as both topic and prompt.

    Args:
        dataset: dataset identifier (case-insensitive).
        data_path: CSV path (factscore_final) or directory of .jsonl files
            (medlfqav2); unused otherwise.

    Raises:
        ValueError: for an unrecognized dataset name.
    """
    name = dataset.lower()

    # The three factscore variants differed only in which name file they
    # read; consolidated from three copy-pasted branches.
    factscore_name_files = {
        "factscore": 'data/factscore_names.txt',
        "factscore_v2": 'data/factscore_v2_names.txt',
        "factscore_v3": 'data/factscore_v3_names.txt',
    }
    if name in factscore_name_files:
        with open(factscore_name_files[name], 'r') as fp:
            names = [line.strip() for line in fp.readlines()]
        prompts = [
            f"Please write one biographical paragraph about {n}."
            for n in names
        ]
        return names, prompts

    if name == "factscore_final":
        df = pd.read_csv(data_path, index_col=0)
        # NOTE: a set — topic order is not deterministic across runs.
        names = set([n.strip() for n in df['Name']])
        prompts = [
            f"Please write one biographical paragraph about {n}."
            for n in names
        ]
        return names, prompts

    if name == "medlfqa":
        datasets = {}
        suffix = "_test_MedLFQA.jsonl"
        dataset_dir = "/Users/cherian/Projects/OLAPH/MedLFQA"
        for path in os.listdir(dataset_dir):
            if "MedLFQA" not in path:
                continue
            dataset_name = path[:-len(suffix)]
            with open(os.path.join(dataset_dir, path), 'r') as fp:
                datasets[dataset_name] = [json.loads(line) for line in fp.readlines()]

        prompts = []
        for _, dset in datasets.items():
            prompts += [pt['Question'] for pt in dset]
        prompts = list(set(prompts))  # dedupe; order not deterministic
        return prompts, prompts

    if name == "medlfqav2":
        datasets = {}
        suffix = ".jsonl"
        for filename in os.listdir(data_path):
            dataset_name = filename[:-len(suffix)]
            with open(os.path.join(data_path, filename), 'r') as fp:
                datasets[dataset_name] = [json.loads(line) for line in fp.readlines()]

        prompts = []
        for _, dset in datasets.items():
            prompts += [pt['Question'] for pt in dset]
        return prompts, prompts

    raise ValueError("Unsupported data set.")
93
+
94
def find_unique_element(lst, condition, approx_index):
    """Find the index of an element satisfying ``condition``.

    Checks ``approx_index`` first, then scans outward from it (left and
    right simultaneously) so a match near the guess is found quickly.
    Returns None when no element qualifies.
    """
    if condition(lst[approx_index]):
        return approx_index

    lo, hi = approx_index - 1, approx_index + 1
    length = len(lst)
    while lo >= 0 or hi < length:
        if lo >= 0 and condition(lst[lo]):
            return lo
        if hi < length and condition(lst[hi]):
            return hi
        lo -= 1
        hi += 1

    return None
114
+
115
def load_dataset(
    config : dict
) -> List:
    """Build the annotated dataset: generate responses, atomize them into
    facts, then score each fact with the FActScore-style annotator.

    NOTE(review): contains two `IPython.embed()` calls and hard-coded
    special-casing; this path appears to be research/debug code, not a
    clean pipeline.
    """

    print("Loading responder.")
    responder = GPTClient(config.model.responder.cache_path)

    topics, prompts = get_prompts(config.dataset.name, config.dataset.path)

    # Generate (or fetch cached) responses for every prompt in parallel.
    with ThreadPoolExecutor(max_workers=25) as executor:
        responses = list(
            tqdm(
                executor.map(
                    lambda x : responder.query(x),
                    prompts
                ),
                total=len(prompts)
            )
        )

    # TODO: Uncomment me if I want to run fresh dataset...

    responder.cache_outputs(
        prompts,
        np.zeros((len(responses),), dtype=int),
        responses
    )

    responder.save_cache()

    # Each cached entry is a list of choices; keep only the first one.
    responses = [r[0] for r in responses]


    outputs = [{'prompt': p, 'response': o['message']}
               for p, o in zip(prompts, responses)] # first output is the response we will filter

    import IPython; IPython.embed()
    print("Loading atomizer.")
    atomizer_client = GPTClient(config.model.parser.cache_path, model=config.model.parser.name)

    atomizer = Atomizer(atomizer_client, demo_dir='data/demos')

    CACHE_EXISTS = True

    if CACHE_EXISTS: # TODO: dumb hard-coded variable to side step the slow retrieval
        # Reconstruct atomic facts from the atomizer cache by matching each
        # cached atomization back to the response message it came from.
        ordered_messages = [r['message'] for r in responses]

        responder_cache = responder.cache_dict
        messages = []
        for val in responder_cache.values():
            messages.append(val[0]['message'])

        atomizer_cache = atomizer_client.cache_dict
        idx_guess = 0
        atomic_facts = [[] for _ in range(len(messages))]
        atomic_facts_ph = [[] for _ in range(len(messages))]

        sentences = defaultdict(int)
        for k in tqdm(atomizer_cache.keys()):
            atomized_msg = atomizer_cache[k][0]['message']
            atomized_facts = text_to_sentences(atomized_msg)
            # The sentence being atomized is embedded in the cache key.
            sentence = k.split('\n')[-1].split('facts:')[-1].strip()[:-2]
            cur_idx = -1
            sentences[sentence] += 1
            # if the sentence has appeared more than once we need to find the appropriate match...
            for i in range(sentences[sentence]):
                cur_idx = find_unique_element(messages[cur_idx + 1:], lambda x: sentence in x, approx_index=idx_guess)
            if cur_idx is None: # TODO: TERRIBLE SPECIAL CASING that I looked at by hand...
                raise ValueError()
                # NOTE(review): everything below the raise is unreachable —
                # presumably the hand-tuned fallback, disabled on purpose.
                if idx_guess in (4148, 4149, 4150):
                    cur_idx = 4149
                elif cur_idx == 993:
                    cur_idx = 993
                else:
                    continue
            idx_guess = cur_idx
            atomic_facts[cur_idx].extend(atomized_facts)

        # Re-order the recovered facts to match the prompt order.
        for af, msg in zip(atomic_facts, messages):
            if len(af) == 0:
                continue
            new_idx = ordered_messages.index(msg)
            atomic_facts_ph[new_idx] = af
        atomic_facts = atomic_facts_ph

    else:
        # Fresh atomization path: run the atomizer over every response.
        with ThreadPoolExecutor(max_workers=10) as executor:
            atoms = list(
                tqdm(
                    executor.map(
                        lambda x : atomizer.run(*x),
                        [(o['response'],) for o in outputs]
                    ),
                    total=len(outputs)
                )
            )

        atomizer.save_cache()
        atomic_facts = [[fact for _, facts in atom[0] for fact in facts] for atom in atoms]


    dataset = []

    for p, r, af in zip(prompts, responses, atomic_facts):
        atoms = [{'atom': fact} for fact in af]
        data_pt = {'prompt': p, 'response': r, 'atomic_facts': atoms}
        dataset.append(data_pt)

    # time to annotate responses using factscore code
    print("Loading annotator.")
    scorer_client = GPTClient(config.model.annotator.cache_path, model=config.model.annotator.name)
    scorer = Scorer(scorer_client, config, model_name="retrieval")

    scorer_inputs = [(topic, output['response'], fact) for topic, output, fact in zip(topics, outputs, atomic_facts)]
    with ThreadPoolExecutor(max_workers=4) as executor:
        scores = list(
            tqdm(
                executor.map(
                    lambda x : scorer.get_score(*x, knowledge_source='medlfqa'),
                    scorer_inputs
                ),
                total=len(scorer_inputs)
            )
        )
    # scorer.save_cache()

    # NOTE(review): the dataset built above is discarded and rebuilt here
    # from the scorer's per-fact decisions.
    dataset = []

    for p, r, s in zip(prompts, responses, scores):
        data_pt = {
            'prompt': p,
            'response': r,
            'atomic_facts': s['decisions'][0]
        }
        dataset.append(data_pt)

    import IPython; IPython.embed()

    return dataset
254
+
255
def split_dataset(
    dataset : List,
    train_perc : float = 0.33,
    valid_perc : float = 0.33,
    rng : np.random.Generator = None
) -> Tuple[List, List, List]:
    """
    Splits dataset into three parts. Split into training and validation is specified here.
    """
    n = len(dataset)
    n_train = int(n * train_perc)
    n_valid = int(n * valid_perc)

    # When an rng is supplied, shuffle the dataset in place before slicing.
    if rng is not None:
        rng.shuffle(dataset)

    return (
        dataset[:n_train],
        dataset[n_train:n_train + n_valid],
        dataset[n_train + n_valid:],
    )
MACI-main/conditional-conformal/src/featurizer.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+
4
+ from functools import partial
5
+ from tqdm.contrib.concurrent import process_map
6
+ from concurrent.futures import ThreadPoolExecutor
7
+
8
+ from tqdm import tqdm
9
+
10
+ import client
11
+
12
+ from typing import Dict, List, Tuple
13
+
14
+
15
# Prompt asking the model to self-assess P(correct) for each claim; the model
# must reply with bare JSONL lines of {"id": ..., "gpt-score": ...}.
# (The typo "esimated" is preserved: this string is a cache key / API payload.)
SELF_ASSESS_PROMPT = 'You will get a list of claims and the original prompt that motivated these claims. For each claim, assess the probability of correctness. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "gpt-score":[SCORE]}. Directly return the jsonl with NO explanation or ANY other formatting. For the [SCORE], return the esimated probability of correctness to three significant figures.\n'

# Variant asking for a quoted "T"/"F" verdict instead of a probability.
SELF_BOOL_PROMPT = 'You will get a list of claims and the original prompt that motivated these claims. For each claim, assess the correctness. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "gpt-bool":[BOOL]}. Directly return the jsonl with NO explanation or ANY other formatting. For the [BOOL], return "T" or "F" in quotes so that it is valid json.\n'

# Pool size used by the commented-out batch featurizer implementation below.
MAX_WORKERS = 20
20
+
21
def get_features(
    dataset: List[Dict],
    config : Dict
) -> np.ndarray:
    """Build a (n_claims, n_features) matrix of per-claim features.

    Feature columns are selected by ``config.model.prob.features``; supported
    names are 'frequency' (agreement of resampled responses with each claim)
    and 'selfeval' (the model's self-assessed probability of correctness).
    Results are cached per-dataset under ``.cache/``.

    Raises:
        ValueError: if no supported feature name is configured.
    """
    from gpt import GPTClient
    feature_names = config.model.prob.features
    all_features = []
    if 'frequency' in feature_names:
        # BUGFIX: renamed from `client`, which shadowed the imported
        # `client` module at the top of this file.
        freq_client = GPTClient(f'.cache/{config.dataset.name}_frequency.pkl')

        with ThreadPoolExecutor(max_workers=5) as executor:
            frequencies = list(
                tqdm(
                    executor.map(
                        lambda x: get_frequency(freq_client, [af['atom'] for af in x['atomic_facts']], x['prompt'], config.model.prob.frequency.model),
                        dataset
                    ),
                    total=len(dataset)
                )
            )
        freq_client.save_cache()
        all_features.append(np.concatenate(frequencies).reshape(-1, 1))

    if 'selfeval' in feature_names:
        eval_client = GPTClient(f'.cache/{config.dataset.name}_self_evals.pkl')

        with ThreadPoolExecutor(max_workers=25) as executor:
            self_evals = list(
                tqdm(
                    executor.map(
                        lambda x: get_self_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], eval_client),
                        dataset
                    ),
                    total=len(dataset)
                )
            )
        eval_client.save_cache()
        all_features.append(np.concatenate(self_evals).reshape(-1, 1))

    if not all_features:
        # np.concatenate on an empty list raises an opaque error; fail loudly.
        raise ValueError(f"No supported features among {feature_names}.")

    return np.concatenate(all_features, axis=1)
67
+
68
+ # def get_features(
69
+ # dataset : List[Dict],
70
+ # config : Dict
71
+ # ) -> np.ndarray:
72
+ # feature_names = config.features
73
+ # num_claims = np.sum([len(dat['claims']) for dat in dataset])
74
+ # all_features = []
75
+ # for feat in feature_names:
76
+ # if feat == "embedding":
77
+ # embeds = np.zeros((num_claims, int(config.embedding.n_dimensions)))
78
+ # print("Fetching embeddings.")
79
+ # embedding_func = partial(get_embedding, model=config.embedding.model, n_dim=config.embedding.n_dimensions)
80
+ # res = process_map(embedding_func, [dat['claims'] for dat in dataset], max_workers=MAX_WORKERS)
81
+ # i = 0
82
+ # for dat in tqdm(dataset):
83
+ # len_dat = len(dat['claims'])
84
+ # embeds[i:(i + len_dat)] = get_embedding(dat['claims'], config.embedding.model, config.embedding.n_dimensions)
85
+ # i += len_dat
86
+ # all_features.append(embeds)
87
+
88
+ # elif feat == "selfeval":
89
+ # print("Fetching selfevals.")
90
+ # evals = np.zeros((num_claims, 1))
91
+ # selfeval_func = partial(get_self_eval, model=config.selfeval.model.name)
92
+ # res = process_map(selfeval_func, dataset, max_workers=MAX_WORKERS)
93
+ # i = 0
94
+ # for dat in tqdm(dataset):
95
+ # len_dat = len(dat['claims'])
96
+ # evals[i:(i + len_dat)] = get_self_eval(dat['claims'], dat['prompt'], config.selfeval.model.name)
97
+ # i += len_dat
98
+ # all_features.append(evals)
99
+ # elif feat == "frequency":
100
+ # print("Fetching frequency.")
101
+ # freqs = np.zeros(((num_claims), 1))
102
+ # i = 0
103
+ # for dat in tqdm(dataset):
104
+ # len_dat = len(dat['claims'])
105
+ # freqs[i:(i + len_dat)] = get_frequency(dat['claims'], dat['prompt'], config.frequency.model.n_samples, config.frequency.model.name)
106
+ # i += len_dat
107
+ # all_features.append(freqs)
108
+ # else:
109
+ # raise ValueError(f"{feat} not supported.")
110
+ # return np.concatenate(all_features, axis=1)
111
+
112
+
113
def get_embedding(
    subclaims : List[str],
    client : "client.Client",  # needs to be an embedding client, not a GPT chat client
    n_dim : int = 8
) -> np.ndarray:
    """Embed each subclaim and return the first ``n_dim`` dimensions.

    Currently unsupported: no embedding client exists yet, so this always
    raises ValueError. (The dead sketch that followed the raise was removed;
    it queried ``client`` per claim and truncated each vector to ``n_dim``.)

    Raises:
        ValueError: always, until an embedding client is implemented.
    """
    raise ValueError("not supported yet")
125
+
126
+
127
def _eval_self(
    prompt : str,
    subclaims : List,
    client : client.Client,
    err_msg : str = None
) -> Tuple[Tuple[str, List], np.ndarray]:
    """Ask the model to score P(correct) for each subclaim via a JSONL reply.

    Builds a numbered claim list, sends SELF_ASSESS_PROMPT, and parses one
    {"id": ..., "gpt-score": ...} line per claim. On a malformed reply it
    retries once per failure mode (wrong line count / unparseable JSON) by
    re-querying with an error note appended; after that it gives up and
    returns ((None, None), None).

    Returns:
        ((prompt_sent, raw_model_output), per-claim score array) on success.
    """
    claim_string = "\n".join(
        [str(i) + ": " + fact for i, fact in enumerate(subclaims)]
    )
    self_eval_prompt = SELF_ASSESS_PROMPT
    self_eval_prompt += f"The original prompt is: {prompt}.\n"
    self_eval_prompt += f"The claims are: {claim_string}.\n"

    if err_msg is not None:
        self_eval_prompt += "\n" + err_msg

    self_evals = client.query(self_eval_prompt)
    parsed_evals = self_evals[0]['message']
    # Strip markdown code fences the model sometimes wraps around JSONL.
    parsed_evals = parsed_evals.replace("```jsonl\n", "")
    parsed_evals = parsed_evals.replace("```", "")
    final_evals = np.zeros((len(parsed_evals.splitlines()),))
    try:
        assert len(final_evals) == len(subclaims)
    except AssertionError:
        # 'exactly' marks the line-count retry message: give up the 2nd time.
        if err_msg is not None and 'exactly' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return exactly {len(subclaims)} lines of JSON."
        print(err_msg)
        return _eval_self(prompt, subclaims, client, err_msg=err_msg)
    try:
        for line in parsed_evals.splitlines():
            eval = json.loads(line)
            idx = int(eval["id"])
            final_evals[idx] += float(eval["gpt-score"])
    except Exception as ex:
        # 'requested' marks the bad-JSON retry message: give up the 2nd time.
        if err_msg is not None and 'requested' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return the lines in the requested JSON format with NO additional formatting."
        print(err_msg)
        return _eval_self(prompt, subclaims, client, err_msg=err_msg)
    return (self_eval_prompt, self_evals), final_evals
170
+
171
+
172
def get_self_eval(
    prompt : str,
    subclaims : List[str],
    client : client.Client
) -> np.ndarray:
    """Fetch self-assessed correctness probabilities for each subclaim.

    Caches the (prompt, raw reply) pair on the client. When ``_eval_self``
    gives up after its retries, returns an array of -1 sentinels instead.
    """
    (sent_prompt, raw_reply), evals = _eval_self(prompt, subclaims, client)

    if sent_prompt is None:
        return -1 * np.ones((len(subclaims),))  # -1 prob is error

    client.cache_outputs(
        [sent_prompt],
        np.zeros((1,), dtype=int),
        [raw_reply]
    )

    return evals
195
+
196
def _bool_self(
    prompt : str,
    subclaims : List,
    client : client.Client,
    err_msg : str = None
) -> Tuple[Tuple[str, List], np.ndarray]:
    """Boolean twin of ``_eval_self``: ask for a 'T'/'F' verdict per subclaim.

    Same JSONL protocol and one-retry-per-failure-mode behavior; gives up
    with ((None, None), None). On success the second element is a list of
    'T'/'F' strings (defaulting to 'T' for lines the model omitted, despite
    the np.ndarray annotation).
    """
    claim_string = "\n".join(
        [str(i) + ": " + fact for i, fact in enumerate(subclaims)]
    )
    self_eval_prompt = SELF_BOOL_PROMPT
    self_eval_prompt += f"The original prompt is: {prompt}.\n"
    self_eval_prompt += f"The claims are: {claim_string}.\n"

    if err_msg is not None:
        self_eval_prompt += "\n" + err_msg

    self_evals = client.query(self_eval_prompt)
    parsed_evals = self_evals[0]['message']
    # Strip markdown code fences the model sometimes wraps around JSONL.
    parsed_evals = parsed_evals.replace("```jsonl\n", "")
    parsed_evals = parsed_evals.replace("```", "")
    final_evals = ['T' for i in range(len(parsed_evals.splitlines()))]
    try:
        assert len(final_evals) == len(subclaims)
    except AssertionError:
        # 'exactly' marks the line-count retry message: give up the 2nd time.
        if err_msg is not None and 'exactly' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return exactly {len(subclaims)} lines of JSON."
        print(err_msg)
        return _bool_self(prompt, subclaims, client, err_msg=err_msg)
    try:
        for line in parsed_evals.splitlines():
            eval = json.loads(line)
            idx = int(eval["id"])
            final_evals[idx] = eval["gpt-bool"]
    except Exception as ex:
        # 'requested' marks the bad-JSON retry message: give up the 2nd time.
        if err_msg is not None and 'requested' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_evals}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return the lines in the requested JSON format with NO additional formatting."
        print(err_msg)
        return _bool_self(prompt, subclaims, client, err_msg=err_msg)
    return (self_eval_prompt, self_evals), final_evals
239
+
240
+
241
def get_bool_eval(
    prompt : str,
    subclaims : List[str],
    client : client.Client
) -> np.ndarray:
    """Fetch T/F self-assessments for each subclaim, caching the raw reply.

    On unrecoverable parse failure returns an array of -1 sentinels rather
    than the usual list of 'T'/'F' strings.
    """
    (sent_prompt, raw_reply), verdicts = _bool_self(prompt, subclaims, client)

    if sent_prompt is None:
        return -1 * np.ones((len(subclaims),))  # -1 prob is error
    client.cache_outputs(
        [sent_prompt],
        np.zeros((1,), dtype=int),
        [raw_reply]
    )

    return verdicts
263
+
264
+
265
def _eval_support(
    output : str,
    subclaims : List,
    client : client.Client,
    err_msg : str = None
) -> Tuple[Tuple[str, List], np.ndarray]:
    """Score how a text relates to each subclaim: 1 supports, -1 contradicts,
    0 unrelated.

    Same JSONL query/parse/retry protocol as ``_eval_self``: one retry per
    failure mode (wrong line count / bad JSON), then give up and return
    ((None, None), None).

    Returns:
        ((prompt_sent, raw_model_output), per-claim score array) on success.
    """
    claim_string = "\n".join(
        [str(i) + ": " + fact for i, fact in enumerate(subclaims)]
    )
    counting_prompt = (
        'You will get a list of claims and piece of text. For each claim, score whether the text supports, contradicts, or is unrelated to the claim. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "score":[SCORE]}. Directly return the jsonl with NO explanation or ANY other formatting. For the [SCORE], return 1 for supports, -1 for contradicts, and 0 for unrelated. The claims are:\n'
        + claim_string
        + "\n\nThe text is:\n"
        + output
    )
    if err_msg is not None:
        counting_prompt += "\n" + err_msg

    support_scores = client.query(counting_prompt)
    parsed_scores = support_scores[0]['message']
    # Strip markdown code fences the model sometimes wraps around JSONL.
    parsed_scores = parsed_scores.replace("```jsonl\n", "")
    parsed_scores = parsed_scores.replace("```", "")
    final_scores = np.zeros((len(parsed_scores.splitlines()),))
    try:
        assert len(final_scores) == len(subclaims)
    except AssertionError:
        # 'exactly' marks the line-count retry message: give up the 2nd time.
        if err_msg is not None and 'exactly' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_scores}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return exactly {len(subclaims)} lines of JSON."
        print(err_msg)
        return _eval_support(output, subclaims, client, err_msg=err_msg)
    try:
        for line in parsed_scores.splitlines():
            score = json.loads(line)
            idx = int(score["id"])
            final_scores[idx] += float(score["score"])
    except Exception as ex:
        # 'requested' marks the bad-JSON retry message: give up the 2nd time.
        if err_msg is not None and 'requested' in err_msg:
            print(f"I'm giving up on {claim_string} and {parsed_scores}, since I already retried this.")
            return (None, None), None
        err_msg = f"IMPORTANT: This is a retry. Make sure you return the lines in the requested JSON format with NO additional formatting."
        print(err_msg)
        return _eval_support(output, subclaims, client, err_msg=err_msg)
    return (counting_prompt, support_scores), final_scores
310
+
311
+
312
def get_frequency(
    client : client.Client,
    subclaims : List,
    prompt : str,
    config : dict
) -> np.ndarray:
    """
    Returns a vector of (frequency) scores corresponding to each entry of the subclaims list.

    Resamples ``config.n_samples`` alternate responses to ``prompt`` at
    ``config.temperature`` and, for each subclaim, averages the
    support/contradict/unrelated scores (from ``_eval_support``) across the
    samples. The resamples are cached under sample index 1 (index 0 holds
    the primary response elsewhere in the pipeline — TODO confirm).
    """
    # Generate n_samples alternate outputs with temperature 1.0.
    alternate_outputs = client.query(
        prompt, 1, n_samples=config.n_samples, temperature=config.temperature
    )
    client.cache_outputs(
        [prompt],
        [int(1)],
        [alternate_outputs]
    )

    alternate_outputs = [o['message'] for o in alternate_outputs]

    # Score every alternate output against the subclaims in parallel.
    with ThreadPoolExecutor(max_workers=config.n_samples) as executor:
        all_scores = list(
            executor.map(
                lambda x : _eval_support(x, subclaims, client),
                alternate_outputs
            )
        )

    # to_cache = [s[0] for s in all_scores if s[0][0] is not None]

    # client.cache_outputs(
    #     [c[0] for c in to_cache],
    #     np.zeros((len(to_cache),), dtype=int),
    #     [c[1] for c in to_cache]
    # )

    # TODO: error handling if this is all empty?
    # Average per-claim scores over the samples that parsed successfully.
    parsed_scores = np.mean([s[1] for s in all_scores if s[1] is not None], axis=0)

    return parsed_scores
MACI-main/conditional-conformal/src/gpt.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+
3
+ from typing import List
4
+
5
+ from client import Client
6
+
7
+ from tenacity import (
8
+ retry,
9
+ stop_after_attempt,
10
+ wait_random_exponential,
11
+ ) # for exponential backoff
12
+
13
class GPTClient(Client):
    """Caching OpenAI chat-completion client.

    Thin wrapper around ``openai.Client().chat.completions.create`` with
    exponential-backoff retries; prompt/response caching comes from the
    ``Client`` base class.
    """
    def __init__(
        self,
        cache_file : str,
        model : str = 'gpt-3.5-turbo'
    ):
        super(GPTClient, self).__init__(cache_file, model)
        self.client = openai.Client()
        # Per-process usage counters (not persisted with the cache).
        self.tokens_used = 0
        self.requests_made = 0

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
    def _query(
        self,
        prompt : List[str],
        role : List[str] = None,
        max_tokens : int = 1000,
        temperature: float = 0,
        response_format : str = None,
        n_samples: int = 1
    ):
        """Issue one chat-completion request and return its choices.

        Args:
            prompt: message content (passed directly as a single message's
                content — presumably a str despite the List[str] annotation;
                TODO confirm).
            role: chat role for the message; defaults to "user".
            max_tokens, temperature, response_format: forwarded to the API.
            n_samples: number of completions to request (API param ``n``).

        Returns:
            One dict per choice: {'logprobs': token logprobs,
            'message': completion text}.

        Retries up to 6 times with randomized exponential backoff (1-60s).
        """
        if role is None:
            messages = [{"role": "user", "content": prompt}]
        else:
            messages = [{"role": role, "content": prompt}]

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_format=response_format,
            max_tokens=max_tokens,
            temperature=temperature,
            n=n_samples,
            logprobs=True
        )
        self.tokens_used += completion.usage.total_tokens
        self.requests_made += 1
        # print(self.tokens_used, self.requests_made)
        outputs = []
        for choice in completion.choices:
            output_dict = {
                'logprobs': choice.logprobs.content,
                'message': choice.message.content
            }
            outputs.append(output_dict)
        return outputs
MACI-main/conditional-conformal/src/llm_utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from tqdm import tqdm
6
+ from typing import Dict, List
7
+
8
+ from query import (
9
+ generate_subclaim_prompt, generate_annotation_prompt,
10
+ generate_merge_prompt, query_llm
11
+ )
12
+
13
+ import client
14
+
15
+ MERGE_PROMPT = "You will get an instruction and a set of facts that are true. Construct an answer using ONLY the facts provided, and use ALL of the facts provided. If no facts are given, reply and say that you don't know enough to respond.\n"
16
+
17
+
18
def parse_responses(
    outputs : List[Dict],
    parser_config : str,
    annotate : bool = False,
    annotator_config : str = None
):
    """Decompose each response into subclaims, optionally annotating them.

    Mutates each entry in place (adds a 'claims' key) and returns the list.
    """
    for entry in tqdm(outputs):
        claims = get_subclaims(entry["prompt"], entry["response"], parser_config)
        if annotate:
            claims = add_annotations(entry["prompt"], claims, annotator_config)
        entry["claims"] = claims
    return outputs
31
+
32
def get_subclaims(
    prompt : str,
    response : str,
    parser_config : str
) -> List[Dict]:
    """Split a response into one subclaim dict per line of the parser's reply."""
    subclaim_prompt = generate_subclaim_prompt(prompt, response)
    parsed = query_llm([subclaim_prompt], parser_config)[0]  # first output only
    return [{'message': line} for line in parsed['message'].splitlines()]
41
+
42
def add_annotations(
    prompt : str,
    subclaims : List[Dict],
    annotator_config : str
) -> List[Dict]:
    """Attach an 'annotation' truth value to each subclaim via the annotator LLM.

    Re-queries up to 5 times when the annotator returns the wrong number of
    lines; after that, all subclaims are marked 'F'. Unparseable annotation
    lines are also marked 'F'.
    """
    annotation_prompt = generate_annotation_prompt(prompt, subclaims)
    annotations = query_llm([annotation_prompt], annotator_config)[0]
    annotations = annotations['message'].splitlines()
    num_retries = 0
    while len(annotations) != len(subclaims):
        print(f"Annotation length does not match subclaims for {prompt}. Retrying query.")
        annotations = query_llm([annotation_prompt], annotator_config)[0]
        annotations = annotations['message'].splitlines()
        num_retries += 1
        if num_retries > 5:
            print("Giving up and assigning False to all subclaims.")
            annotations = ['F' for _ in subclaims]
    for a, subclaim in zip(annotations, subclaims):
        try:
            subclaim['annotation'] = json.loads(a)['value']
        except (ValueError, KeyError, TypeError):
            # BUGFIX: previously dropped into an interactive IPython shell,
            # which hangs any non-interactive run — and the give-up path
            # above produces bare 'F' lines that always fail json.loads,
            # guaranteeing that hang. Mark the claim false instead.
            print(f"Could not parse annotation line {a!r}; assigning 'F'.")
            subclaim['annotation'] = 'F'
    return subclaims
65
+
66
+
67
+ def _concat_claims(
68
+ subclaims : List[str]
69
+ ) -> str:
70
+ return "\n".join(
71
+ f"{i}: {subclaim}" for i, subclaim in enumerate(subclaims)
72
+ )
73
+
74
def _get_merged_output(
    prompt : str,
    subclaims : List[str],
    client : client.Client
) -> str:
    """Ask the LLM to recompose the given facts into a single answer.

    Returns ((final_prompt, raw_output), merged message text) so callers can
    cache the raw exchange.
    """
    merge_request = MERGE_PROMPT + f"The original instruction was: {prompt}\n"
    merge_request += f"The facts are: {_concat_claims(subclaims)}"

    raw = client.query(merge_request)

    return (merge_request, raw), raw[0]['message']
86
+
87
+
88
def merge_claims(
    dataset : List,
    client : client.Client
) -> List:
    """Recompose each unit's filtered claims into a final answer via the LLM.

    Queries in parallel, caches every (prompt, raw output) pair on the
    client, and returns the merged message strings in dataset order.
    """
    with ThreadPoolExecutor(max_workers=25) as executor:
        results = list(
            tqdm(
                executor.map(
                    lambda unit : _get_merged_output(unit['prompt'], unit['filtered_claims'], client),
                    dataset
                ),
                total=len(dataset)
            )
        )

    cache_pairs = [res[0] for res in results]

    client.cache_outputs(
        [pair[0] for pair in cache_pairs],
        np.zeros((len(cache_pairs),), dtype=int),
        [pair[1] for pair in cache_pairs]
    )

    return [res[1] for res in results]
MACI-main/conditional-conformal/src/postprocess_factscore.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json

# Convert FActScore unlabeled-prediction files into this project's dataset
# format: [{'prompt', 'claims': [{'message', 'annotation'}], 'response',
# 'topic', 'metadata'}, ...].  Paths are machine-specific (author's laptop).
output = []
prompt_to_idx = {}
idx = 0
with open("/Users/cherian/Downloads/factscore-unlabeled-predictions/ChatGPT.jsonl") as fp:
    for line in fp:
        res = json.loads(line)
        new_res = {}
        new_res['prompt'] = res['prompt']
        new_res['claims'] = []
        # Prefer ChatGPT labels; fall back to the LLAMA+NP annotator.
        annotator = 'ChatGPT_Labels' if 'ChatGPT_Labels' in res else 'LLAMA+NP_Labels'
        for fact, annotation in zip(res['facts'], res[annotator]):
            # FActScore marks supported facts 'S'; map to T(rue) / F(alse).
            a = 'T' if annotation == 'S' else 'F'
            new_res['claims'].append(
                {'message': fact, 'annotation': a}
            )
        output.append(new_res)
        prompt_to_idx[res['prompt']] = idx
        idx += 1

# Join the raw responses/topics back in by matching on the prompt text.
with open("/Users/cherian/Projects/FActScore/factscore/data/unlabeled/ChatGPT.jsonl", 'r') as fp:
    for line in fp:
        res = json.loads(line)
        idx = prompt_to_idx.get(res['input'], None)
        if idx is None:
            continue
        else:
            output[idx]['response'] = res['output']
            output[idx]['topic'] = res['topic']
            output[idx]['metadata'] = res['cat']

with open("data/factscore_processed.json", 'w') as fp:
    fp.write(json.dumps(output) + "\n")
MACI-main/conditional-conformal/src/prob_model.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from sklearn.linear_model import LogisticRegressionCV
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.optim as optim
9
+
10
+ from typing import List
11
+
12
+ from conformal import compute_conformity_scores
13
+
14
def fit_model(
    features : np.ndarray,
    labels : np.ndarray,
    config : dict,
    dataset_train : List = None,
    eval_dict : dict = None
):
    """Fit the claim-correctness probability model named by config.model.prob.name.

    features: (n_claims, n_features) design matrix.
    labels: (n_claims,) 0/1 correctness labels.
    dataset_train / eval_dict: only used by the "torch" model, which trains
    against a differentiable conformal loss and periodically reports claim
    retention on the validation/test splits.

    Raises ValueError for unimplemented ("XGBoost") or unknown model names.
    """
    name = config.model.prob.name
    if name == "logistic":
        model = LogisticRegressionCV()
        model.fit(X=features, y=labels)
        return model
    elif name == "XGBoost":
        raise ValueError("not implemented yet")
    elif name == "torch":
        # no data splitting for now when constructing conformal loss
        model = LogisticRegression(features.shape[1])

        # NOTE(review): lr=1 is unusually large for Adam — confirm intentional.
        optimizer = optim.Adam(model.parameters(), lr=1)
        x = torch.tensor(features, requires_grad=True, dtype=torch.float32)

        for i in range(500):
            optimizer.zero_grad()
            probs = model.forward(x)

            loss, avg_train = get_conformal_loss(probs, labels, dataset_train, config.conformal.alpha)
            if i % 100 == 0:
                # Periodically report what fraction of claims survives the
                # conformal filter at the current model state.
                probs_valid = model.forward(torch.tensor(eval_dict['X_valid'], dtype=torch.float32)).detach().numpy()
                probs_split = np.array_split(probs_valid, eval_dict['splits_valid'])
                threshold = np.quantile(compute_conformity_scores(eval_dict['dataset_valid'], probs_split), 1 - config.conformal.alpha)
                probs_test = model.forward(torch.tensor(eval_dict['X_test'], dtype=torch.float32)).detach().numpy()
                probs_split = np.array_split(probs_test, eval_dict['splits_test'])
                avg = 0
                for prob in probs_split:
                    avg_retain = np.mean(prob > threshold.item())
                    avg += avg_retain
                print(f"Average % of train claims retained: {avg_train}")
                print(f"Average % of test claims retained: {avg / len(probs_split)}")
                print(f"Loss at iteration {i}: {loss.item()}")

            loss.backward()
            optimizer.step()
        return model

    else:
        # Fixed: previously the ValueError was *returned* instead of raised,
        # so callers silently received an exception object as the "model".
        raise ValueError(f"{name} not available.")
60
+
61
+
62
def get_conformal_loss(probs, labels, dataset_train, alpha):
    """Differentiable surrogate for the conformal-filtering objective.

    probs: (n_claims, 1) tensor of claim-correctness probabilities.
    labels: (n_claims,) array of 0/1 labels (0 = false claim).
    dataset_train: list of datapoints; dat['atomic_facts'] gives each
        prompt's claims, used to split the flat claim axis per prompt.
    alpha: target error level for the conformity-score quantile.

    Returns (loss, avg_retained) where loss is a smooth count of removed
    claims on the held-out prompts and avg_retained the mean retention rate.
    """
    # Boundaries between prompts: claims are flattened across the dataset.
    claim_splits = torch.tensor(
        np.cumsum([len(dat['atomic_facts']) for dat in dataset_train])[:-1]
    )

    claim_probs = torch.tensor_split(probs, claim_splits)
    # Fixed: the original used `1 - labels` (an *integer* array) to index the
    # probability tensor, which selects positions 0/1 repeatedly instead of
    # masking out the false claims. Use a boolean mask of label == 0.
    false_masks = np.array_split(np.asarray(labels) == 0, claim_splits.numpy())

    scores = []
    for c_prob, c_mask in zip(claim_probs, false_masks):
        if c_mask.any():
            # Conformity score: highest probability assigned to a false claim.
            scores.append(c_prob[torch.from_numpy(c_mask)].max())
        else:
            # No false claims for this prompt: score 0 so it never raises the
            # calibrated threshold (probabilities are sigmoid outputs in (0,1)).
            scores.append(c_prob.new_zeros(()))

    # use a random set of scores to calibrate, the rest to compute the loss
    random_indices = np.random.permutation(len(scores))
    threshold_index_set = set(random_indices[:25].tolist())
    loss_index_set = set(random_indices[25:].tolist())
    if not loss_index_set:
        raise ValueError("get_conformal_loss needs more than 25 prompts.")

    threshold_scores = [scores[i] for i in range(len(scores)) if i in threshold_index_set]

    threshold = torch.quantile(torch.stack(threshold_scores), 1 - alpha)
    loss = 0
    avg = 0
    for idx, c_prob in enumerate(claim_probs):
        if idx in loss_index_set:
            # Smooth indicator of "claim falls below the threshold" (removed).
            loss += torch.sigmoid((threshold - c_prob)).mean()
            avg_retain = (c_prob > threshold).float().mean()
            avg += avg_retain
    if np.isnan(loss.item()):
        raise ValueError(claim_probs[0])
    return loss, avg / len(loss_index_set)
91
+
92
class LogisticRegression(nn.Module):
    """Single-layer logistic regression: sigmoid(W x + b) -> (n, 1) probabilities."""

    def __init__(self, n_features):
        """n_features: dimensionality of the input feature vectors."""
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, x):
        # torch.sigmoid replaces torch.nn.functional.sigmoid, which is
        # deprecated and emits a warning on modern PyTorch.
        return torch.sigmoid(self.linear(x))
100
+
101
+
MACI-main/conditional-conformal/src/query.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import openai
3
+
4
# Instruction for decomposing a full LLM response into independent subclaims.
SUBCLAIM_PROMPT = 'Please breakdown the following response to a prompt into a set of small, independent claims. Return each subclaim (with no other characters) on a new line. \n'

# Instruction for recombining the retained facts into a single fluent answer.
MERGE_PROMPT = "You will get an instruction and a set of facts that are true. Construct an answer using ONLY the facts provided, and use ALL of the facts provided. If no facts are given, reply and say that you don't know enough to respond.\n"

# Instruction asking the model to label each claim T (true) / S (subjective) / F (false),
# one JSON object per line.
ANNOTATION_PROMPT = 'You will get an instruction and a set of claims made in response to that instruction. Determine whether each claim is true, subjective, or false. Each returned determination should be {"claim_id": ID, "value": TRUTH_VALUE} and be on its own line with NO other characters. The truth value should be in quotes and it should be T for Factual, S for Subjective, and F for False.\n'

# Instruction scoring each claim against a reference text: 1 supports / -1 contradicts / 0 unrelated.
FREQUENCY_PROMPT = 'You will get a list of claims and piece of text. For each claim, score whether the text supports, contradicts, or is unrelated to the claim. Directly return a jsonl, where each line is {"id":[CLAIM_ID], "score":[SCORE]}. Directly return the jsonl with no explanation or other formatting. For the [SCORE], return 1 for supports, -1 for contradicts, and 0 for unrelated.\n'
11
+
12
+ def _concat_claims(
13
+ subclaims : List[str]
14
+ ) -> str:
15
+ return "\n".join(
16
+ f"{i}: {subclaim['message']}" for i, subclaim in enumerate(subclaims)
17
+ )
18
+
19
def generate_subclaim_prompt(
    prompt : str,
    response : str
) -> str:
    """Build the instruction asking the LLM to split `response` into subclaims."""
    return (
        SUBCLAIM_PROMPT
        + f"The original instruction was: {prompt}\n"
        + f"The response to be broken down into subclaims is: {response}"
    )
27
+
28
def generate_merge_prompt(
    prompt : str,
    subclaims : List[str]
) -> str:
    """Build the instruction asking the LLM to merge retained facts into an answer."""
    header = f"The original instruction was: {prompt}\n"
    facts = f"The facts are: {_concat_claims(subclaims)}"
    return MERGE_PROMPT + header + facts
37
+
38
def generate_annotation_prompt(
    prompt : str,
    subclaims : List[str]
) -> str:
    """Build the instruction asking the LLM to label each claim T/S/F."""
    return "".join([
        ANNOTATION_PROMPT,
        f"The original instruction was: {prompt}\n",
        f"The claims are: \n{_concat_claims(subclaims)}",
    ])
46
+
47
def generate_frequency_prompt(
    subclaims : List[str],
    output : str,
) -> str:
    """Build the instruction asking the LLM to score claims against `output`."""
    claims_part = f"The claims are: {_concat_claims(subclaims)}\n"
    text_part = f"The text is: {output}"
    return FREQUENCY_PROMPT + claims_part + text_part
54
+
55
def query_gpt(
    client : openai.Client,
    prompts : List[str],
    model : str = "gpt-3.5-turbo",
    roles : List[str] = None,
    max_tokens : int = 1000,
    temperature: float = 0,
    response_format : str = None,
    n_samples: int = 1
):
    """Send ONE chat-completion request built from `prompts`.

    NOTE(review): all prompts are packed into a single conversation (one
    message per prompt), not one request per prompt — confirm callers
    expect this when passing more than one prompt.

    Returns the raw openai ChatCompletion object (logprobs requested).
    """
    # Default every message to the "user" role unless explicit roles are given.
    if roles is None:
        messages = [{"role": "user", "content": prompt} for prompt in prompts]
    else:
        messages = [{"role": role, "content": prompt} for role, prompt in zip(roles, prompts)]

    # NOTE(review): response_format=None is forwarded verbatim; confirm the
    # installed openai client accepts None (vs. omitting the field).
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        response_format=response_format,
        max_tokens=max_tokens,
        temperature=temperature,
        n=n_samples,
        logprobs=True
    )
    return completion
80
+
81
def query_embedding(
    client : openai.Client,
    prompts : List[str],
    model : str = "text-embedding-3-small",
    **kwargs
):
    """Embed `prompts` with the OpenAI embeddings endpoint.

    NOTE(review): only the FIRST embedding (`data[0]`) is returned even when
    multiple prompts are passed — confirm callers always send one prompt.
    """
    embed = client.embeddings.create(input = prompts, model = model, **kwargs).data[0].embedding
    return embed
89
+
90
def query_llm(
    prompts : List[str],
    model : str,
    **kwargs
) -> Dict:
    """Dispatch `prompts` to a chat or embedding model based on the model name.

    Returns a list of {'logprobs', 'message'} dicts for chat ('gpt') models,
    or a single embedding vector for 'embedding' models.
    Raises ValueError for unsupported model names.
    """
    if 'gpt' in model:
        client = openai.Client() # OPENAI_API_KEY should be set as an environment variable
        completion = query_gpt(client, prompts, model, **kwargs)
        outputs = []
        for choice in completion.choices:
            # Keep both the text and per-token logprobs for downstream scoring.
            output_dict = {
                'logprobs': choice.logprobs.content,
                'message': choice.message.content
            }
            outputs.append(output_dict)
        return outputs
    elif 'embedding' in model:
        client = openai.Client()
        output = query_embedding(client, prompts, model, **kwargs)
        return output

    else:
        raise ValueError(f"Model {model} is not supported in query.")
MACI-main/conditional-conformal/src/ray_data.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from tqdm import tqdm
6
+
7
+ from config import get_config
8
+
9
+ from featurizer import get_frequency, get_self_eval
10
+ from gpt import GPTClient
11
+
12
+ from atomizer import text_to_sentences
13
+ from dataset import get_prompts
14
+ from scorer import Scorer
15
+
16
+ import ray
17
+
18
def parse_args():
    """Parse the command-line arguments (just the config file path)."""
    parser = argparse.ArgumentParser(
        prog="conformal-safety",
        description="Auto-filter claims from LLM to meet accuracy and safety guarantees.",
    )
    parser.add_argument('-config_path', '-c', default='configs/default.toml', help="Config for construction.")
    return parser.parse_args()
26
+
27
def find_unique_element(lst, condition, approx_index):
    """Return the index of an element of `lst` satisfying `condition`,
    searching outward from the hint `approx_index`; None if no match.

    Robustness fix: an empty list or an out-of-range hint no longer raises
    IndexError — the hint is clamped into the valid index range.
    """
    if not lst:
        return None

    # Clamp the hint so a bad guess cannot index out of bounds.
    approx_index = max(0, min(approx_index, len(lst) - 1))

    # Check the approximate index first
    if condition(lst[approx_index]):
        return approx_index

    # Expand outwards from the approximate index (left side probed first
    # at each distance, matching the original search order).
    left = approx_index - 1
    right = approx_index + 1
    while left >= 0 or right < len(lst):
        if left >= 0 and condition(lst[left]):
            return left
        if right < len(lst) and condition(lst[right]):
            return right
        left -= 1
        right += 1

    # No element satisfies the condition.
    return None
47
+
48
+
49
@ray.remote
def parallel_scorer(*args, **kwargs):
    # NOTE(review): deliberately short-circuited stub — the early `return None`
    # makes the call below unreachable, and `run_experiment` is not defined in
    # this module. Confirm whether this should be re-enabled or deleted.
    return None
    return run_experiment(*args, **kwargs)
53
+
54
+
55
if __name__ == "__main__":
    # WIP driver script: generates responses, atomizes them into facts,
    # matches cached atomizer outputs back to responses, then scores facts.
    # NOTE(review): contains interactive IPython breakpoints and references
    # names that are never defined here (args.seed, args.n_trials, args.type,
    # X, Y, n_test, n_calib, alpha, parallel_coverage_experiment,
    # parallel_experiment) — the ray section cannot run as written.
    args = parse_args()
    config = get_config(args.config_path)

    import IPython; IPython.embed()
    responder = GPTClient(config.model.responder.cache_path)

    topics, prompts = get_prompts(config.dataset.name)

    # Query the responder for every prompt in parallel (cache-backed client).
    with ThreadPoolExecutor(max_workers=25) as executor:
        responses = list(
            tqdm(
                executor.map(
                    lambda x : responder.query(x),
                    prompts
                ),
                total=len(prompts)
            )
        )

    responses = [r[0] for r in responses]


    outputs = [{'prompt': p, 'response': o['message']}
               for p, o in zip(prompts, responses)] # first output is the response we will filter

    print("Loading atomizer.")
    atomizer_client = GPTClient(config.model.parser.cache_path, model=config.model.parser.name)

    responder_cache = responder.cache_dict
    messages = []
    for val in responder_cache.values():
        messages.append(val[0]['message'])

    # Re-associate each cached atomizer output with the response it came
    # from by substring-matching the sentence embedded in the cache key,
    # starting the search near the previously matched index.
    atomizer_cache = atomizer_client.cache_dict
    idx_guess = 0
    atomic_facts = [[] for _ in range(len(messages))]
    for k in tqdm(atomizer_cache.keys()):
        atomized_msg = atomizer_cache[k][0]['message']
        atomized_facts = text_to_sentences(atomized_msg)
        sentence = k.split('\n')[-1].split('facts:')[-1].strip()[:-2]
        cur_idx = find_unique_element(messages, lambda x: sentence in x, approx_index=idx_guess)
        if cur_idx is None: # TODO: TERRIBLE SPECIAL CASING that I looked at by hand...
            if idx_guess == 4151:
                cur_idx = 4152
            else:
                cur_idx = idx_guess
        idx_guess = cur_idx
        atomic_facts[cur_idx].extend(atomized_facts)

    # time to annotate responses using factscore code
    print("Loading annotator.")
    scorer_client = GPTClient(config.model.annotator.cache_path, model=config.model.annotator.name)
    scorer = Scorer(scorer_client, config, model_name="retrieval")

    scorer_inputs = [(topic, output['response'], fact) for topic, output, fact in zip(topics, outputs, atomic_facts)]

    import IPython; IPython.embed()


    # connect to cluster
    ray.init(address="auto")

    results = []

    # NOTE(review): this loop was apparently pasted from another experiment
    # script; none of the names it uses exist in this module.
    for seed in range(args.seed, args.seed + args.n_trials):
        if args.type == 'coverage':
            result = parallel_coverage_experiment.remote(
                (X, Y), n_test, n_calib, alpha, methods=args.methods, seed=seed
            )
        else:
            result = parallel_experiment.remote(
                (X, Y), n_test, n_calib, alpha, methods=args.methods, seed=seed
            )
        results.append(result)

    trial_results = ray.get(results)

    # Score all (topic, response, facts) triples; single worker because the
    # scorer shares sqlite / cache state.
    with ThreadPoolExecutor(max_workers=1) as executor:
        scores = list(
            tqdm(
                executor.map(
                    lambda x : scorer.get_score(*x),
                    scorer_inputs
                ),
                total=len(scorer_inputs)
            )
        )
    scorer.save_cache()

    dataset = []

    for p, r, s in zip(prompts, responses, scores):
        data_pt = {
            'prompt': p,
            'response': r,
            'atomic_facts': s['decisions'][0]
        }
        dataset.append(data_pt)

    import IPython
    IPython.embed()

    # client = GPTClient(f'.cache/{config.dataset.name}_frequency.pkl')

    # with ThreadPoolExecutor(max_workers=5) as executor:
    #     frequencies = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_frequency(client, [af['atom'] for af in x['atomic_facts']], x['prompt'], config.model.prob.frequency.model),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # client.save_cache()

    # eval_client = GPTClient(f'.cache/{config.dataset.name}_self_evals.pkl')

    # with ThreadPoolExecutor(max_workers=25) as executor:
    #     self_evals = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_self_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], eval_client),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # eval_client.save_cache()

    # features = np.concatenate(
    #     [
    #         np.concatenate(frequencies).reshape(-1,1),
    #         np.concatenate(self_evals).reshape(-1,1)
    #     ],
    #     axis=1
    # )
MACI-main/conditional-conformal/src/retrieval.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import os
4
+
5
+ import sqlite3
6
+ import numpy as np
7
+ import pickle as pkl
8
+
9
+ from rank_bm25 import BM25Okapi
10
+
11
# Delimiter used to join/split passages stored in a single DB text column.
SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####"
# Maximum passage length in (RoBERTa) tokens when chunking documents.
MAX_LENGTH = 256
13
+
14
class DocDB(object):
    """Sqlite backed document storage.

    Implements get_doc_text(doc_id).
    """

    def __init__(self, db_path=None, data_path=None, cache_path=None):
        """Open (or build, from `data_path`) the sqlite DB and load the title cache."""
        self.db_path = db_path
        self.cache_file = cache_path
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)

        self.cache_dict = self.load_cache()

        cursor = self.connection.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")

        if len(cursor.fetchall())==0:
            assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB."
            print (f"{self.db_path} is empty. start building DB from {data_path}...")
            self.build_db(self.db_path, data_path)

    def load_cache(self, allow_retry=True):
        """Load the pickled title->passages cache from disk (or S3); {} if absent."""
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pkl.load(f)
                    break
                except Exception: # if there are concurent processes, things can fail
                    if not allow_retry:
                        assert False
                    print ("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        elif 's3' in self.cache_file:
            from aws_utils import s3_open
            s3_path = self.cache_file.removeprefix('s3://')
            bucket_name = s3_path.split('/')[0]
            path_to_file = '/'.join(s3_path.split('/')[1:])
            with s3_open(bucket_name, path_to_file) as fp:
                cache = pkl.load(fp)
        else:
            cache = {}
        return cache

    def save_cache(self):
        """Merge the on-disk cache into ours, then persist the union."""
        # load the latest cache first, since if there were other processes
        # running in parallel, cache might have been updated
        for k, v in self.load_cache().items():
            self.cache_dict[k] = v

        with open(self.cache_file, "wb") as f:
            pkl.dump(self.cache_dict, f)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def path(self):
        """Return the path to the file that backs this database."""
        # Fixed: previously returned `self.path` — the bound method itself
        # (the method name shadows no attribute); the backing file lives in
        # `self.db_path`.
        return self.db_path

    def close(self):
        """Close the connection to the database."""
        self.connection.close()

    def build_db(self, db_path, data_path):
        """Build the documents table from a jsonl of {'title', 'text'} records,
        chunking each document into MAX_LENGTH-token passages joined by
        SPECIAL_SEPARATOR."""
        from transformers import RobertaTokenizer
        tokenizer = RobertaTokenizer.from_pretrained("roberta-large")

        titles = set()
        output_lines = []
        tot = 0
        start_time = time.time()
        c = self.connection.cursor()
        c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")

        with open(data_path, "r") as f:
            for line in f:
                dp = json.loads(line)
                title = dp["title"]
                text = dp["text"]
                if title in titles:
                    continue
                titles.add(title)
                if type(text)==str:
                    text = [text]
                passages = [[]]
                for sent_idx, sent in enumerate(text):
                    assert len(sent.strip())>0
                    tokens = tokenizer(sent)["input_ids"]
                    max_length = MAX_LENGTH - len(passages[-1])
                    if len(tokens) <= max_length:
                        passages[-1].extend(tokens)
                    else:
                        passages[-1].extend(tokens[:max_length])
                        offset = max_length
                        while offset < len(tokens):
                            passages.append(tokens[offset:offset+MAX_LENGTH])
                            offset += MAX_LENGTH

                # Drop passages that contain only special tokens (0 = <s>, 2 = </s>).
                psgs = [tokenizer.decode(tokens) for tokens in passages if np.sum([t not in [0, 2] for t in tokens])>0]
                text = SPECIAL_SEPARATOR.join(psgs)
                output_lines.append((title, text))
                tot += 1

                if len(output_lines) == 1000000:
                    c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
                    output_lines = []
                    print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))

        if len(output_lines) > 0:
            c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
            print ("Finish saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))

        self.connection.commit()
        self.connection.close()

    def get_text_from_title(self, title):
        """Fetch the raw text of the doc for 'doc_id'."""
        # Apply hand-curated title corrections before lookup.
        # NOTE(review): this file is re-read on every call — consider caching.
        with open('data/wiki_corrections.txt') as fp:
            all_names = fp.readlines()
        all_names = [n.strip() for n in all_names]
        name_converter = {names.split('=')[0]:names.split('=')[1] for names in all_names}
        if title in name_converter:
            title = name_converter[title]

        if title in self.cache_dict:
            results = self.cache_dict[title]
        else:
            print("I SHOULD NOT BE HERE.")
            cursor = self.connection.cursor()
            cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
            results = cursor.fetchall()
            results = [r for r in results]
            cursor.close()
            try:
                assert results is not None and len(results)==1, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            except Exception: # if there are concurent processes, things can fail
                print (f"Retrieval error for {title}: Retry in 5sec...")
                # time.sleep(5)
                cursor = self.connection.cursor()
                cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
                results = cursor.fetchall()
                results = [r for r in results]
                # NOTE(review): placeholder text substituted on lookup failure.
                results = [['blah blah blah']]
                cursor.close()
            results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)]
            assert len(results)>0, f"`topic` in your data ({title}) is likely to be not a valid title in the DB."
            self.cache_dict[title] = results
        return results
165
+
166
class Retrieval(object):
    """Retrieve the top-k passages for a (topic, query) pair.

    Backed by a DocDB; results are memoized in a JSON cache and the per-topic
    BM25 indices / GTR passage embeddings in a pickle cache.
    """

    def __init__(self, db, cache_path, embed_cache_path,
                 retrieval_type="gtr-t5-large", batch_size=None):
        self.db = db
        self.cache_path = cache_path
        self.embed_cache_path = embed_cache_path
        self.retrieval_type = retrieval_type
        self.batch_size = batch_size
        assert retrieval_type=="bm25" or retrieval_type.startswith("gtr-")

        # Sentence encoder is loaded lazily (only needed for gtr-* retrieval).
        self.encoder = None
        self.load_cache()
        # Counters of new entries since load; save_cache() only writes if > 0.
        self.add_n = 0
        self.add_n_embed = 0

    def load_encoder(self):
        """Lazily load the sentence-transformers encoder onto the GPU."""
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer("sentence-transformers/" + self.retrieval_type)
        encoder = encoder.cuda()
        encoder = encoder.eval()
        self.encoder = encoder
        assert self.batch_size is not None

    def load_cache(self):
        """Load the JSON result cache and the pickled embedding cache ({} if absent)."""
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r") as f:
                self.cache = json.load(f)
        else:
            self.cache = {}
        if os.path.exists(self.embed_cache_path):
            with open(self.embed_cache_path, "rb") as f:
                self.embed_cache = pkl.load(f)
        else:
            self.embed_cache = {}

    def save_cache(self):
        """Persist both caches, merging in anything written by parallel processes."""
        if self.add_n > 0:
            if os.path.exists(self.cache_path):
                with open(self.cache_path, "r") as f:
                    new_cache = json.load(f)
                self.cache.update(new_cache)

            with open(self.cache_path, "w") as f:
                json.dump(self.cache, f)

        if self.add_n_embed > 0:
            if os.path.exists(self.embed_cache_path):
                with open(self.embed_cache_path, "rb") as f:
                    new_cache = pkl.load(f)
                self.embed_cache.update(new_cache)

            with open(self.embed_cache_path, "wb") as f:
                pkl.dump(self.embed_cache, f)

    def get_bm25_passages(self, topic, query, passages, k):
        """Rank `passages` by BM25 score against `query`; return the top k."""
        if topic in self.embed_cache:
            bm25 = self.embed_cache[topic]
        else:
            bm25 = BM25Okapi([psg["text"].replace("<s>", "").replace("</s>", "").split() for psg in passages])
            self.embed_cache[topic] = bm25
            self.add_n_embed += 1
        scores = bm25.get_scores(query.split())
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_gtr_passages(self, topic, retrieval_query, passages, k):
        """Rank `passages` by GTR dense inner-product similarity; return the top k."""
        if self.encoder is None:
            self.load_encoder()
        if topic in self.embed_cache:
            passage_vectors = self.embed_cache[topic]
        else:
            inputs = [psg["title"] + " " + psg["text"].replace("<s>", "").replace("</s>", "") for psg in passages]
            passage_vectors = self.encoder.encode(inputs, batch_size=self.batch_size, device=self.encoder.device)
            self.embed_cache[topic] = passage_vectors
            self.add_n_embed += 1
        query_vectors = self.encoder.encode([retrieval_query],
                                            batch_size=self.batch_size,
                                            device=self.encoder.device)[0]
        scores = np.inner(query_vectors, passage_vectors)
        indices = np.argsort(-scores)[:k]
        return [passages[i] for i in indices]

    def get_passages(self, topic, question, k):
        """Return the top-k passages for `question` about `topic` (memoized)."""
        retrieval_query = topic + " " + question.strip()
        cache_key = topic + "#" + retrieval_query

        if cache_key not in self.cache:
            passages = self.db.get_text_from_title(topic)
            if self.retrieval_type=="bm25":
                self.cache[cache_key] = self.get_bm25_passages(topic, retrieval_query, passages, k)
            else:
                self.cache[cache_key] = self.get_gtr_passages(topic, retrieval_query, passages, k)
            # Either exactly k results, or everything when fewer than k exist.
            assert len(self.cache[cache_key]) in [k, len(passages)]
            self.add_n += 1


        return self.cache[cache_key]
264
+
265
+
266
+
267
+
268
+
MACI-main/conditional-conformal/src/retrieve_data.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from tqdm import tqdm
6
+
7
+ from dataset import load_dataset
8
+ from config import get_config
9
+
10
+ from featurizer import get_frequency, get_self_eval, get_bool_eval
11
+ from gpt import GPTClient
12
+
13
def parse_args():
    """Return the parsed CLI arguments; only a config path is accepted."""
    arg_parser = argparse.ArgumentParser(
        prog="conformal-safety",
        description="Auto-filter claims from LLM to meet accuracy and safety guarantees.",
    )
    arg_parser.add_argument('-config_path', '-c', default='configs/default.toml', help="Config for construction.")
    parsed = arg_parser.parse_args()
    return parsed
21
+
22
if __name__ == "__main__":
    # Interactive feature-extraction scratchpad: loads the dataset, then drops
    # into IPython. The feature extraction passes (frequency, self-eval,
    # bool-eval) are kept below as commented-out reference code.
    args = parse_args()
    config = get_config(args.config_path)

    dataset = load_dataset(config)

    # client = GPTClient(f'.cache/{config.dataset.name}_frequency.pkl')


    # with ThreadPoolExecutor(max_workers=8) as executor:
    #     frequencies = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_frequency(client, [af['atom'] for af in x['atomic_facts']], x['prompt'], config.model.prob.frequency.model),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # client.save_cache()

    # eval_client = GPTClient(f'.cache/{config.dataset.name}_self_evals.pkl')

    # with ThreadPoolExecutor(max_workers=25) as executor:
    #     self_evals = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_self_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], eval_client),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # eval_client.save_cache()

    # bool_client = GPTClient(f'.cache/{config.dataset.name}_bool_evals.pkl')

    # with ThreadPoolExecutor(max_workers=25) as executor:
    #     self_bools = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_bool_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], bool_client),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # bool_client.save_cache()

    # features = np.concatenate(
    #     [
    #         np.concatenate(frequencies).reshape(-1,1),
    #         np.concatenate(self_evals).reshape(-1,1)
    #     ],
    #     axis=1
    # )

    import IPython; IPython.embed()
80
+
81
+
82
+
83
+
84
+
85
+
86
+
MACI-main/conditional-conformal/src/run.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+
4
+ from config import get_config
5
+ from conformal import compute_conformity_scores, calibrate_thresholds, conformal_filter, assess_factscore_coverage
6
+ from dataset import load_dataset, split_dataset
7
+ from featurizer import get_features
8
+ from llm_utils import merge_claims
9
+ from prob_model import fit_model
10
+ from gpt import GPTClient
11
+
12
+
13
def parse_args():
    """Build the CLI parser and return the parsed arguments (config path only)."""
    cli = argparse.ArgumentParser(
        prog="conformal-safety",
        description="Auto-filter claims from LLM to meet accuracy and safety guarantees.",
    )
    cli.add_argument('-config_path', '-c', default='configs/default.toml', help="Config for construction.")
    return cli.parse_args()
21
+
22
+
23
if __name__ == "__main__":
    # End-to-end pipeline: load annotated data, fit the claim-correctness
    # model, conformally calibrate a filtering threshold on validation data,
    # filter the test responses, then merge retained claims back into prose.
    args = parse_args()

    config = get_config(args.config_path)

    rng = np.random.default_rng(seed=config.dataset.seed)

    # annotate dataset
    dataset = load_dataset(config)

    # split dataset into train / validation / test
    dataset_train, dataset_valid, dataset_test = split_dataset(
        dataset,
        train_perc=config.dataset.train_percent,
        valid_perc=config.dataset.valid_percent,
        rng=rng if config.dataset.randomize else None
    )

    X_train = get_features(dataset_train, config)

    # Flatten per-claim correctness labels across prompts into 0/1 int8.
    y_train = np.concatenate([[c['is_supported'] for c in dat['atomic_facts']] for dat in dataset_train])
    y_train[y_train == True] = 1
    y_train[y_train == False] = 0
    y_train = y_train.astype(np.int8)

    X_valid = get_features(dataset_valid, config)
    y_valid = np.concatenate([[c['is_supported'] for c in dat['atomic_facts']] for dat in dataset_valid])
    y_valid[y_valid == True] = 1
    y_valid[y_valid == False] = 0
    y_valid = y_valid.astype(np.int8)
    # Claim-array split points mapping the flat claim axis back to prompts.
    splits_valid = np.cumsum([len(dat['atomic_facts']) for dat in dataset_valid])[:-1]

    X_test = get_features(dataset_test, config)
    y_test = np.concatenate([[c['is_supported'] for c in dat['atomic_facts']] for dat in dataset_test])
    y_test[y_test == True] = 1
    y_test[y_test == False] = 0
    y_test = y_test.astype(np.int8)
    splits_test = np.cumsum([len(dat['atomic_facts']) for dat in dataset_test])[:-1]

    model = fit_model(X_train, y_train, config, dataset_train,
                      eval_dict={'X_valid': X_valid, 'X_test': X_test, 'dataset_valid': dataset_valid, 'splits_valid': splits_valid, 'splits_test': splits_test})

    # NOTE(review): predict_proba assumes the sklearn ("logistic") model; the
    # "torch" model from fit_model has no predict_proba — confirm config.
    scores_valid = model.predict_proba(X_valid)[:,1]
    scores_valid = np.array_split(scores_valid, splits_valid)

    scores_test = model.predict_proba(X_test)[:,1]
    scores_test = np.array_split(scores_test, splits_test)
    # identify features for scoring
    score_features_v = [np.zeros((len(u['atomic_facts']), 1)) for u in dataset_valid]
    score_features_te = [np.zeros((len(u['atomic_facts']), 1)) for u in dataset_test]

    conf_scores_valid = compute_conformity_scores(dataset_valid, scores_valid)

    # fit error probability function using training set (or just define it?)
    # we want to be more sure about correctness on more sensitive prompts
    alpha_fn = lambda x: [config.conformal.alpha] * len(x) # TODO: dumb one for now.

    # identify features for conditional calibration
    conf_features_v = np.zeros((len(dataset_valid),1))
    conf_features_te = np.zeros((len(dataset_test),1))

    # calibrate a threshold on the validation set
    thresholds = calibrate_thresholds(
        conf_features_te,
        conf_features_v,
        conf_scores_valid,
        alpha_fn
    )

    dataset_test = conformal_filter(
        dataset_test,
        scores_test,
        thresholds
    )

    if config.dataset.name.lower() == "factscore":
        assess_factscore_coverage(dataset_test, config.conformal.alpha)

    print("Merging filtered responses.")

    merge_client = GPTClient(cache_file = config.model.merger.cache_path)
    merged_responses = merge_claims(
        dataset_test,
        merge_client
    )
    merge_client.save_cache()

    # Show one random before/after pair for a quick sanity check.
    rand_idx = rng.integers(0, len(dataset_test))
    print(dataset_test[rand_idx]['response']['message'] + "\n")
    print(merged_responses[rand_idx])

    import IPython; IPython.embed()
115
+
116
+
117
+
118
+
119
+
MACI-main/conditional-conformal/src/scorer.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import string
2
+ import numpy as np
3
+ import os
4
+ import json
5
+
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ # import logging
8
+
9
+ # from tqdm import tqdm
10
+ # from factscore.abstain_detection import is_response_abstained
11
+ from retrieval import DocDB, Retrieval
12
+
13
class Scorer(object):
    """FActScore-style factuality scorer.

    Judges each atomic fact of a generation as supported/unsupported by
    prompting an LLM ``client`` with context retrieved from a registered
    knowledge source, then aggregates per-response support rates
    (optionally length-penalized via ``gamma``).
    """

    def __init__(self,
                 client,
                 config,
                 model_name="retrieval+ChatGPT",
                 batch_size=256):
        # NOTE(review): `client` is assumed to expose .query(prompt) ->
        # [{'message': str}, ...] and .save_cache() (see _get_score /
        # save_cache); `config` supplies annotator data/cache paths — confirm
        # against the project's client/config classes.
        assert model_name in ["retrieval+llama", "retrieval+llama+npm", "retrieval+ChatGPT", "npm", "retrieval+ChatGPT+npm", "retrieval"]
        self.model_name = model_name
        self.client = client
        self.config = config

        self.data_dir = config.model.annotator.data_path
        self.cache_dir = config.model.annotator.retrieval_cache_path

        self.db = {}          # knowledge-source name -> DocDB
        self.retrieval = {}   # knowledge-source name -> Retrieval (or a plain dict for "medlfqa")
        self.npm = {}         # knowledge-source name -> NPM scorer (only for "*+npm" variants)
        self.batch_size = batch_size  # batch size for retrieval

        self.af_generator = None

    def save_cache(self):
        """Persist every cache held by this scorer (client, npm, retrieval, doc DBs)."""
        self.client.save_cache()
        if "npm" in self.model_name:
            for v in self.npm.values():
                v.save_cache()
        for v in self.retrieval.values():
            # The "medlfqa" knowledge source stores a plain dict here (see
            # register_knowledge_source), which has no cache to persist --
            # skip anything without a save_cache method.
            if hasattr(v, "save_cache"):
                v.save_cache()
        # BUG FIX: the original looped `for k, v in self.db:`, which iterates
        # the dict *keys* (strings) and fails to unpack as soon as any
        # knowledge source is registered. Iterate the values instead.
        for v in self.db.values():
            v.save_cache()

    def register_knowledge_source(self, name="enwiki-20230401", db_path=None, data_path=None):
        """Register a retrieval backend for `name`.

        For "medlfqa" this loads the MedLFQA jsonl files and keeps an
        in-memory question -> reference-answer dict; otherwise it opens a
        DocDB plus a BM25 Retrieval layer with on-disk caches.
        """
        assert name not in self.retrieval, f"{name} already registered"

        if db_path is None:
            db_path = os.path.join(self.data_dir, f"{name}.db")

        if data_path is None:
            data_path = os.path.join(self.data_dir, f"{name}.jsonl")

        if name == "medlfqa":
            datasets = {}
            suffix = "_test_MedLFQA.jsonl"

            for path in os.listdir(self.data_dir):
                if "MedLFQA" not in path:
                    continue
                dataset_name = path[:-len(suffix)]
                with open(os.path.join(self.data_dir, path), 'r') as fp:
                    datasets[dataset_name] = [json.loads(line) for line in fp.readlines()]
            retrieval = {}
            for _, dataset in datasets.items():
                for pt in dataset:
                    retrieval[pt['Question']] = {
                        'context': pt['Free_form_answer'],
                        'must_have': pt['Must_have'],
                        'nice_to_have': pt['Nice_to_have']
                    }
            self.retrieval[name] = retrieval

        else:
            db_cache_path = os.path.join(self.cache_dir, f"db-{name}.pkl")
            cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.json")
            embed_cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.pkl")

            self.db[name] = DocDB(db_path=db_path, data_path=data_path, cache_path=db_cache_path)
            self.retrieval[name] = Retrieval(self.db[name], cache_path, embed_cache_path, retrieval_type="bm25", batch_size=self.batch_size)

    def get_score(self,
                  topics,
                  generations,
                  atomic_facts,
                  gamma=10,
                  knowledge_source=None):
        """Score generations against a knowledge source.

        Args:
            topics / generations / atomic_facts: parallel lists (or single
                strings) of prompt topics, full generations, and the list of
                atomic-fact strings per generation (None = abstained).
            gamma: length-penalty knob; responses with fewer than `gamma`
                facts are down-weighted by exp(1 - gamma/len).
            knowledge_source: registered source name (default enwiki dump).

        Returns:
            dict with mean "score", "respond_ratio", per-fact "decisions",
            "num_facts_per_response", and "init_score" (pre-penalty) when
            gamma is truthy.
        """
        if knowledge_source is None:
            # use the default knowledge source
            knowledge_source = "enwiki-20230401"

        if knowledge_source not in self.retrieval:
            self.register_knowledge_source(knowledge_source)

        if type(topics)==type(generations)==str:
            topics = [topics]
            generations = [generations]
            atomic_facts = [atomic_facts]
        else:
            assert type(topics)==type(generations)==list, "`topics` and `generations` should be lists."
            assert len(topics)==len(generations), "`topics` and `generations` should have the same length"
            assert len(topics)==len(atomic_facts), "`topics` and `atomic_facts` should have the same length"

        respond_ratio = np.mean([facts is not None for facts in atomic_facts])

        scores = []
        init_scores = []
        decisions = []
        for topic, generation, facts in zip(topics, generations, atomic_facts):
            if facts is None:
                # abstained response: recorded as None, excluded from the mean
                decisions.append(None)
            else:
                decision = []
                for fact in facts:
                    # earlier decisions are fed back as few-shot context
                    decision.append(
                        self._get_score(topic, generation, fact, knowledge_source, decision)
                    )
                score = np.mean([d["is_supported"] for d in decision])

                if gamma:
                    init_scores.append(score)
                    penalty = 1.0 if len(facts)>gamma else np.exp(1-gamma/max(len(facts), 1))
                    score = penalty * score

                decisions.append(decision)
                scores.append(score)

        # NOTE(review): if every response abstained, `scores` is empty and
        # np.mean returns nan with a RuntimeWarning — callers should check
        # respond_ratio.
        out = {"score": np.mean(scores),
               "respond_ratio": respond_ratio,
               "decisions": decisions,
               "num_facts_per_response": np.mean([len(d) for d in decisions if d is not None])}

        if gamma:
            out["init_score"] = np.mean(init_scores)

        return out

    def _get_score(self, topic, generation, atom, knowledge_source, prev_decisions = []):
        """Judge a single atomic fact; returns {"atom", "is_supported"}.

        Builds a True/False prompt from retrieved context (or the MedLFQA
        reference answer) plus previous decisions, then parses the client's
        free-text answer.
        """
        definition = f"Answer the question about {topic} based on the given context and your previous answers.\n\n"
        atom = atom.strip()
        if knowledge_source == "medlfqa":
            context = self.retrieval[knowledge_source][topic]['context']
        else:
            passages = self.retrieval[knowledge_source].get_passages(topic, atom, k=5)
            context = ""
            # reversed so the highest-ranked passage ends up closest to the question
            for psg in reversed(passages):
                context += "Title: {}\nText: {}\n\n".format(psg["title"], psg["text"].replace("<s>", "").replace("</s>", ""))
        definition += context.strip()
        if not definition[-1] in string.punctuation:
            definition += "."
        prompt = f"{definition.strip()}\n\n"
        for prev_decision in prev_decisions:
            prev_score = "True" if prev_decision["is_supported"] else "False"
            prompt += f"Previous input: {prev_decision['atom']}\nTrue or False? Output: {prev_score}\n"

        prompt += f"Input: {atom.strip()} True or False?\nOutput:"
        output = self.client.query(prompt)

        # Parse the free-text verdict: prefer an explicit True/False token;
        # when both appear, the later one wins; otherwise fall back to a
        # negation-keyword heuristic.
        generated_answer = output[0]['message'].lower()
        if "true" in generated_answer or "false" in generated_answer:
            if "true" in generated_answer and "false" not in generated_answer:
                is_supported = True
            elif "false" in generated_answer and "true" not in generated_answer:
                is_supported = False
            else:
                is_supported = generated_answer.index("true") > generated_answer.index("false")
        else:
            is_supported = all([keyword not in generated_answer.lower().translate(str.maketrans("", "", string.punctuation)).split() for keyword in ["not", "cannot", "unknown", "information"]])

        if is_supported and "npm" in self.model_name:
            # optional second-stage check with the nonparametric model
            npprob = self.npm[knowledge_source].get_probabilty(topic, atom)
            is_supported = npprob > 0.3

        decision = {"atom": atom, "is_supported": is_supported}

        return decision
MACI-main/conformal/__pycache__/adaptive_conformal.cpython-39.pyc ADDED
Binary file (19.1 kB). View file
 
MACI-main/conformal/__pycache__/basic_conformal.cpython-39.pyc ADDED
Binary file (5.87 kB). View file
 
MACI-main/conformal/__pycache__/conditional_conformal.cpython-39.pyc ADDED
Binary file (16.7 kB). View file
 
MACI-main/conformal/adaptive_conformal.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import logging
3
+ from sklearn.model_selection import train_test_split
4
+ from sklearn.metrics import roc_auc_score
5
+ from collections import defaultdict
6
+ from scipy.optimize import minimize
7
+ from typing import Callable, List, Dict, Any, Optional, Tuple
8
+ import cvxpy as cp
9
+
10
class MACIAdaptiveConformal:
    """Split-conformal filter over per-claim scores.

    Calibrates a cumulative-cost budget ``tau_hat`` on labeled data, then at
    prediction time removes the lowest-scored claims whose accumulated cost
    -log(1 - s) fits within the budget (with randomized tie-breaking at the
    boundary).
    """

    def __init__(
        self,
        score_function: Callable,
        random_state: Optional[int] = None,
        eps: float = 1e-6,
        **kwargs,
    ) -> None:
        # score_function: maps a list of data entries to per-sample claim
        # scores (list of arrays, or raw per-claim iterables).
        # eps: keeps scores strictly below 1 so -log(1 - s) stays finite.
        self.score_function = score_function
        self.random_state = random_state
        self.eps = float(eps)
        self.tau_hat: Optional[float] = None  # calibrated budget; set by fit_on_calib
        self._rng = np.random.default_rng(self.random_state)

    def _process_raw_scores(self, raw_scores: List, data: List[Dict]) -> List[np.ndarray]:
        """Normalize raw score output into one float array per sample.

        Arrays are passed through; otherwise each score iterable is truncated
        to the sample's number of atomic facts and NaNs are replaced with 0.
        """
        if raw_scores and isinstance(raw_scores[0], np.ndarray):
            return [np.asarray(s, dtype=float) for s in raw_scores]
        per_sample_scores: List[np.ndarray] = []
        # entries may wrap the sample under a 'sample' key or be the sample itself
        samples = [d.get('sample', d) for d in data]
        for i, s_i in enumerate(raw_scores):
            n_claims = len(samples[i].get("atomic_facts", []))
            s_arr = np.asarray(list(s_i), dtype=float)[:n_claims]
            per_sample_scores.append(np.nan_to_num(s_arr, nan=0.0))
        return per_sample_scores

    def _compute_nonconformity_score(self, sample: dict, scores_i: np.ndarray) -> float:
        """Cost budget needed to remove every unsupported claim of `sample`.

        Claims are sorted by ascending score (stable mergesort); the score is
        the summed cost -log(1 - s) up to the highest-scored unsupported
        claim, i.e. the budget that would filter out all false claims.
        """
        atomic_facts = sample.get("atomic_facts", [])
        if not atomic_facts or scores_i.size == 0: return 0.0
        labels = np.asarray([af.get("is_supported", False) for af in atomic_facts], dtype=bool)
        s_raw = np.asarray(scores_i, dtype=float)
        s_raw = np.nan_to_num(s_raw, nan=0.0, posinf=1.0, neginf=0.0)
        # clip below 1 - eps so the log cost is finite
        s = np.clip(s_raw, 0.0, 1.0 - self.eps)
        idx = np.argsort(s, kind='mergesort')
        s_sorted_asc, labels_asc = s[idx], labels[idx]
        false_positions = np.where(~labels_asc)[0]
        if not false_positions.size: return 0.0
        # index of the last (highest-scored) unsupported claim in sorted order
        k_star = int(false_positions.max())
        costs = -np.log(1.0 - s_sorted_asc)
        return float(np.sum(costs[:k_star + 1]))

    def fit_on_calib(self, calib_data: List[dict], alpha: float = 0.1) -> "MACIAdaptiveConformal":
        """Calibrate tau_hat as the conformal (1-alpha) quantile of the
        per-sample nonconformity scores. Returns self."""
        raw_scores = self.score_function(calib_data)
        per_sample_scores = self._process_raw_scores(raw_scores, calib_data)
        calib_samples = [entry.get('sample', entry) for entry in calib_data]
        s_values = [self._compute_nonconformity_score(s, sc) for s, sc in zip(calib_samples, per_sample_scores)]

        logging.info(f"  - Calibration set size: {len(calib_data)} samples")
        if not s_values:
            raise ValueError("Cannot compute scores from calibration data.")

        logging.info(f"  - Nonconformity stats: min={min(s_values):.4f}, max={max(s_values):.4f}, mean={np.mean(s_values):.4f}")

        # conformal order statistic: ceil((1-alpha)(n+1))-th smallest, capped at max
        n = len(s_values)
        quantile_index = int(np.ceil((1.0 - alpha) * (n + 1))) - 1
        quantile_index = min(quantile_index, n - 1)

        sorted_s_values = np.sort(s_values)
        self.tau_hat = sorted_s_values[quantile_index]

        logging.info(f"  - Assigned tau_hat: {self.tau_hat:.4f}")
        return self

    def predict(self, data: List[dict]) -> Tuple[List[dict], List[float]]:
        """Filter each sample's claims within the calibrated budget.

        Removes the K lowest-scored claims whose cumulative cost fits in
        tau_hat; surviving claims go to sample['filtered_claims'].

        Returns:
            (filtered samples, per-sample retention rates).
        """
        if self.tau_hat is None: raise ValueError("Model is not calibrated.")
        raw_scores = self.score_function(data)
        per_sample_scores = self._process_raw_scores(raw_scores, data)
        samples = [d.get('sample', d) for d in data]

        filtered_data, retention_rates = [], []
        for sample, s_raw in zip(samples, per_sample_scores):
            atomic_facts = sample.get("atomic_facts", [])
            new_sample = dict(sample)  # shallow copy; original sample untouched
            if not atomic_facts or s_raw.size == 0:
                new_sample["filtered_claims"] = []
                # no claims at all counts as full retention; claims without scores as zero
                retention_rates.append(1.0 if not atomic_facts else 0.0)
            else:
                s_tmp = np.asarray(s_raw, dtype=float)
                s_tmp = np.nan_to_num(s_tmp, nan=0.0, posinf=1.0, neginf=0.0)
                s = np.clip(s_tmp, 0.0, 1.0 - self.eps)
                # sort claims by ascending score; low scores are removed first
                indexed_items = sorted(list(zip(s, atomic_facts)), key=lambda x: x[0])
                s_sorted_asc = np.array([item[0] for item in indexed_items])
                costs = -np.log(1.0 - s_sorted_asc)
                cumulative_costs = np.concatenate(([0.0], np.cumsum(costs)))
                possible_K_indices = np.where(cumulative_costs <= self.tau_hat)[0]
                K = int(possible_K_indices.max()) if possible_K_indices.size > 0 else 0
                # Boundary randomization: with probability proportional to leftover budget,
                # include one more boundary item (i.e., increase K by 1) if feasible.
                # This randomization reduces discretization bias at the threshold.
                if K < len(costs):
                    leftover = float(self.tau_hat - cumulative_costs[K])
                    next_cost = float(costs[K])  # cost of the (K)-th item in sorted order
                    if np.isfinite(next_cost) and next_cost > 0.0 and leftover > 0.0:
                        p = float(np.clip(leftover / next_cost, 0.0, 1.0))
                        if self._rng.uniform(0.0, 1.0) < p:
                            K = K + 1
                # keep everything from position K onward (the higher-scored claims)
                new_sample["filtered_claims"] = [item[1] for item in indexed_items[K:]]
                retention_rates.append(len(new_sample["filtered_claims"]) / len(atomic_facts))
            filtered_data.append(new_sample)
        return filtered_data, retention_rates
109
+
110
class SubgroupOptimizedMACI:
    """Subgroup-wise ensemble weighting + conformal calibration.

    Bins samples by a grouper-provided scalar, learns per-bin ensemble
    weights over several scoring models, and calibrates one
    MACIAdaptiveConformal threshold per bin. Prediction routes each sample
    to its bin's model (fallback weights when a bin has no trained model).
    """

    def __init__(self, model_names: List[str], grouper: Any, n_bins: int = 3, **kwargs):
        # kwargs are forwarded to MACIAdaptiveConformal (random_state, eps, ...)
        self.model_names, self.grouper, self.n_bins, self.kwargs = model_names, grouper, n_bins, kwargs
        self.weights, self.conformal_models = {}, {}
        self.fallback_weights, self.bin_edges = None, None
        self.bin_labels = ['low', 'medium', 'high'] if n_bins == 3 else [f'group_{i}' for i in range(n_bins)]
        # Timing accumulators
        self._timing: Dict[str, float] = {
            'weight_optimization_s': 0.0,
            'calibration_s': 0.0
        }

    def _get_subgroup_label(self, value: float) -> str:
        """Map a scalar grouper value to a bin label (non-finite -> first bin)."""
        if self.bin_edges is None or not np.isfinite(value):
            return self.bin_labels[0]
        bin_index = np.digitize(value, self.bin_edges)
        return self.bin_labels[min(bin_index, len(self.bin_labels) - 1)]

    def _group_data_by_bins(self, data: List[Dict], bin_edges: np.ndarray) -> Dict[str, List[Dict]]:
        """Partition entries by bin label.

        NOTE(review): the `bin_edges` argument is unused — labeling goes
        through self.bin_edges via _get_subgroup_label; kept for signature
        compatibility with in-file callers.
        """
        grouped_data = defaultdict(list)
        values = self.grouper.compute_values([d['sample'] for d in data])
        for item, value in zip(data, values):
            label = self._get_subgroup_label(value)
            grouped_data[label].append(item)
        return grouped_data

    def _learn_robust_weights_by_retention(self, training_data: List[Dict], target_tpr: float = 0.95) -> np.ndarray:
        """
        Stable convex program for learning ensemble weights on the probability simplex.

        Uses an epigraph reformulation with explicit nonnegative slack variables and
        Tikhonov regularization to improve numerical stability across solvers.
        """
        # Flatten per-claim labels and the per-model score vectors.
        all_scores, all_labels = [], []
        for entry in training_data:
            sample, scores_dict = entry.get('sample', {}), entry.get('scores', {})
            labels = [af.get("is_supported", False) for af in sample.get("atomic_facts", [])]
            scores_per_model = [scores_dict.get(m, []) for m in self.model_names]
            min_len = min(len(labels), *[len(s) for s in scores_per_model])
            if min_len == 0:
                continue
            for i in range(min_len):
                all_labels.append(labels[i])
                all_scores.append([s[i] for s in scores_per_model])

        if len(all_labels) < 2 or len(np.unique(all_labels)) < 2:
            logging.warning("Skipping weight optimization: insufficient or single-class labels.")
            return np.ones(len(self.model_names)) / len(self.model_names)

        scores_matrix = np.nan_to_num(np.array(all_scores, dtype=float))
        labels_array = np.array(all_labels, dtype=int)
        n_models = scores_matrix.shape[1]

        pos = scores_matrix[labels_array == 1]
        neg = scores_matrix[labels_array == 0]
        if pos.shape[0] == 0 or neg.shape[0] == 0:
            logging.warning("Skipping weight optimization: missing positive or negative samples.")
            return np.ones(len(self.model_names)) / len(self.model_names)

        # Up-weight confidently-scored negatives; rescale so the two classes
        # contribute comparable total weight.
        neg_proxy = np.mean(neg, axis=1)
        neg_w = np.clip(neg_proxy, 0.0, 1.0) ** 2
        neg_w = neg_w / (np.mean(neg_w) + 1e-12)

        pos_w = np.ones(pos.shape[0], dtype=float)
        sum_pos = np.sum(pos_w)
        sum_neg = np.sum(neg_w)
        if sum_pos > 0 and sum_neg > 0:
            scale = sum_pos / sum_neg
            neg_w = neg_w * scale

        # beta grows with the TPR target: missing positives costs more.
        alpha = 1.0
        beta = 5.0 * (target_tpr / max(1.0 - target_tpr, 1e-6))

        def solve_with(ridge: float, eps_w: float, solver_name: str) -> Optional[np.ndarray]:
            # One attempt of the convex program with given regularization and solver.
            try:
                w = cp.Variable(n_models)
                t = cp.Variable()
                slack_neg = cp.Variable(neg.shape[0], nonneg=True)
                slack_pos = cp.Variable(pos.shape[0], nonneg=True)

                constraints = [
                    neg @ w - t <= slack_neg,
                    t - pos @ w <= slack_pos,
                    w >= eps_w,
                    cp.sum(w) == 1,
                    t >= 0,
                    t <= 1
                ]
                objective = (
                    alpha * cp.sum(cp.multiply(neg_w, slack_neg)) +
                    beta * cp.sum(cp.multiply(pos_w, slack_pos)) +
                    ridge * cp.sum_squares(w)
                )
                prob = cp.Problem(cp.Minimize(objective), constraints)

                if solver_name == 'osqp':
                    prob.solve(solver=cp.OSQP, verbose=False, eps_abs=1e-6, eps_rel=1e-6, max_iter=20000, polishing=True, linsys_solver='qdldl')
                elif solver_name == 'ecos':
                    prob.solve(solver=cp.ECOS, verbose=False, max_iters=200000, abstol=1e-7, reltol=1e-7, feastol=1e-7)
                elif solver_name == 'scs':
                    prob.solve(solver=cp.SCS, verbose=False, max_iters=300000, eps=2e-5, acceleration_lookback=20)
                else:
                    return None

                if w.value is None:
                    return None

                # project back to the simplex (solvers can return tiny negatives)
                w_val = np.array(w.value, dtype=float).reshape(-1)
                if not np.all(np.isfinite(w_val)):
                    return None
                w_val = np.clip(w_val, 0.0, None)
                s = np.sum(w_val)
                if s <= 1e-12:
                    return None
                w_val = w_val / s
                logging.info("  - Weight optimization completed")
                return w_val
            except Exception as e:
                logging.debug(f"{solver_name.upper()} attempt failed (ridge={ridge}, eps_w={eps_w}): {e}")
                return None

        solver_order = []
        solver_pref = (self.kwargs or {}).get('solver', 'auto')
        if solver_pref in ('osqp', 'ecos', 'scs'):
            solver_order = [solver_pref] + [s for s in ('osqp', 'ecos', 'scs') if s != solver_pref]
        else:
            solver_order = ['osqp', 'ecos', 'scs']

        # Escalate regularization until some solver returns a finite simplex point.
        for ridge in (5e-3, 5e-2, 1e-1, 5e-1):
            for eps_w in (0.0, 1e-6, 1e-4):
                for slv in solver_order:
                    sol = solve_with(ridge=ridge, eps_w=eps_w, solver_name=slv)
                    if sol is not None:
                        return sol

        logging.warning("CVXPY solvers failed repeatedly; falling back to AUC-based SLSQP optimizer as last resort.")
        return self._learn_robust_weights(training_data)

    def _learn_robust_weights(self, training_data: List[Dict]) -> np.ndarray:
        """Fallback weight learner: maximize ensemble ROC-AUC via multi-start SLSQP."""
        all_scores, all_labels = [], []
        for entry in training_data:
            sample, scores_dict = entry.get('sample', {}), entry.get('scores', {})
            labels = [af.get("is_supported", False) for af in sample.get("atomic_facts", [])]
            if not all(m in scores_dict for m in self.model_names): continue
            scores_per_model = [scores_dict.get(m, []) for m in self.model_names]
            min_len = min(len(labels), *[len(s) for s in scores_per_model])
            if min_len == 0: continue
            for i in range(min_len):
                all_labels.append(labels[i])
                all_scores.append([s[i] for s in scores_per_model])

        if len(all_labels) < 2 or len(np.unique(all_labels)) < 2:
            return np.ones(len(self.model_names)) / len(self.model_names)

        scores_matrix = np.nan_to_num(np.array(all_scores, dtype=float))
        labels_array = np.array(all_labels, dtype=int)
        n_models = scores_matrix.shape[1]

        def objective_fn(weights: np.ndarray) -> float:
            # negative AUC of the weighted ensemble (minimized)
            w = weights / np.sum(weights) if np.sum(weights) > 0 else weights
            ensemble_scores = scores_matrix @ w
            try: return -roc_auc_score(labels_array, ensemble_scores)
            except ValueError: return 0.0

        best_score, best_weights = -1.0, np.ones(n_models) / n_models
        for _ in range(10):
            # NOTE(review): uses the global np.random state, not self.kwargs'
            # random_state — results are not reproducible across runs; confirm
            # whether seeding is desired before changing.
            w0 = np.random.dirichlet(np.ones(n_models))
            res = minimize(objective_fn, w0, method='SLSQP', bounds=[(0, 1)] * n_models, constraints=({'type': 'eq', 'fun': lambda w: np.sum(w) - 1.0}))
            if res.success and -res.fun > best_score:
                best_score, best_weights = -res.fun, res.x / np.sum(res.x)
        return best_weights

    def get_budgets(self):
        """Return {subgroup label: calibrated tau_hat}."""
        return {subgroup: model.tau_hat for subgroup, model in self.conformal_models.items()}

    def get_weights(self):
        """Return the learned weights, fallback weights, and bin definitions."""
        return {
            'subgroup_weights': self.weights,
            'fallback_weights': self.fallback_weights,
            'bin_edges': None if self.bin_edges is None else np.asarray(self.bin_edges).tolist(),
            'bin_labels': list(self.bin_labels) if self.bin_labels is not None else None,
        }

    def _compute_ensemble_scores(self, data: List[Dict], subgroup_label: str) -> List[np.ndarray]:
        """Weighted per-claim ensemble scores for `data` using the subgroup's weights."""
        subgroup_weights = self.weights.get(subgroup_label, self.fallback_weights)
        if subgroup_weights is None:
            raise RuntimeError(f"Weights not learned for subgroup '{subgroup_label}'.")

        final_scores = []
        for entry in data:
            scores_dict = entry.get('scores', {})
            scores_per_model = [scores_dict.get(m, []) for m in self.model_names]
            min_len = min(len(entry['sample']['atomic_facts']), *[len(s) for s in scores_per_model])
            if min_len == 0:
                final_scores.append(np.array([]))
            else:
                scores_matrix = np.array([np.nan_to_num(s[:min_len]) for s in scores_per_model]).T
                final_scores.append(scores_matrix @ subgroup_weights)
        return final_scores

    def fit(self, data: List[dict], alpha: float = 0.1, ensemble_train_ratio: float = 0.5, target_tpr: float = 0.95):
        """Learn subgroup-specific ensemble weights and conformal thresholds."""
        import time  # local import: only used for coarse timing accumulation

        random_state = self.kwargs.get("random_state")
        grouper_name = self.grouper.__class__.__name__
        logging.info(f"SubgroupOptimizedMACI training started (grouper: '{grouper_name}')")

        ensemble_train_data, calib_data = train_test_split(
            data,
            test_size=1.0 - ensemble_train_ratio,
            random_state=random_state
        )
        logging.info(f"  - Data split: ensemble training {len(ensemble_train_data)} / conformal calibration {len(calib_data)}")

        # Bin edges are learned from the training split only to avoid leakage.
        logging.info(f"  - Learning bin edges by '{grouper_name}' values...")
        train_values = self.grouper.compute_values([d['sample'] for d in ensemble_train_data])
        finite_train_values = train_values[np.isfinite(train_values)]
        quantiles = np.linspace(0, 1, self.n_bins + 1)[1:-1]
        self.bin_edges = np.quantile(finite_train_values, quantiles) if len(finite_train_values) > 0 else np.array([])
        logging.info(f"  - Learned bin edges: {self.bin_edges}")

        grouped_ensemble_data = self._group_data_by_bins(ensemble_train_data, self.bin_edges)
        grouped_calib_data = self._group_data_by_bins(calib_data, self.bin_edges)

        for label in self.bin_labels:
            logging.info(f"--- Processing group '{label}' ---")
            sub_ensemble_data = grouped_ensemble_data.get(label, [])
            sub_calib_data = grouped_calib_data.get(label, [])

            if not sub_ensemble_data or not sub_calib_data:
                logging.warning(f"Skipping group '{label}' due to insufficient data.")
                continue

            logging.info(f"  - Learning ensemble weights (n={len(sub_ensemble_data)})...")
            # IDIOM FIX: replaced inline __import__('time') calls with a
            # regular function-scope import of `time` above.
            _t0 = time.perf_counter()
            self.weights[label] = self._learn_robust_weights_by_retention(sub_ensemble_data, target_tpr=target_tpr)
            self._timing['weight_optimization_s'] += time.perf_counter() - _t0

            logging.info(f"  - Calibrating threshold (n={len(sub_calib_data)})...")
            # bind `label` as a default so the closure is safe to reuse later
            score_func = lambda data, l=label: self._compute_ensemble_scores(data, l)

            conformal_model = MACIAdaptiveConformal(score_function=score_func, **self.kwargs)
            _t1 = time.perf_counter()
            conformal_model.fit_on_calib(sub_calib_data, alpha)
            self._timing['calibration_s'] += time.perf_counter() - _t1
            self.conformal_models[label] = conformal_model

        logging.info("--- Training fallback model on all data ---")
        self.fallback_weights = self._learn_robust_weights_by_retention(ensemble_train_data, target_tpr=target_tpr)

        logging.info("✅ Training complete.")
        return self

    def get_timing(self) -> Dict[str, float]:
        """Return a copy of the accumulated timing statistics (seconds)."""
        return dict(self._timing)

    def predict(self, data: List[dict]) -> Tuple[List[dict], List[float]]:
        """Filter claims for each sample via its subgroup's calibrated model.

        Returns (filtered entries in the original order, retention rates).
        """
        if not self.conformal_models: raise ValueError("모델이 학습되지 않았습니다.")

        grouped_data_with_indices = defaultdict(list)
        values = self.grouper.compute_values([d['sample'] for d in data])
        for i, (item, value) in enumerate(zip(data, values)):
            label = self._get_subgroup_label(value)
            grouped_data_with_indices[label].append((i, item))

        # placeholders preserve the caller's ordering across subgroup batches
        results_placeholder = [None] * len(data)
        rates_placeholder = [None] * len(data)

        for label, indexed_subgroup_data in grouped_data_with_indices.items():
            if not indexed_subgroup_data: continue
            original_indices = [item[0] for item in indexed_subgroup_data]
            subgroup_data = [item[1] for item in indexed_subgroup_data]
            model = self.conformal_models.get(label)

            if model:
                logging.info(f"  - Predicting for group '{label}' (n={len(subgroup_data)})...")
                predicted_samples, rates = model.predict(subgroup_data)

                for i, original_item, predicted_sample, rate in zip(original_indices, subgroup_data, predicted_samples, rates):
                    new_result_item = original_item.copy()
                    new_result_item['sample'] = predicted_sample
                    results_placeholder[i] = new_result_item
                    rates_placeholder[i] = rate
            else:
                # No calibrated model for this bin: predict with fallback
                # weights and a zero budget (keeps all claims).
                logging.warning(f"No trained model for group '{label}'. Using fallback weights for prediction.")
                fallback_score_func = lambda data_list, l=label: self._compute_ensemble_scores(data_list, l)
                fallback_model = MACIAdaptiveConformal(score_function=fallback_score_func, **self.kwargs)
                fallback_model.tau_hat = 0.0
                predicted_samples, rates = fallback_model.predict(subgroup_data)
                for i, original_item, predicted_sample, rate in zip(original_indices, subgroup_data, predicted_samples, rates):
                    new_result_item = original_item.copy()
                    new_result_item['sample'] = predicted_sample
                    results_placeholder[i] = new_result_item
                    rates_placeholder[i] = rate

        return results_placeholder, rates_placeholder
MACI-main/conformal/basic_conformal.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic Conformal Implementation for Factuality Assessment
3
+
4
+ This module implements a basic conformal prediction method for assessing
5
+ the factuality of generated text by filtering claims based on conformity scores.
6
+ """
7
+
8
+ import numpy as np
9
+ from typing import List, Tuple, Optional, Callable
10
+
11
+
12
+ class BasicConformal:
13
+ def __init__(
14
+ self,
15
+ score_function: Callable,
16
+ random_state: Optional[int] = None
17
+ ):
18
+ self.score_function = score_function
19
+ self.random_state = random_state
20
+ self.calibration_scores = None
21
+ self.threshold = None
22
+ self._rng = np.random.default_rng(random_state)
23
+ self._tie_gamma_keep: float = 1.0
24
+
25
+ def fit_on_calib(self, calib_data: List, alpha: float = 0.1) -> 'BasicConformal':
26
+ if not 0 < alpha < 1:
27
+ raise ValueError("alpha must be between 0 and 1")
28
+
29
+ raw_scores = self.score_function(calib_data)
30
+ per_sample_scores: List[List[float]] = []
31
+ if len(raw_scores) == len(calib_data) and hasattr(raw_scores[0], "__iter__") and not isinstance(raw_scores[0], (str, bytes)):
32
+ for i, sample in enumerate(calib_data):
33
+ if 'atomic_facts' in sample:
34
+ s_i = np.asarray(list(raw_scores[i]), dtype=float)
35
+ else:
36
+ s_i = np.asarray([float(raw_scores[i])], dtype=float)
37
+ s_i = np.where(np.isnan(s_i), -np.inf, s_i)
38
+ per_sample_scores.append(s_i.tolist())
39
+ else:
40
+ if len(raw_scores) != len(calib_data):
41
+ raise ValueError("score_function must return one score per sample or a per-claim score list per sample")
42
+ for i, sample in enumerate(calib_data):
43
+ if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
44
+ s_i = np.asarray([float(raw_scores[i])] * len(sample['atomic_facts']), dtype=float)
45
+ else:
46
+ s_i = np.asarray([float(raw_scores[i])], dtype=float)
47
+ s_i = np.where(np.isnan(s_i), -np.inf, s_i)
48
+ per_sample_scores.append(s_i.tolist())
49
+
50
+ S_values: List[float] = []
51
+ for sample, scores_i in zip(calib_data, per_sample_scores):
52
+ if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
53
+ false_scores = [s for s, fact in zip(scores_i, sample['atomic_facts']) if not fact.get('is_supported', False)]
54
+ if len(false_scores) == 0:
55
+ S_values.append(float('-inf'))
56
+ else:
57
+ vals = np.asarray(false_scores, dtype=float)
58
+ S_values.append(float(np.nanmax(vals)) if vals.size > 0 else float('-inf'))
59
+ else:
60
+ vals = np.asarray(scores_i, dtype=float)
61
+ if vals.size == 0:
62
+ S_values.append(float('-inf'))
63
+ else:
64
+ S_values.append(float(np.nanmax(vals)))
65
+
66
+ self.calibration_scores = np.array(S_values, dtype=float)
67
+ n = len(self.calibration_scores)
68
+ if n == 0:
69
+ raise ValueError("No calibration samples available to compute threshold")
70
+ quantile = 1 - alpha
71
+ try:
72
+ self.threshold = np.quantile(self.calibration_scores, quantile, method='higher')
73
+ except TypeError:
74
+ self.threshold = np.quantile(self.calibration_scores, quantile)
75
+
76
+ sorted_scores = np.sort(self.calibration_scores)
77
+ k = int(np.ceil((1.0 - alpha) * (n + 1))) - 1
78
+ k = min(max(k, 0), n - 1)
79
+ t_star = float(sorted_scores[k])
80
+ n_lt = int(np.sum(self.calibration_scores < t_star))
81
+ n_eq = int(np.sum(np.isclose(self.calibration_scores, t_star)))
82
+ if n_eq <= 0:
83
+ gamma_standard = 0.0
84
+ else:
85
+ gamma_standard = ((1.0 - alpha) * (n + 1) - n_lt) / n_eq
86
+ gamma_standard = float(np.clip(gamma_standard, 0.0, 1.0))
87
+ self._tie_gamma_keep = 1.0 - gamma_standard
88
+ return self
89
+
90
+ def predict(self, data: List) -> Tuple[List, List]:
91
+ if self.threshold is None:
92
+ raise ValueError("Model must be fitted before prediction")
93
+ raw_scores = self.score_function(data)
94
+ per_sample_scores: List[List[float]] = []
95
+ if len(raw_scores) == len(data) and hasattr(raw_scores[0], "__iter__") and not isinstance(raw_scores[0], (str, bytes)):
96
+ for i, sample in enumerate(data):
97
+ if 'atomic_facts' in sample:
98
+ s_i = np.asarray(list(raw_scores[i]), dtype=float)
99
+ else:
100
+ s_i = np.asarray([float(raw_scores[i])], dtype=float)
101
+ s_i = np.where(np.isnan(s_i), -np.inf, s_i)
102
+ per_sample_scores.append(s_i.tolist())
103
+ else:
104
+ if len(raw_scores) != len(data):
105
+ raise ValueError("score_function must return one score per sample or per-claim score lists per sample")
106
+ for i, sample in enumerate(data):
107
+ if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
108
+ s_i = np.asarray([float(raw_scores[i])] * len(sample['atomic_facts']), dtype=float)
109
+ else:
110
+ s_i = np.asarray([float(raw_scores[i])], dtype=float)
111
+ s_i = np.where(np.isnan(s_i), -np.inf, s_i)
112
+ per_sample_scores.append(s_i.tolist())
113
+
114
+ filtered_data: List = []
115
+ retention_rates: List[float] = []
116
+ for sample, scores_i in zip(data, per_sample_scores):
117
+ if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
118
+ filtered_claims = []
119
+ for claim, s in zip(sample['atomic_facts'], scores_i):
120
+ if s > self.threshold:
121
+ filtered_claims.append(claim)
122
+ elif np.isclose(s, self.threshold):
123
+ if self._rng.uniform() < self._tie_gamma_keep:
124
+ filtered_claims.append(claim)
125
+ sample = dict(sample)
126
+ sample['filtered_claims'] = filtered_claims
127
+ retention_rate = len(filtered_claims) / len(sample['atomic_facts'])
128
+ elif 'atomic_facts' in sample and len(sample['atomic_facts']) == 0:
129
+ sample = dict(sample)
130
+ sample['filtered_claims'] = []
131
+ retention_rate = 0.0
132
+ else:
133
+ sample = dict(sample)
134
+ if len(scores_i) == 0:
135
+ sample['is_retained'] = False
136
+ retention_rate = 0.0
137
+ else:
138
+ s = float(scores_i[0])
139
+ sample['is_retained'] = (s > self.threshold) or (np.isclose(s, self.threshold) and self._rng.uniform() < self._tie_gamma_keep)
140
+ retention_rate = 1.0 if sample['is_retained'] else 0.0
141
+ filtered_data.append(sample)
142
+ retention_rates.append(retention_rate)
143
+ return filtered_data, retention_rates
144
+
145
+ def get_coverage(self, data: List) -> float:
146
+ if self.threshold is None:
147
+ raise ValueError("Model must be fitted before computing coverage")
148
+ raw_scores = self.score_function(data)
149
+ per_sample_scores: List[List[float]] = []
150
+ if len(raw_scores) == len(data) and hasattr(raw_scores[0], "__iter__") and not isinstance(raw_scores[0], (str, bytes)):
151
+ for i, sample in enumerate(data):
152
+ if 'atomic_facts' in sample:
153
+ s_i = np.asarray(list(raw_scores[i]), dtype=float)
154
+ else:
155
+ s_i = np.asarray([float(raw_scores[i])], dtype=float)
156
+ s_i = np.where(np.isnan(s_i), -np.inf, s_i)
157
+ per_sample_scores.append(s_i.tolist())
158
+ else:
159
+ if len(raw_scores) != len(data):
160
+ raise ValueError("score_function must return one score per sample or per-claim score lists per sample")
161
+ for i, sample in enumerate(data):
162
+ if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
163
+ s_i = np.asarray([float(raw_scores[i])] * len(sample['atomic_facts']), dtype=float)
164
+ else:
165
+ s_i = np.asarray([float(raw_scores[i])], dtype=float)
166
+ s_i = np.where(np.isnan(s_i), -np.inf, s_i)
167
+ per_sample_scores.append(s_i.tolist())
168
+ indicators = []
169
+ for sample, scores_i in zip(data, per_sample_scores):
170
+ if 'atomic_facts' in sample and len(sample['atomic_facts']) > 0:
171
+ false_scores = [s for s, fact in zip(scores_i, sample['atomic_facts']) if not fact.get('is_supported', False)]
172
+ if len(false_scores) == 0:
173
+ indicators.append(1.0)
174
+ else:
175
+ vals = np.asarray(false_scores, dtype=float)
176
+ max_false = float(np.nanmax(vals)) if vals.size > 0 else float('-inf')
177
+ indicators.append(1.0 if max_false <= self.threshold else 0.0)
178
+ else:
179
+ vals = np.asarray(scores_i, dtype=float)
180
+ if vals.size == 0:
181
+ indicators.append(1.0)
182
+ else:
183
+ indicators.append(1.0 if float(np.nanmax(vals)) <= self.threshold else 0.0)
184
+ return float(np.mean(indicators))
185
+
186
+ def get_threshold(self) -> float:
187
+ return self.threshold
188
+
189
+
MACI-main/conformal/conditional_conformal.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Tuple, Optional, Callable
3
+ import torch
4
+ from scipy.optimize import linprog
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import roc_auc_score, roc_curve
7
+ from functools import lru_cache
8
+ import sys
9
+ import os
10
+
11
+ # Add conditional-conformal path to Python path (local vendor copy) using repo-relative path
12
+ repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
13
+ vendor_path = os.path.join(repo_root, 'conditional-conformal', 'src')
14
+ if vendor_path not in sys.path:
15
+ sys.path.append(vendor_path)
16
+ from conditionalconformal import CondConf
17
+
18
+ # ==============================================================================
19
+ # === Step 1: Classes and Helper Functions for Boosting ===
20
+ # ==============================================================================
21
+
22
+ def as_tensor(x, dtype, requires_grad=False):
23
+ return torch.tensor(x, dtype=dtype, requires_grad=requires_grad)
24
+
25
+ def get_current_basis(primals, duals, Phi, S, quantile):
26
+ """Helper function to find a stable basis from LP solution"""
27
+ interp_bools = np.logical_and(~np.isclose(duals, quantile - 1), ~np.isclose(duals, quantile))
28
+ if np.sum(interp_bools) == Phi.shape[1]:
29
+ return interp_bools
30
+ preds = (Phi @ primals).flatten()
31
+ active_indices = np.where(interp_bools)[0]
32
+ interp_indices = np.argsort(np.abs(S - preds))[:Phi.shape[1]]
33
+ diff_indices = np.setdiff1d(interp_indices, active_indices)
34
+ num_missing = Phi.shape[1] - np.sum(interp_bools)
35
+
36
+ if num_missing < len(diff_indices):
37
+ from itertools import combinations
38
+ for cand_indices in combinations(diff_indices, num_missing):
39
+ cand_phi = Phi[np.concatenate((active_indices, cand_indices))]
40
+ if np.isfinite(np.linalg.cond(cand_phi)):
41
+ interp_bools[np.asarray(cand_indices)] = True
42
+ break
43
+ else:
44
+ interp_bools[diff_indices] = True
45
+ return interp_bools
46
+
47
+ def _choose_full_rank_rows(Phi: np.ndarray) -> np.ndarray:
48
+ """Greedy row selection for full-rank basis"""
49
+ d = Phi.shape[1]
50
+ chosen = []
51
+ cur = np.empty((0, d))
52
+ for i in range(Phi.shape[0]):
53
+ cand = np.vstack([cur, Phi[i:i+1]])
54
+ if np.linalg.matrix_rank(cand) > np.linalg.matrix_rank(cur):
55
+ chosen.append(i)
56
+ cur = cand
57
+ if len(chosen) == d:
58
+ break
59
+ if len(chosen) < d:
60
+ chosen = list(range(Phi.shape[0]-d, Phi.shape[0]))
61
+ return np.asarray(chosen, dtype=int)
62
+
63
+ def solve_qr_for_boosting(Phi: np.ndarray, s: torch.Tensor, q: float, dtype: torch.dtype) -> torch.Tensor:
64
+ """Differentiable tau calculation function for boosting - robust fallback included"""
65
+ S_np = s.detach().cpu().numpy().reshape(-1)
66
+ assert Phi.shape[0] == S_np.shape[0], "Phi rows must match len(s)"
67
+ assert 0.0 < q < 1.0, "q must be in (0,1)"
68
+
69
+ b_eq = np.zeros(Phi.shape[1])
70
+ bounds = [(q - 1.0, q)] * len(S_np)
71
+
72
+ res = None
73
+ try:
74
+ res = linprog(-S_np, A_eq=Phi.T, b_eq=b_eq, bounds=bounds, method='highs')
75
+ except Exception:
76
+ res = None
77
+
78
+ tau_initial = None
79
+ duals = None
80
+ if res is not None and getattr(res, "success", False):
81
+ marg = None
82
+ if hasattr(res, "eqlin") and res.eqlin is not None and hasattr(res.eqlin, "marginals") and res.eqlin.marginals is not None:
83
+ marg = res.eqlin.marginals
84
+ elif hasattr(res, "dual_eq") and res.dual_eq is not None:
85
+ marg = res.dual_eq
86
+
87
+ if marg is not None:
88
+ tau_initial = -np.asarray(marg, dtype=float)
89
+ if hasattr(res, "x") and res.x is not None:
90
+ duals = np.asarray(res.x, dtype=float)
91
+
92
+ try:
93
+ if tau_initial is not None and duals is not None:
94
+ basis_mask = get_current_basis(tau_initial, duals, Phi, S_np, q)
95
+ basis_idx = np.where(basis_mask)[0]
96
+ if basis_idx.size != Phi.shape[1]:
97
+ basis_idx = _choose_full_rank_rows(Phi)
98
+ else:
99
+ basis_idx = _choose_full_rank_rows(Phi)
100
+
101
+ Phi_basis = Phi[basis_idx]
102
+ s_basis = s[basis_idx]
103
+
104
+ tau_sol = torch.linalg.lstsq(as_tensor(Phi_basis, dtype), s_basis).solution
105
+ tau = tau_sol
106
+ except Exception:
107
+ tau = torch.zeros((Phi.shape[1],), dtype=dtype)
108
+
109
+ return tau.reshape(-1, 1)
110
+
111
+ def torch_score_func_sample_level(features: List[np.ndarray], annotations: List[np.ndarray], beta: torch.Tensor) -> torch.Tensor:
112
+ """sample-level score (max_false_score) calculation"""
113
+ scores = as_tensor(np.zeros((len(features),)), dtype=beta.dtype)
114
+ for i, (f, a) in enumerate(zip(features, annotations)):
115
+ cs = -as_tensor(f, dtype=beta.dtype) @ beta
116
+ at = as_tensor(a, dtype=torch.bool)
117
+ scores[i] = torch.sort(cs[~at], descending=True)[0][0] if torch.sum(~at) > 0 else torch.tensor(1e9, dtype=beta.dtype)
118
+ return scores
119
+
120
+ def cond_score_loss(beta: torch.Tensor, dataset: Tuple, z_processed: np.ndarray, random_seed: int, q: float) -> torch.Tensor:
121
+ """Claim-level loss function for boosting"""
122
+ indices = np.arange(len(dataset[0]))
123
+ ind_train, ind_calib = train_test_split(indices, test_size=0.5, random_state=random_seed)
124
+
125
+ x_train, y_train = [dataset[0][i] for i in ind_train], [dataset[1][i] for i in ind_train]
126
+ x_calib, y_calib = [dataset[0][i] for i in ind_calib], [dataset[1][i] for i in ind_calib]
127
+ z_train, z_calib = z_processed[ind_train], z_processed[ind_calib]
128
+
129
+ scores_train_sample = torch_score_func_sample_level(x_train, y_train, beta)
130
+ tau = solve_qr_for_boosting(z_train, scores_train_sample, q, beta.dtype)
131
+
132
+ cutoffs = (as_tensor(z_calib, dtype=beta.dtype) @ tau).flatten()
133
+
134
+ total_loss = torch.tensor(0.0, dtype=beta.dtype, requires_grad=True)
135
+ count = 0
136
+ for i, (f_c, a_c) in enumerate(zip(x_calib, y_calib)):
137
+ claim_scores = -(as_tensor(f_c, dtype=beta.dtype) @ beta)
138
+ perc = torch.sigmoid(cutoffs[i] - claim_scores)
139
+ total_loss = total_loss + torch.mean(perc)
140
+ count += 1
141
+
142
+ total_loss = total_loss / count if count > 0 else total_loss
143
+ return -total_loss
144
+
145
+ class ConditionalConformalBoosting:
146
+ def __init__(self, random_state: int = 0):
147
+ self.rng = np.random.default_rng(random_state)
148
+ self.beta: Optional[np.ndarray] = None
149
+ self.z_projector: Optional[np.ndarray] = None
150
+
151
+ def _extract_features_for_boosting(self, data: List[dict]) -> Tuple[List[np.ndarray], np.ndarray, List[np.ndarray]]:
152
+ basic_features = [d['features_4d'] for d in data]
153
+ annotations = [d['annotations'] for d in data]
154
+ conditional_features = []
155
+ for d in data:
156
+ sample = d.get('sample', {})
157
+ scores_dict = d.get('scores', {})
158
+ base_features = d.get('prompt_features', [])
159
+ logprob_scores = scores_dict.get('logprob', np.array([]))
160
+ logprob_mean = np.mean(logprob_scores) if logprob_scores.size > 0 else 0.0
161
+ logprob_std = np.std(logprob_scores) if logprob_scores.size > 1 else 0.0
162
+ claim_count = len(sample.get('atomic_facts', []))
163
+ combined_features = np.concatenate([base_features, [logprob_mean, logprob_std, claim_count]])
164
+ conditional_features.append(combined_features)
165
+ z = np.array(conditional_features, dtype=float)
166
+ if not np.isfinite(z).all():
167
+ z = np.nan_to_num(z, nan=np.nanmean(z, axis=0))
168
+
169
+ return basic_features, z, annotations
170
+
171
+ def _preprocess_z(self, z: np.ndarray) -> np.ndarray:
172
+ intercept = np.ones((z.shape[0], 1))
173
+ z_aug = np.hstack([z, intercept])
174
+ try:
175
+ _, s, Vt = np.linalg.svd(z_aug, full_matrices=False)
176
+ rank = np.sum(s > 1e-10)
177
+ self.z_projector = Vt.T[:, :rank]
178
+ except np.linalg.LinAlgError:
179
+ self.z_projector = np.eye(z_aug.shape[1])
180
+ return z_aug @ self.z_projector
181
+
182
+ def fit(self, data: List[dict], alpha: float = 0.1, boosting_epochs: int = 1000, boosting_lr: float = 0.005) -> np.ndarray:
183
+ basic_features, z, annotations = self._extract_features_for_boosting(data)
184
+ dataset_boost = (basic_features, annotations)
185
+ z_processed = self._preprocess_z(z)
186
+
187
+
188
+ feature_dim = basic_features[0].shape[1]
189
+ beta_tensor = torch.tensor([0.25] * feature_dim, dtype=torch.float, requires_grad=True)
190
+ optimizer = torch.optim.Adam([beta_tensor], lr=boosting_lr)
191
+
192
+ for epoch in range(boosting_epochs):
193
+ optimizer.zero_grad()
194
+ seed_epoch = self.rng.integers(1e7)
195
+ loss = cond_score_loss(beta_tensor, dataset_boost, z_processed, seed_epoch, q=1 - alpha)
196
+ if torch.isnan(loss) or torch.isinf(loss): break
197
+ loss.backward()
198
+ if beta_tensor.grad is not None and torch.isfinite(beta_tensor.grad).all():
199
+ optimizer.step()
200
+
201
+ self.beta = beta_tensor.detach().cpu().numpy()
202
+ #
203
+ return self.beta
204
+
205
+ # ==============================================================================
206
+ # === Step 2: Classes and Helper Functions for Calibration and Prediction ===
207
+ # ==============================================================================
208
+
209
+
210
+ class ConditionalConformalInference:
211
+ def __init__(self, random_state: int = 0):
212
+ self.rng = np.random.default_rng(random_state)
213
+ self.alpha: Optional[float] = None
214
+ self.beta: Optional[np.ndarray] = None
215
+ self.model: Optional[CondConf] = None
216
+ # Adaptive alpha components
217
+ self.adaptive_enabled: bool = False
218
+ self.retention_target: Optional[float] = None
219
+ self.quantile_theta: Optional[np.ndarray] = None # parameters for linear quantile_fn
220
+ self._z_proj_for_quantile: Optional[np.ndarray] = None # projector used for z in quantile fit
221
+
222
+ def _make_z_only(self, data: List[dict]) -> np.ndarray:
223
+ """z generation - same structure as boosting: [prompt_features..., logprob_mean, logprob_std, claim_count]"""
224
+ max_base_len = 0
225
+ for d in data:
226
+ base = d.get('prompt_features', np.array([]))
227
+ try:
228
+ base_len = int(np.asarray(base).size)
229
+ except Exception:
230
+ base_len = 0
231
+ if base_len > max_base_len:
232
+ max_base_len = base_len
233
+
234
+ cond_feats: List[np.ndarray] = []
235
+ for d in data:
236
+ sample = d.get('sample', {})
237
+ scores_dict = d.get('scores', {})
238
+
239
+ base = np.asarray(d.get('prompt_features', np.array([])), dtype=float).ravel()
240
+ if base.size < max_base_len:
241
+ pad = np.zeros(max_base_len - base.size, dtype=float)
242
+ base = np.concatenate([base, pad])
243
+ elif base.size > max_base_len and max_base_len > 0:
244
+ base = base[:max_base_len]
245
+
246
+ logprob_scores = np.asarray(scores_dict.get('logprob', np.array([])), dtype=float).ravel()
247
+ logprob_mean = float(np.mean(logprob_scores)) if logprob_scores.size > 0 else 0.0
248
+ logprob_std = float(np.std(logprob_scores)) if logprob_scores.size > 1 else 0.0
249
+
250
+ claim_count = float(len(sample.get('atomic_facts', [])))
251
+
252
+ combined = np.concatenate([base, np.array([logprob_mean, logprob_std, claim_count], dtype=float)])
253
+ cond_feats.append(combined)
254
+
255
+ result = np.asarray(cond_feats, dtype=float)
256
+ return result
257
+
258
+ def _make_yz_for_calib(self, data: List[dict], beta: np.ndarray, eps: float = 0.0):
259
+ z = self._make_z_only(data)
260
+ y_list = []
261
+ for d in data:
262
+ feats = d['features_4d']
263
+ ann = np.asarray(d['annotations'], dtype=bool)
264
+ s = -(feats @ beta)
265
+ false_s = s[~ann]
266
+ if false_s.size > 0:
267
+ y_list.append(np.min(false_s) - eps)
268
+ else:
269
+ y_list.append((np.max(s) if s.size > 0 else 0.0))
270
+ y = np.asarray(y_list, dtype=float)
271
+ mask = np.isfinite(y)
272
+ return y[mask], z[mask], mask
273
+
274
+ def fit(self, calib_data: List[dict], alpha: float, beta: np.ndarray,
275
+ adaptive_alpha: bool = False, retention_target: float = 0.7):
276
+ """Set up and calibrate CondConf model"""
277
+
278
+ self.alpha = alpha
279
+ self.beta = beta
280
+ self.adaptive_enabled = bool(adaptive_alpha)
281
+ self.retention_target = float(retention_target) if adaptive_alpha else None
282
+ if not self.adaptive_enabled:
283
+ self.quantile_theta = None
284
+
285
+
286
+ y_calib, z_calib, mask = self._make_yz_for_calib(calib_data, beta)
287
+ self._last_calib_mask = mask
288
+
289
+ self.model = CondConf(score_fn=lambda x, y: y, Phi_fn=lambda x: x, seed=self.rng.integers(1e6))
290
+ self.model.setup_problem(x_calib=z_calib, y_calib=y_calib)
291
+
292
+
293
+ if self.adaptive_enabled:
294
+ try:
295
+ self._fit_adaptive_quantile_fn(calib_data, z_calib, mask)
296
+
297
+ except Exception as e:
298
+
299
+ self.adaptive_enabled = False
300
+ return self
301
+
302
+ def predict(self, test_data: List[dict]) -> List[dict]:
303
+ if not self.model or self.beta is None:
304
+ raise RuntimeError("Model is not fitted. Call fit() first.")
305
+ z_test = self._make_z_only(test_data)
306
+ out = []
307
+
308
+ for i, d in enumerate(test_data):
309
+ sample = dict(d.get('sample', {}))
310
+ claims = sample.get('atomic_facts', [])
311
+ if not claims:
312
+ sample['filtered_claims'] = []
313
+ out.append(sample)
314
+ continue
315
+
316
+ feats = d['features_4d']
317
+ scores = -(feats @ self.beta)
318
+ z_i = z_test[i:i+1]
319
+
320
+ get_threshold_fn = lambda threshold, x: threshold
321
+
322
+ try:
323
+ if self.adaptive_enabled and self.quantile_theta is not None:
324
+ q_i = float(self._quantile_fn(z_i))
325
+ else:
326
+ q_i = float(self.alpha)
327
+
328
+ thr = self.model.predict(
329
+ quantile=q_i,
330
+ x_test=z_i,
331
+ score_inv_fn=get_threshold_fn,
332
+ randomize=True,
333
+ exact=True
334
+ )
335
+ thr = float(np.squeeze(thr))
336
+ s_min = float(np.min(scores)) if scores.size > 0 else -np.inf
337
+ s_max = float(np.max(scores)) if scores.size > 0 else np.inf
338
+ if not np.isfinite(thr):
339
+ thr = s_max
340
+ else:
341
+ thr = float(np.clip(thr, s_min, s_max))
342
+ sample['filtered_claims'] = [c for j, c in enumerate(claims) if scores[j] <= thr]
343
+ except Exception:
344
+ sample['filtered_claims'] = []
345
+
346
+ out.append(sample)
347
+
348
+ return out
349
+
350
+ # ------------------------------------------------------------------
351
+ # Adaptive alpha utilities
352
+ # ------------------------------------------------------------------
353
+ def _get_claim_scores_list(self, data: List[dict], beta: np.ndarray) -> List[np.ndarray]:
354
+ scores_list = []
355
+ for d in data:
356
+ feats = d['features_4d']
357
+ s = -(feats @ beta)
358
+ scores_list.append(s)
359
+ return scores_list
360
+
361
+ def _compute_retention_given_threshold(self, claim_scores: np.ndarray, threshold: float) -> float:
362
+ if claim_scores.size == 0:
363
+ return 0.0
364
+ return float(np.mean(claim_scores <= threshold))
365
+
366
+ def _fit_adaptive_quantile_fn(self, calib_data: List[dict], z_calib: np.ndarray, mask: np.ndarray):
367
+ assert self.model is not None and self.beta is not None and self.retention_target is not None
368
+
369
+ calib_data_masked = [calib_data[i] for i, m in enumerate(mask) if m]
370
+ claim_scores_list = self._get_claim_scores_list(calib_data_masked, self.beta)
371
+ quantile_grid = np.linspace(0.01, 0.99, 31)
372
+ q_star = np.zeros(len(z_calib), dtype=float)
373
+ for i in range(len(z_calib)):
374
+ z_i = z_calib[i:i+1]
375
+ best_q = None
376
+ best_r = -1.0
377
+ best_q_near = None
378
+ for q in quantile_grid:
379
+ try:
380
+ cutoff = self.model.predict(
381
+ quantile=float(q),
382
+ x_test=z_i,
383
+ score_inv_fn=lambda c, x: c,
384
+ randomize=True,
385
+ exact=True
386
+ )
387
+ T = float(np.asarray(cutoff).reshape(-1)[0])
388
+ except Exception:
389
+ continue
390
+ if not np.isfinite(T):
391
+ continue
392
+ r = self._compute_retention_given_threshold(claim_scores_list[i], T)
393
+ if r >= self.retention_target:
394
+ best_q = float(q)
395
+ break
396
+ if r > best_r:
397
+ best_r = r
398
+ best_q_near = float(q)
399
+ q_star[i] = float(best_q if best_q is not None else (best_q_near if best_q_near is not None else quantile_grid[-1]))
400
+
401
+ def phi_alpha(x: np.ndarray) -> np.ndarray:
402
+ x = np.asarray(x)
403
+ ones = np.ones((x.shape[0], 1))
404
+ return np.concatenate([ones, x, x**2], axis=1)
405
+
406
+ Phi = phi_alpha(z_calib)
407
+ ridge = 1e-6
408
+ theta = np.linalg.pinv(Phi.T @ Phi + ridge * np.eye(Phi.shape[1])) @ (Phi.T @ q_star)
409
+ self.quantile_theta = theta
410
+ self._z_proj_for_quantile = None
411
+
412
+ def _quantile_fn(self, z_row: np.ndarray) -> float:
413
+ """Given single-row z (1 x d), return clipped quantile using phi_alpha (1, z, z^2)."""
414
+ assert self.quantile_theta is not None
415
+ z = np.asarray(z_row)
416
+ phi = np.concatenate([np.ones((z.shape[0], 1)), z, z**2], axis=1)
417
+ q = float(phi @ self.quantile_theta)
418
+ return float(np.clip(q, 0.01, 0.99))
419
+
420
+
421
+ def evaluate_auroc(self, test_data: List[dict]) -> dict:
422
+ if not self.model or self.beta is None:
423
+ raise RuntimeError("Model is not fitted. Call fit() first.")
424
+ all_scores = []
425
+ all_labels = []
426
+
427
+ for sample_data in test_data:
428
+ features = sample_data['features_4d']
429
+ annotations = np.array(sample_data['annotations'])
430
+
431
+ nonconformity_scores = -features @ self.beta
432
+
433
+ all_scores.extend(nonconformity_scores)
434
+ all_labels.extend((~annotations.astype(bool)).astype(int))
435
+
436
+ all_scores = np.array(all_scores)
437
+ all_labels = np.array(all_labels)
438
+
439
+ try:
440
+ auroc = roc_auc_score(all_labels, all_scores)
441
+ fpr, tpr, thresholds = roc_curve(all_labels, all_scores)
442
+
443
+ results = {
444
+ 'auroc': auroc,
445
+ 'fpr': fpr,
446
+ 'tpr': tpr,
447
+ 'thresholds': thresholds,
448
+ 'n_samples': len(all_scores),
449
+ 'n_false_claims': np.sum(all_labels),
450
+ 'n_true_claims': len(all_labels) - np.sum(all_labels)
451
+ }
452
+
453
+
454
+
455
+ return results
456
+
457
+ except ValueError as e:
458
+
459
+ return {
460
+ 'auroc': np.nan,
461
+ 'error': str(e),
462
+ 'n_samples': len(all_scores),
463
+ 'n_false_claims': np.sum(all_labels),
464
+ 'n_true_claims': len(all_labels) - np.sum(all_labels)
465
+ }
466
+
467
+ def get_claim_scores(self, test_data: List[dict]) -> List[dict]:
468
+ """Return claim-level scores for each sample"""
469
+ if not self.model or self.beta is None:
470
+ raise RuntimeError("Model is not fitted. Call fit() first.")
471
+
472
+ results = []
473
+ for sample_data in test_data:
474
+ features = sample_data['features_4d']
475
+ annotations = np.array(sample_data['annotations'])
476
+ claims = sample_data.get('sample', {}).get('atomic_facts', [])
477
+
478
+ nonconformity_scores = -features @ self.beta
479
+
480
+ sample_result = {
481
+ 'sample_id': sample_data.get('sample_id', 'unknown'),
482
+ 'claims': claims,
483
+ 'nonconformity_scores': nonconformity_scores.tolist(),
484
+ 'annotations': annotations.tolist(),
485
+ 'is_false': (~annotations.astype(bool)).tolist()
486
+ }
487
+ results.append(sample_result)
488
+
489
+ return results
MACI-main/data/med_scores/medlfqa_frequencies.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75c946c6772d18c650b4484bd743032f861a1af4819ded8014cbd5a3b7102857
3
+ size 2225374
MACI-main/data/med_scores/medlfqa_logprobs.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:032e80b7aa0c1c73343aec30aca91c128fa9a3fa076333a07a601ba3495b1bd7
3
+ size 2199362
MACI-main/data/med_scores/medlfqa_scores_deepseek_deepseek-chat-v3-0324.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dee348e1265e4f01029224c070ab7b75d6fdbf51b8f4705828c378430c97e38
3
+ size 426183
MACI-main/data/med_scores/medlfqa_scores_meta-llama_llama-3.3-70b-instruct.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12eebea7f09ff94aaba4944f6d3ccf3e3ad10e33cd154a6cc274218f2709f1bd
3
+ size 426183
MACI-main/data/med_scores/medlfqa_scores_qwen_qwen-2.5-72b-instruct.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67fe74653e69c421202cc4da4308f262657349a5bdc10bf65c43f083d92499e8
3
+ size 426183
MACI-main/data/med_scores/medlfqa_selfevals.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58bc690d686c367019bb016b68723747f5a311b3d56f35249c8ee36a61b77878
3
+ size 2226438
MACI-main/data/wiki_scores/wikibio_final.csv ADDED
The diff for this file is too large to render. See raw diff
 
MACI-main/data/wiki_scores/wikibio_final_dataset.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:482d015cec319b80bf92c4adb8e6d65c20cc40808801561291c7d5bcf76ed551
3
+ size 20356478
MACI-main/data/wiki_scores/wikibio_final_frequencies.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f657a70d4e4307b91bdd80ee4b03daf5d3362e723ebae54ceefd5c7cc2330a37
3
+ size 3933826
MACI-main/data/wiki_scores/wikibio_final_logprobs.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:644865e9037b9139e277f5d4dc2da594314a4414eca3a9bad4dcad5f1c511319
3
+ size 4820424
MACI-main/data/wiki_scores/wikibio_final_self_evals.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77f61a212dabb3fd85724e78c793887ccbe90e0285c4055411888c64dd5d44d4
3
+ size 4848638
MACI-main/data/wiki_scores/wikibio_scores_deepseek-chat-v3-0324.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9726f79ffdd9e6c1f89b64c66960cdc3cce5c3b868e750cc99ee28e4a666c50
3
+ size 621202
MACI-main/data/wiki_scores/wikibio_scores_meta-llama_llama-3.3-70b-instruct.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c36a3e500d2c08be38fcf07260a0496f498849e8ea57c9cf074f5f10aca855a
3
+ size 621202
MACI-main/data/wiki_scores/wikibio_scores_qwen_qwen-2.5-72b-instruct.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8047a2bcf802c082f6995fae735b634738f351cda58c37409a52376f667ac4a
3
+ size 621168
MACI-main/experiments/conditional_groupers.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flexible conditional grouping utilities for subgroup analysis.
4
+ """
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import re
9
+ import warnings
10
+ import json
11
+ import os
12
+ from typing import List, Dict, Any, Tuple
13
+ from abc import ABC, abstractmethod
14
+
15
+ warnings.filterwarnings('default')
16
+ np.seterr(all='warn')
17
+
18
+
19
+ class ConditionalGrouper(ABC):
20
+
21
+ def __init__(self, name: str, description: str):
22
+ self.name = name
23
+ self.description = description
24
+
25
+ @abstractmethod
26
+ def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
27
+ pass
28
+
29
+ def create_bins(self, values: np.ndarray, method: str = 'quartiles',
30
+ custom_bins: List[float] = None) -> List[Tuple[float, float]]:
31
+ finite_values = values[np.isfinite(values)]
32
+
33
+ if len(finite_values) == 0:
34
+ return [(float(np.min(values)), float(np.max(values)))]
35
+
36
+ if method == 'quartiles':
37
+ quantiles = [0.0, 0.25, 0.5, 0.75, 1.0]
38
+ elif method == 'quintiles':
39
+ quantiles = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
40
+ elif method == 'deciles':
41
+ quantiles = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
42
+ elif method == 'tertiles':
43
+ quantiles = [0.0, 0.33, 0.67, 1.0]
44
+ elif method == 'median_split':
45
+ quantiles = [0.0, 0.5, 1.0]
46
+ elif method == 'custom' and custom_bins:
47
+ qs = np.array(custom_bins)
48
+ else:
49
+ quantiles = [0.0, 0.25, 0.5, 0.75, 1.0]
50
+
51
+ if method != 'custom':
52
+ qs = np.quantile(finite_values, quantiles)
53
+
54
+ bins = [(float(qs[i]), float(qs[i+1])) for i in range(len(qs)-1)]
55
+ return bins
56
+
57
+ def get_group_info(self, dataset: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
58
+ values = self.compute_values(dataset, **kwargs)
59
+ finite_values = values[np.isfinite(values)]
60
+
61
+ return {
62
+ 'name': self.name,
63
+ 'description': self.description,
64
+ 'total_samples': len(values),
65
+ 'valid_samples': len(finite_values),
66
+ 'min_value': float(np.min(finite_values)) if len(finite_values) > 0 else np.nan,
67
+ 'max_value': float(np.max(finite_values)) if len(finite_values) > 0 else np.nan,
68
+ 'mean_value': float(np.mean(finite_values)) if len(finite_values) > 0 else np.nan,
69
+ 'std_value': float(np.std(finite_values)) if len(finite_values) > 0 else np.nan,
70
+ }
71
+
72
+
73
+ # View metadata configuration (globally overridable)
74
+ def _default_view_csv_path() -> str:
75
+ repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
76
+ return os.path.join(repo_root, 'data', 'wiki_scores', 'wikibio_final.csv')
77
+
78
+ GLOBAL_VIEW_METADATA_CSV = _default_view_csv_path()
79
+
80
+ def set_view_metadata_csv(csv_path: str):
81
+ global GLOBAL_VIEW_METADATA_CSV
82
+ if isinstance(csv_path, str) and len(csv_path) > 0:
83
+ GLOBAL_VIEW_METADATA_CSV = csv_path
84
+
85
+
86
class ViewCountGrouper(ConditionalGrouper):
    """Group samples by the Wikipedia view count of the biography subject.

    The subject name is parsed out of each sample's prompt and looked up in
    the CSV at GLOBAL_VIEW_METADATA_CSV (override via set_view_metadata_csv).
    Missing or non-finite counts fall back to the smallest finite count seen
    in the CSV (0.0 when the CSV yields no finite counts at all).
    """

    def __init__(self):
        super().__init__(
            name="view_count",
            description="Wikipedia view count (from wikibio_final.csv)"
        )
        # Lazy-load state: the CSV is read on first compute_values() call and
        # re-read whenever the global CSV path changes afterwards.
        self._loaded = False
        self._csv_path = None
        self._name_to_views = {}
        self._global_min_count = 0.0

    @staticmethod
    def _parse_name_from_prompt(prompt: str) -> str:
        """Extract the subject name from a generation prompt.

        Falls back to everything after a literal 'about ', then to the whole
        stripped prompt, when the regex pattern does not match.
        """
        if not isinstance(prompt, str):
            try:
                prompt = str(prompt)
            except Exception:
                return ""
        txt = prompt.strip()
        # Typical pattern: "Please write one biographical paragraph about {NAME}."
        import re
        m = re.search(r"about\s+(.+?)(?:[\.]|\n|$)", txt, flags=re.IGNORECASE)
        if m:
            return m.group(1).strip()
        # Fallback: try after 'about '
        if 'about ' in txt:
            return txt.split('about ', 1)[-1].strip().rstrip('.').strip()
        return txt

    def _ensure_loaded(self):
        """Populate the name→views mapping from the configured CSV.

        Best-effort: any failure (missing file, bad CSV, bad columns) leaves
        an empty mapping and a 0.0 global minimum, and still marks the
        grouper as loaded so we don't retry on every call.
        """
        # Lazy-load and refresh if global path changed
        if (not self._loaded) or (self._csv_path != GLOBAL_VIEW_METADATA_CSV):
            try:
                df = pd.read_csv(GLOBAL_VIEW_METADATA_CSV)
                # Expected columns: 'Name' plus at least one of 'Views' / 'max_counts'.
                name_col = 'Name' if 'Name' in df.columns else None
                views_col = 'Views' if 'Views' in df.columns else None
                maxc_col = 'max_counts' if 'max_counts' in df.columns else None
                mapping = {}
                values_for_min = []
                if name_col and (views_col or maxc_col):
                    for _, row in df.iterrows():
                        name = str(row[name_col]).strip()
                        v = np.nan
                        # Per-row preference: Views if finite, else max_counts
                        if views_col is not None:
                            try:
                                vv = float(row[views_col])
                                if np.isfinite(vv):
                                    v = vv
                            except Exception:
                                pass
                        if (not np.isfinite(v)) and maxc_col is not None:
                            try:
                                mv = float(row[maxc_col])
                                if np.isfinite(mv):
                                    v = mv
                            except Exception:
                                pass
                        # NaN entries are kept in the mapping so a later direct
                        # name hit still goes through the NaN fallback path.
                        mapping[name] = v
                        if np.isfinite(v):
                            values_for_min.append(v)
                self._name_to_views = mapping
                self._csv_path = GLOBAL_VIEW_METADATA_CSV
                # Global minimum over available finite counts; default to 0.0 if none
                self._global_min_count = float(np.min(values_for_min)) if len(values_for_min) > 0 else 0.0
                self._loaded = True
            except Exception:
                # If loading fails, mark as loaded with empty mapping
                self._name_to_views = {}
                self._csv_path = GLOBAL_VIEW_METADATA_CSV
                self._global_min_count = 0.0
                self._loaded = True

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Return one view count per sample as a float array.

        Lookup order: exact parsed name, then whitespace-collapsed name,
        then the global minimum count as a last-resort fallback.
        """
        self._ensure_loaded()
        values = []
        for sample in dataset:
            prompt = sample.get('prompt', '')
            name = self._parse_name_from_prompt(prompt)
            # Direct match first
            val = self._name_to_views.get(name)
            if val is None:
                # Try naive normalization: collapse spaces
                key2 = " ".join(name.split())
                val = self._name_to_views.get(key2, np.nan)
            # Fallback: global min count if missing or NaN
            if val is None or (isinstance(val, float) and not np.isfinite(val)):
                val = self._global_min_count
            values.append(float(val))
        return np.array(values, dtype=float)
177
+
178
+
179
class FalseClaimRiskGrouper(ConditionalGrouper):
    """Heuristic, text-only risk index for false claims (higher = riskier).

    Combines surface features of the response (length, sentence count,
    list-like structure, digit density, absolute-language density,
    enumeration keywords) minus a citation-density discount into a single
    score clipped to [0, 1]. No model calls — everything is computed from
    the raw prompt/response text.
    """

    def __init__(self):
        super().__init__(
            name="false_claim_risk",
            description="Text-based false-claim risk index (higher → more risk)"
        )
        # Absolute/overconfident wording, counted per word in _absolute_density.
        # NOTE(review): multi-word entries ('no doubt') can never match the
        # single-word tokens produced by the \w+ tokenizer below — confirm intent.
        self.abs_terms = [
            'always', 'never', 'guarantee', 'guaranteed', 'cure', 'proven',
            'will', 'must', 'definitely', 'certainly', 'undoubtedly', 'no doubt'
        ]
        # Topics that typically elicit enumerations of many factual claims.
        self.enum_keywords = [
            'symptom', 'symptoms', 'signs', 'causes', 'cause', 'types', 'treatments',
            'treatment', 'risk factors', 'complications', 'side effects', 'prevention'
        ]
        # Evidence/citation cues; matches lower the final risk score.
        self.citation_patterns = [
            r'according\s+to', r'based\s+on', r'research\s+(?:shows?|indicates?|suggests?)',
            r'studies?\s+(?:show|indicate|suggest|reveal|demonstrate)', r'\(\d{4}\)', r'\[[\d,\s-]+\]'
        ]
        self.compiled_cite = [re.compile(p, re.IGNORECASE) for p in self.citation_patterns]

    @staticmethod
    def _num_sentences(text: str) -> int:
        """Rough sentence count via terminal punctuation/newlines (min 1 for non-empty text)."""
        if not text:
            return 0
        return max(1, text.count('.') + text.count('!') + text.count('?') + text.count('\n'))

    @staticmethod
    def _listiness(text: str) -> int:
        """Count list-like markers: separators, bullets, and numbered items."""
        if not text:
            return 0
        markers = [',', ';', '\n', '-', '*', '•']
        count = sum(text.count(m) for m in markers)
        # Enumerations like "1.", "2)", "(3)"
        count += len(re.findall(r'(?:(?<=\s)|^)(?:\d{1,2}[\.)\]])', text))
        return count

    def _citation_density(self, text: str) -> float:
        """Citation-cue matches per whitespace-delimited word (0.0 for empty text)."""
        if not text:
            return 0.0
        words = text.split()
        if not words:
            return 0.0
        matches = 0
        # Patterns are IGNORECASE-compiled, so lowercasing here is redundant
        # but harmless.
        low = text.lower()
        for pat in self.compiled_cite:
            matches += len(pat.findall(low))
        return matches / max(1, len(words))

    def _absolute_density(self, text: str) -> float:
        """Fraction of \\w+ tokens that are absolute/overconfident terms."""
        if not text:
            return 0.0
        words = re.findall(r"\b\w+\b", text.lower())
        if not words:
            return 0.0
        abs_cnt = sum(1 for w in words if w in self.abs_terms)
        return abs_cnt / max(1, len(words))

    def _enum_keyword_score(self, prompt: str, response: str) -> float:
        """Number of enumeration keywords appearing anywhere in prompt+response."""
        txt = f"{prompt} {response}".lower()
        return float(sum(1 for k in self.enum_keywords if k in txt))

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Return one composite risk score in [0, 1] per sample."""
        vals = []
        for sample in dataset:
            prompt = sample.get('prompt', '') or ''
            response = sample.get('response', '') or ''
            resp = str(response)

            # Features, each normalized to roughly [0, 1]
            num_words = len(resp.split())
            len_norm = min(1.0, num_words / 400.0)
            sent_norm = min(1.0, self._num_sentences(resp) / 12.0)
            list_norm = min(1.0, self._listiness(resp) / 40.0)
            num_density = (sum(ch.isdigit() for ch in resp) / max(1, len(resp)))
            abs_density = self._absolute_density(resp)
            cite_density = self._citation_density(resp)
            enum_score = min(1.0, self._enum_keyword_score(str(prompt), resp) / 4.0)

            # Composite risk (clipped to [0,1]); citation cues subtract risk.
            risk = (
                0.30 * len_norm +
                0.15 * sent_norm +
                0.20 * list_norm +
                0.10 * num_density +
                0.15 * abs_density +
                0.10 * enum_score -
                0.10 * cite_density
            )
            vals.append(float(np.clip(risk, 0.0, 1.0)))
        return np.array(vals, dtype=float)
270
+
271
+
272
class MedicalContentGrouper(ConditionalGrouper):
    """Classify medical prompts into 3 intent groups.

    Labels (see _classify): 0 = information-seeking,
    1 = interpretation-seeking, 2 = action-seeking. create_bins returns
    fixed discrete intervals around these integer labels.
    """

    def __init__(self):
        super().__init__(
            name="medical_content",
            description="Medical content (Information/Interpretation/Action)"
        )

    @staticmethod
    def _normalize(text: str) -> str:
        """Lowercase and collapse whitespace; best-effort str() for non-strings."""
        if not isinstance(text, str):
            try:
                text = str(text)
            except Exception:
                return ""
        return " ".join(text.strip().lower().split())

    def _classify(self, prompt: str) -> int:
        """Map a prompt to 0 (information), 1 (interpretation) or 2 (action).

        Rules fire in priority order: action phrases, then info/drug entity
        cues with query words, then interpretation/symptom phrases, then a
        generic what/why fallback; default is interpretation (1).
        """
        p = self._normalize(prompt)

        # Heuristic keyword sets
        # NOTE(review): entries containing '.*' (e.g. "does .* do") are used
        # as regexes by contains_any; "means?" has no '.*' so its '?' is a
        # literal substring and will rarely match — confirm intent.
        info_kw = [
            "what is", "what are", "definition", "define", "symptom", "signs", "cause", "why",
            "prognosis", "life expectancy", "effect", "does .* do", "means?", "treatment", "therapy",
            "disease", "syndrome", "disorder", "cancer", "diabetes", "ards", "tay-sachs", "paget",
            "thalassemia", "psp", "rosacea", "empyema"
        ]
        drug_kw = [
            "drug", "medication", "medicine", "dose", "dosage", "tablet", "pill", "mg", "patch",
            "paxlovid", "zoloft", "lexapro", "meloxicam", "naproxen", "fentanyl", "celexa", "restoril",
            "calcitonin", "latanoprost", "aldactazide", "nicoderm"
        ]
        symptom_kw = [
            "pain", "ache", "swelling", "lump", "dark urine", "dizziness", "lightheaded", "fatigue",
            "muscle aches", "discharge", "sunburn", "hoarder", "smell"
        ]
        interpret_kw = [
            "what does it mean", "what does .* mean", "when should you worry", "should i worry",
        ]
        action_kw = [
            "should i", "do i need", "is it okay", "can i", "how to", "how do i", "stop", "start",
            "continue", "switch", "swap", "get tested", "try", "take", "drink", "use"
        ]

        def contains_any(keys: List[str]) -> bool:
            # Keys containing '.*' are tried as regex patterns against the
            # normalized prompt; every key also gets a plain substring test.
            for k in keys:
                if " .* " in k or ".*" in k:
                    import re
                    if re.search(k, p):
                        return True
                if k in p:
                    return True
            return False

        # Action-seeking first (high precision phrases)
        if contains_any(action_kw):
            return 2

        # Information-seeking: has disease/drug entity cues and info-type query words
        if (contains_any(info_kw) or contains_any(drug_kw)) and ("?" in prompt or contains_any(["what", "why", "signs", "symptom", "life expectancy", "treatment"])):
            return 0

        # Interpretation-seeking: general symptom phrases or interpret patterns
        if contains_any(interpret_kw) or contains_any(symptom_kw):
            return 1

        # Fallback: map generic questions with what/why to information
        if contains_any(["what", "why"]):
            return 0

        # Otherwise treat as action if imperative-like
        if contains_any(["how to", "how do i"]):
            return 2

        # Default to interpretation
        return 1

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """Return one float intent label (0/1/2) per sample's prompt."""
        values = []
        for sample in dataset:
            prompt = sample.get('prompt', '')
            values.append(self._classify(prompt))
        return np.array(values, dtype=float)

    def create_bins(self, values: np.ndarray, method: str = 'ignored', custom_bins: List[float] = None) -> List[Tuple[float, float]]:
        """Fixed discrete bins around the three integer labels; `method` is ignored."""
        return [(-0.5, 0.5), (0.5, 1.5), (1.5, 2.5)]
357
+
358
+
359
class ExpertQAFieldGrouper(ConditionalGrouper):
    """ExpertQA official metadata.field based 3-group classifier.

    Labels:
        0 — Biology/Medicine (Healthcare/Medicine, Biology, Chemistry,
            Psychology, Environmental Science)
        1 — Engineering/Technology (Engineering and Technology,
            Physics and Astronomy, Architecture)
        2 — Other (every other field, non-string fields, unknown prompts)

    The prompt→field mapping is read lazily from ``mapping_path``
    ('/expertqa_prompt_to_field.json' by default). If the file is missing or
    unreadable, every sample falls into group 2. Values are integer labels,
    so create_bins returns fixed discrete intervals.
    """

    def __init__(self, mapping_path: str = "/expertqa_prompt_to_field.json"):
        super().__init__(
            name="expertqa_field",
            description="ExpertQA metadata.field → {Bio/Med, Eng/Tech, Other}"
        )
        self.mapping_path = mapping_path
        self._loaded = False
        self._prompt_to_field = {}

        self.bio_med_fields = {
            "Healthcare / Medicine",
            "Biology",
            "Chemistry",
            "Psychology",
            "Environmental Science",
        }
        self.eng_tech_fields = {
            "Engineering and Technology",
            "Physics and Astronomy",
            "Architecture",
        }

    @staticmethod
    def _normalize(text: str) -> str:
        """Collapse whitespace runs; coerce non-strings best-effort ('' on failure)."""
        if not isinstance(text, str):
            try:
                text = str(text)
            except Exception:
                return ""
        return " ".join(text.strip().split())

    def _ensure_loaded(self):
        """Load the prompt→field JSON exactly once; failures leave an empty mapping."""
        if self._loaded:
            return
        mapping = {}
        try:
            if os.path.exists(self.mapping_path):
                with open(self.mapping_path, "r", encoding="utf-8") as f:
                    raw = json.load(f)
                # Normalize keys so lookups survive whitespace differences.
                mapping = {self._normalize(k): v for k, v in raw.items()}
        except Exception:
            mapping = {}
        self._prompt_to_field = mapping
        self._loaded = True

    def _field_to_group(self, field: str) -> int:
        """Map a raw field string to its group id; 2 for unknown/non-string."""
        if not isinstance(field, str):
            return 2
        stripped = field.strip()
        if stripped in self.bio_med_fields:
            return 0
        if stripped in self.eng_tech_fields:
            return 1
        return 2

    def compute_values(self, dataset: List[Dict[str, Any]], **kwargs) -> np.ndarray:
        """One float group label per sample, keyed on 'prompt' then 'question'."""
        self._ensure_loaded()
        labels = []
        for sample in dataset:
            field = self._prompt_to_field.get(self._normalize(sample.get('prompt', '')))
            if field is None:
                # Fall back to the raw question text when the prompt misses.
                field = self._prompt_to_field.get(self._normalize(sample.get('question', '')))
            labels.append(float(self._field_to_group(field)))
        return np.array(labels, dtype=float)

    def create_bins(self, values: np.ndarray, method: str = 'ignored', custom_bins: List[float] = None) -> List[Tuple[float, float]]:
        """Fixed discrete bins around the three integer labels; `method` is ignored."""
        return [(-0.5, 0.5), (0.5, 1.5), (1.5, 2.5)]
444
+
445
+
446
def get_available_groupers() -> Dict[str, ConditionalGrouper]:
    """Build the registry of named grouping strategies selectable by callers."""
    registry: Dict[str, ConditionalGrouper] = {}
    registry['view_count'] = ViewCountGrouper()
    registry['medical_content'] = MedicalContentGrouper()
    registry['false_claim_risk'] = FalseClaimRiskGrouper()
    return registry
452
+
453
+
454
def compute_conditional_coverage_by_grouper(
    filtered_dataset: List[Dict[str, Any]],
    grouping_values: np.ndarray,
    bins: List[Tuple[float, float]]
) -> List[float]:
    """Per-bin marginal coverage for a grouper's value bins.

    A sample is "covered" when none of its retained ('filtered_claims')
    claims is unsupported; each bin's coverage is the mean indicator over
    samples whose finite grouping value lies in [lo, hi]. Empty bins yield
    NaN.
    """

    def _covered(entry: Dict[str, Any]) -> float:
        # 1.0 iff every retained claim is supported (vacuously true if none).
        for claim in entry.get('filtered_claims', []):
            if not claim.get('is_supported', False):
                return 0.0
        return 1.0

    results: List[float] = []
    for lo, hi in bins:
        members = [
            filtered_dataset[idx]
            for idx, v in enumerate(grouping_values)
            if np.isfinite(v) and lo <= v <= hi
        ]
        if members:
            results.append(float(np.mean([_covered(m) for m in members])))
        else:
            results.append(np.nan)
    return results
491
+
492
+
493
def compute_retention_by_grouper(
    filtered_dataset: List[Dict[str, Any]],
    grouping_values: np.ndarray,
    bins: List[Tuple[float, float]]
) -> List[Dict[str, Any]]:
    """Per-bin claim-retention statistics for a grouper's value bins.

    For each [lo, hi] bin, returns a dict with the bin bounds, the member
    sample count, retained vs total claim counts, and the retention rate
    (NaN for empty bins or zero total claims).
    """
    stats: List[Dict[str, Any]] = []

    for lo, hi in bins:
        members = [
            filtered_dataset[idx]
            for idx, v in enumerate(grouping_values)
            if np.isfinite(v) and lo <= v <= hi
        ]

        if not members:
            stats.append({
                'bin': (float(lo), float(hi)),
                'samples': 0,
                'retained': 0,
                'total': 0,
                'rate': np.nan,
            })
            continue

        total = sum(len(m.get('atomic_facts', [])) for m in members)
        kept = sum(len(m.get('filtered_claims', [])) for m in members)
        rate = (kept / total) if total > 0 else np.nan

        stats.append({
            'bin': (float(lo), float(hi)),
            'samples': len(members),
            'retained': int(kept),
            'total': int(total),
            'rate': float(rate) if not np.isnan(rate) else np.nan,
        })

    return stats
MACI-main/experiments/run_experiment.py ADDED
@@ -0,0 +1,1127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pickle
3
+ import os
4
+ import sys
5
+ import json
6
+ import argparse
7
+ import time
8
+ import logging
9
+ import warnings
10
+ from datetime import datetime
11
+ from typing import Optional, Dict, Any, List
12
+ from collections import defaultdict
13
+ warnings.filterwarnings('default')
14
+ warnings.simplefilter('ignore', category=FutureWarning)
15
+ np.seterr(all='warn')
16
+
17
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
18
+
19
+ from conformal.basic_conformal import BasicConformal
20
+ from conformal.adaptive_conformal import MACIAdaptiveConformal, SubgroupOptimizedMACI
21
+ from conditional_groupers import get_available_groupers
22
+ from conditional_groupers import set_view_metadata_csv
23
+
24
# Judge-model identifiers; load_1000_samples uses these keys to locate the
# per-model score files and to order the per-claim LLM-score ensemble.
MODEL_NAMES = ['qwen-2.5-72b-instruct', 'deepseek-chat-v3-0324', 'llama-3.3-70b-instruct']
25
+
26
def setup_logging(log_dir: str):
    """Configure the root logger to write to both console and a timestamped file.

    Existing handlers are removed first so repeated calls (e.g. across
    experiment runs in one process) do not duplicate output.
    """
    os.makedirs(log_dir, exist_ok=True)
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_filename = os.path.join(log_dir, f"experiment_log_{stamp}.log")

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Drop any stale handlers before installing fresh ones.
    for old_handler in list(root.handlers):
        root.removeHandler(old_handler)

    # Same bare-message format for file and console output.
    for handler in (logging.FileHandler(log_filename), logging.StreamHandler()):
        handler.setFormatter(logging.Formatter('%(message)s'))
        root.addHandler(handler)

    logging.info(f"📝 Logging to {log_filename}")
47
+
48
+
49
def load_1000_samples(data_dir: str, scores_dir: Optional[str] = None, dataset_type: str = "auto", limit_samples: int = 1000):
    """Load up to `limit_samples` samples and attach LLM scores.

    Resolves the dataset type (wikibio/medlfqa, auto-detected from which
    pickle exists under `data_dir`), loads the basic score archives
    (frequencies/logprobs/selfevals) and the per-judge-model .npz score
    files, then aligns everything per sample into dicts with keys
    'sample', 'annotations', 'scores', 'features_4d', 'prompt_features'.
    Samples with no atomic facts, or whose self-evals are all -1, are
    skipped. Missing scores fall back to zeros (basic) or 0.5 (LLM judges).
    """
    logging.info(f"📁 Loading up to {limit_samples} samples with provided scores...")

    # --- Resolve dataset type by probing for the known pickle locations ---
    if dataset_type == "auto":
        wikibio_path = os.path.join(data_dir, "wiki_scores", "wikibio_final_dataset.pkl")
        medlfqa_path = os.path.join(data_dir, "med_scores", "medlfqa_dataset.pkl")

        if os.path.exists(wikibio_path):
            dataset_type = "wikibio"
            logging.info(f" 🔍 Auto-detected dataset type: {dataset_type}")
        elif os.path.exists(medlfqa_path):
            dataset_type = "medlfqa"
            logging.info(f" 🔍 Auto-detected dataset type: {dataset_type}")
        else:
            raise FileNotFoundError(f"Could not find dataset files in {data_dir}")

    # --- Per-dataset file layout ---
    if dataset_type == "wikibio":
        dataset_path = os.path.join(data_dir, "wiki_scores", "wikibio_final_dataset.pkl")
        base_scores_dir = os.path.join(data_dir, "wiki_scores")
        score_prefix = "wikibio_scores"
        basic_scores = {
            'frequencies': os.path.join(base_scores_dir, "wikibio_final_frequencies.npz"),
            'logprobs': os.path.join(base_scores_dir, "wikibio_final_logprobs.npz"),
            'selfevals': os.path.join(base_scores_dir, "wikibio_final_self_evals.npz")
        }
    elif dataset_type == "medlfqa":
        dataset_path = os.path.join(data_dir, "med_scores", "medlfqa_dataset.pkl")
        base_scores_dir = os.path.join(data_dir, "med_scores")
        score_prefix = "medlfqa_scores"
        basic_scores = {
            'frequencies': os.path.join(base_scores_dir, "medlfqa_frequencies.npz"),
            'logprobs': os.path.join(base_scores_dir, "medlfqa_logprobs.npz"),
            'selfevals': os.path.join(base_scores_dir, "medlfqa_selfevals.npz")
        }
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    logging.info(f" 📊 Dataset: {dataset_path}")
    logging.info(f" 🎯 Score prefix: {score_prefix}")

    with open(dataset_path, 'rb') as f:
        dataset = pickle.load(f)

    dataset_1000 = dataset[:limit_samples]

    # Basic score archives; left as empty dicts when a file is missing.
    frequencies = {}
    logprobs = {}
    selfevals = {}

    for score_type, score_path in basic_scores.items():
        try:
            if score_type == 'frequencies':
                frequencies = np.load(score_path, allow_pickle=True)
            elif score_type == 'logprobs':
                logprobs = np.load(score_path, allow_pickle=True)
            elif score_type == 'selfevals':
                selfevals = np.load(score_path, allow_pickle=True)
            logging.info(f" ✅ Loaded {score_type}: {score_path}")
        except FileNotFoundError:
            logging.warning(f" ⚠️ {score_type} not found: {score_path}")

    # Explicit scores_dir overrides the dataset's default directory.
    if scores_dir is not None and os.path.isdir(scores_dir):
        score_files_dir = scores_dir
    else:
        score_files_dir = base_scores_dir

    logging.info(f" 🎯 Score files directory: {score_files_dir}")

    import glob
    all_npz_files = sorted(glob.glob(os.path.join(score_files_dir, f"{score_prefix}_*.npz")))
    def find_by_tokens(token_options: List[List[str]]):
        # First token set whose tokens all appear in some filename wins;
        # returns None when no candidate file matches any option.
        for tokens in token_options:
            for fp in all_npz_files:
                name = os.path.basename(fp).lower()
                if all(t in name for t in tokens):
                    return fp
        return None

    # Token fallbacks go from most to least specific per judge model.
    score_files = {
        'qwen-2.5-72b-instruct': find_by_tokens([
            ['qwen-2.5-72b','instruct'], ['qwen','instruct'], ['qwen']
        ]),
        'deepseek-chat-v3-0324': find_by_tokens([
            ['deepseek','chat','v3'], ['deepseek','chat'], ['deepseek']
        ]),
        'llama-3.3-70b-instruct': find_by_tokens([
            ['llama-3.3-70b','instruct'], ['llama-3.3','instruct'], ['llama']
        ]),
    }

    # Per-model prompt -> per-claim score list; TypeError covers filename=None.
    llm_scores = {}
    for model_name, filename in score_files.items():
        try:
            model_data = np.load(filename, allow_pickle=True)
            model_prompts = model_data['prompts'].tolist()
            model_scores_list = model_data['scores_list'].tolist()
            llm_scores[model_name] = {p: s for p, s in zip(model_prompts, model_scores_list)}
            logging.info(f" ✅ Loaded {model_name} scores")
        except (FileNotFoundError, TypeError):
            logging.warning(f" ⚠️ {model_name} scores not found or invalid: (unknown)")
            llm_scores[model_name] = {}

    # --- Align every usable sample with its score arrays ---
    aligned_data = []
    for i, sample in enumerate(dataset_1000):
        prompt = sample['prompt']
        atomic_facts = sample.get('atomic_facts', [])
        n_claims = len(atomic_facts)

        if n_claims == 0:
            continue

        # Skip samples whose self-evals are the all -1 sentinel.
        # NOTE(review): this pre-check keys selfevals by prompt, while the
        # wikibio extraction below keys by 'arr_{i}' — confirm both key
        # schemes are intended for the same archive.
        if prompt in selfevals:
            selfeval_vals = selfevals[prompt]
            if hasattr(selfeval_vals, 'ndim') and selfeval_vals.ndim == 1:
                if np.allclose(selfeval_vals, -1):
                    continue
            elif np.allclose(selfeval_vals, -1):
                continue

        annotations = np.array([af.get('is_supported', False) for af in atomic_facts])

        # Frequency scores: wikibio archives use positional 'arr_{i}' keys,
        # medlfqa archives are keyed by prompt; scalars are broadcast and
        # NaNs mapped to 0.0.
        freq_scores = np.zeros(n_claims)
        if dataset_type == 'wikibio':
            key = f'arr_{i}'
            if key in frequencies:
                freq_vals = frequencies[key]
                if hasattr(freq_vals, 'ndim') and freq_vals.ndim == 1:
                    freq_scores = freq_vals[:n_claims]
                else:
                    freq_scores = np.full(n_claims, freq_vals.item() if hasattr(freq_vals, 'item') else freq_vals)
                freq_scores = np.nan_to_num(freq_scores, nan=0.0)
        else:
            if prompt in frequencies:
                freq_vals = frequencies[prompt]
                if hasattr(freq_vals, 'ndim') and freq_vals.ndim == 1:
                    freq_scores = freq_vals[:n_claims]
                else:
                    freq_val = freq_vals.item() if hasattr(freq_vals, 'item') else freq_vals
                    freq_val = 0.0 if np.isnan(freq_val) else freq_val
                    freq_scores = np.full(n_claims, freq_val)
                freq_scores = np.nan_to_num(freq_scores, nan=0.0)

        # Log-prob scores, same key scheme and fallbacks as frequencies.
        if dataset_type == 'wikibio':
            key = f'arr_{i}'
            if key in logprobs:
                lp_vals = logprobs[key]
                if hasattr(lp_vals, 'ndim') and lp_vals.ndim == 1:
                    logprob_scores = np.nan_to_num(lp_vals[:n_claims], nan=0.0)
                else:
                    v = lp_vals.item() if hasattr(lp_vals, 'item') else lp_vals
                    v = 0.0 if np.isnan(v) else v
                    logprob_scores = np.full(n_claims, v)
            else:
                logprob_scores = np.zeros(n_claims)
        else:
            if prompt in logprobs:
                logprob_vals = logprobs[prompt]
                if hasattr(logprob_vals, 'ndim') and logprob_vals.ndim == 1:
                    logprob_scores = logprob_vals[:n_claims]
                    logprob_scores = np.nan_to_num(logprob_scores, nan=0.0)
                else:
                    logprob_val = logprob_vals.item() if hasattr(logprob_vals, 'item') else logprob_vals
                    logprob_val = 0.0 if np.isnan(logprob_val) else logprob_val
                    logprob_scores = np.full(n_claims, logprob_val)
            else:
                logprob_scores = np.zeros(n_claims)

        # Self-eval scores, same key scheme and fallbacks as above.
        if dataset_type == 'wikibio':
            key = f'arr_{i}'
            if key in selfevals:
                se_vals = selfevals[key]
                if hasattr(se_vals, 'ndim') and se_vals.ndim == 1:
                    selfeval_scores = np.nan_to_num(se_vals[:n_claims], nan=0.0)
                else:
                    v = se_vals.item() if hasattr(se_vals, 'item') else se_vals
                    v = 0.0 if np.isnan(v) else v
                    selfeval_scores = np.full(n_claims, v)
            else:
                selfeval_scores = np.zeros(n_claims)
        else:
            if prompt in selfevals:
                selfeval_vals = selfevals[prompt]
                if hasattr(selfeval_vals, 'ndim') and selfeval_vals.ndim == 1:
                    selfeval_scores = selfeval_vals[:n_claims]
                    selfeval_scores = np.nan_to_num(selfeval_scores, nan=0.0)
                else:
                    selfeval_val = selfeval_vals.item() if hasattr(selfeval_vals, 'item') else selfeval_vals
                    selfeval_val = 0.0 if np.isnan(selfeval_val) else selfeval_val
                    selfeval_scores = np.full(n_claims, selfeval_val)
            else:
                selfeval_scores = np.zeros(n_claims)

        # Position of each claim within the response, normalized to [0, 1]
        # (0.5 for a single-claim sample).
        ordinal_scores = np.arange(n_claims)
        if n_claims > 1:
            ordinal_scores = ordinal_scores / (n_claims - 1)
        else:
            ordinal_scores = np.array([0.5])

        # Per-judge-model LLM scores, clipped to [0, 1]; 0.5 when unseen.
        scores_dict = {}
        for model_name, model_data in llm_scores.items():
            if prompt in model_data:
                scores_dict[model_name] = np.array(model_data[prompt][:n_claims])
                scores_dict[model_name] = np.clip(scores_dict[model_name], 0.0, 1.0)
            else:
                scores_dict[model_name] = np.full(n_claims, 0.5)

        valid_llm_scores = []
        for model_name in MODEL_NAMES:
            if model_name in scores_dict:
                valid_llm_scores.append(scores_dict[model_name])

        # Ensemble = mean across judges; the std penalty is currently
        # disabled (lambda_uncertainty = 0.0).
        if valid_llm_scores:
            ensemble_mean = np.mean(valid_llm_scores, axis=0)
            ensemble_std = np.std(valid_llm_scores, axis=0)
            lambda_uncertainty = 0.0
            ensemble_scores = ensemble_mean - lambda_uncertainty * ensemble_std
            ensemble_scores = np.clip(ensemble_scores, 0.0, 1.0)
        else:
            ensemble_scores = np.full(n_claims, 0.5)

        # 4-D per-claim feature matrix: frequency, self-eval,
        # max-normalized logprob, ordinal position.
        features_4d = np.concatenate((
            freq_scores.reshape(-1, 1),
            selfeval_scores.reshape(-1, 1),
            (logprob_scores / (np.max(logprob_scores) + 1e-8)).reshape(-1, 1),
            ordinal_scores.reshape(-1, 1)
        ), axis=1)

        aligned_data.append({
            'sample': sample,
            'annotations': annotations,
            'scores': {
                'frequency': freq_scores,
                'selfeval': selfeval_scores,
                'logprob': logprob_scores,
                'ensemble': ensemble_scores,
                **scores_dict
            },
            'features_4d': features_4d,
            # Prompt-level covariates: intercept, response length, prompt length.
            'prompt_features': np.array([1.0, len(sample.get('response', '')), len(prompt)])
        })

    logging.info(f"✅ Loaded {len(aligned_data)} valid samples")
    return aligned_data
293
+
294
+
295
def create_splits(data, calib_ratio=0.7, test_ratio=0.3, random_seed=42):
    """Randomly split *data* into calibration/test subsets by ratio.

    Returns (calib_data, test_data, calib_indices, test_indices); the test
    size is capped so the two splits never exceed the dataset size.
    """
    n = len(data)
    calib_size = int(n * calib_ratio)
    # Cap the test split so calib + test never overruns the data.
    test_size = min(int(n * test_ratio), n - calib_size)

    logging.info(f"📊 Creating splits: {calib_size} calib ({calib_ratio*100:.0f}%), {test_size} test ({test_ratio*100:.0f}%)")

    # Seeded shuffle keeps the split reproducible across runs.
    np.random.seed(random_seed)
    order = np.random.permutation(n)
    calib_idx = order[:calib_size]
    test_idx = order[calib_size:calib_size + test_size]

    calib_data = [data[i] for i in calib_idx]
    test_data = [data[i] for i in test_idx]

    logging.info(f"🎲 Random split with seed {random_seed}: calib indices {calib_idx[:5]}..., test indices {test_idx[:5]}...")

    return calib_data, test_data, calib_idx, test_idx
318
+
319
+
320
def run_bcp_experiment(calib_data, test_data, score_type='frequency', alpha=0.1, **kwargs):
    """
    Run BCP (Split Conformal) experiment.

    Builds a non-conformity score function from the pre-aligned per-claim
    scores (non-conformity = 1 - score for every score type), calibrates a
    BasicConformal model on the calibration split, then reports marginal
    coverage and claim retention on the test split.

    Returns a dict with 'coverage', 'retention_rate', 'retained_claims',
    'total_claims', and the raw 'filtered_results'.
    """
    logging.info(f"📈 Running BCI (Split Conformal) with {score_type} scores...")

    calib_samples = [item['sample'] for item in calib_data]
    test_samples = [item['sample'] for item in test_data]

    # Hoisted out of score_function: the prompt -> aligned-item mapping is
    # invariant, so build it once instead of on every invocation.
    sample_to_data = {item['sample']['prompt']: item for item in calib_data + test_data}

    def score_function(samples):
        result = []
        for sample in samples:
            item = sample_to_data.get(sample['prompt'])
            scores = item['scores'].get(score_type) if item is not None else None
            if scores is not None:
                # All score types are "higher is better", so non-conformity is
                # uniformly 1 - score. (The original duplicated this in both
                # branches of a dead if/else; collapsed here.)
                result.append(1.0 - scores)
            else:
                # Unknown prompt or missing score type: neutral per-claim fallback.
                n_claims = len(sample.get('atomic_facts', []))
                result.append(np.full(n_claims, 0.5))
        return result

    basic_conformal = BasicConformal(score_function=score_function, random_state=0)
    basic_conformal.fit_on_calib(calib_samples, alpha=alpha)
    filtered_results, _ = basic_conformal.predict(test_samples)

    coverage = compute_marginal_coverage(filtered_results)
    retention = evaluate_retention(filtered_results, "BCP")

    return {
        'coverage': coverage,
        'retention_rate': retention['overall_retention_rate'],
        'retained_claims': retention['retained_claims'],
        'total_claims': retention['total_claims'],
        'filtered_results': filtered_results
    }
367
+
368
def run_as_experiment(calib_data: List[Dict], test_data: List[Dict],
                      model_names: List[str],
                      alpha: float,
                      as_mode: str,
                      subgroup_name: str, **kwargs) -> Dict:
    """Run MACI (Adaptive Subclaims) experiment for a given subgroup.

    Args:
        calib_data: calibration records ({'sample': ..., 'scores': ...}).
        test_data: test records with the same structure.
        model_names: score-model names to ensemble.
        alpha: miscoverage level.
        as_mode: 'subgroup_optimized' uses SubgroupOptimizedMACI with a
            grouper; any other value falls back to MACIAdaptiveConformal
            with a simple mean-of-models score.
        subgroup_name: grouper key (only used in subgroup-optimized mode).
        **kwargs: 'random_state' and 'target_tpr' are honored.

    Returns:
        Dict with coverage/retention metrics, budgets, weights, filtered
        results and a timing breakdown.
    """
    logging.info(f"📊 Running MACI experiment with mode: {as_mode} for subgroup: '{subgroup_name}'...")

    timing: Dict[str, float] = {}

    if as_mode == 'subgroup_optimized':
        available_groupers = get_available_groupers()
        if subgroup_name not in available_groupers:
            raise ValueError(f"Unknown subgroup: {subgroup_name}")
        grouper = available_groupers[subgroup_name]

        as_model = SubgroupOptimizedMACI(
            model_names=model_names,
            grouper=grouper,
            n_bins=3,
            random_state=kwargs.get('random_state', 0),
            solver='osqp',
        )
        # [FIXED] Removed an unused `t0 = time.perf_counter()` here: fit()
        # reports its own timing breakdown via get_timing().
        as_model.fit(calib_data, alpha=alpha, ensemble_train_ratio=0.5, target_tpr=kwargs.get('target_tpr', 0.95))
        timing_details = as_model.get_timing()
        timing['maci_weight_optimization_s'] = timing_details.get('weight_optimization_s', 0.0)
        timing['maci_calibration_s'] = timing_details.get('calibration_s', 0.0)

        t1 = time.perf_counter()
        filtered_results, _ = as_model.predict(test_data)
        timing['maci_inference_s'] = time.perf_counter() - t1
        budgets = as_model.get_budgets()
        weights = as_model.get_weights()
    else:
        # [FIXED] Removed an unused `score_type = kwargs.get("as_score_type", ...)`
        # local: the fallback score is always the unweighted model mean.
        def score_function(data_list: List[Dict]) -> List[np.ndarray]:
            """Average the per-model scores; neutral 0.5 when none exist."""
            scores_list = []
            for item in data_list:
                valid_scores = [item['scores'][m] for m in model_names if m in item['scores']]
                if valid_scores:
                    scores_list.append(np.mean(valid_scores, axis=0))
                else:
                    scores_list.append(np.array([0.5] * len(item.get('sample', {}).get('atomic_facts', []))))
            return scores_list

        as_model = MACIAdaptiveConformal(score_function=score_function, random_state=kwargs.get('random_state', 0))
        t0 = time.perf_counter()
        as_model.fit_on_calib(calib_data, alpha=alpha)
        timing['maci_calibration_s'] = time.perf_counter() - t0
        t1 = time.perf_counter()
        filtered_results, _ = as_model.predict(test_data)
        timing['maci_inference_s'] = time.perf_counter() - t1
        budgets = {'overall': as_model.tau_hat}
        weights = None

    coverage = compute_marginal_coverage(filtered_results)
    retention = evaluate_retention(filtered_results, "MACI")
    return {
        'coverage': coverage,
        'retention_rate': retention['overall_retention_rate'],
        'retained_claims': retention['retained_claims'],
        'total_claims': retention['total_claims'],
        'budgets': budgets,
        'weights': weights,
        'filtered_results': filtered_results,
        'timing': timing
    }
436
+
437
+
438
def run_cci_experiment(
    calib_data,
    test_data,
    alpha=0.1,
    boosting_epochs=1000,
    boosting_lr=0.005,
    calib_split_for_boost=0.3,
    random_seed=0,
    adaptive_alpha: bool = False,
    retention_target: float = 0.7
):
    """
    Two-stage CCI:
      - Stage 1 (Boosting): learn beta on a subset of calib_data
      - Stage 2 (CondConf): calibrate CondConf on the remaining calib_data using learned beta
      - Predict on test_data

    Returns a dict of coverage/retention metrics, the learned beta, split
    sizes and a timing breakdown; returns a stub dict with None metrics when
    the CCI dependencies are unavailable.
    """
    logging.info("🎯 Running CCI (Boosting -> CondConf) with internal calib split...")

    try:
        from conformal.conditional_conformal import ConditionalConformalBoosting, ConditionalConformalInference
    except Exception as e:
        # Best-effort: report the method as skipped rather than crashing the run.
        logging.error(f"CCI unavailable due to missing dependencies: {e}")
        return {
            "coverage": None,
            "retention_rate": None,
            "retained_claims": 0,
            "total_claims": 0,
            "filtered_results": [],
            "timing": {"cci_skipped": True, "error": str(e)}
        }

    # Internal calib split: a boosting subset and a conformal subset.
    rng = np.random.default_rng(random_seed)
    idx = np.arange(len(calib_data))
    rng.shuffle(idx)
    k = int(len(idx) * calib_split_for_boost)
    idx_boost, idx_conf = idx[:k], idx[k:]
    if len(idx_conf) == 0:
        # [FIXED] A stray trailing "[1]" subscript here (idx[-1:] [1]) raised
        # IndexError on the length-1 fallback slice.
        idx_boost, idx_conf = idx[:-1], idx[-1:]
    calib_boost = [calib_data[i] for i in idx_boost]
    calib_conf = [calib_data[i] for i in idx_conf]
    logging.info(f" 🔧 calib split -> boost:{len(calib_boost)} | conf:{len(calib_conf)} (seed={random_seed})")

    # Stage 1: learn beta.
    booster = ConditionalConformalBoosting(random_state=random_seed)
    t_boost_0 = time.perf_counter()
    beta = booster.fit(
        calib_boost,
        boosting_epochs=boosting_epochs,
        boosting_lr=boosting_lr
    )
    t_boost_1 = time.perf_counter()

    # Stage 2: calibrate CondConf with the learned beta.
    infer = ConditionalConformalInference(random_state=random_seed)
    t_fit_0 = time.perf_counter()
    infer.fit(calib_conf, alpha=alpha, beta=beta, adaptive_alpha=adaptive_alpha, retention_target=retention_target)
    t_fit_1 = time.perf_counter()
    # [FIXED] Result was bound to an unused local; call kept for its side
    # effects (AUROC evaluation/logging).
    infer.evaluate_auroc(test_data)
    t_pred_0 = time.perf_counter()
    filtered_results = infer.predict(test_data)
    t_pred_1 = time.perf_counter()

    coverage = compute_marginal_coverage(filtered_results)
    retention = evaluate_retention(filtered_results, "CCI")

    return {
        "coverage": coverage,
        "retention_rate": retention["overall_retention_rate"],
        "retained_claims": retention["retained_claims"],
        "total_claims": retention["total_claims"],
        "filtered_results": filtered_results,
        "beta": beta,
        "calib_sizes": {"boost": len(calib_boost), "conf": len(calib_conf)},
        "split_seed": random_seed,
        "timing": {
            "cci_boost_fit_s": t_boost_1 - t_boost_0,
            "cci_condconf_fit_s": t_fit_1 - t_fit_0,
            "cci_inference_s": t_pred_1 - t_pred_0,
            "cci_adaptive_alpha_enabled": bool(adaptive_alpha)
        }
    }
518
+
519
def evaluate_retention(filtered_dataset: List[Dict], method_name: str = "") -> Dict:
    """Summarize how many atomic claims survived filtering across a dataset.

    Each entry may either wrap its payload under a 'sample' key or be the
    payload itself; non-dict payloads are skipped with a warning.
    """
    if not filtered_dataset:
        return {'overall_retention_rate': 0.0, 'retained_claims': 0, 'total_claims': 0}

    n_original = 0
    n_retained = 0
    for entry in filtered_dataset:
        record = entry.get('sample', entry)
        if not isinstance(record, dict):
            logging.warning(f"Skipping invalid item in retention evaluation: {type(record)}")
            continue
        n_original += len(record.get('atomic_facts', []))
        n_retained += len(record.get('filtered_claims', []))

    rate = n_retained / n_original if n_original > 0 else 0.0
    return {
        'overall_retention_rate': rate,
        'retained_claims': n_retained,
        'total_claims': n_original
    }
548
+
549
def compute_marginal_coverage(filtered_dataset: List[Dict]):
    """Fraction of samples whose retained claims are all supported.

    A sample with no retained claims counts as covered (vacuously true);
    non-dict claims are ignored when checking support.
    """
    flags = []
    for entry in filtered_dataset:
        record = entry.get('sample', entry)
        if not isinstance(record, dict):
            logging.warning(f"Skipping invalid item in coverage calculation: {type(record)}")
            continue

        kept = record.get('filtered_claims', [])
        if not kept:
            flags.append(1.0)
            continue

        all_supported = all(
            claim.get('is_supported', False)
            for claim in kept if isinstance(claim, dict)
        )
        flags.append(1.0 if all_supported else 0.0)

    return np.mean(flags) if flags else 0.0
566
+
567
def compute_conditional_coverage(test_data, filtered_results, grouper, alpha=0.1, binning_method='quartiles'):
    """Compute conditional coverage for subgroups"""
    # Merge each original sample with its filtered claims so the grouper
    # sees scores and filtering outcomes together.
    combined_data = []
    for original, filtered in zip(test_data, filtered_results):
        merged = dict(original['sample'])
        merged['scores'] = original['scores']
        merged['filtered_claims'] = filtered.get('filtered_claims', [])
        combined_data.append(merged)

    # Every requested binning method currently maps onto tertiles.
    method = {
        'quantile': 'tertiles',
        'equal_width': 'tertiles',
        'quartiles': 'tertiles'
    }.get(binning_method, 'tertiles')

    values = grouper.compute_values(combined_data)

    if len(values) == 0:
        logging.warning(f" ⚠️ Warning: {grouper.__class__.__name__} returned no values")
        return {}

    if np.all(values == values[0]):
        logging.warning(f" ⚠️ Warning: {grouper.__class__.__name__} all values identical ({values[0]:.4f})")

    bins = grouper.create_bins(values, method=method)

    labels = ['low', 'medium', 'high'] if len(bins) == 3 else [f'bin_{i}' for i in range(len(bins))]
    groups = {}
    for i, (lo, hi) in enumerate(bins):
        # Closed on the right only for the final bin so the maximum is included.
        if i == len(bins) - 1:
            mask = (values >= lo) & (values <= hi)
        else:
            mask = (values >= lo) & (values < hi)
        label = labels[i] if i < len(labels) else f'bin_{i}'
        groups[label] = np.where(mask)[0].tolist()

    results = {}
    for label, members in groups.items():
        if len(members) == 0:
            continue

        covered_flags = []
        claims_total = 0
        claims_kept = 0
        for idx in members:
            kept = filtered_results[idx].get('filtered_claims', [])
            originals = test_data[idx]['sample'].get('atomic_facts', [])

            bad = any(not claim.get('is_supported', False) for claim in kept)
            covered_flags.append(0.0 if bad else 1.0)

            claims_total += len(originals)
            claims_kept += len(kept)

        results[label] = {
            'size': len(members),
            'coverage': np.mean(covered_flags) if covered_flags else 0.0,
            'retention_rate': claims_kept / claims_total if claims_total > 0 else 0.0,
            'retained_claims': claims_kept,
            'total_claims': claims_total,
            'target_coverage': 1 - alpha,
        }

    return results
640
+
641
+
642
def save_aggregated_results_to_json(results: Dict, args: argparse.Namespace):
    """Serialize aggregated experiment results to a timestamped JSON file.

    Bulky/non-serializable per-method entries (filtered results, weights,
    budgets, ...) are excluded; numpy scalars and arrays are converted to
    native Python types on the fly.
    """
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    default_output_dir = os.path.join(repo_root, 'analysis', 'experiment_results')
    output_dir = getattr(args, 'time_out', None) or default_output_dir
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    groups_str = "_".join(sorted(args.conditional_groups))
    filename = f"results_{args.dataset_type}_{args.model_set}_{groups_str}_{timestamp}.json"
    filepath = os.path.join(output_dir, filename)

    logging.info(f"\n💾 Saving aggregated results to {filepath}...")

    def convert_to_native_types(obj):
        """json.dumps fallback: numpy scalars/arrays -> native, else str."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, defaultdict):
            return dict(obj)
        try:
            json.dumps(obj)
            return obj
        except TypeError:
            return str(obj)

    keys_to_exclude = {'filtered_results', 'beta', 'weights', 'budgets', 'calib_sizes', 'split_seed'}

    serializable_data = {}
    for method_name, method_data in results.items():
        method_payload = {}
        for key, value in method_data.items():
            if key in keys_to_exclude:
                continue
            try:
                # Round-trip through json to force everything into native types.
                method_payload[key] = json.loads(json.dumps(value, default=convert_to_native_types))
            except Exception as e:
                logging.warning(f"Could not serialize key '{key}' for method '{method_name}'. Skipping. Error: {e}")
        serializable_data[method_name] = method_payload

    try:
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(serializable_data, f, indent=4, ensure_ascii=False)
        logging.info(f"✅ Successfully saved results.")
    except Exception as e:
        logging.error(f"❌ Failed to save results to JSON: {e}")
692
+
693
def main():
    """CLI entry point: run BCI/CCI/MACI conformal experiments over repeated
    random calibration/test splits, log per-run and aggregated marginal and
    subgroup results, and persist them to JSON."""
    parser = argparse.ArgumentParser(description="Experiment with three conformal methods")
    parser.add_argument("--random-seed", type=int, default=123, help="Random seed")
    parser.add_argument("--data-dir", type=str, default=None, help="Data directory (defaults to repo_root/data)")
    parser.add_argument("--log-dir", type=str, default=None, help="Directory to save logs (defaults to repo_root/logs)")
    parser.add_argument("--dataset-type", type=str, default="auto", choices=["auto", "wikibio", "medlfqa"],
                        help="Dataset type (auto-detected if not specified)")
    parser.add_argument("--alpha", type=float, default=0.1, help="Significance level (fixed if --adaptive-alpha is false)")
    parser.add_argument("--adaptive-alpha", action='store_true', help="Enable per-sample adaptive alpha (learn q*(z) for retention target)")
    parser.add_argument("--retention-target", type=float, default=0.4, help="Target retention used to learn adaptive alpha")
    parser.add_argument("--scores-dir", type=str, default=None, help="Directory containing final NPZ score files (optional)")
    parser.add_argument("--calib-ratio", type=float, default=0.75, help="Calibration set ratio")
    parser.add_argument("--test-ratio", type=float, default=0.25, help="Test set ratio")
    parser.add_argument("--boosting-epochs", type=int, default=100, help="Boosting epochs")
    parser.add_argument("--n-runs", type=int, default=10, help="Number of repeated runs with different random splits")
    parser.add_argument("--model-set", type=str, default="fixed", choices=["fixed"], help="Model set (fixed 3 models)")
    parser.add_argument("--bcp-score-type", type=str, default="frequency",
                        choices=['frequency', 'selfeval', 'logprob', 'ensemble'],
                        help="Score type for BCI")
    # --as-score-type removed; MACI uses ensemble by default
    parser.add_argument("--as-mode", type=str, default="subgroup_optimized", choices=["standard", "subgroup_optimized"], help="AS variant")
    parser.add_argument("--conditional-groups", type=str, nargs='*',
                        default=['false_claim_risk', 'medicalcontent', 'view_count'],
                        choices=['false_claim_risk', 'medicalcontent', 'view_count'],
                        help="Conditional groups to analyze")
    parser.add_argument("--view-metadata-csv", type=str, default=None,
                        help="Optional CSV for view_count grouper; defaults to repo-relative data path")
    parser.add_argument("--binning-method", type=str, default="quantile",
                        choices=['quantile', 'equal_width'],
                        help="Binning method for conditional groups")
    parser.add_argument("--limit-samples", type=int, default=2000, help="Max number of samples to load")
    parser.add_argument("--target-tpr", type=float, default=0.8, help="Target TPR for subgroup-optimized AS")

    parser.add_argument("--time-profile", action='store_true', help="Enable timing profile output")
    parser.add_argument("--time-out", type=str, default=None, help="Directory to save timing JSON (defaults to repo_root/analysis/experiment_results)")

    args = parser.parse_args()

    # Resolve repo-relative defaults for any path not given on the CLI.
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if not args.data_dir:
        args.data_dir = os.path.join(repo_root, 'data')
    if not args.log_dir:
        args.log_dir = os.path.join(repo_root, 'logs')
    if not getattr(args, 'time_out', None):
        args.time_out = os.path.join(repo_root, 'analysis', 'experiment_results')

    setup_logging(args.log_dir)
    if args.view_metadata_csv:
        set_view_metadata_csv(args.view_metadata_csv)

    logging.info("=" * 80)
    logging.info(f"📊 Setup: {args.calib_ratio*100:.0f}% calibration + {args.test_ratio*100:.0f}% test, α={args.alpha}, adaptive={args.adaptive_alpha}")
    logging.info(f"🔄 Number of runs: {args.n_runs}")
    logging.info(f"🏷️ BCI Score: {args.bcp_score_type}")
    logging.info(f"🧠 CCI: enabled")
    logging.info(f"🎯 MACI: enabled")
    logging.info(f"📊 Conditional groups: {args.conditional_groups}")
    logging.info(f"🔧 Binning Method: {args.binning_method}")
    logging.info(f"🆕 Using provided scores and enhanced features")

    limit_samples = args.limit_samples
    data = load_1000_samples(args.data_dir, scores_dir=args.scores_dir, dataset_type=args.dataset_type, limit_samples=limit_samples)

    from collections import defaultdict
    all_runs_results = defaultdict(lambda: defaultdict(list))

    # Resolve the requested conditional groupers up front.
    groupers = []
    available_groupers = get_available_groupers()
    for group_name in args.conditional_groups:
        if group_name in available_groupers:
            groupers.append(available_groupers[group_name])
        else:
            logging.warning(f"⚠️ Unknown grouper: {group_name}")

    # Auto-detect the dataset type from the score layout when requested.
    detected_dataset_type = args.dataset_type
    if detected_dataset_type == "auto":
        if data and 'scores' in data[0] and isinstance(data[0]['scores'].get('frequency'), np.ndarray):
            detected_dataset_type = 'medlfqa'
        else:
            detected_dataset_type = 'wikibio'
    logging.info(f"➡️ Using detected dataset type: {detected_dataset_type}")

    # NOTE(review): these paths are never consumed downstream in this
    # function — confirm whether they should feed into data loading.
    factscore_npz_path = None
    if detected_dataset_type == 'wikibio':
        wikibio_npz_path = os.path.join(args.data_dir, "wiki_scores", "wikibio_final_frequencies.npz")

    model_names_to_use = MODEL_NAMES
    logging.info(f" MACI Models: {', '.join(model_names_to_use)}")

    for run_idx in range(args.n_runs):
        logging.info(f"\n🔄 Run {run_idx + 1}/{args.n_runs}")
        logging.info("-" * 50)

        random_seed = args.random_seed + run_idx
        calib_data, test_data, calib_idx, test_idx = create_splits(
            data, args.calib_ratio, args.test_ratio, random_seed=random_seed
        )

        logging.info(f"📊 Run {run_idx + 1} sizes: {len(calib_data)} calib, {len(test_data)} test (seed: {random_seed})")

        results = {}

        try:
            results['BCI'] = run_bcp_experiment(calib_data, test_data, score_type=args.bcp_score_type, alpha=args.alpha)
        except Exception as e:
            logging.error(f"❌ BCP failed: {e}")
            import traceback
            logging.error(f"Traceback: {traceback.format_exc()}")
            results['BCI'] = None

        try:
            results['CCI'] = run_cci_experiment(
                calib_data, test_data,
                alpha=args.alpha,
                boosting_epochs=args.boosting_epochs,
                adaptive_alpha=args.adaptive_alpha,
                retention_target=args.retention_target
            )
        except Exception as e:
            logging.error(f"❌ CCI failed: {e}")
            import traceback
            logging.error(f"Traceback: {traceback.format_exc()}")
            results['CCI'] = None

        results['MACI'] = {
            'coverage': [],
            'retention_rate': [],
            'retained_claims': [],
            'total_claims': [],
            'subgroup_results': {}
        }

        logging.info("--- Starting MACI Experiments ---")
        mace_marginal_results_set = False

        for subgroup_name in args.conditional_groups:
            try:
                mace_subgroup_result = run_as_experiment(
                    calib_data, test_data,
                    model_names=model_names_to_use,
                    alpha=args.alpha,
                    as_mode='subgroup_optimized',
                    subgroup_name=subgroup_name,
                    random_state=random_seed,
                    target_tpr=args.target_tpr
                )

                if mace_subgroup_result and 'filtered_results' in mace_subgroup_result:
                    # Flatten {'sample': {...}} wrappers so coverage helpers
                    # see the claims at the top level.
                    flat_filtered_results = []
                    for res in mace_subgroup_result['filtered_results']:
                        flat_item = dict(res.get('sample', {}))
                        flat_item['filtered_claims'] = res.get('sample', {}).get('filtered_claims', [])
                        flat_filtered_results.append(flat_item)
                    mace_subgroup_result['filtered_results'] = flat_filtered_results

                # The first successful subgroup run supplies MACI's marginal metrics.
                if not mace_marginal_results_set:
                    results['MACI']['coverage'] = mace_subgroup_result['coverage']
                    results['MACI']['retention_rate'] = mace_subgroup_result['retention_rate']
                    results['MACI']['retained_claims'] = mace_subgroup_result['retained_claims']
                    results['MACI']['total_claims'] = mace_subgroup_result['total_claims']
                    results['MACI']['filtered_results'] = mace_subgroup_result.get('filtered_results', [])
                    results['MACI']['timing'] = mace_subgroup_result.get('timing', {})
                    mace_marginal_results_set = True

                target_grouper = available_groupers.get(subgroup_name)
                if target_grouper:
                    try:
                        conditional_results = compute_conditional_coverage(
                            test_data,
                            mace_subgroup_result['filtered_results'],
                            target_grouper,
                            args.alpha,
                            args.binning_method
                        )
                        results['MACI']['subgroup_results'][target_grouper.__class__.__name__] = conditional_results
                    except Exception as e:
                        logging.error(f" ❌ MACI subgroup analysis for {target_grouper.__class__.__name__} failed: {e}")

            except Exception as e:
                logging.error(f"❌ MACI ({subgroup_name}) failed: {e}")
                import traceback
                logging.error(f"Traceback: {traceback.format_exc()}")

        # Fold this run's marginal and subgroup metrics into the aggregates.
        for method_name, result in results.items():
            if not result or result.get('coverage') is None:
                continue

            all_runs_results[method_name]['coverage'].append(result['coverage'])
            all_runs_results[method_name]['retention_rate'].append(result['retention_rate'])
            all_runs_results[method_name]['retained_claims'].append(result['retained_claims'])
            all_runs_results[method_name]['total_claims'].append(result['total_claims'])

            run_subgroup_results = {}
            if method_name == 'MACI':
                run_subgroup_results = result.get('subgroup_results', {})
            else:
                for grouper in groupers:
                    try:
                        conditional_results = compute_conditional_coverage(
                            test_data,
                            result['filtered_results'],
                            grouper,
                            args.alpha,
                            args.binning_method
                        )
                        run_subgroup_results[grouper.__class__.__name__] = conditional_results
                    except Exception as e:
                        logging.error(f" ❌ {grouper.__class__.__name__} failed for {method_name}: {e}")

            all_runs_results[method_name]['subgroup_results'].append(run_subgroup_results)

        logging.info(f"\n📊 Run {run_idx + 1} Results:")
        for method_name, result in results.items():
            if not result or result.get('coverage') is None:
                logging.info(f" {method_name}: ❌ FAILED or SKIPPED")
                continue
            logging.info(f" {method_name}: Coverage={result['coverage']:.4f}, Retention={result['retention_rate']:.3f}, Claims={result['retained_claims']}/{result['total_claims']}")

        if args.time_profile:
            timing_payload = {
                'dataset_type': detected_dataset_type,
                'model_set': args.model_set,
                'boosting_epochs': args.boosting_epochs,
                'adaptive_alpha': args.adaptive_alpha,
                'retention_target': args.retention_target,
                'run_idx': run_idx,
                # [FIXED] results['CCI'] is None when the method failed;
                # .get('CCI', {}) would then raise AttributeError.
                'CCI': (results.get('CCI') or {}).get('timing', {}),
                'MACI': {}
            }
            try:
                first_subgroup = next(iter(results['MACI'].get('subgroup_results', {}).keys()), None)
                if first_subgroup:
                    mace_timing = results['MACI'].get('timing')
                    timing_payload['MACI'] = mace_timing if mace_timing else {}
            except Exception:
                pass
            # (Removed a no-op fallback that re-assigned {} to an already-empty entry.)

            os.makedirs(args.time_out, exist_ok=True)
            tstamp = datetime.now().strftime('%Y%m%d-%H%M%S')
            timing_path = os.path.join(args.time_out, f"time_profile_{detected_dataset_type}_{args.model_set}_{tstamp}.json")
            with open(timing_path, 'w', encoding='utf-8') as f:
                json.dump(timing_payload, f, indent=2, ensure_ascii=False)
            logging.info(f"⏱️ Saved timing profile to {timing_path}")

        # --- Optional debug dumps of claims (first run only) ---
        # Helpers are defined once; the original duplicated them per section.
        def _get_claim_text(c: Dict[str, Any]) -> str:
            if not isinstance(c, dict):
                return str(c)
            return c.get('atom') or c.get('text') or c.get('claim') or c.get('fact') or str(c)

        def _get_claim_support(c: Dict[str, Any]) -> str:
            if isinstance(c, dict):
                v = c.get('is_supported')
                if isinstance(v, (bool, np.bool_)):
                    return 'T' if bool(v) else 'F'
            return '?'

        def _filtered_pairs(item):
            if not item:
                return []
            claims = item.get('filtered_claims')
            if claims is None and isinstance(item.get('sample'), dict):
                claims = item['sample'].get('filtered_claims', [])
            return [(_get_claim_text(c), _get_claim_support(c)) for c in (claims or [])]

        def _dump_sample(idx: int) -> None:
            """Log original and per-method filtered claims for test sample idx."""
            sample = test_data[idx]['sample']
            prompt = sample.get('prompt', '')
            original_pairs = [(_get_claim_text(c), _get_claim_support(c)) for c in sample.get('atomic_facts', [])]
            # [FIXED] The lookup key was 'MACE', which never exists in
            # `results` (the method is stored under 'MACI'), so the MACI dump
            # was always empty. Also guard against None entries for failed
            # methods.
            bci_item = (results.get('BCI') or {}).get('filtered_results', [None] * len(test_data))[idx]
            cci_item = (results.get('CCI') or {}).get('filtered_results', [None] * len(test_data))[idx]
            mace_item = (results.get('MACI') or {}).get('filtered_results', [None] * len(test_data))[idx]
            logging.info("\n=== SAMPLE CLAIMS DUMP ===")
            logging.info(f"[Test idx={idx}] Prompt: {prompt}")
            logging.info(f"Original claims ({len(original_pairs)}):")
            for i, (t, lab) in enumerate(original_pairs, 1):
                logging.info(f" {i:2d}. [{lab}] {t}")
            for label, pairs in (("BCI", _filtered_pairs(bci_item)),
                                 ("CCI", _filtered_pairs(cci_item)),
                                 ("MACI", _filtered_pairs(mace_item))):
                logging.info(f"\n[{label}] filtered claims ({len(pairs)}):")
                for i, (t, lab) in enumerate(pairs, 1):
                    logging.info(f" {i:2d}. [{lab}] {t}")

        if run_idx == 0 and getattr(args, 'show_sample_idx', None) is not None and args.show_sample_idx >= 0:
            idx = int(args.show_sample_idx)
            if 0 <= idx < len(test_data):
                _dump_sample(idx)

        if run_idx == 0 and getattr(args, 'show_sample_count', 0) > 0:
            for idx in range(min(int(args.show_sample_count), len(test_data))):
                _dump_sample(idx)

    # --- Aggregated summary across all runs ---
    logging.info("\n" + "=" * 100)
    logging.info("📊 AGGREGATED RESULTS (All Runs)")
    logging.info("=" * 100)
    for method_name in sorted(all_runs_results.keys()):
        method_results = all_runs_results[method_name]

        if not method_results['coverage']:
            logging.info(f"\n{method_name}: ❌ NO SUCCESSFUL RUNS")
            continue

        n_runs = len(method_results['coverage'])
        coverage_mean = np.mean(method_results['coverage'])
        coverage_std = np.std(method_results['coverage'])
        retention_mean = np.mean(method_results['retention_rate'])
        retention_std = np.std(method_results['retention_rate'])
        retained_claims_mean = np.mean(method_results['retained_claims'])
        retained_claims_std = np.std(method_results['retained_claims'])
        total_claims_mean = np.mean(method_results['total_claims'])

        logging.info(f"\n{'='*20} {method_name} ({n_runs} runs) {'='*20}")
        logging.info(f"📈 MARGINAL RESULTS:")
        logging.info(f" Coverage: {coverage_mean:.4f} ± {coverage_std:.4f}")
        logging.info(f" Retention Rate: {retention_mean:.3f} ± {retention_std:.3f}")
        logging.info(f" Claims: {retained_claims_mean:.1f} ± {retained_claims_std:.1f}/{total_claims_mean:.1f}")

        if method_results['subgroup_results']:
            logging.info(f"\n📊 SUBGROUP RESULTS:")

            # Pool per-run subgroup metrics: grouper -> group -> metric lists.
            subgroup_data = {}
            for run_results in method_results['subgroup_results']:
                for grouper_name, grouper_results in run_results.items():
                    if grouper_name not in subgroup_data:
                        subgroup_data[grouper_name] = {}

                    for group_name, group_result in grouper_results.items():
                        if group_name not in subgroup_data[grouper_name]:
                            subgroup_data[grouper_name][group_name] = {
                                'coverage': [], 'retention_rate': [], 'retained_claims': [],
                                'total_claims': [], 'size': []
                            }

                        subgroup_data[grouper_name][group_name]['coverage'].append(group_result['coverage'])
                        subgroup_data[grouper_name][group_name]['retention_rate'].append(group_result['retention_rate'])
                        subgroup_data[grouper_name][group_name]['retained_claims'].append(group_result['retained_claims'])
                        subgroup_data[grouper_name][group_name]['total_claims'].append(group_result['total_claims'])
                        subgroup_data[grouper_name][group_name]['size'].append(group_result['size'])

            for grouper_name, groups in subgroup_data.items():
                logging.info(f"\n 🔍 {grouper_name}:")

                for group_name, group_data in groups.items():
                    if not group_data['coverage']:
                        continue

                    group_coverage_mean = np.mean(group_data['coverage'])
                    group_coverage_std = np.std(group_data['coverage'])
                    group_retention_mean = np.mean(group_data['retention_rate'])
                    group_retention_std = np.std(group_data['retention_rate'])
                    group_retained_claims_mean = np.mean(group_data['retained_claims'])
                    group_retained_claims_std = np.std(group_data['retained_claims'])
                    group_total_claims_mean = np.mean(group_data['total_claims'])
                    group_size_mean = np.mean(group_data['size'])

                    target_coverage = 1 - args.alpha
                    violation_marker = "⚠️ " if abs(group_coverage_mean - target_coverage) > 0.014 else "✅ "

                    logging.info(f" {violation_marker}{group_name}:")
                    logging.info(f" Coverage: {group_coverage_mean:.3f} ± {group_coverage_std:.3f} (target: {target_coverage:.1f})")
                    logging.info(f" Retention: {group_retention_mean:.3f} ± {group_retention_std:.3f}")
                    logging.info(f" Claims: {group_retained_claims_mean:.1f} ± {group_retained_claims_std:.1f}/{group_total_claims_mean:.1f}")
                    logging.info(f" Group size: {group_size_mean:.1f} samples")
                    logging.info(f" Coverage gap: {group_coverage_mean - target_coverage:+.3f}")

    logging.info("\n" + "=" * 100)

    save_aggregated_results_to_json(all_runs_results, args)


if __name__ == "__main__":
    main()
MACI-main/requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==2.0.2
2
+ scipy==1.13.1
3
+ scikit-learn==1.6.1
4
+ pandas==2.3.1
5
+ matplotlib==3.9.4
6
+ seaborn==0.13.2
7
+ tqdm==4.67.1
8
+ cvxpy==1.7.1
9
+ conditionalconformal==0.0.5
10
+ torch==2.8.0
11
+ torchvision==0.23.0
12
+ torchaudio==2.8.0