Spaces:
Running
Running
| import json | |
| from context_cite import ContextCiter | |
| import re | |
| import torch | |
| from transformers import LlamaForCausalLM, LlamaTokenizer | |
def all_normalize(obj):
    """Min-max normalise every span score in ``obj`` over one global range.

    ``obj`` is indexed as ``obj[sentence][doc][span]`` and every span is a
    sequence whose first item is kept untouched and whose second item is the
    score to rescale. The nested lists are modified in place and ``obj`` is
    returned for convenience.

    Each span is rewritten as a 2-tuple ``(span[0], normalised_score)``.
    When there are no spans at all the input is returned unchanged; when all
    scores are equal (zero range) every score becomes ``0.0`` instead of
    raising ``ZeroDivisionError``.
    """
    all_values = [
        each_span[1]
        for output_sent_result in obj
        for each_doc in output_sent_result
        for each_span in each_doc
    ]
    if not all_values:
        # Nothing to normalise; max()/min() would raise on an empty list.
        return obj

    min_val = min(all_values)
    value_range = max(all_values) - min_val

    for output_sent_result in obj:
        for each_doc in output_sent_result:
            for j, each_span in enumerate(each_doc):
                if value_range:
                    score = (each_span[1] - min_val) / value_range
                else:
                    # All scores identical: the range is degenerate.
                    score = 0.0
                each_doc[j] = (each_span[0], score)
    return obj
def all_normalize_in(obj):
    """Min-max normalise span scores separately inside each output sentence.

    ``obj`` is indexed as ``obj[sentence][doc][span]``; every span is a
    sequence whose first item is kept and whose second item is the score.
    Unlike a global normalisation, the minimum and maximum are recomputed
    for every top-level entry (sentence), so scores are only comparable
    within one sentence.

    The nested lists are modified in place and ``obj`` is returned. Each span
    is rewritten as a 2-tuple ``(span[0], normalised_score)``. A sentence
    without any span is skipped, and a sentence whose scores are all equal
    gets ``0.0`` everywhere instead of raising ``ZeroDivisionError``.
    """
    for output_sent_result in obj:
        all_values = [
            each_span[1]
            for each_doc in output_sent_result
            for each_span in each_doc
        ]
        if not all_values:
            # max()/min() would raise on an empty sentence; nothing to do.
            continue

        min_val = min(all_values)
        value_range = max(all_values) - min_val

        for each_doc in output_sent_result:
            for j, each_span in enumerate(each_doc):
                if value_range:
                    score = (each_span[1] - min_val) / value_range
                else:
                    # Degenerate range: every score in this sentence is equal.
                    score = 0.0
                each_doc[j] = (each_span[0], score)
    return obj
