# Demo / model.py
# Last commit by arubenruben: "Submit Demo" (8b13a4f)
import torch
from transformers import BertModel
import os
class EnsembleIdentfier(torch.nn.Module):
    """Ensemble of ``LanguageIdentfier`` models restored from ``.pt`` files.

    Every file ending in ``.pt`` inside ``models_path`` is loaded as the
    state dict of a fresh ``LanguageIdentfier(model_name)``. The sub-models
    are kept on the CPU. ``forward`` moves them to ``self.device`` one at a
    time, so at most one sub-model sits on the accelerator at once.
    """

    def __init__(self, models_path, model_name):
        """Load every ``.pt`` checkpoint found in ``models_path``.

        Args:
            models_path: Directory containing the state-dict checkpoints.
            model_name: Pretrained encoder name/path forwarded to
                ``LanguageIdentfier``.
        """
        super().__init__()
        self.model_name = model_name
        self.models = torch.nn.ModuleList()
        for checkpoint_path in self._checkpoint_paths(models_path):
            model = LanguageIdentfier(self.model_name)
            # map_location="cpu": checkpoints saved from GPU tensors would
            # otherwise fail to load on CPU-only hosts. Sub-models live on
            # the CPU anyway and are moved to the device lazily in forward().
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(state_dict)
            model.eval()
            self.models.append(model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _checkpoint_paths(models_path):
        """Return full paths of the ``.pt`` files in ``models_path``, sorted by name."""
        # os.listdir returns entries in arbitrary order. Sorting makes the
        # row order of forward()'s output reproducible across machines.
        return [
            os.path.join(models_path, filename)
            for filename in sorted(os.listdir(models_path))
            if filename.endswith(".pt")
        ]

    def forward(self, input_ids, attention_mask):
        """Run every sub-model on the batch and stack their outputs.

        Args:
            input_ids: Token-id tensor whose first dimension is the batch.
            attention_mask: Mask tensor forwarded to each sub-model, or None.

        Returns:
            Tensor of shape ``(len(self.models), batch)`` on ``self.device``.
            Row ``i`` is sub-model ``i``'s output with dim 1 squeezed away.
            Sub-models are ordered by checkpoint filename.
            ``LanguageIdentfier`` ends in a sigmoid, so these values are
            probabilities despite the historical "logits" naming.
        """
        # Callers may pass CPU tensors while the sub-models run on CUDA.
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        scores = torch.zeros(len(self.models), input_ids.shape[0], device=self.device)
        for row, model in enumerate(self.models):
            model.to(self.device)
            try:
                scores[row] = model(input_ids, attention_mask=attention_mask).squeeze(dim=1)
            finally:
                # Evict even on error so a failed call does not leave this
                # sub-model occupying accelerator memory.
                model.cpu()
        return scores
class LanguageIdentfier(torch.nn.Module):
    """Single-score classifier head on top of a pretrained ``BertModel``.

    The encoder's pooled output passes through dropout and a one-unit
    linear layer. A sigmoid then squashes the result into (0, 1).
    """

    def __init__(self, model_name):
        """Load the encoder ``model_name`` and build the classification head."""
        super().__init__()
        # Attribute names (model / dropout / linear / sigmoid) determine the
        # state_dict keys, so saved checkpoints depend on them staying as-is.
        self.model = BertModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        encoder_width = self.model.config.hidden_size
        self.linear = torch.nn.Linear(encoder_width, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return one sigmoid score per example in the batch.

        The trailing dimension of the result is 1, because the linear layer
        has a single output unit. The result is presumably ``(batch, 1)``,
        assuming the pooled output is ``(batch, hidden)``; TODO confirm.
        """
        encoder_out = self.model(input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is the pooled representation.
        pooled = encoder_out[1]
        return self.sigmoid(self.linear(self.dropout(pooled)))