|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
|
|
class FingerNetConfig(PretrainedConfig):
    """Configuration for ``FingerNetSurf``.

    The model built from this config creates one independent MLP
    ("estimator") per entry of ``y_dim``.

    Args:
        x_dim: List of input sizes; ``FingerNetSurf`` reads only
            ``x_dim[0]``, as the input feature count of every estimator.
            Defaults to ``[6]``.
        y_dim: Output size of each estimator. Its length is the number of
            estimators. Defaults to ``[6, 1800]``.
        h1_dim: Width of the first hidden layer of each estimator.
            Defaults to ``[100, 1000]``.
        h2_dim: Width of the second hidden layer of each estimator.
            Defaults to ``[100, 1000]``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "fingernet"

    def __init__(
        self,
        x_dim=None,
        y_dim=None,
        h1_dim=None,
        h2_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # ``None`` sentinels instead of list literals in the signature: a list
        # default is evaluated once at definition time, so every config built
        # with defaults would share (and could mutate) the same list objects.
        # Each call now gets fresh lists with the same default values.
        self.x_dim = [6] if x_dim is None else x_dim
        self.y_dim = [6, 1800] if y_dim is None else y_dim
        self.h1_dim = [100, 1000] if h1_dim is None else h1_dim
        self.h2_dim = [100, 1000] if h2_dim is None else h2_dim
|
|
|
|
|
|
|
|
|
class FingerNetSurf(PreTrainedModel):
    """A set of independent MLP estimators that all consume the same input.

    For each ``i`` in ``range(len(config.y_dim))`` an ``nn.Sequential`` named
    ``estimator_{i}`` is stored in ``self.model``::

        Linear(x_dim[0] -> h1_dim[i]) -> ReLU
        -> Linear(h1_dim[i] -> h2_dim[i]) -> ReLU
        -> Linear(h2_dim[i] -> y_dim[i])

    ``forward`` returns the outputs of all estimators as a list, in index order.
    """

    config_class = FingerNetConfig

    def __init__(self, config):
        """Build one estimator per entry of ``config.y_dim``.

        Args:
            config: A ``FingerNetConfig`` (or compatible object) exposing
                ``x_dim``, ``y_dim``, ``h1_dim`` and ``h2_dim``.

        Raises:
            ValueError: If ``h1_dim`` or ``h2_dim`` has fewer entries than
                ``y_dim``. Extra trailing entries are ignored.
        """
        super().__init__(config)

        self.x_dim = config.x_dim
        self.y_dim = config.y_dim
        self.h1_dim = config.h1_dim
        self.h2_dim = config.h2_dim

        n_estimators = len(self.y_dim)
        # Fail early with a readable message rather than an IndexError from
        # inside the construction loop.
        for name, dims in (("h1_dim", self.h1_dim), ("h2_dim", self.h2_dim)):
            if len(dims) < n_estimators:
                raise ValueError(
                    f"config.{name} has {len(dims)} entries but config.y_dim "
                    f"has {n_estimators}; one hidden size per estimator is required."
                )

        self.model = nn.ModuleDict()
        # zip stops at the shortest list; after the check above that is
        # exactly len(y_dim) estimators, matching the previous behavior.
        layer_sizes = zip(self.h1_dim, self.h2_dim, self.y_dim)
        for i, (hidden1, hidden2, out_features) in enumerate(layer_sizes):
            self.model[f"estimator_{i}"] = nn.Sequential(
                nn.Linear(self.x_dim[0], hidden1),
                nn.ReLU(),
                nn.Linear(hidden1, hidden2),
                nn.ReLU(),
                nn.Linear(hidden2, out_features),
            )

        self.post_init()

    def forward(self, x):
        """Run every estimator on the same input.

        Args:
            x: Input tensor. ``nn.Linear`` acts on the last dimension, so that
                dimension must equal ``x_dim[0]``. Leading dimensions are
                passed through.

        Returns:
            A list with one tensor per estimator. Entry ``i`` is the output of
            ``estimator_{i}``, whose last dimension is ``y_dim[i]``.
        """
        # nn.ModuleDict preserves insertion order, so values() yields
        # estimator_0, estimator_1, ... exactly as they were registered.
        return [estimator(x) for estimator in self.model.values()]
|
|
|
|