import os

import torch
from transformers import BertModel


class EnsembleIdentfier(torch.nn.Module):
    """Ensemble of ``LanguageIdentfier`` models loaded from a directory.

    One ``LanguageIdentfier`` is built for every ``*.pt`` file found directly
    inside ``models_path``; each file is loaded as that model's state dict and
    the model is switched to eval mode.

    Attributes:
        model_name: Name/path handed to ``BertModel.from_pretrained`` for
            every ensemble member.
        models: ``ModuleList`` holding the members, ordered by checkpoint
            filename.
        device: CUDA when available, otherwise CPU. This is a plain attribute
            chosen at construction time; calling ``.to()`` on the ensemble
            does not update it.
    """

    def __init__(self, models_path, model_name):
        """Load every ``.pt`` checkpoint in ``models_path``.

        Args:
            models_path: Directory containing the ``.pt`` state-dict files.
            model_name: Pretrained BERT identifier used to build each member
                before its weights are loaded.
        """
        super().__init__()
        self.model_name = model_name
        self.models = torch.nn.ModuleList()

        # os.listdir returns entries in arbitrary order; sorting makes row i
        # of the forward() output refer to the same checkpoint on every run.
        for filename in sorted(os.listdir(models_path)):
            if not filename.endswith(".pt"):
                continue
            model = LanguageIdentfier(self.model_name)
            # map_location="cpu" lets checkpoints that were saved from a GPU
            # load on CPU-only hosts; members are moved to self.device lazily
            # in forward().
            # NOTE(review): torch.load unpickles the file - only load
            # checkpoints from trusted sources.
            state_dict = torch.load(
                os.path.join(models_path, filename), map_location="cpu"
            )
            model.load_state_dict(state_dict)
            model.eval()
            self.models.append(model)

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Run every ensemble member on the batch.

        Members are moved to ``self.device`` one at a time and returned to the
        CPU right after their forward pass, so only one member occupies the
        accelerator at any moment. Gradient tracking is not disabled here;
        wrap the call in ``torch.no_grad()`` for pure inference.

        Args:
            input_ids: Token ids; the first dimension is the batch.
            attention_mask: Attention mask forwarded unchanged to each member
                (``None`` is passed through as-is).

        Returns:
            Tensor of shape ``(len(self.models), batch)`` on ``self.device``;
            row ``i`` holds the sigmoid outputs of ``self.models[i]``.
        """
        # Inputs must live on the same device as the member being evaluated.
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        logits = torch.zeros(
            len(self.models), input_ids.shape[0], device=self.device
        )
        for i, model in enumerate(self.models):
            model.to(self.device)
            try:
                # Each member returns (batch, 1); drop the trailing dimension.
                logits[i] = model(
                    input_ids, attention_mask=attention_mask
                ).squeeze(dim=1)
            finally:
                # Release accelerator memory even if this member failed.
                model.cpu()
        return logits


class LanguageIdentfier(torch.nn.Module):
    """BERT encoder followed by dropout, a 1-unit linear layer and a sigmoid."""

    def __init__(self, model_name):
        """Build the classifier.

        Args:
            model_name: Identifier passed to ``BertModel.from_pretrained``.
        """
        super().__init__()
        self.model = BertModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.linear = torch.nn.Linear(self.model.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Score a batch of token sequences.

        Args:
            input_ids: Token ids fed to the BERT encoder.
            attention_mask: Attention mask fed to the BERT encoder.

        Returns:
            Sigmoid-activated scores with a trailing dimension of size 1.
            Despite the local variable name, the values are post-sigmoid,
            i.e. in ``(0, 1)``, not raw logits.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask)
        # Index 1 of the encoder output is used as the pooled representation.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.linear(pooled_output)
        logits = self.sigmoid(logits)
        return logits