import torch
from model import RoBERTa
from torch import nn
from transformers import RobertaTokenizerFast


class RoBERTaModule(nn.Module):
    """Wraps a custom RoBERTa model together with the pretrained roberta-base tokenizer."""

    def __init__(self):
        super().__init__()
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        self.model = RoBERTa(
            vocab_size=self.tokenizer.vocab_size,
            padding_idx=self.tokenizer.pad_token_id,
        )

    def forward(self, x, attn_mask):
        return self.model(x, attn_mask)

    def inference(self, sentence):
        """Predict the token(s) behind each <mask> in `sentence`.

        Returns a single string for one mask, a list of strings for several,
        or an error message if the sentence contains no <mask> token.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.model.eval()

        tokenizer = self.tokenizer
        input_ids = tokenizer.encode(sentence)
        input_ids_tensor = torch.tensor([input_ids], device=device)
        # Single unpadded sequence, so this mask is all ones unless the input contains padding.
        attention_mask = (input_ids_tensor != tokenizer.pad_token_id).long()

        # Locate every <mask> position in the encoded input.
        mask_token_id = tokenizer.mask_token_id
        mask_indices = [i for i, token in enumerate(input_ids) if token == mask_token_id]
        if not mask_indices:
            return "No <mask> token found"

        with torch.no_grad():
            logits = self.model(input_ids_tensor, attention_mask)

        # Greedy decoding: take the highest-scoring vocabulary entry at each mask position.
        predicted_tokens = []
        for idx in mask_indices:
            pred_token_id = logits[0, idx].argmax().item()
            predicted_tokens.append(tokenizer.decode([pred_token_id]))
        return predicted_tokens if len(predicted_tokens) > 1 else predicted_tokens[0]

    def load_checkpoint(self, path="finishedBest10.pt"):
        """Load trained weights from a checkpoint saved with a `model_state_dict` key."""
        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        self.model.load_state_dict(checkpoint["model_state_dict"])
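

# Minimal usage sketch, assuming a trained checkpoint exists at the default path
# used by load_checkpoint above and that `model.RoBERTa` is importable from this
# repo. Shows masked-token prediction end to end; the example sentence is
# illustrative, not from the original source.
if __name__ == "__main__":
    module = RoBERTaModule()
    module.load_checkpoint()  # defaults to "finishedBest10.pt"
    print(module.inference("The capital of France is <mask>."))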