import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_alphapilot import AlphaPilotConfig


class AlphaPilotModel(PreTrainedModel):
    """Feed-forward network mapping a state vector to ``config.action_dim`` outputs.

    The network is a stack of ``nn.Linear`` layers, one per entry of
    ``config.hidden_layers``, each followed by ``nn.ReLU``. A final
    ``nn.Linear`` projects to ``config.action_dim``, with no activation
    after it.

    NOTE(review): ``__init__`` does not call ``self.post_init()``. The
    ``transformers`` docs use it for weight init / final setup of
    ``PreTrainedModel`` subclasses. Confirm whether omitting it is
    intentional before changing, since adding it may alter initialization.
    """

    config_class = AlphaPilotConfig

    def __init__(self, config):
        """Build the MLP from ``config``.

        Args:
            config: An ``AlphaPilotConfig``. Reads ``state_dim`` (input
                width), ``hidden_layers`` (iterable of hidden widths, may be
                empty) and ``action_dim`` (output width).
        """
        super().__init__(config)

        # Consecutive entries of ``widths`` are the (in, out) sizes of each
        # hidden Linear layer.
        widths = [config.state_dim, *config.hidden_layers]

        # Order matters: Linear/ReLU alternate, so Linear layers sit at even
        # indices of the Sequential (net.0, net.2, ...). Those indices become
        # the state_dict keys that saved checkpoints rely on.
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]

        # Output projection from the last hidden width (or from state_dim
        # when there are no hidden layers) to the action dimension.
        stack.append(nn.Linear(widths[-1], config.action_dim))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run the MLP on ``x``.

        Args:
            x: Tensor whose last dimension equals ``config.state_dim``
                (required by the first ``nn.Linear``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``config.action_dim``. The output is raw, since no
            final activation is applied.
        """
        return self.net(x)