def load_json(file_path):
    """Load a ``.json`` or ``.jsonl`` file and return the decoded content.

    For a path ending in ``.jsonl`` the file is treated as a sequence of JSON
    values separated by whitespace (one per line, or pretty-printed over
    several lines) and a list of the decoded values is returned; an empty
    file yields ``[]``. Any other path is decoded as a single JSON document.

    Raises ``json.JSONDecodeError`` on malformed content and ``OSError`` if
    the file cannot be opened.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        data = file.read()

    if not str(file_path).endswith('.jsonl'):
        return json.loads(data)

    # Decode value after value instead of splicing the text into one big
    # array, which breaks on ordinary one-object-per-line files.
    decoder = json.JSONDecoder()
    objects = []
    pos = 0
    end = len(data)
    while True:
        # raw_decode does not accept leading whitespace, so skip it manually.
        while pos < end and data[pos].isspace():
            pos += 1
        if pos >= end:
            break
        value, pos = decoder.raw_decode(data, pos)
        objects.append(value)
    return objects
def ma(text):
    """Return the offset just past the first document header in ``text``.

    A header is text of the form ``Document [<digits>](Title:<anything but ')'>)``.
    The returned value is the index of the first character after the closing
    parenthesis of the first such header, or ``0`` when no header is found.
    """
    header = re.search(r"Document \[\d+\]\(Title:[^)]+\)", text)
    return header.end() if header else 0
def write_json(file_path, data):
    """Serialise ``data`` as JSON (4-space indentation) into ``file_path``.

    The file is opened in text write mode, so existing content is replaced.
    """
    with open(file_path, mode='w') as handle:
        json.dump(data, handle, indent=4)
def load_model(model_name_or_path, token=None):
    """Load a causal language model and its tokenizer in evaluation mode.

    Args:
        model_name_or_path: Identifier or local path handed to
            ``from_pretrained`` for both the model and the tokenizer.
        token: Optional access token forwarded to both ``from_pretrained``
            calls. Defaults to ``None`` so no placeholder string is ever
            sent as a credential.
            NOTE(review): with ``None`` the library is expected to fall back
            to its own credential lookup (e.g. a cached login) - confirm
            against the installed transformers version.

    Returns:
        A ``(model, tokenizer)`` tuple; ``model.eval()`` has been called.
    """
    # Imported lazily so that importing this module does not require
    # the Auto* classes until a model is actually loaded.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        token=token,
    )
    # The tokenizer needs the same credential as the model for gated repos.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=token)
    model.eval()
    return model, tokenizer
def compute_log_prob(model, tokenizer, input_text, output_text):
    """Return the summed log-probability of ``output_text`` given ``input_text``.

    The prompt and the output are tokenized separately, concatenated, and run
    through the model in a single forward pass. The logit at position ``t``
    is used to score the token at position ``t + 1``, so the output tokens are
    scored with the logits from the last prompt position up to the
    second-to-last position of the full sequence.

    Args:
        model: Callable accepting ``input_ids=`` and returning an object with
            a ``logits`` tensor of shape ``(batch, seq, vocab)``. If it has a
            ``device`` attribute the ids are moved there first.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` tensor
            when called with ``return_tensors="pt"``; it must also accept
            ``add_special_tokens``.
        input_text: Conditioning text (prompt).
        output_text: Text whose likelihood is measured.

    Returns:
        Sum over output tokens of ``log p(token | prompt, previous tokens)``
        as a Python float (``0.0`` for an empty output).

    Raises:
        ValueError: If the prompt tokenizes to zero tokens, because the first
            output token would then have nothing to be conditioned on.
    """
    prompt_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]
    # No special tokens here: a BOS inserted between prompt and output would
    # be scored as if it were part of the output.
    output_ids = tokenizer(
        output_text, return_tensors="pt", add_special_tokens=False
    )["input_ids"]
    # An empty output may come back with a non-integer dtype; align it.
    output_ids = output_ids.to(dtype=prompt_ids.dtype)

    prompt_len = prompt_ids.shape[1]
    if prompt_len == 0:
        raise ValueError("input_text produced no tokens; cannot condition on it")

    device = getattr(model, "device", None)
    if device is not None:
        prompt_ids = prompt_ids.to(device)
        output_ids = output_ids.to(device)

    full_ids = torch.cat([prompt_ids, output_ids], dim=1)
    total_len = full_ids.shape[1]

    with torch.no_grad():
        logits = model(input_ids=full_ids).logits

    # Position t predicts token t + 1, hence the shift by one.
    step_logits = logits[:, prompt_len - 1:total_len - 1, :].float()
    log_probs = torch.nn.functional.log_softmax(step_logits, dim=-1)
    token_log_probs = log_probs.gather(2, output_ids.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum().item()
def compute_contributions(model, tokenizer, question, docs, output):
    """Score each document by how much removing it changes the output log-prob.

    The prompt is ``question``, a blank line, then the documents joined by
    newlines. The baseline is the log-probability of ``output`` with all
    documents present; for every index the document is left out and the
    difference ``baseline - reduced`` is recorded.

    Returns:
        A list with one contribution per entry of ``docs``, in order.
    """
    def score(doc_subset):
        # Same prompt layout for the full and the reduced document sets.
        prompt = question + '\n\n' + '\n'.join(doc_subset)
        return compute_log_prob(model, tokenizer, prompt, output)

    baseline = score(docs)
    return [
        baseline - score(docs[:skip] + docs[skip + 1:])
        for skip in range(len(docs))
    ]
class InterpretableAttributer:
    """Attribute generated sentences to source documents at several levels.

    Supported level names are ``'doc'``, ``'span'`` and ``'word'``. Only the
    span level has an implementation in this class; document-level scores
    are derived from span-level scores by averaging (see ``span_to_doc``).
    """

    _VALID_LEVELS = ('doc', 'span', 'word')

    def __init__(self, levels=('doc', 'span', 'word'), model='gpt-2'):
        """Validate ``levels`` and remember which model to load later.

        Args:
            levels: Iterable of level names to compute.
            model: Name or path passed to ``load_model`` the first time the
                model is needed. Nothing is loaded at construction time.

        Raises:
            ValueError: If an unknown level name is given.
        """
        levels = list(levels)
        for level in levels:
            if level not in self._VALID_LEVELS:
                raise ValueError(f'Invalid level: {level}')
        # span before doc: doc-level parsing reuses the span-level output.
        self.levels = sorted(levels, key=lambda x: ['span', 'doc', 'word'].index(x))
        self.model_name = model
        # Loaded lazily by _ensure_model so construction stays cheap.
        self.model = None
        self.tokenizer = None

    def _ensure_model(self):
        """Load the model and tokenizer on first use."""
        if getattr(self, 'model', None) is None or getattr(self, 'tokenizer', None) is None:
            self.model, self.tokenizer = load_model(self.model_name)

    def attribute(self, question, docs, output):
        """Run every configured level over every sentence of ``output``.

        Returns:
            A dict mapping level name to a list with one raw attribution
            result per sentence in ``output``.
        """
        attribute_results = {}
        for level in self.levels:
            attribute_results[level] = [
                self._attribute(question, docs, sentence, level)
                for sentence in output
            ]
        return attribute_results

    def _attribute(self, question, docs, output, level):
        """Dispatch one sentence to ``<level>_level_attribution``.

        Raises:
            ValueError: If ``level`` is not a known level name.
            NotImplementedError: If the method for ``level`` does not exist.
        """
        if level not in self._VALID_LEVELS:
            raise ValueError(f'Invalid level: {level}')
        # NOTE(review): only span_level_attribution is defined in this class;
        # the doc and word variants must be supplied elsewhere (e.g. subclass).
        handler = getattr(self, f'{level}_level_attribution', None)
        if handler is None:
            raise NotImplementedError(f'Attribution for {level} not implemented yet')
        return handler(question, docs, output)

    def span_level_attribution(self, question, docs, output):
        """Attribute one output sentence to context spans via ContextCiter.

        The documents are joined with blank lines to form the context and the
        given sentence is injected as the response to attribute.

        Returns:
            A list of record dicts taken from the attribution dataframe.
        """
        # USE CONTEXT CITE
        self._ensure_model()
        context = '\n\n'.join(docs)
        cc = ContextCiter(self.model, self.tokenizer, context, question)
        # NOTE(review): relies on ContextCiter private members to force the
        # response text instead of generating one - verify on library upgrade.
        _, prompt = cc._get_prompt_ids(return_prompt=True)
        cc._cache["output"] = prompt + output
        return cc.get_attributions(as_dataframe=True, top_k=1000).data.to_dict(orient='records')

    def parse_attribution_results(self, docs, results):
        """Convert raw attribution output into per-document span lists.

        For the span level, each record's ``"Source"`` text is located in the
        joined context, converted to an offset relative to its document, and
        grouped by document. Within every document the spans are sorted by
        offset and the first span's offset is replaced by the end of the
        document header as reported by ``ma``. Scores are then normalised per
        sentence with ``all_normalize_in``.

        Returns:
            A dict with ``'span_level'`` (and ``'doc_level'`` when requested),
            indexed as ``[sentence][doc]``.

        Raises:
            ValueError: If doc-level parsing is requested without span results.
            NotImplementedError: For levels other than span and doc.
        """
        context = '\n\n'.join(docs)
        lens = [len(doc) for doc in docs]
        len_sep = len('\n\n')
        final_results = {}
        for level, result in results.items():
            if level == 'span':
                ordered_all_sents = []
                for output_sent_result in result:
                    # Where the previous occurrence of a span text ended, so a
                    # repeated text is matched at its next occurrence.
                    final_end_for_span = {}
                    ordered = [[] for _ in range(len(docs))]
                    for each_span in output_sent_result:
                        span_text = each_span["Source"]
                        span_score = each_span["Score"]
                        start = final_end_for_span.get(span_text, 0)
                        # NOTE(review): find() returns -1 when the text is not
                        # present; such spans end up in the first document.
                        span_start = context.find(span_text, start)
                        final_end_for_span[span_text] = span_start + len(span_text)
                        # locate the document
                        doc_idx = 0
                        while span_start > lens[doc_idx]:
                            span_start -= lens[doc_idx] + len_sep
                            doc_idx += 1
                        ordered[doc_idx].append((span_start, span_score))
                    for i, doc in enumerate(docs):
                        ordered[i] = sorted(ordered[i], key=lambda x: x[0])
                        # A document may have received no span at all.
                        if ordered[i]:
                            ordered[i][0] = (ma(doc), ordered[i][0][1])
                    ordered_all_sents.append(ordered)
                final_results[level + '_level'] = all_normalize_in(ordered_all_sents)
            elif level == 'doc':
                # Doc scores are derived from the parsed span-level output,
                # which is available because levels are ordered span first.
                if 'span_level' not in final_results:
                    raise ValueError('doc level parsing requires span level results')
                self.span_to_doc(final_results)
            else:
                raise NotImplementedError(f'Parsing for {level} not implemented yet')
        return final_results

    def span_to_doc(self, results):
        """Add ``results['doc_level']`` as the mean span score per document.

        Reads ``results['span_level']`` (indexed ``[sentence][doc][span]``
        with the score as the second item of each span) and writes a
        ``[sentence][doc]`` list of floats. A document without spans gets
        ``0.0`` rather than NaN.
        """
        import numpy as np
        span_level = results['span_level']
        doc_level = []
        for output_sent_result in span_level:
            doc_level.append([
                float(np.mean([span[1] for span in doc])) if doc else 0.0
                for doc in output_sent_result
            ])
        results['doc_level'] = doc_level

    def attribute_for_result(self, result):
        """Compute attributions for one result dict and store them in it.

        Reads ``result['doc_cache']``, ``result['data']['question']`` and
        ``result['output']``; the parsed attribution entries are merged into
        ``result`` in place.
        """
        docs = result['doc_cache']
        question = result['data']['question']
        output = result['output']
        attribution_results = self.attribute(question, docs, output)
        parsed_results = self.parse_attribution_results(docs, attribution_results)
        result.update(parsed_results)
        if 'doc' not in self.levels:
            # if doc is not in the levels, we need to convert the span level to doc level
            print('Converting span level to doc level...')
            # Best effort on purpose: a failed conversion is reported but must
            # not discard the span-level results already stored.
            try:
                self.span_to_doc(result)
                print('Conversion successful')
            except Exception as e:
                print(f'Error converting span level to doc level: {e}')

    def attribute_for_results(self, results):
        """Apply ``attribute_for_result`` to every item and return ``results``."""
        for result in results:
            self.attribute_for_result(result)
        return results
if __name__ == '__main__':
    # Script entry point: read previously generated results, add span-level
    # attributions to every entry, and write the enriched list to a new file.
    results = load_json('res_attr.json')
    attributer = InterpretableAttributer(levels=['span'])
    # attribute_for_results updates the entries in place and returns the list.
    write_json('res_attr_span.json', attributer.attribute_for_results(results))