File size: 694 Bytes
88ca267 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_alphapilot import AlphaPilotConfig
class AlphaPilotModel(PreTrainedModel):
    """Feed-forward network mapping a state vector to an action vector.

    The network is a single ``nn.Sequential`` stored on ``self.net``:

    * one ``nn.Linear`` per entry of ``config.hidden_layers``, each followed
      by ``nn.ReLU``; the first takes ``config.state_dim`` inputs and each
      subsequent one takes the previous layer's width;
    * a final ``nn.Linear`` projecting to ``config.action_dim`` with no
      activation afterwards.

    With an empty ``config.hidden_layers`` the model is a single linear map
    from ``config.state_dim`` to ``config.action_dim``.
    """

    config_class = AlphaPilotConfig

    def __init__(self, config):
        super().__init__(config)
        # Chain of layer widths: input size, then every hidden width.
        widths = [config.state_dim, *config.hidden_layers]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head: widths[-1] is the last hidden width, or state_dim
        # when there are no hidden layers.
        modules.append(nn.Linear(widths[-1], config.action_dim))
        # Keep the attribute name ``net`` and this module order unchanged:
        # nn.Sequential names its children by position, so both determine
        # the parameter names (and hence checkpoint compatibility).
        self.net = nn.Sequential(*modules)
        # NOTE(review): transformers' custom-model guide calls
        # ``self.post_init()`` at the end of ``__init__``; this class does
        # not. Confirm whether that omission is intentional.

    def forward(self, x):
        """Run ``x`` through the network and return the raw output.

        ``nn.Linear`` acts on the last dimension, so ``x``'s last dimension
        must equal ``config.state_dim``. The result's last dimension is
        ``config.action_dim``. Leading dimensions are presumably batch
        dimensions; this is not checked here.
        """
        return self.net(x)