# File size: 5,933 Bytes
# commit: 906e061
import argparse
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from config import get_config
from featurizer import get_frequency, get_self_eval
from gpt import GPTClient
from atomizer import text_to_sentences
from dataset import get_prompts
from scorer import Scorer
import ray
def parse_args():
    """Parse command-line arguments for the conformal-safety driver.

    Returns:
        argparse.Namespace with: config_path, seed, n_trials, type, methods.

    Note: the ``__main__`` block reads ``args.seed``, ``args.n_trials``,
    ``args.type`` and ``args.methods``; the original parser never defined
    them, which raised AttributeError at runtime. They are added here with
    backward-compatible defaults (``-config_path`` is kept verbatim so
    existing invocations still work).
    """
    parser = argparse.ArgumentParser(
        prog="conformal-safety",
        description="Auto-filter claims from LLM to meet accuracy and safety guarantees.",
    )
    parser.add_argument('-config_path', '-c', default='configs/default.toml', help="Config for construction.")
    parser.add_argument('--seed', type=int, default=0, help="Base random seed for the first trial.")
    parser.add_argument('--n_trials', type=int, default=1, help="Number of trials (seeds seed..seed+n_trials-1).")
    parser.add_argument('--type', default='coverage', help="Experiment type; 'coverage' selects the coverage experiment.")
    parser.add_argument('--methods', nargs='*', default=None, help="Methods to evaluate in each trial.")
    args = parser.parse_args()
    return args
def find_unique_element(lst, condition, approx_index):
    """Return the index of an element of *lst* satisfying *condition*.

    The search starts at *approx_index* and widens outward one position at a
    time, so a match near the guess is found quickly. At equal distance the
    lower index is preferred. Returns None when nothing matches.
    """
    # Fast path: the guess itself already satisfies the predicate.
    if condition(lst[approx_index]):
        return approx_index

    size = len(lst)
    offset = 1
    # Widen the window [approx_index - offset, approx_index + offset] until
    # both edges fall off the list.
    while True:
        lo = approx_index - offset
        hi = approx_index + offset
        if lo < 0 and hi >= size:
            return None  # entire list exhausted without a match
        if lo >= 0 and condition(lst[lo]):
            return lo
        if hi < size and condition(lst[hi]):
            return hi
        offset += 1
@ray.remote
def parallel_scorer(*args, **kwargs):
    # Ray remote task wrapper around run_experiment.
    # NOTE(review): stubbed out — the early `return None` disables the task,
    # making the line below unreachable. `run_experiment` is also not defined
    # or imported anywhere in this file; resolve that before removing the stub.
    return None
    return run_experiment(*args, **kwargs)
