# GRM/models/MLP.py
# Uploaded by Kang2691196427 ("Upload 29 files", commit 78cd756, verified).
import torch
import torch.nn as nn
class MLP(nn.Module):
    """Score every (input, class-embedding) pair with a shared feed-forward net.

    Each input row is concatenated with each row of ``class_embeddings`` and
    the concatenated vector is passed through a stack of ``Linear`` + ``ReLU``
    layers followed by a final ``Linear`` output layer.

    Args:
        class_embeddings: 2-D tensor of shape ``(n, h)`` holding one embedding
            per class. A detached copy is stored, so it is not trained and
            later changes to the caller's tensor do not affect the module.
        input_size: Width of the first linear layer. Because ``forward``
            concatenates the input with an embedding, this must equal the
            feature width of ``x`` plus ``h``.
        hidden_sizes: Widths of the hidden layers, in order. May be empty.
        output_size: Number of values produced per (input, class) pair.
    """

    def __init__(self, class_embeddings: torch.Tensor, input_size: int = 256,
                 hidden_sizes=(256, 256), output_size: int = 1):
        super().__init__()

        # Register as a non-persistent buffer: the tensor follows the module
        # through .to()/.cuda()/.half(), while the state_dict keys stay the
        # same as when it was a plain attribute (existing checkpoints load).
        self.register_buffer(
            "class_embeddings",
            class_embeddings.clone().detach(),
            persistent=False,
        )

        # Hidden layers: Linear followed by ReLU for every requested width.
        layers = []
        in_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_size, hidden_size))
            layers.append(nn.ReLU())
            in_size = hidden_size

        # Output layer (no activation).
        layers.append(nn.Linear(in_size, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output for every input/class combination.

        Args:
            x: 2-D tensor of shape ``(batch_size, d)`` where
                ``d + h == input_size``.

        Returns:
            Tensor of shape ``(batch_size, n)`` when the network emits a single
            value per pair, otherwise ``(batch_size, n, output_size)``.
        """
        batch_size = x.size(0)
        num_classes = self.class_embeddings.size(0)

        # Pair every input row with every class embedding:
        # (batch_size, d) -> (batch_size, n, d) and (n, h) -> (batch_size, n, h).
        x_expanded = x.unsqueeze(1).expand(-1, num_classes, -1)
        # Normally a no-op thanks to the buffer; kept for callers that feed
        # inputs living on a different device than the module.
        embeddings = self.class_embeddings.to(x.device)
        embeddings_expanded = embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate along the last (feature) dimension.
        x_combined = torch.cat((x_expanded, embeddings_expanded), dim=-1)

        # Flatten to (batch_size * n, d + h). The feature width is given
        # explicitly because -1 cannot be inferred for an empty batch.
        x_combined = x_combined.reshape(
            batch_size * num_classes, x_combined.size(-1)
        )

        # Forward pass through the network.
        output = self.model(x_combined)

        # Restore the (batch, class) layout. A single output value per pair is
        # squeezed away to keep the historical (batch_size, n) result shape.
        out_features = output.size(-1)
        if out_features == 1:
            return output.reshape(batch_size, num_classes)
        return output.reshape(batch_size, num_classes, out_features)