Spaces:
Build error
Build error
File size: 1,735 Bytes
8b13a4f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | import torch
from transformers import BertModel
import os
class EnsembleIdentfier(torch.nn.Module):
    """Ensemble of ``LanguageIdentfier`` checkpoints evaluated one at a time.

    Every ``*.pt`` file in ``models_path`` is loaded as the state dict of a
    ``LanguageIdentfier(model_name)``. Members are kept on the CPU and moved
    to ``self.device`` only while they run, so at most one member occupies
    accelerator memory during ``forward``.

    Args:
        models_path: directory containing the ``.pt`` checkpoint files.
        model_name: name/path passed to ``LanguageIdentfier`` for each member.

    Raises:
        ValueError: if ``models_path`` contains no ``.pt`` files.
    """

    def __init__(self, models_path, model_name):
        super().__init__()
        self.model_name = model_name
        self.models = torch.nn.ModuleList()
        # Sort so the member order (and therefore the row order of forward's
        # output) is reproducible; os.listdir returns entries in arbitrary order.
        checkpoints = sorted(
            name for name in os.listdir(models_path) if name.endswith(".pt")
        )
        if not checkpoints:
            raise ValueError(f"no .pt checkpoints found in {models_path!r}")
        for filename in checkpoints:
            model = LanguageIdentfier(self.model_name)
            # Load onto the CPU: a checkpoint saved from a GPU would otherwise
            # fail to load on a CPU-only host, and members live on the CPU
            # between forward passes anyway.
            # NOTE(review): torch.load unpickles the file -- only point
            # models_path at trusted checkpoints.
            state_dict = torch.load(
                os.path.join(models_path, filename), map_location="cpu"
            )
            model.load_state_dict(state_dict)
            model.eval()
            self.models.append(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, input_ids, attention_mask):
        """Run every member on the batch and collect their outputs row-wise.

        Args:
            input_ids: token id tensor whose first dimension is the batch;
                moved to ``self.device`` if it is not already there.
            attention_mask: mask forwarded to every member (may be ``None``);
                moved to ``self.device`` if given.

        Returns:
            A float tensor of shape ``(len(self.models), input_ids.shape[0])``
            on ``self.device``. Row ``i`` is member ``i``'s output with its
            size-1 dimension 1 squeezed away.
        """
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        logits = torch.zeros(
            len(self.models), input_ids.shape[0], device=self.device
        )
        for i, model in enumerate(self.models):
            model.to(self.device)
            try:
                logits[i] = model(
                    input_ids, attention_mask=attention_mask
                ).squeeze(dim=1)
            finally:
                # Always park the member back on the CPU, even if it raised,
                # so a failed batch does not leave it holding device memory.
                model.cpu()
        return logits
class LanguageIdentfier(torch.nn.Module):
    """Single-output classifier head on top of a pretrained ``BertModel``.

    The encoder's pooled output is passed through dropout (p=0.1), a
    ``hidden_size -> 1`` linear layer and a sigmoid. The value returned by
    ``forward`` is therefore already squashed into (0, 1), even though
    callers in this file refer to it as "logits".

    Args:
        model_name: name/path given to ``BertModel.from_pretrained``.
    """

    def __init__(self, model_name):
        super().__init__()
        # Submodule names/order are kept as-is: they define the state_dict
        # keys that saved checkpoints are loaded against.
        self.model = BertModel.from_pretrained(model_name)
        hidden = self.model.config.hidden_size
        self.dropout = torch.nn.Dropout(0.1)
        self.linear = torch.nn.Linear(hidden, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return sigmoid scores with a trailing dimension of size 1.

        NOTE(review): assumes ``encoded[1]`` is the pooled ``(batch, hidden)``
        representation (BERT's pooler output) -- confirm against the
        transformers version in use.
        """
        encoded = self.model(input_ids, attention_mask=attention_mask)
        pooled = encoded[1]
        return self.sigmoid(self.linear(self.dropout(pooled)))
|