| import torch | |
| from transformers import ( | |
| LongformerTokenizer, | |
| LongformerForQuestionAnswering | |
| ) | |
| from typing import List, Dict, Tuple | |
| from dotenv import load_dotenv | |
| from src.readers.base_reader import Reader | |
| load_dotenv() | |
class LongformerReader(Reader):
    """Extractive question-answering reader built on a Longformer checkpoint.

    Each candidate passage is paired with the query, run through
    ``LongformerForQuestionAnswering`` and the highest-scoring start/end
    token positions are decoded back into an answer string.
    """

    def __init__(self) -> None:
        """Load the tokenizer and the question-answering model weights."""
        checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
        self.tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
        self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers: int = 5) -> List[Tuple]:
        """Extract one answer span per passage for ``query``.

        Args:
            query: The question to answer.
            context: Mapping whose ``'texts'`` entry holds the candidate
                passages; only the first ``num_answers`` passages are read.
            num_answers: Maximum number of passages to process, and thus the
                maximum number of answers returned.

        Returns:
            A list with one ``[answer, [], []]`` entry per processed passage.
            The two trailing lists are always empty here (presumably
            placeholders expected by callers of ``Reader.read`` -- confirm
            against the base class). ``answer`` is an empty string when the
            predicted end position precedes the predicted start position.

        Raises:
            KeyError: If ``context`` has no ``'texts'`` entry.
        """
        # Longest input the model can embed. The position table reserves two
        # slots (RoBERTa-style offset past the padding index), hence the -2.
        # The tokenizer limit is honoured too when it is the smaller one.
        max_length = min(
            self.tokenizer.model_max_length,
            self.model.config.max_position_embeddings - 2,
        )

        answers = []
        for text in context['texts'][:num_answers]:
            # Truncate only the passage, never the question, so that over-long
            # passages no longer overflow the position embeddings.
            encoding = self.tokenizer(
                query,
                text,
                truncation="only_second",
                max_length=max_length,
                return_tensors="pt",
            )
            input_ids = encoding["input_ids"]

            # Pure inference: skip autograd bookkeeping to save memory/time.
            with torch.no_grad():
                outputs = self.model(
                    input_ids, attention_mask=encoding["attention_mask"])

            # Batch size is 1, so the flattened argmax is the token position.
            start = int(torch.argmax(outputs.start_logits))
            end = int(torch.argmax(outputs.end_logits))

            # Decode the id span directly; an inverted span yields "".
            answer = self.tokenizer.decode(
                input_ids[0, start:end + 1].tolist())
            answers.append([answer, [], []])
        return answers