import torch
import torch.nn as nn


class MLP(nn.Module):
    """Score every input vector against every row of a fixed class-embedding table.

    For each input row ``x[b]`` and each class embedding ``e[k]``, the
    concatenation ``[x[b], e[k]]`` is passed through a ReLU MLP. The result is
    one prediction per (input, class) pair.

    Args:
        class_embeddings: Tensor with one row per class, shape ``(n, h)``.
            It is detached and copied, so it is not trained, and later in-place
            changes to the caller's tensor do not affect the model.
        input_size: Width of the *concatenated* feature, i.e.
            ``x.size(-1) + h``. It is not just the width of ``x``, because the
            first Linear layer consumes the concatenation.
        hidden_sizes: Widths of the hidden layers. Each hidden layer is
            followed by a ReLU.
        output_size: Number of outputs per (input, class) pair.
    """

    def __init__(self, class_embeddings, input_size=256, hidden_sizes=(256, 256), output_size=1):
        super().__init__()
        # A non-persistent buffer follows module.to(...), .half(), .double(), etc.
        # It is still excluded from state_dict, so checkpoint keys are unchanged.
        self.register_buffer(
            "class_embeddings", class_embeddings.detach().clone(), persistent=False
        )

        # Hidden layers: Linear followed by ReLU.
        layers = []
        in_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_size, hidden_size))
            layers.append(nn.ReLU())
            in_size = hidden_size
        # Output layer (no activation).
        layers.append(nn.Linear(in_size, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return predictions for every (input, class) pair.

        Args:
            x: Tensor of shape ``(batch_size, d)``, where ``d`` plus the
                embedding width must equal ``input_size``.

        Returns:
            Tensor of shape ``(batch_size, n)`` when ``output_size == 1``.
            Otherwise the shape is ``(batch_size, n, output_size)``.
        """
        batch_size = x.size(0)
        num_classes = self.class_embeddings.size(0)

        # Pair every input row with every class embedding.
        # Both tensors become (batch_size, n, features).
        x_expanded = x.unsqueeze(1).expand(batch_size, num_classes, -1)
        # .to() is a no-op when the buffer already lives on x's device. It is kept
        # so that a module that was never moved with .to() still works.
        embeddings = self.class_embeddings.to(x.device)
        embeddings_expanded = embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate along the last (feature) dimension: (batch_size, n, d + h).
        x_combined = torch.cat((x_expanded, embeddings_expanded), dim=-1)
        # Flatten the pairs into rows of a 2-D batch. The feature width is passed
        # explicitly rather than as -1: view(0, -1) is ambiguous and raises for an
        # empty batch or an empty class table.
        x_combined = x_combined.reshape(batch_size * num_classes, x_combined.size(-1))

        output = self.model(x_combined)  # (batch_size * n, output_size)
        output = output.view(batch_size, num_classes, output.size(-1))
        if output.size(-1) == 1:
            # Keep the historical (batch_size, n) shape for the single-output case.
            output = output.squeeze(-1)
        return output