if __name__ == "__main__":
    # End-to-end pipeline: (1) query the responder model for every prompt,
    # (2) map cached atomizer outputs back to their source responses,
    # (3) annotate atomic facts with the scorer, (4) assemble the dataset.
    args = parse_args()
    config = get_config(args.config_path)
    import IPython; IPython.embed()  # NOTE(review): leftover interactive debug breakpoint

    # --- 1) Generate responses for all prompts (threaded fan-out) ---
    responder = GPTClient(config.model.responder.cache_path)
    topics, prompts = get_prompts(config.dataset.name)
    with ThreadPoolExecutor(max_workers=25) as executor:
        responses = list(
            tqdm(
                executor.map(
                    lambda x : responder.query(x),
                    prompts
                ),
                total=len(prompts)
            )
        )
    # Each query returns a sequence; keep only the first completion.
    responses = [r[0] for r in responses]
    outputs = [{'prompt': p, 'response': o['message']}
               for p, o in zip(prompts, responses)] # first output is the response we will filter

    # --- 2) Re-associate cached atomizer outputs with their responses ---
    print("Loading atomizer.")
    atomizer_client = GPTClient(config.model.parser.cache_path, model=config.model.parser.name)
    responder_cache = responder.cache_dict
    messages = []
    for val in responder_cache.values():
        messages.append(val[0]['message'])
    atomizer_cache = atomizer_client.cache_dict
    # Cache keys arrive roughly in response order, so each match is expected
    # near the previous one; idx_guess seeds the outward search.
    idx_guess = 0
    atomic_facts = [[] for _ in range(len(messages))]
    for k in tqdm(atomizer_cache.keys()):
        atomized_msg = atomizer_cache[k][0]['message']
        atomized_facts = text_to_sentences(atomized_msg)
        # The cache key embeds the sentence that was atomized (after the last
        # 'facts:' marker, minus a two-char suffix); recover it to locate the
        # response message it came from.
        sentence = k.split('\n')[-1].split('facts:')[-1].strip()[:-2]
        cur_idx = find_unique_element(messages, lambda x: sentence in x, approx_index=idx_guess)
        if cur_idx is None: # TODO: TERRIBLE SPECIAL CASING that I looked at by hand...
            # Hand-verified fallback: one known cache entry maps to index 4152;
            # otherwise assume the fact belongs to the current guess.
            if idx_guess == 4151:
                cur_idx = 4152
            else:
                cur_idx = idx_guess
        idx_guess = cur_idx
        atomic_facts[cur_idx].extend(atomized_facts)

    # time to annotate responses using factscore code
    # --- 3) Score each (topic, response, facts) triple ---
    print("Loading annotator.")
    scorer_client = GPTClient(config.model.annotator.cache_path, model=config.model.annotator.name)
    scorer = Scorer(scorer_client, config, model_name="retrieval")
    scorer_inputs = [(topic, output['response'], fact) for topic, output, fact in zip(topics, outputs, atomic_facts)]
    import IPython; IPython.embed()  # NOTE(review): leftover interactive debug breakpoint
    # connect to cluster
    ray.init(address="auto")
    results = []
    # NOTE(review): this loop references names not defined in this file
    # (X, Y, n_test, n_calib, alpha, parallel_coverage_experiment,
    # parallel_experiment) — as written it will raise NameError. Confirm
    # whether it was pasted in from another script.
    for seed in range(args.seed, args.seed + args.n_trials):
        if args.type == 'coverage':
            result = parallel_coverage_experiment.remote(
                (X, Y), n_test, n_calib, alpha, methods=args.methods, seed=seed
            )
        else:
            result = parallel_experiment.remote(
                (X, Y), n_test, n_calib, alpha, methods=args.methods, seed=seed
            )
        results.append(result)
    trial_results = ray.get(results)
    # Scoring is serialized (max_workers=1) — presumably to avoid overloading
    # the retrieval backend; TODO confirm before raising the worker count.
    with ThreadPoolExecutor(max_workers=1) as executor:
        scores = list(
            tqdm(
                executor.map(
                    lambda x : scorer.get_score(*x),
                    scorer_inputs
                ),
                total=len(scorer_inputs)
            )
        )
    scorer.save_cache()  # persist annotations so reruns hit the cache

    # --- 4) Assemble the final dataset of scored atomic facts ---
    dataset = []
    for p, r, s in zip(prompts, responses, scores):
        data_pt = {
            'prompt': p,
            'response': r,
            'atomic_facts': s['decisions'][0]
        }
        dataset.append(data_pt)
    import IPython
    IPython.embed()  # NOTE(review): leftover interactive debug breakpoint
    # Commented-out featurization pass (claim frequency + model self-eval),
    # kept for reference:
    # client = GPTClient(f'.cache/{config.dataset.name}_frequency.pkl')
    # with ThreadPoolExecutor(max_workers=5) as executor:
    #     frequencies = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_frequency(client, [af['atom'] for af in x['atomic_facts']], x['prompt'], config.model.prob.frequency.model),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # client.save_cache()
    # eval_client = GPTClient(f'.cache/{config.dataset.name}_self_evals.pkl')
    # with ThreadPoolExecutor(max_workers=25) as executor:
    #     self_evals = list(
    #         tqdm(
    #             executor.map(
    #                 lambda x: get_self_eval(x['prompt'], [af['atom'] for af in x['atomic_facts']], eval_client),
    #                 dataset
    #             ),
    #             total=len(dataset)
    #         )
    #     )
    # eval_client.save_cache()
    # features = np.concatenate(
    #     [
    #         np.concatenate(frequencies).reshape(-1,1),
    #         np.concatenate(self_evals).reshape(-1,1)
    #     ],
    #     axis=1
    # )