Update app.py
Browse files
app.py
CHANGED
|
@@ -99,14 +99,13 @@ class Retriever:
|
|
| 99 |
|
| 100 |
return retrieved_texts
|
| 101 |
|
| 102 |
-
|
| 103 |
class RAG:
|
| 104 |
def __init__(self,
|
| 105 |
file_path,
|
| 106 |
device,
|
| 107 |
context_model_name="facebook/dpr-ctx_encoder-multiset-base",
|
| 108 |
question_model_name="facebook/dpr-question_encoder-multiset-base",
|
| 109 |
-
generator_name="
|
| 110 |
|
| 111 |
# generator_name = "valhalla/bart-large-finetuned-squadv1"
|
| 112 |
# generator_name = "'vblagoje/bart_lfqa'"
|
|
@@ -122,22 +121,24 @@ class RAG:
|
|
| 122 |
|
| 123 |
|
| 124 |
def abstractive_query(self, question):
|
| 125 |
-
self.generator_tokenizer = BartTokenizer.from_pretrained(
|
| 126 |
-
self.generator_model = BartForConditionalGeneration.from_pretrained(
|
| 127 |
context = self.retriever.retrieve_top_k(question, k=5)
|
|
|
|
| 128 |
|
| 129 |
input_text = "answer: " + " ".join(context) + " " + question
|
| 130 |
|
| 131 |
-
inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=
|
| 132 |
-
outputs = self.generator_model.generate(inputs, max_length=
|
| 133 |
|
| 134 |
answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 135 |
return answer
|
| 136 |
|
| 137 |
def extractive_query(self, question):
|
| 138 |
-
context = self.retriever.retrieve_top_k(question, k=
|
|
|
|
| 139 |
|
| 140 |
-
inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=
|
| 141 |
with torch.no_grad():
|
| 142 |
model_inputs = inputs.to(device)
|
| 143 |
outputs = self.generator_model(**model_inputs)
|
|
|
|
| 99 |
|
| 100 |
return retrieved_texts
|
| 101 |
|
|
|
|
| 102 |
class RAG:
|
| 103 |
def __init__(self,
|
| 104 |
file_path,
|
| 105 |
device,
|
| 106 |
context_model_name="facebook/dpr-ctx_encoder-multiset-base",
|
| 107 |
question_model_name="facebook/dpr-question_encoder-multiset-base",
|
| 108 |
+
generator_name="facebook/bart-large"):
|
| 109 |
|
| 110 |
# generator_name = "valhalla/bart-large-finetuned-squadv1"
|
| 111 |
# generator_name = "'vblagoje/bart_lfqa'"
|
|
|
|
| 121 |
|
| 122 |
|
| 123 |
def abstractive_query(self, question):
    """Generate a free-form (abstractive) answer to *question*.

    Retrieves the top-5 context passages via ``self.retriever``, builds a
    single ``"answer: <passages> <question>"`` prompt, and runs beam-search
    generation with the BART generator.

    Parameters
    ----------
    question : str
        Natural-language question to answer.

    Returns
    -------
    str
        Generated answer with special tokens stripped.
    """
    # Load the generator lazily and ONCE: from_pretrained() downloads /
    # deserializes a multi-hundred-MB checkpoint, so re-running it on every
    # query (as the original did) is very wasteful.
    # NOTE(review): the original referenced bare `generator_name` / `device`;
    # both are __init__ parameters of this class, so they are read from the
    # instance here — confirm __init__ stores them as attributes.
    if not hasattr(self, "generator_model"):
        self.generator_tokenizer = BartTokenizer.from_pretrained(self.generator_name)
        self.generator_model = BartForConditionalGeneration.from_pretrained(
            self.generator_name
        ).to(self.device)

    context = self.retriever.retrieve_top_k(question, k=5)
    # input_text = question + " " + " ".join(context)

    # Prompt: task prefix + retrieved passages + the question itself.
    input_text = "answer: " + " ".join(context) + " " + question

    # Truncate overly long prompts to the encoder limit and move the token
    # ids onto the generator's device before decoding.
    inputs = self.generator_tokenizer.encode(
        input_text, return_tensors='pt', max_length=500, truncation=True
    ).to(self.device)
    outputs = self.generator_model.generate(
        inputs,
        max_length=150,
        min_length=2,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )

    answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer
|
| 136 |
|
| 137 |
def extractive_query(self, question):
|
| 138 |
+
context = self.retriever.retrieve_top_k(question, k=15)
|
| 139 |
+
|
| 140 |
|
| 141 |
+
inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=300, padding="max_length")
|
| 142 |
with torch.no_grad():
|
| 143 |
model_inputs = inputs.to(device)
|
| 144 |
outputs = self.generator_model(**model_inputs)
|