import copy

import torch
from torch import nn
from transformers import RobertaTokenizerFast

from model import RoBERTa


class RoBERTaModule(nn.Module):
    """Bundle a project-local RoBERTa model with its HuggingFace tokenizer.

    Provides a thin ``forward`` pass-through for training plus a
    convenience ``inference`` helper that fills in ``<mask>`` tokens
    in a raw sentence, and a checkpoint loader.
    """

    def __init__(self):
        super().__init__()
        # Tokenizer supplies vocab size and pad id for the model.
        self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
        self.model = RoBERTa(
            vocab_size=self.tokenizer.vocab_size,
            padding_idx=self.tokenizer.pad_token_id,
        )

    def forward(self, x, attn_mask):
        """Delegate directly to the wrapped model."""
        return self.model(x, attn_mask)

    def inference(self, sentence):
        """Predict the token(s) behind each mask token in *sentence*.

        Returns a single decoded string when exactly one mask is present,
        a list of strings when several are, or the literal string
        "No token found" when the sentence contains no mask token.
        Side effect: moves the model to the selected device and puts it
        in eval mode.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.model.eval()

        tok = self.tokenizer
        token_ids = tok.encode(sentence)
        batch = torch.tensor([token_ids]).to(device)
        # 1 for real tokens, 0 for padding (a single sentence has no padding,
        # so this is effectively all ones).
        attn = (batch != tok.pad_token_id).long()

        masked_positions = [
            pos for pos, tid in enumerate(token_ids) if tid == tok.mask_token_id
        ]
        if not masked_positions:
            return "No token found"

        with torch.no_grad():
            logits = self.model(batch, attn)

        decoded = [
            tok.decode([logits[0, pos].argmax().item()]) for pos in masked_positions
        ]
        return decoded if len(decoded) > 1 else decoded[0]

    def load_checkpoint(self, path="finishedBest10.pt"):
        """Restore model weights from a checkpoint file (mapped to CPU).

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only
        load checkpoint files from trusted sources.
        """
        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        self.model.load_state_dict(checkpoint["model_state_dict"])