from typing import Sequence

import torch
import torch.nn as nn
|
|
|
|
|
class MLP(nn.Module):
    """Score each input row against a fixed table of class embeddings.

    For every (input row, class embedding) pair the two vectors are
    concatenated along the feature axis and fed through a ReLU MLP, giving one
    score vector per pair.

    Args:
        class_embeddings: 2-D tensor ``(num_classes, embedding_dim)``. A
            detached private copy is kept; it is never trained.
        input_size: Width of the concatenated pair, i.e. the input feature
            dimension plus ``embedding_dim``.
        hidden_sizes: Widths of the hidden layers, each followed by ReLU.
        output_size: Width of the final linear layer.
    """

    def __init__(
        self,
        class_embeddings: torch.Tensor,
        input_size: int = 256,
        hidden_sizes: Sequence[int] = (256, 256),
        output_size: int = 1,
    ):
        super().__init__()

        # Non-persistent buffer: module.to()/.cuda()/.half() now move/cast it
        # along with the parameters, while state_dict keys stay exactly as
        # before (the original plain attribute was not saved either).
        self.register_buffer(
            "class_embeddings",
            class_embeddings.detach().clone(),
            persistent=False,
        )

        layers = []
        in_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_size, hidden_size))
            layers.append(nn.ReLU())
            in_size = hidden_size
        layers.append(nn.Linear(in_size, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score every row of ``x`` against every class embedding.

        Args:
            x: 2-D tensor ``(batch_size, feature_dim)`` where
                ``feature_dim + embedding_dim == input_size``.

        Returns:
            ``(batch_size, num_classes)`` when ``output_size == 1`` (the
            original contract), otherwise
            ``(batch_size, num_classes, output_size)``.
        """
        batch_size = x.size(0)
        num_classes = self.class_embeddings.size(0)

        # Move the small (num_classes, embedding_dim) table *before* expanding.
        # Expanding first and then calling .to() would materialise and copy
        # batch_size copies of the table whenever the devices differ.
        embeddings = self.class_embeddings.to(x.device)

        # Pair every input row with every class embedding:
        # (B, C, feature_dim + embedding_dim).
        x_expanded = x.unsqueeze(1).expand(batch_size, num_classes, -1)
        embeddings_expanded = embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        x_combined = torch.cat((x_expanded, embeddings_expanded), dim=-1)

        # Flatten the pair axes into one batch axis for the MLP. Explicit sizes
        # (not -1) keep this valid when batch_size or num_classes is 0.
        x_combined = x_combined.view(batch_size * num_classes, x_combined.size(-1))

        output = self.model(x_combined)  # (B * C, output_size)
        output = output.view(batch_size, num_classes, output.size(-1))

        # Single-output heads keep the original (B, C) return shape.
        if output.size(-1) == 1:
            output = output.squeeze(-1)
        return output