import numpy as np
import string
import re
import collections
import torch
import nltk
from bs4 import BeautifulSoup
def one_paragraph(text):
    paras = text.lstrip('\n').split('\n\n')
    if not paras:
        return ''
    else:
        return paras[0].rstrip('\n')

def strong_one_paragraph(text):
    paras = text.lstrip('\n').split('\n')
    if not paras:
        return ''
    else:
        return paras[0].rstrip('\n')
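# Illustrative examples (not part of the original module): one_paragraph keeps
# everything up to the first blank line, strong_one_paragraph only the first line.
# >>> one_paragraph('\n\nFirst paragraph.\nStill first.\n\nSecond paragraph.')
# 'First paragraph.\nStill first.'
# >>> strong_one_paragraph('First line.\nSecond line.')
# 'First line.'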
def compute_str_em(data):
    """Compute STR-EM metric (only for ASQA).

    Args:
        data: requires fields `qa_pairs/short_answers` and `output`
    Returns:
        STR-EM and STR-EM-HIT
    """
    if 'qa_pairs' not in data[0] or data[0]['qa_pairs'] is None:
        return 0
    acc = []
    hit = []
    for item in data:
        loc_acc = []
        for qa_pair in item['qa_pairs']:
            loc_acc.append(exact_presence(qa_pair['short_answers'], item["output"]))
        acc.append(np.mean(loc_acc))
        hit.append(int(np.mean(loc_acc) == 1))
    return 100 * np.mean(acc), 100 * np.mean(hit)
def average(func):
    """Decorator: apply `func` to every item in a dataset and average the results."""
    def avg_func(dataset):
        print(len(dataset))  # debug: report dataset size
        results = [func(*data) for data in dataset] if dataset else []
        if results:
            return np.mean(np.array(results), axis=0).tolist()
        else:
            return None
    return avg_func
def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
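# Illustrative example (not part of the original module): lowercases, strips
# punctuation and articles, and collapses whitespace.
# >>> normalize_answer("The Quick, Brown Fox!")
# 'quick brown fox'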
def compute_f1(a_gold, a_pred):
    """Compute F1 score between two strings."""
    def _get_tokens(s):
        if not s:
            return []
        return normalize_answer(s).split()

    gold_toks = _get_tokens(a_gold)
    pred_toks = _get_tokens(a_pred)
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
        return int(gold_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
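# Illustrative example (not part of the original module): after normalization the
# gold tokens are ['cat', 'sat'] and the predicted tokens are ['cat', 'sat', 'down'],
# so precision = 2/3, recall = 2/2, and F1 = 0.8.
# compute_f1("the cat sat", "a cat sat down")  # -> 0.8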
def compute_exact(a_gold, a_pred):
    """Check whether two strings are equal up to normalization."""
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def exact_presence(short_answers, context):
    """Verify if any of the answers is present in the given context.

    Args:
        short_answers: list of short answers to look for in the context
        context: a paragraph to search for short answers
    Returns:
        True if any of the short answers is present in the context
    """
    n_short_answers = [normalize_answer(sa) for sa in short_answers]
    n_context = normalize_answer(context)
    for ans in n_short_answers:
        if ans in n_context:
            return True
    return False
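# Illustrative example (not part of the original module): matching is done on
# normalized text, so punctuation and case do not matter.
# >>> exact_presence(["Paris!", "Lyon"], "The capital of France is paris")
# True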
def output_begin_with(word):
    def f(self) -> bool:
        return self.last_message.strip().lower().startswith(word)
    return f

def output_end_with(word):
    def f(self) -> bool:
        return strong_one_paragraph(self.last_message.strip()).endswith(word)
    return f
def make_as(datakey):
    def f(passage):
        return {datakey: passage}
    return f

def cut_and_make_as(datakey):
    def f(passage):
        return {datakey: one_paragraph(passage)}
    return f
def remove_curly_citations(sent):
    """Strip `{n}`-style citation markers. Renamed: it was originally a second
    `remove_citations` definition and was silently shadowed by the bracket version below."""
    return re.sub(r"{\d+", "", re.sub(r" {\d+", "", sent)).replace(" |", "").replace("}", "").replace("{", "")

def remove_citations(sent):
    """Strip `[n]`-style citation markers from a sentence."""
    return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")
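# Illustrative example (not part of the original module):
# >>> remove_citations("Water boils at 100C [1][2]. See also [3].")
# 'Water boils at 100C. See also.'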
def match_document(ref_mark, output_ref_span):
    ref = set()
    ref_span = []
    for num in ref_mark:
        ref_str = str(num)
        if ref_str in output_ref_span:
            ref_parts = output_ref_span[ref_str].split("[")
            if len(ref_parts) > 1:
                ref_id_parts = ref_parts[1].split("]")
                if len(ref_id_parts) > 0:
                    ref_id = ref_id_parts[0].strip()
                    if ref_id.isdigit():
                        ref.add(int(ref_id))  # collect the document id
            ref_span_parts = output_ref_span[ref_str].split(":", 1)  # fragment after the first colon
            if len(ref_span_parts) > 1:
                ref_span.append(ref_span_parts[1].strip())  # collect the quoted sentence fragment
            else:
                ref_span.append('')
    return list(ref), ref_span
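# Illustrative example (not part of the original module), assuming each entry in
# output_ref_span follows a "Document [<id>]: <quoted snippet>" shape:
# >>> match_document([1], {"1": "Document [3]: The Eiffel Tower is in Paris."})
# ([3], ['The Eiffel Tower is in Paris.'])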
def get_max_memory():
    """Get the maximum memory available on each GPU for loading models."""
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    max_memory = f'{free_in_GB-6}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    return max_memory
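# Illustrative usage (not part of the original module): the returned dict matches
# the `max_memory` argument Hugging Face transformers accepts alongside
# `device_map` when sharding a model across GPUs, e.g.
#   model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", max_memory=get_max_memory())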
def each_make_as(key):
    def function(output):
        sents = nltk.sent_tokenize(one_paragraph(output))
        if len(sents) > 3:
            sents = sents[:3]
        return [make_as(key)(sent) for sent in sents]
    return function

def each_par_make_as(key):
    def function(output):
        sents = one_paragraph(output).split('\n')
        if len(sents) > 3:
            sents = sents[:3]
        return [make_as(key)(sent) for sent in sents]
    return function
def sentence(key):
    def function(output):
        sents = nltk.sent_tokenize(one_paragraph(output))
        for sent in sents:
            refs = re.findall(r'\[\d+\]', sent)
            if refs:
                return make_as(key)(sent)
        return make_as(key)('')
    return function

def sentences(key):
    def function(output):
        sents = nltk.sent_tokenize(one_paragraph(output))
        return [make_as(key)(sent) for sent in sents][:1]
    return function

def three_sentences(key):
    def function(output):
        sents = nltk.sent_tokenize(one_paragraph(output))
        return [make_as(key)(sent) for sent in sents][:3]
    return function
def first_sentence(text):
    sents = nltk.sent_tokenize(one_paragraph(text))
    return sents[0] if sents else ''
def flatten_dict(d, parent_key='', sep='_'):
    items = []
    for k, v in d.items():
        new_key = f'{parent_key}{sep}{k}' if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)
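# Illustrative example (not part of the original module):
# >>> flatten_dict({'a': {'b': 1, 'c': {'d': 2}}, 'e': 3})
# {'a_b': 1, 'a_c_d': 2, 'e': 3}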
def parse_html_prompt(input_str):
    soup = BeautifulSoup(input_str, "html.parser")
    # Handle the content inside <p></p>
    p_content = soup.find("p").decode_contents().replace("<br>", "\n")
    p_content = re.sub(r'<span[^>]*>(.*?)</span>', r'<\1>', p_content)
    template = p_content.strip().replace(' <br/>', '').replace(' ', '').replace('<br/>', '')
    # Parse the component-item blocks
    components = {}
    for item in soup.find_all("div", class_="component-item"):
        key_span = item.find("div", class_="component-key").find("span")
        key = key_span.get_text(strip=True) if key_span else ""
        value_div = item.find("div", class_="component-value")
        value_content = value_div.decode_contents()
        value_content = re.sub(r'<span[^>]*>(.*?)</span>', r'{\1}', value_content)
        components[key] = value_content.strip().replace(' <br/>', '').replace('<br/>', '')
    # Parse the self-info-item blocks
    self_prompt = {}
    for item in soup.find_all("div", class_="self-info-item"):
        key_span = item.find("div", class_="component-key").find("span")
        key = key_span.get_text(strip=True) if key_span else ""
        value_div = item.find("div", class_="component-value")
        value = value_div.get_text(strip=True) if value_div else ""
        self_prompt[key] = value.replace(' <br/>', '').replace('<br/>', '')
    return {
        'template': template,
        'components': components,
        'self_prompt': self_prompt
    }
def parse_html_destination(input_str):
    soup = BeautifulSoup(input_str, "html.parser")
    destination = soup.find("destination").get_text(strip=True)
    prompt_key = soup.find("prompt_key").get_text(strip=True)
    return destination, prompt_key
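# Illustrative example (not part of the original module), using the custom tags
# this parser expects:
# >>> parse_html_destination('<destination>writer</destination><prompt_key>draft</prompt_key>')
# ('writer', 'draft')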
def parse_html_new_model(input_str):
    soup = BeautifulSoup(input_str, "html.parser")
    model_type = soup.find("model_type").get_text(strip=True)
    model_name = soup.find("model").get_text(strip=True)
    key = soup.find("prompt_key").get_text(strip=True)
    return model_type, model_name, key

def parse_delete_destination(input_str):
    soup = BeautifulSoup(input_str, "html.parser")
    destination = soup.find("deletedestination").get_text(strip=True)
    return destination

def parse_html_header(input_str):
    soup = BeautifulSoup(input_str, "html.parser")
    header = soup.find("to_head").get_text(strip=True)
    return header
def parse_html_config(info):
    config = ''
    if 'class="component-value"' in info:
        func = parse_html_prompt
        config = 'prompt'
    elif '</destination>' in info:
        func = parse_html_destination
        config = 'destination'
    elif '<model_type>' in info:
        func = parse_html_new_model
        config = 'new_model'
    elif 'deletedestination' in info:
        config = 'delete_destination'
        func = parse_delete_destination
    elif 'to_head' in info:
        config = 'header'
        func = parse_html_header
    else:
        raise NotImplementedError
    result = func(info)
    print(info, 'parsed as', config)
    return config, result
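# Illustrative usage (not part of the original module): the dispatcher sniffs the
# markup to pick a parser (it also prints a debug line before returning), e.g.
# parse_html_config('<to_head>system</to_head>')  # -> ('header', 'system')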