theachyuttiwari committed on
Commit
daaf6f3
·
1 Parent(s): 29b1daa

Upload common.py

Browse files
Files changed (1) hide show
  1. common.py +120 -0
common.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re

import torch

# Column names expected in the raw KILT Wikipedia article records
# (used elsewhere to select/drop dataset columns).
kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
                          'wikidata_info', 'history']

# Column names of the flattened per-paragraph view of the articles
# (one row per paragraph, with character/paragraph span bookkeeping).
kilt_wikipedia_paragraph_columns = ['wikipedia_id', 'start_paragraph_id', 'start_character', 'end_paragraph_id',
                                    'end_character', 'title', 'section', 'text']
10
+
11
+
12
def clean_question(text):
    """Normalize an ELI5 question string.

    Strips reference markup, flattens newlines, collapses repeated
    whitespace, drops the reddit "[deleted]" placeholder, and
    lowercases/trims the result.
    """
    q = cleanup_references(text)
    # Flatten to one line, then squeeze runs of whitespace to single spaces.
    q = re.sub(r"\s\s+", " ", q.replace("\n", " "))
    return q.replace("[deleted]", "").lower().strip()
18
+
19
+
20
def cleanup_references(text):
    """Remove reddit-style reference markup from *text*.

    Three passes, in order:
      1. drop numbered citation links entirely (text and URL),
      2. keep link text but drop the URL for inline links,
      3. drop any remaining dangling ``_URL_<n>_`` placeholders.
    """
    # URL reference where we need to remove both the link text and URL
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal
    # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)

    # URL reference where we need to preserve link text but remove URL
    # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
    # At the outbreak of the Civil War, Leyburn left his church and joined the South.
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)

    # Lastly remove dangling _URL_<n>_ references. Note \d+ (not \d): the
    # data contains multi-digit placeholders such as _URL_19_ (see example
    # above), which a single-digit pattern would leave in the text.
    result = re.sub(r"_URL_\d+_", "", result, 0, re.MULTILINE)
    return result
35
+
36
+
37
def clean_answer(text):
    """Normalize an ELI5 answer string.

    Strips reference markup, flattens newlines, collapses repeated
    whitespace, removes Wikipedia "BULLET::::-" list markers, and
    truncates to the default word budget via trim().
    """
    a = cleanup_references(text)
    # Flatten to one line and squeeze whitespace runs first (removing the
    # bullet markers afterwards mirrors the original pass order).
    a = re.sub(r"\s\s+", " ", a.replace("\n", " "))
    a = a.replace("BULLET::::-", "")
    return trim(a.strip())
43
+
44
+
45
def trim(text, word_count: int = 100):
    """Truncate *text* to at most *word_count* space-separated tokens.

    Splits on single spaces (consecutive spaces yield empty tokens, which
    are preserved on re-join), keeps the first *word_count* tokens, and
    joins them back with single spaces.
    """
    tokens = text.split(" ")
    return " ".join(tokens[:word_count])
47
+
48
+
49
def articles_to_paragraphs(examples):
    """Flatten a batch of KILT Wikipedia articles into one row per paragraph.

    *examples* is a batched mapping with "text" (each item holding a
    "paragraph" list), "wikipedia_id", and "wikipedia_title". Each output
    row carries the most recently seen "Section::::" header paragraph as
    its section; header paragraphs themselves are also emitted as rows.
    Character spans always cover the whole paragraph (0..len).
    """
    out = {"wikipedia_id": [], "title": [], "section": [], "text": [],
           "start_paragraph_id": [], "end_paragraph_id": [],
           "start_character": [], "end_character": []}
    for art_idx, article in enumerate(examples["text"]):
        current_section = ""
        for par_idx, paragraph in enumerate(article["paragraph"]):
            # Section headers update the running section label and are
            # still appended below, matching the original behavior.
            if "Section::::" in paragraph:
                current_section = paragraph
            out["wikipedia_id"].append(examples["wikipedia_id"][art_idx])
            out["title"].append(examples["wikipedia_title"][art_idx])
            out["section"].append(current_section)
            out["text"].append(paragraph)
            out["start_paragraph_id"].append(par_idx)
            out["end_paragraph_id"].append(par_idx)
            out["start_character"].append(0)
            out["end_character"].append(len(paragraph))
    return out
71
+
72
+
73
def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
    """Build a KILT-format datapoint from an ELI5 example and retrieved passages.

    Keeps at most *topk* passages whose "text" has strictly more than
    *min_length* whitespace-separated words, projects each onto *columns*,
    and emits the KILT structure described at
    https://github.com/facebookresearch/KILT#kilt-data-format:
    one output entry per answer, plus one entry holding the provenance list.
    """
    candidates = [{k: passage[k] for k in columns} for passage in wiki_passages]
    kept = [c for c in candidates if len(c["text"].split()) > min_length][:topk]

    # One output element per gold answer...
    output = [{"answer": ans} for ans in eli5_example["answers"]["text"]]

    # ...then a single element with the evidence set from the KILT ks.
    provenance = []
    for c in kept:
        provenance.append({
            "wikipedia_id": c["wikipedia_id"],  # *mandatory*
            "title": c["title"],
            "section": c["section"],
            "start_paragraph_id": c["start_paragraph_id"],
            "start_character": c["start_character"],
            "end_paragraph_id": c["end_paragraph_id"],
            "end_character": c["end_character"],
            "text": c["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None,  # dataset/task specific
        })
    output.append({"provenance": provenance})

    return {"id": eli5_example["q_id"],
            "input": eli5_example["title"],
            "output": output,  # each element is an answer or provenance (can have multiple of each)
            "meta": None  # dataset/task specific
            }
103
+
104
+
105
def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
    """Encode a batch of question strings into dense vectors.

    Tokenizes with fixed-length padding/truncation, runs the encoder
    without gradient tracking, and returns the pooled representations as
    a numpy array on the host.
    """
    tokenized = question_tokenizer(questions, max_length=max_length, padding="max_length",
                                   truncation=True, return_tensors="pt")
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)
    with torch.no_grad():
        reps = question_model(input_ids, attention_mask).pooler_output
    return reps.cpu().numpy()
112
+
113
+
114
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
    """Encode a batch of passages (mapping with a "text" list) into vectors.

    Mirrors embed_questions() but reads the texts from passages["text"]
    and returns a mapping with an "embeddings" key, ready for use as a
    dataset map function.
    """
    tokenized = ctx_tokenizer(passages["text"], max_length=max_length, padding="max_length",
                              truncation=True, return_tensors="pt")
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)
    with torch.no_grad():
        reps = ctx_model(input_ids, attention_mask).pooler_output
    return {"embeddings": reps.cpu().numpy()}