|
|
import argparse |
|
|
import numpy as np |
|
|
|
|
|
from config import get_config |
|
|
from conformal import compute_conformity_scores, calibrate_thresholds, conformal_filter, assess_factscore_coverage |
|
|
from dataset import load_dataset, split_dataset |
|
|
from featurizer import get_features |
|
|
from llm_utils import merge_claims |
|
|
from prob_model import fit_model |
|
|
from gpt import GPTClient |
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
parser = argparse.ArgumentParser( |
|
|
prog="conformal-safety", |
|
|
description="Auto-filter claims from LLM to meet accuracy and safety guarantees.", |
|
|
) |
|
|
parser.add_argument('-config_path', '-c', default='configs/default.toml', help="Config for construction.") |
|
|
args = parser.parse_args() |
|
|
return args |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
args = parse_args() |
|
|
|
|
|
config = get_config(args.config_path) |
|
|
|
|
|
rng = np.random.default_rng(seed=config.dataset.seed) |
|
|
|
|
|
|
|
|
dataset = load_dataset(config) |
|
|
|
|
|
|
|
|
dataset_train, dataset_valid, dataset_test = split_dataset( |
|
|
dataset, |
|
|
train_perc=config.dataset.train_percent, |
|
|
valid_perc=config.dataset.valid_percent, |
|
|
rng=rng if config.dataset.randomize else None |
|
|
) |
|
|
|
|
|
X_train = get_features(dataset_train, config) |
|
|
|
|
|
y_train = np.concatenate([[c['is_supported'] for c in dat['atomic_facts']] for dat in dataset_train]) |
|
|
y_train[y_train == True] = 1 |
|
|
y_train[y_train == False] = 0 |
|
|
y_train = y_train.astype(np.int8) |
|
|
|
|
|
X_valid = get_features(dataset_valid, config) |
|
|
y_valid = np.concatenate([[c['is_supported'] for c in dat['atomic_facts']] for dat in dataset_valid]) |
|
|
y_valid[y_valid == True] = 1 |
|
|
y_valid[y_valid == False] = 0 |
|
|
y_valid = y_valid.astype(np.int8) |
|
|
splits_valid = np.cumsum([len(dat['atomic_facts']) for dat in dataset_valid])[:-1] |
|
|
|
|
|
X_test = get_features(dataset_test, config) |
|
|
y_test = np.concatenate([[c['is_supported'] for c in dat['atomic_facts']] for dat in dataset_test]) |
|
|
y_test[y_test == True] = 1 |
|
|
y_test[y_test == False] = 0 |
|
|
y_test = y_test.astype(np.int8) |
|
|
splits_test = np.cumsum([len(dat['atomic_facts']) for dat in dataset_test])[:-1] |
|
|
|
|
|
model = fit_model(X_train, y_train, config, dataset_train, |
|
|
eval_dict={'X_valid': X_valid, 'X_test': X_test, 'dataset_valid': dataset_valid, 'splits_valid': splits_valid, 'splits_test': splits_test}) |
|
|
|
|
|
scores_valid = model.predict_proba(X_valid)[:,1] |
|
|
scores_valid = np.array_split(scores_valid, splits_valid) |
|
|
|
|
|
scores_test = model.predict_proba(X_test)[:,1] |
|
|
scores_test = np.array_split(scores_test, splits_test) |
|
|
|
|
|
score_features_v = [np.zeros((len(u['atomic_facts']), 1)) for u in dataset_valid] |
|
|
score_features_te = [np.zeros((len(u['atomic_facts']), 1)) for u in dataset_test] |
|
|
|
|
|
conf_scores_valid = compute_conformity_scores(dataset_valid, scores_valid) |
|
|
|
|
|
|
|
|
|
|
|
alpha_fn = lambda x: [config.conformal.alpha] * len(x) |
|
|
|
|
|
|
|
|
conf_features_v = np.zeros((len(dataset_valid),1)) |
|
|
conf_features_te = np.zeros((len(dataset_test),1)) |
|
|
|
|
|
|
|
|
thresholds = calibrate_thresholds( |
|
|
conf_features_te, |
|
|
conf_features_v, |
|
|
conf_scores_valid, |
|
|
alpha_fn |
|
|
) |
|
|
|
|
|
dataset_test = conformal_filter( |
|
|
dataset_test, |
|
|
scores_test, |
|
|
thresholds |
|
|
) |
|
|
|
|
|
if config.dataset.name.lower() == "factscore": |
|
|
assess_factscore_coverage(dataset_test, config.conformal.alpha) |
|
|
|
|
|
print("Merging filtered responses.") |
|
|
|
|
|
merge_client = GPTClient(cache_file = config.model.merger.cache_path) |
|
|
merged_responses = merge_claims( |
|
|
dataset_test, |
|
|
merge_client |
|
|
) |
|
|
merge_client.save_cache() |
|
|
|
|
|
rand_idx = rng.integers(0, len(dataset_test)) |
|
|
print(dataset_test[rand_idx]['response']['message'] + "\n") |
|
|
print(merged_responses[rand_idx]) |
|
|
|
|
|
import IPython; IPython.embed() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|