import random

import nltk
import numpy as np
import torch
import torch.nn.functional as F
from nltk.tokenize import sent_tokenize
from transformers import BertTokenizer, BertModel

nltk.download('punkt')  # sentence-splitter model required by sent_tokenize

def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

class PolyEncoder(torch.nn.Module):
    def __init__(self, bert_model_name='klue/bert-base', poly_m=16):
        super().__init__()
        self.poly_m = poly_m
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        # poly_m learned codes, each of which attends over the context
        # tokens to extract one global context feature
        self.poly_code_embeddings = torch.nn.Embedding(poly_m, self.bert_model.config.hidden_size)

    def forward(self, context_input_ids, context_attention_mask, question_input_ids, question_attention_mask):
        # Encode the question and keep its [CLS] embedding
        question_outputs = self.bert_model(input_ids=question_input_ids, attention_mask=question_attention_mask)
        question_cls_embeddings = question_outputs.last_hidden_state[:, 0, :]  # (batch, hidden)

        # Encode the context token by token
        context_outputs = self.bert_model(input_ids=context_input_ids, attention_mask=context_attention_mask)
        context_hidden_states = context_outputs.last_hidden_state  # (batch, seq_len, hidden)

        # Poly codes, shared across the batch
        poly_codes = self.poly_code_embeddings.weight.unsqueeze(0).expand(context_hidden_states.size(0), -1, -1)  # (batch, poly_m, hidden)

        # Each poly code attends over the context tokens: softmax over the
        # seq_len dimension, with padding positions masked out
        attention_logits = torch.einsum('bnd,bmd->bnm', poly_codes, context_hidden_states)  # (batch, poly_m, seq_len)
        attention_logits = attention_logits.masked_fill(context_attention_mask.unsqueeze(1) == 0, float('-inf'))
        attention_weights = F.softmax(attention_logits, dim=-1)
        poly_context_embeddings = torch.einsum('bnm,bmd->bnd', attention_weights, context_hidden_states)  # (batch, poly_m, hidden)

        # Score each poly context vector against the question [CLS] embedding
        scores = torch.einsum('bnd,bd->bn', poly_context_embeddings, question_cls_embeddings)  # (batch, poly_m)

        # Aggregate over the poly_m dimension: keep the best-matching code
        return scores.max(dim=1).values  # (batch,)
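
# Quick shape check (a minimal sketch using random token ids): forward()
# returns one relevance score per context sentence in the batch.
#   >>> enc = PolyEncoder()
#   >>> ids = torch.randint(0, 1000, (3, 10))
#   >>> mask = torch.ones(3, 10, dtype=torch.long)
#   >>> enc(ids, mask, ids, mask).shape
#   torch.Size([3])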

def get_top_n_relevant_sentences(context, question, tokenizer, model, top_n):
    context_sentences = sent_tokenize(context)  # split the context into sentences with NLTK

    context_inputs = tokenizer(context_sentences, padding=True, truncation=True, return_tensors='pt')
    question_inputs = tokenizer(question, return_tensors='pt')

    # Score every sentence against the question; the single question is
    # broadcast to match the sentence batch
    with torch.no_grad():
        scores = model(context_inputs['input_ids'], context_inputs['attention_mask'],
                       question_inputs['input_ids'].expand(len(context_sentences), -1),
                       question_inputs['attention_mask'].expand(len(context_sentences), -1))

    # Rank sentence indices by score, keeping only the first sentence per
    # distinct score value
    ranked = sorted(enumerate(scores.tolist()), key=lambda item: item[1], reverse=True)
    unique_values = set()
    top_keys = []
    for idx, value in ranked:
        if value not in unique_values:
            unique_values.add(value)
            top_keys.append(idx)
        if len(top_keys) == top_n:
            break

    return [context_sentences[idx] for idx in top_keys]

# Example runner
def run_example(context, question):
    # Load the tokenizer and model; for repeated calls, load these once and reuse
    tokenizer = BertTokenizer.from_pretrained('klue/bert-base')
    model = PolyEncoder(bert_model_name='klue/bert-base')
    model.eval()

    top_n_sentences = get_top_n_relevant_sentences(context, question, tokenizer, model, top_n=5)
    sentences = " ".join(top_n_sentences)
    print(sentences)
    return sentences
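
# Example usage: a minimal sketch. The sample context and question below are
# made-up placeholders (klue/bert-base is a Korean BERT, and the poly codes
# here are randomly initialized and untrained), so the retrieved sentences
# are only illustrative.
if __name__ == '__main__':
    set_seed(42)
    sample_context = (
        "Poly-Encoders score candidates with a set of learned context codes. "
        "Bi-Encoders encode the query and each candidate independently. "
        "Cross-Encoders jointly encode every query-candidate pair. "
        "Sentence retrieval selects the sentences most relevant to a question. "
        "BERT supplies the contextual embeddings used by all of these models. "
        "This file retrieves the top sentences of a context for a question."
    )
    sample_question = "How do Poly-Encoders score candidates?"
    run_example(sample_context, sample_question)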