File size: 3,096 Bytes
be3dd25
 
 
b28c678
 
 
 
 
 
be3dd25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub only when a token is configured in the
# ``visual_cuad_vcuad`` environment variable. Calling ``login(token=None)``
# makes huggingface_hub prompt for a token interactively, which hangs or fails
# when this module is imported in a non-interactive process.
_hf_token = os.environ.get("visual_cuad_vcuad")
if _hf_token:
    login(token=_hf_token)


# Checkpoint the tokenizer is loaded from.
TOK_NAME = "mrm8488/longformer-base-4096-finetuned-squadv2"
# Checkpoint providing the fine-tuned question-answering weights.
MODEL_NAME = "jira877832/cuad-longformer-squadv2-finetuned"

# Lazily-populated cache for the tokenizer/model (filled by get_model()).
_tokenizer = None
_model = None

def get_model():
    """Return the ``(tokenizer, model)`` pair, loading each on first use.

    The tokenizer is loaded from ``TOK_NAME`` and the model from ``MODEL_NAME``.
    Both are stored in module-level globals so later calls reuse the same
    objects. A freshly loaded model is switched to evaluation mode.
    """
    global _tokenizer, _model

    if _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(TOK_NAME)

    if _model is None:
        loaded = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        loaded.eval()
        _model = loaded

    return _tokenizer, _model

def _best_spans_for_chunk(chunk, offsets, start_scores, end_scores,
                          context_start, context_end, top_k, max_answer_len,
                          n_best_starts):
    """Return the best-scoring non-empty answer spans of one encoded chunk.

    Args:
        chunk: The raw context string that was encoded.
        offsets: Offset-mapping tensor for the encoded pair, one
            ``(char_start, char_end)`` row per token.
        start_scores: Start logits for every token of the encoded pair.
        end_scores: End logits for every token of the encoded pair.
        context_start: Index of the first context (sequence 1) token.
        context_end: Index of the last context (sequence 1) token.
        top_k: Maximum number of spans kept per start token.
        max_answer_len: Spans satisfy ``start <= end < start + max_answer_len``.
        n_best_starts: Number of highest-scoring start tokens to explore.

    Returns:
        List of ``(answer_text, score)`` tuples, where ``score`` is the
        start logit plus the end logit.

    Keeping only the ``top_k`` best spans per start is lossless for the
    overall top-k (up to ties): a span outranked by ``top_k`` spans sharing
    its start token can never make the final cut. This avoids materializing
    one substring per (start, end) pair, i.e. up to
    ``n_best_starts * context_length`` strings per chunk.
    """
    # Convert to plain Python once; per-element tensor indexing and .item()
    # inside the inner loop dominated the runtime before.
    offset_pairs = offsets.tolist()
    context_logits = start_scores[context_start:context_end + 1]
    best_starts = (
        context_logits.argsort(descending=True)[:n_best_starts] + context_start
    ).tolist()

    spans = []
    for start_idx in best_starts:
        stop = min(start_idx + max_answer_len, context_end + 1)
        if stop <= start_idx:
            continue
        # Same float32 addition as scoring each (start, end) pair separately.
        span_scores = (start_scores[start_idx] + end_scores[start_idx:stop]).tolist()
        ranked = sorted(range(len(span_scores)), key=span_scores.__getitem__, reverse=True)
        start_char = offset_pairs[start_idx][0]
        kept = 0
        for rel_end in ranked:
            end_char = offset_pairs[start_idx + rel_end][1]
            answer_text = chunk[start_char:end_char].strip()
            if not answer_text:
                continue
            spans.append((answer_text, span_scores[rel_end]))
            kept += 1
            if kept >= top_k:
                break
    return spans


def answer_topk_longformer(question, chunks, top_k=5, max_answer_len=4096,
                           n_best_starts=20):
    """Rank candidate answer spans for *question* across all *chunks*.

    Each non-blank chunk is encoded together with the question (the chunk is
    truncated to fit 4096 tokens), run through the Longformer QA model with
    global attention on the question tokens, and searched for answer spans.
    Candidates from all chunks are pooled and ranked together.

    Args:
        question: Question text (first sequence of each encoded pair).
        chunks: Iterable of context strings; whitespace-only chunks are skipped.
        top_k: Number of answers to return; non-positive values return ``[]``.
        max_answer_len: Maximum answer length in tokens.
        n_best_starts: Number of highest-scoring start tokens explored per chunk.

    Returns:
        Up to ``top_k`` ``(answer_text, score)`` tuples, sorted by descending
        score, where ``score`` is the start logit plus the end logit.
    """
    if top_k <= 0:
        return []

    tokenizer, model = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    all_answers = []

    for chunk in chunks:
        if not chunk.strip():
            continue
        # No padding: with a single sequence it only adds masked positions,
        # and padding to 4096 forced a full-length forward pass every time.
        encoding = tokenizer(
            question, chunk,
            return_tensors="pt",
            truncation="only_second",
            max_length=4096,
            return_offsets_mapping=True,
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        offsets = encoding["offset_mapping"][0]
        sequence_ids = encoding.sequence_ids(0)

        # Token range belonging to the context (sequence id 1).
        context_start = next((i for i, s in enumerate(sequence_ids) if s == 1), None)
        context_end = next(
            (i for i in range(len(sequence_ids) - 1, -1, -1) if sequence_ids[i] == 1),
            None,
        )
        if context_start is None or context_end is None:
            continue

        sep_indices = (input_ids[0] == tokenizer.sep_token_id).nonzero(as_tuple=True)[0]
        if len(sep_indices) < 2:
            continue
        # Global attention on every token up to and including the first separator.
        question_end = sep_indices[0].item() + 1
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[0, :question_end] = 1

        with torch.no_grad():
            outputs = model(
                input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
            )

        all_answers.extend(
            _best_spans_for_chunk(
                chunk, offsets, outputs.start_logits[0], outputs.end_logits[0],
                context_start, context_end, top_k, max_answer_len, n_best_starts,
            )
        )

    all_answers.sort(key=lambda pair: pair[1], reverse=True)
    return all_answers[:top_k]