| | import string |
| | import numpy as np |
| | import os |
| | import json |
| |
|
| | from concurrent.futures import ThreadPoolExecutor |
| | |
| |
|
| | |
| | |
| | from retrieval import DocDB, Retrieval |
| |
|
class Scorer(object):
    """Score generations by verifying their atomic facts against a knowledge source.

    For every atomic fact a prompt is built from context retrieved for the
    topic (plus the verdicts already given for earlier facts of the same
    generation) and sent to ``client``; the reply is parsed into a
    supported / not-supported decision.

    Args:
        client: Object exposing ``query(prompt)`` and ``save_cache()``.
            ``query`` must return a sequence whose first element is a mapping
            with a ``'message'`` string.
        config: Configuration object; ``config.model.annotator.data_path`` and
            ``config.model.annotator.retrieval_cache_path`` are read.
        model_name: One of the supported pipeline names (see the assertion in
            ``__init__``).
        batch_size: Batch size forwarded to ``Retrieval``.
    """

    def __init__(self,
                 client,
                 config,
                 model_name="retrieval+ChatGPT",
                 batch_size=256):
        assert model_name in ["retrieval+llama", "retrieval+llama+npm", "retrieval+ChatGPT", "npm", "retrieval+ChatGPT+npm", "retrieval"]
        self.model_name = model_name
        self.client = client
        self.config = config

        self.data_dir = config.model.annotator.data_path
        self.cache_dir = config.model.annotator.retrieval_cache_path

        # Per-knowledge-source registries, keyed by source name.
        self.db = {}
        self.retrieval = {}
        self.npm = {}
        self.batch_size = batch_size

        self.af_generator = None

    def save_cache(self):
        """Persist the caches of the client and of every registered component."""
        self.client.save_cache()
        if "npm" in self.model_name:
            for npm in self.npm.values():
                npm.save_cache()
        for retrieval in self.retrieval.values():
            # The "medlfqa" source is stored as a plain dict and has no cache
            # to persist; only retriever objects expose ``save_cache``.
            if hasattr(retrieval, "save_cache"):
                retrieval.save_cache()
        # Iterate the values: iterating the dict itself yields only the keys.
        for db in self.db.values():
            db.save_cache()

    def register_knowledge_source(self, name="enwiki-20230401", db_path=None, data_path=None):
        """Register the knowledge source ``name`` for later scoring.

        Args:
            name: Identifier of the knowledge source. ``"medlfqa"`` is loaded
                from JSONL files in the data directory; any other name is
                backed by a ``DocDB`` and a BM25 ``Retrieval``.
            db_path: Path of the database file; defaults to
                ``<data_dir>/<name>.db``.
            data_path: Path of the raw data file; defaults to
                ``<data_dir>/<name>.jsonl``.

        Raises:
            AssertionError: If ``name`` has already been registered.
        """
        assert name not in self.retrieval, f"{name} already registered"

        if db_path is None:
            db_path = os.path.join(self.data_dir, f"{name}.db")

        if data_path is None:
            data_path = os.path.join(self.data_dir, f"{name}.jsonl")

        if name == "medlfqa":
            self.retrieval[name] = self._load_medlfqa()
        else:
            db_cache_path = os.path.join(self.cache_dir, f"db-{name}.pkl")
            cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.json")
            embed_cache_path = os.path.join(self.cache_dir, f"retrieval-{name}.pkl")

            self.db[name] = DocDB(db_path=db_path, data_path=data_path, cache_path=db_cache_path)
            self.retrieval[name] = Retrieval(self.db[name], cache_path, embed_cache_path, retrieval_type="bm25", batch_size=self.batch_size)

    def _load_medlfqa(self):
        """Build the question -> reference-answer lookup from the MedLFQA files."""
        suffix = "_test_MedLFQA.jsonl"
        datasets = {}
        # Sorted so that, when a question occurs in several files, the entry
        # that wins does not depend on the directory listing order.
        for path in sorted(os.listdir(self.data_dir)):
            if "MedLFQA" not in path:
                continue
            dataset_name = path[:-len(suffix)]
            with open(os.path.join(self.data_dir, path), 'r') as fp:
                # Skip blank lines (e.g. a trailing newline at end of file).
                datasets[dataset_name] = [json.loads(line) for line in fp if line.strip()]

        lookup = {}
        for dataset in datasets.values():
            for pt in dataset:
                lookup[pt['Question']] = {
                    'context': pt['Free_form_answer'],
                    'must_have': pt['Must_have'],
                    'nice_to_have': pt['Nice_to_have']
                }
        return lookup

    def get_score(self,
                  topics,
                  generations,
                  atomic_facts,
                  gamma=10,
                  knowledge_source=None):
        """Score one or several generations from their atomic facts.

        Args:
            topics: A topic string, or a list of topic strings.
            generations: A generation string, or a list matching ``topics``.
            atomic_facts: The facts of the single generation when strings are
                given, otherwise one entry per generation. An entry of
                ``None`` marks a generation without facts (counted as "no
                response" and excluded from the score).
            gamma: Length-penalty threshold. Generations with at most
                ``gamma`` facts are scaled by ``exp(1 - gamma / num_facts)``.
                A falsy value disables the penalty.
            knowledge_source: Name of the knowledge source; defaults to
                ``"enwiki-20230401"`` and is registered on first use.

        Returns:
            dict with keys ``"score"``, ``"respond_ratio"``, ``"decisions"``
            and ``"num_facts_per_response"``; ``"init_score"`` (the mean
            un-penalised score) is added when ``gamma`` is truthy.
        """
        if knowledge_source is None:
            knowledge_source = "enwiki-20230401"

        if knowledge_source not in self.retrieval:
            self.register_knowledge_source(knowledge_source)

        if isinstance(topics, str) and isinstance(generations, str):
            topics = [topics]
            generations = [generations]
            atomic_facts = [atomic_facts]
        else:
            assert isinstance(topics, list) and isinstance(generations, list), "`topics` and `generations` should be lists."
            assert len(topics) == len(generations), "`topics` and `generations` should have the same length"
            assert len(topics) == len(atomic_facts), "`topics` and `atomic_facts` should have the same length"

        respond_ratio = np.mean([facts is not None for facts in atomic_facts])

        scores = []
        init_scores = []
        decisions = []
        for topic, generation, facts in zip(topics, generations, atomic_facts):
            if facts is None:
                decisions.append(None)
                continue

            decision = []
            for fact in facts:
                # ``decision`` holds the verdicts reached so far for this
                # generation; they are replayed in the prompt as history.
                decision.append(
                    self._get_score(topic, generation, fact, knowledge_source, decision)
                )
            score = np.mean([d["is_supported"] for d in decision])

            if gamma:
                init_scores.append(score)
                penalty = 1.0 if len(facts) > gamma else np.exp(1 - gamma / max(len(facts), 1))
                score = penalty * score

            decisions.append(decision)
            scores.append(score)

        out = {"score": np.mean(scores),
               "respond_ratio": respond_ratio,
               "decisions": decisions,
               "num_facts_per_response": np.mean([len(d) for d in decisions if d is not None])}

        if gamma:
            out["init_score"] = np.mean(init_scores)

        return out

    def _get_score(self, topic, generation, atom, knowledge_source, prev_decisions=None):
        """Decide whether a single atomic fact is supported.

        Args:
            topic: Topic (or question) the fact is about.
            generation: The full generation; not used by this method.
            atom: The atomic fact to verify.
            knowledge_source: Name of a registered knowledge source.
            prev_decisions: Decisions already made for earlier facts; each is
                replayed in the prompt. ``None`` means no history.

        Returns:
            dict ``{"atom": <stripped fact>, "is_supported": <bool>}``.
        """
        if prev_decisions is None:
            prev_decisions = ()

        definition = f"Answer the question about {topic} based on the given context and your previous answers.\n\n"
        atom = atom.strip()
        if knowledge_source == "medlfqa":
            context = self.retrieval[knowledge_source][topic]['context']
        else:
            passages = self.retrieval[knowledge_source].get_passages(topic, atom, k=5)
            # Passages are emitted in reverse of the order returned by the retriever.
            context = "".join(
                "Title: {}\nText: {}\n\n".format(psg["title"], psg["text"].replace("<s>", "").replace("</s>", ""))
                for psg in reversed(passages)
            )
        definition += context.strip()
        if definition[-1] not in string.punctuation:
            definition += "."

        prompt = f"{definition.strip()}\n\n"
        for prev_decision in prev_decisions:
            prev_score = "True" if prev_decision["is_supported"] else "False"
            prompt += f"Previous input: {prev_decision['atom']}\nTrue or False? Output: {prev_score}\n"
        prompt += f"Input: {atom.strip()} True or False?\nOutput:"

        output = self.client.query(prompt)

        generated_answer = output[0]['message'].lower()
        has_true = "true" in generated_answer
        has_false = "false" in generated_answer
        if has_true and has_false:
            # Both words occur: whichever appears last decides.
            is_supported = generated_answer.index("true") > generated_answer.index("false")
        elif has_true or has_false:
            is_supported = has_true
        else:
            # No explicit verdict: treat the answer as supported unless it
            # contains a word signalling refusal or missing information.
            words = generated_answer.translate(str.maketrans("", "", string.punctuation)).split()
            is_supported = all(keyword not in words for keyword in ["not", "cannot", "unknown", "information"])

        if is_supported and "npm" in self.model_name:
            # NOTE(review): nothing visible in this class populates
            # ``self.npm``; this lookup relies on it being filled elsewhere.
            npprob = self.npm[knowledge_source].get_probabilty(topic, atom)
            is_supported = npprob > 0.3

        decision = {"atom": atom, "is_supported": is_supported}

        return decision