Spaces: Build error
| import torch | |
| from transformers import BertModel | |
| import os | |
class EnsembleIdentfier(torch.nn.Module):
    """Ensemble of ``LanguageIdentfier`` models loaded from a directory.

    Every file ending in ``.pt`` inside ``models_path`` is loaded as the state
    dict of one ``LanguageIdentfier(model_name)``. Members are kept on the CPU
    and moved to ``self.device`` one at a time inside ``forward``, so only one
    member occupies accelerator memory at any moment.

    Args:
        models_path: Directory containing the ``.pt`` checkpoints.
        model_name: Passed unchanged to ``LanguageIdentfier`` for each member.
    """

    def __init__(self, models_path, model_name):
        super().__init__()
        self.model_name = model_name
        self.models = torch.nn.ModuleList()
        # os.listdir order is arbitrary; sort so row i of forward()'s output
        # always corresponds to the same checkpoint.
        for filename in sorted(os.listdir(models_path)):
            if not filename.endswith(".pt"):
                continue
            model = LanguageIdentfier(self.model_name)
            # map_location="cpu": a checkpoint saved on a GPU must still load on
            # a CPU-only host, and members live on the CPU between forward calls.
            # NOTE(review): torch.load unpickles the file -- only point this at
            # trusted checkpoints (consider weights_only=True on newer torch).
            state_dict = torch.load(
                os.path.join(models_path, filename), map_location="cpu"
            )
            model.load_state_dict(state_dict)
            model.eval()
            self.models.append(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, input_ids, attention_mask):
        """Run every member on the batch and stack their outputs row-wise.

        Args:
            input_ids: Token-id tensor; dimension 0 is used as the batch size.
            attention_mask: Forwarded to each member; may be ``None``.

        Returns:
            Tensor of shape ``(len(self.models), input_ids.shape[0])`` on
            ``self.device``. Row ``i`` is member ``i``'s output squeezed along
            dim 1.
        """
        # The inputs must be on the same device as the member being run.
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        logits = torch.zeros(
            len(self.models), input_ids.shape[0], device=self.device
        )
        for i, model in enumerate(self.models):
            model.to(self.device)
            try:
                logits[i] = model(
                    input_ids, attention_mask=attention_mask
                ).squeeze(dim=1)
            finally:
                # Move the member back even on failure so it does not keep
                # holding accelerator memory.
                model.cpu()
        return logits
class LanguageIdentfier(torch.nn.Module):
    """Binary classifier head on top of a pretrained ``BertModel``.

    The encoder's pooled output goes through dropout (p=0.1), then a linear
    layer with a single output unit, then a sigmoid. Despite the name used by
    callers ("logits"), the returned values are sigmoid outputs in (0, 1).

    Args:
        model_name: Name or path handed to ``BertModel.from_pretrained``.
    """

    def __init__(self, model_name):
        super().__init__()
        # Attribute names and registration order are kept as-is: they define
        # the state_dict keys that saved checkpoints are loaded against.
        self.model = BertModel.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size
        self.dropout = torch.nn.Dropout(0.1)
        self.linear = torch.nn.Linear(hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Score a batch: sigmoid(linear(dropout(pooled_output))).

        Args:
            input_ids: Token ids passed positionally to the encoder.
            attention_mask: Passed to the encoder as ``attention_mask``.

        Returns:
            Sigmoid of the single-unit linear projection of the encoder's
            second output (BertModel's pooled output).
        """
        encoder_out = self.model(input_ids, attention_mask=attention_mask)
        # Index 1 is the pooled output for both tuple and ModelOutput returns.
        pooled = encoder_out[1]
        return self.sigmoid(self.linear(self.dropout(pooled)))