import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class FingerNetConfig(PretrainedConfig):
    """Configuration for :class:`FingerNetSurf`.

    One estimator (a small MLP) is built per entry of ``y_dim``; the
    ``h1_dim`` and ``h2_dim`` lists are indexed in parallel with ``y_dim``.

    Args:
        x_dim: List whose first element is the input feature size shared by
            all estimators. Defaults to ``[6]``.
        y_dim: Output size of each estimator. Defaults to ``[6, 1800]``.
        h1_dim: Width of the first hidden layer of each estimator.
            Defaults to ``[100, 1000]``.
        h2_dim: Width of the second hidden layer of each estimator.
            Defaults to ``[100, 1000]``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "fingernet"

    def __init__(
        self,
        x_dim=None,
        y_dim=None,
        h1_dim=None,
        h2_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Build the default lists per instance: list literals used as default
        # arguments would be shared (and mutable) across every config object.
        self.x_dim = [6] if x_dim is None else x_dim
        self.y_dim = [6, 1800] if y_dim is None else y_dim
        self.h1_dim = [100, 1000] if h1_dim is None else h1_dim
        self.h2_dim = [100, 1000] if h2_dim is None else h2_dim


class FingerNetSurf(PreTrainedModel):
    """Collection of independent MLP estimators that share one input.

    For every index ``i`` of ``config.y_dim`` the model holds a sub-network
    ``estimator_{i}``:
    ``Linear(x_dim[0], h1_dim[i]) -> ReLU -> Linear(h1_dim[i], h2_dim[i])
    -> ReLU -> Linear(h2_dim[i], y_dim[i])``.
    """

    config_class = FingerNetConfig

    def __init__(self, config):
        """Build one estimator per entry of ``config.y_dim``.

        Args:
            config: A :class:`FingerNetConfig` providing ``x_dim``, ``y_dim``,
                ``h1_dim`` and ``h2_dim``.

        Raises:
            IndexError: If ``h1_dim`` or ``h2_dim`` has fewer entries than
                ``y_dim``.
        """
        super().__init__(config)
        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.h1_dim = config.h1_dim
        self.h2_dim = config.h2_dim

        # Every estimator consumes the same input features.
        in_features = self.x_dim[0]

        # Define the model architecture. The "estimator_{i}" keys become part
        # of the parameter names, so they must stay stable for checkpoints.
        self.model = nn.ModuleDict()
        for i, out_features in enumerate(self.y_dim):
            hidden1 = self.h1_dim[i]
            hidden2 = self.h2_dim[i]
            self.model[f"estimator_{i}"] = nn.Sequential(
                nn.Linear(in_features, hidden1),
                nn.ReLU(),
                nn.Linear(hidden1, hidden2),
                nn.ReLU(),
                nn.Linear(hidden2, out_features),
            )

        # Initialize weights via the PreTrainedModel post-init hook.
        self.post_init()

    def forward(self, x):
        """Apply every estimator to the same input.

        Args:
            x: Input tensor fed to each estimator's first ``nn.Linear``; its
                last dimension is therefore expected to equal ``x_dim[0]``.

        Returns:
            A list with one output tensor per estimator, in the order
            ``estimator_0, estimator_1, ...``.
        """
        # nn.ModuleDict preserves insertion order, so iterating its values
        # yields the estimators in index order without rebuilding the keys.
        return [estimator(x) for estimator in self.model.values()]