| | import torch |
| |
|
| | from .modules import ResBlock |
| |
|
| |
|
class MyBackbone(torch.nn.Module):
    """Sequential backbone: input projection, residual blocks, output projection.

    The network is a ``torch.nn.Sequential`` made of, in order:

    1. ``Linear(input_dim, hidden_dim)``
    2. ``num_layers`` instances of ``ResBlock(hidden_dim)``
    3. ``Linear(hidden_dim, output_dim)``

    NOTE(review): ``ResBlock`` is defined in ``.modules`` and is presumably
    shape-preserving on a ``hidden_dim``-wide feature axis (the final linear
    layer requires that) -- confirm against its implementation.

    Args:
        num_layers: Number of ``ResBlock`` modules placed between the two
            linear layers.
        input_dim: ``in_features`` of the first linear layer.
        hidden_dim: Width handed to every ``ResBlock`` and used as the
            ``out_features``/``in_features`` of the surrounding linear layers.
        output_dim: ``out_features`` of the last linear layer.
    """

    def __init__(
        self,
        num_layers: int = 2,
        input_dim: int = 2,
        hidden_dim: int = 128,
        output_dim: int = 2,
    ):
        super().__init__()

        # Keep the construction hyper-parameters readable on the instance.
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Sub-modules are instantiated strictly in stem -> blocks -> head
        # order so that seeded parameter initialisation stays reproducible.
        stem = torch.nn.Linear(input_dim, hidden_dim)
        blocks = [ResBlock(hidden_dim) for _ in range(num_layers)]
        head = torch.nn.Linear(hidden_dim, output_dim)

        self.model = torch.nn.Sequential(stem, *blocks, head)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result.

        The first layer is a ``torch.nn.Linear`` with ``in_features=input_dim``,
        so the last dimension of ``x`` must have that size.
        """
        return self.model(x)
| | |