import os

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; the token is read from the
# environment rather than hard-coded.
login(token=os.environ.get("visual_cuad_vcuad"))

TOK_NAME = "mrm8488/longformer-base-4096-finetuned-squadv2"
MODEL_NAME = "jira877832/cuad-longformer-squadv2-finetuned"

# Lazily-initialized singletons so the tokenizer and model are loaded
# at most once per process.
_tokenizer = None
_model = None


def get_model():
    global _tokenizer, _model
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(TOK_NAME)
    if _model is None:
        _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        _model.eval()
    return _tokenizer, _model


def answer_topk_longformer(question, chunks, top_k=5, max_answer_len=4096):
    """Run extractive QA over each chunk and return the top_k (answer, score) pairs.

    max_answer_len is measured in tokens; CUAD answers are whole contract
    clauses and can be long, hence the generous default.
    """
    tokenizer, model = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    all_answers = []
    for chunk in chunks:
        if not chunk.strip():
            continue

        encoding = tokenizer(
            question,
            chunk,
            return_tensors="pt",
            truncation="only_second",  # truncate the context, never the question
            max_length=4096,
            padding="max_length",
            return_offsets_mapping=True,
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        offsets = encoding["offset_mapping"][0]

        # sequence_ids marks question tokens with 0, context tokens with 1,
        # and special/padding tokens with None.
        sequence_ids = encoding.sequence_ids(0)
        context_start = next((i for i, s in enumerate(sequence_ids) if s == 1), None)
        context_end = next(
            (i for i in range(len(sequence_ids) - 1, -1, -1) if sequence_ids[i] == 1),
            None,
        )
        if context_start is None or context_end is None:
            continue

        sep_indices = (input_ids[0] == tokenizer.sep_token_id).nonzero(as_tuple=True)[0]
        if len(sep_indices) < 2:
            continue
        question_end = sep_indices[0].item() + 1

        # Longformer uses sparse local attention; give the question tokens
        # global attention so every context token can attend to them.
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[0, :question_end] = 1

        with torch.no_grad():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
            )
        start_scores = outputs.start_logits[0]
        end_scores = outputs.end_logits[0]

        # Consider the 20 most likely start positions within the context,
        # then enumerate candidate end positions for each.
        top_starts = start_scores[context_start : context_end + 1].argsort(descending=True)[:20]
        start_indexes = [int(i) + context_start for i in top_starts]

        for start_idx in start_indexes:
            for end_idx in range(start_idx, min(start_idx + max_answer_len, context_end + 1)):
                start_char = offsets[start_idx][0].item()
                end_char = offsets[end_idx][1].item()
                # With return_tensors="pt" the offset mapping is a tensor, so
                # entries are never None; skip degenerate spans instead
                # (special tokens map to the empty span (0, 0)).
                if end_char <= start_char:
                    continue
                answer_text = chunk[start_char:end_char].strip()
                if not answer_text:
                    continue
                score = (start_scores[start_idx] + end_scores[end_idx]).item()
                all_answers.append((answer_text, score))

    return sorted(all_answers, key=lambda x: x[1], reverse=True)[:top_k]
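

# --- Usage sketch (hypothetical) -------------------------------------------
# A minimal example of calling answer_topk_longformer. The question and the
# contract excerpt below are illustrative placeholders, not data from the
# CUAD dataset; in practice `chunks` would come from a contract split into
# windows of text. Running this downloads the fine-tuned model from the Hub.
if __name__ == "__main__":
    question = "What is the governing law of this agreement?"
    chunks = [
        "This Agreement shall be governed by and construed in accordance "
        "with the laws of the State of Delaware, without regard to its "
        "conflict of laws principles.",
    ]
    for answer, score in answer_topk_longformer(question, chunks, top_k=3):
        print(f"{score:8.3f}  {answer}")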