import torch
from transformers import BertModel
import os

class EnsembleIdentfier(torch.nn.Module):
    """Ensemble wrapper over several ``LanguageIdentfier`` checkpoints.

    Loads every ``.pt`` state dict found in *models_path* into its own
    ``LanguageIdentfier`` instance and, on ``forward``, returns the stacked
    per-model outputs (one row per ensemble member).
    """

    def __init__(self, models_path, model_name):
        """
        Args:
            models_path: Directory containing ``.pt`` state-dict checkpoints.
            model_name: Pretrained BERT model name passed to each member.
        """
        super().__init__()
        self.model_name = model_name

        # Choose the compute device up front; members live on CPU and are
        # shuttled to this device one at a time inside forward().
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.models = torch.nn.ModuleList()
        # sorted() makes member order — and therefore the row order of
        # forward()'s output — deterministic across filesystems.
        for filename in sorted(os.listdir(models_path)):
            if not filename.endswith(".pt"):
                continue
            model = LanguageIdentfier(self.model_name)
            # map_location='cpu' lets checkpoints saved on a GPU box load on
            # CPU-only hosts; without it torch.load tries to restore onto the
            # original (possibly absent) CUDA device.
            state_dict = torch.load(
                os.path.join(models_path, filename), map_location='cpu')
            model.load_state_dict(state_dict)
            model.eval()
            self.models.append(model)

    def forward(self, input_ids, attention_mask):
        """Run every ensemble member on the batch.

        Returns:
            Tensor of shape ``(num_models, batch_size)`` — one row of squeezed
            member outputs per model, on ``self.device``.
        """
        logits = torch.zeros(len(self.models), input_ids.shape[0]).to(self.device)

        for i, model in enumerate(self.models):
            # Move members through the device one at a time so only a single
            # BERT fits in (GPU) memory at once; a no-op on CPU-only hosts.
            model.to(self.device)
            logits[i] = model(input_ids, attention_mask=attention_mask).squeeze(dim=1)
            model.cpu()

        return logits


class LanguageIdentfier(torch.nn.Module):
    """Binary language-identification head on top of a pretrained BERT encoder.

    A single linear unit over BERT's pooled output, squashed through a
    sigmoid, yields one probability-like score per input sequence.
    """

    def __init__(self, model_name):
        """
        Args:
            model_name: Hugging Face model name/path for ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.model = BertModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        # hidden_size -> 1: one score per sequence.
        self.linear = torch.nn.Linear(self.model.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return a ``(batch, 1)`` tensor of sigmoid-activated scores."""
        encoder_out = self.model(input_ids, attention_mask=attention_mask)
        # Index 1 is the pooled ([CLS]-derived) representation.
        pooled = self.dropout(encoder_out[1])
        return self.sigmoid(self.linear(pooled))