Spaces:
Running
Running
import os

import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login(token=None) falls back to an interactive prompt, which
# would hang in a headless Space, so skip the call when the secret is unset.
_hf_token = os.environ.get("visual_cuad_vcuad")
if _hf_token:
    login(token=_hf_token)

# Tokenizer and weights come from different repos: the tokenizer is the base
# Longformer-SQuADv2 checkpoint, the model is a CUAD fine-tune of it.
TOK_NAME = "mrm8488/longformer-base-4096-finetuned-squadv2"
MODEL_NAME = "jira877832/cuad-longformer-squadv2-finetuned"

# Lazily populated singletons; loaded on first use by get_model().
_tokenizer = None
_model = None
def get_model():
    """Return the shared ``(tokenizer, model)`` pair, loading each at most once.

    Both objects are cached in module-level globals, so repeated calls are
    cheap. The model is switched to eval mode immediately after it is loaded.
    """
    global _tokenizer, _model
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(TOK_NAME)
    if _model is None:
        loaded = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        loaded.eval()
        _model = loaded
    return _tokenizer, _model
def answer_topk_longformer(question, chunks, top_k=5, max_answer_len=4096):
    """Extract the top-*k* answer spans for *question* across context *chunks*.

    Each chunk is scored independently with the Longformer QA model. Every
    candidate span is kept as ``(answer_text, start_logit + end_logit)`` and
    the best *top_k* across all chunks are returned, highest score first.

    Args:
        question: The question string.
        chunks: Iterable of context strings; blank chunks are skipped.
        top_k: Number of answers to return.
        max_answer_len: Maximum candidate span length, in tokens.

    Returns:
        List of ``(answer_text, score)`` tuples sorted by descending score.
    """
    tokenizer, model = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    all_answers = []
    for chunk in chunks:
        if not chunk.strip():
            continue
        all_answers.extend(
            _chunk_candidates(tokenizer, model, device, question, chunk, max_answer_len)
        )
    return sorted(all_answers, key=lambda x: x[1], reverse=True)[:top_k]


def _chunk_candidates(tokenizer, model, device, question, chunk, max_answer_len):
    """Score one context chunk and return its ``(answer_text, score)`` candidates."""
    encoding = tokenizer(
        question,
        chunk,
        return_tensors="pt",
        truncation="only_second",  # truncate the context, never the question
        max_length=4096,
        padding="max_length",
        return_offsets_mapping=True,
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    offsets = encoding["offset_mapping"][0]
    sequence_ids = encoding.sequence_ids(0)

    # Token range belonging to the context (sequence id 1); None for special
    # and padding tokens, 0 for the question.
    context_start = next((i for i, s in enumerate(sequence_ids) if s == 1), None)
    context_end = next(
        (i for i in range(len(sequence_ids) - 1, -1, -1) if sequence_ids[i] == 1), None
    )
    if context_start is None or context_end is None:
        return []

    sep_indices = (input_ids[0] == tokenizer.sep_token_id).nonzero(as_tuple=True)[0]
    if len(sep_indices) < 2:
        return []
    question_end = sep_indices[0].item() + 1

    # Longformer: give global attention to the question (and leading sep) tokens.
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[0, :question_end] = 1

    with torch.no_grad():
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
    start_scores = outputs.start_logits[0]
    end_scores = outputs.end_logits[0]

    # The 20 most promising start positions inside the context.
    top_starts = start_scores[context_start:context_end + 1].argsort(descending=True)[:20]

    candidates = []
    for rel_start in top_starts:
        start_idx = rel_start.item() + context_start
        # Span start in the original chunk text; invariant for the inner loop.
        start_char = offsets[start_idx][0].item()
        for end_idx in range(start_idx, min(start_idx + max_answer_len, context_end + 1)):
            end_char = offsets[end_idx][1].item()
            answer_text = chunk[start_char:end_char].strip()
            if not answer_text:
                continue
            score = (start_scores[start_idx] + end_scores[end_idx]).item()
            candidates.append((answer_text, score))
    return candidates