# ai/inference.py
# Author: jira877832 — "Update inference.py" (commit b28c678, verified)
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub so the (possibly private) fine-tuned
# model can be downloaded. Only attempt login when the token is actually set:
# login(token=None) would otherwise fall back to an interactive prompt or fail
# at import time on headless machines.
_hf_token = os.environ.get("visual_cuad_vcuad")
if _hf_token:
    login(token=_hf_token)
# Tokenizer comes from the public Longformer-base checkpoint fine-tuned on
# SQuAD v2; the QA head weights come from the CUAD fine-tuned checkpoint.
TOK_NAME = "mrm8488/longformer-base-4096-finetuned-squadv2"
MODEL_NAME = "jira877832/cuad-longformer-squadv2-finetuned"
# Lazily-initialized module-level singletons, populated by get_model().
_tokenizer = None
_model = None
def get_model():
    """Return the cached ``(tokenizer, model)`` pair, loading each on first use.

    The tokenizer and model are stored in module-level singletons so repeated
    calls are cheap; the model is switched to eval mode once after loading.
    """
    global _tokenizer, _model
    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(TOK_NAME)
    if _model is None:
        loaded = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        loaded.eval()
        _model = loaded
    return _tokenizer, _model
def answer_topk_longformer(question, chunks, top_k=5, max_answer_len=4096):
    """Run extractive QA over pre-split context chunks with a Longformer model.

    Each chunk is encoded as a question/context pair (context truncated to the
    4096-token window), scored by the cached QA model, and candidate answer
    spans are enumerated from the top start logits. Candidates from all chunks
    are pooled and the best ``top_k`` are returned.

    Args:
        question: Natural-language question string.
        chunks: Iterable of context strings; blank chunks are skipped.
        top_k: Number of highest-scoring answers to return.
        max_answer_len: Maximum answer span length, measured in *tokens*.
            The default (4096) effectively allows spans up to the end of the
            context window.

    Returns:
        List of ``(answer_text, score)`` tuples, sorted by descending score,
        at most ``top_k`` long (empty if no chunk yields a candidate).
    """
    tokenizer, model = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    all_answers = []
    for chunk in chunks:
        if not chunk.strip():
            continue
        encoding = tokenizer(
            question, chunk,
            return_tensors="pt",
            truncation="only_second",  # truncate only the context, never the question
            max_length=4096,
            padding="max_length",
            return_offsets_mapping=True,
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        offsets = encoding["offset_mapping"][0]
        sequence_ids = encoding.sequence_ids(0)

        # Token span belonging to the context (sequence id == 1); candidate
        # answers are restricted to it.
        context_start = next((i for i, s in enumerate(sequence_ids) if s == 1), None)
        context_end = next(
            (i for i in range(len(sequence_ids) - 1, -1, -1) if sequence_ids[i] == 1),
            None,
        )
        if context_start is None or context_end is None:
            continue

        # A well-formed pair encoding has at least two separator tokens
        # (question </s></s> context </s>); bail out otherwise.
        sep_indices = (input_ids[0] == tokenizer.sep_token_id).nonzero(as_tuple=True)[0]
        if len(sep_indices) < 2:
            continue
        question_end = sep_indices[0].item() + 1

        # Longformer requires global attention on the question tokens
        # (everything up to and including the first separator).
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[0, :question_end] = 1

        with torch.no_grad():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
            )
        start_scores = outputs.start_logits[0]
        end_scores = outputs.end_logits[0]

        # Top-20 candidate start positions within the context window.
        top_starts = start_scores[context_start:context_end + 1].argsort(descending=True)[:20]
        for rel_start in top_starts:
            start_idx = rel_start.item() + context_start
            span_limit = min(start_idx + max_answer_len, context_end + 1)
            for end_idx in range(start_idx, span_limit):
                # BUGFIX: the original guarded `offsets[idx] is None`, which is
                # always False — indexing an offset-mapping tensor yields a
                # tensor row, never None. Within the context span the offsets
                # are always valid character pairs, so no guard is needed;
                # empty spans are filtered via the stripped text below.
                start_char = offsets[start_idx][0].item()
                end_char = offsets[end_idx][1].item()
                answer_text = chunk[start_char:end_char].strip()
                if not answer_text:
                    continue
                score = (start_scores[start_idx] + end_scores[end_idx]).item()
                all_answers.append((answer_text, score))

    # Pool candidates across all chunks and keep the best top_k.
    return sorted(all_answers, key=lambda pair: pair[1], reverse=True)[:top_k]