| | import torch |
| | from transformers import BertModel |
| | import os |
| |
|
class EnsembleIdentfier(torch.nn.Module):
    """Ensemble of LanguageIdentfier models loaded from ``*.pt`` checkpoints.

    Every ``*.pt`` file found in ``models_path`` is loaded (as a state dict)
    into a fresh ``LanguageIdentfier``. Members are kept on CPU and moved to
    the active device one at a time during ``forward`` to bound peak GPU
    memory.
    """

    def __init__(self, models_path, model_name):
        """
        Args:
            models_path: Directory containing ``*.pt`` state-dict checkpoints.
            model_name: Pretrained model name passed to each ensemble member.
        """
        super().__init__()
        self.model_name = model_name
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.models = torch.nn.ModuleList()
        # sorted() makes member order -- and therefore the row order of the
        # logits returned by forward() -- deterministic; os.listdir order is
        # filesystem-dependent.
        for filename in sorted(os.listdir(models_path)):
            if filename.endswith(".pt"):
                model = LanguageIdentfier(self.model_name)
                # map_location='cpu' lets checkpoints saved on a GPU machine
                # load on CPU-only hosts; members are moved onto self.device
                # lazily inside forward().
                model.load_state_dict(
                    torch.load(os.path.join(models_path, filename),
                               map_location='cpu'))
                model.eval()
                self.models.append(model)

    def forward(self, input_ids, attention_mask):
        """Run every member on the batch.

        Returns:
            Tensor of shape ``(num_models, batch_size)`` with one row of
            squeezed member outputs per ensemble member.
        """
        logits = torch.zeros(len(self.models), input_ids.shape[0]).to(self.device)

        # One member resident on the device at a time: slower, but keeps the
        # device footprint to a single model.
        for i, model in enumerate(self.models):
            model.to(self.device)
            logits[i] = model(input_ids, attention_mask=attention_mask).squeeze(dim=1)
            model.cpu()

        return logits
| |
|
| |
|
class LanguageIdentfier(torch.nn.Module):
    """Binary language-identification head on a pretrained BERT encoder.

    Maps a token sequence to a single sigmoid probability using the
    encoder's pooled representation.
    """

    def __init__(self, model_name):
        """
        Args:
            model_name: Pretrained model name for ``BertModel.from_pretrained``.
        """
        super().__init__()
        # NOTE: these attribute names are part of the checkpoint format
        # (state_dict keys) -- do not rename them.
        self.model = BertModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.linear = torch.nn.Linear(self.model.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return probabilities of shape ``(batch_size, 1)``."""
        encoder_out = self.model(input_ids, attention_mask=attention_mask)
        pooled = encoder_out[1]  # pooled output of the encoder
        return self.sigmoid(self.linear(self.dropout(pooled)))
| |
|