alphapilot-v1 / modeling_alphapilot.py
prostochel097's picture
Create modeling_alphapilot.py
88ca267 verified
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_alphapilot import AlphaPilotConfig
class AlphaPilotModel(PreTrainedModel):
    """Fully connected feed-forward network configured by ``AlphaPilotConfig``.

    The network takes inputs of width ``config.state_dim``, passes them
    through one ``nn.Linear`` + ``nn.ReLU`` pair per entry of
    ``config.hidden_layers``, and ends with a plain ``nn.Linear`` of width
    ``config.action_dim`` (no activation after the last layer).
    """

    config_class = AlphaPilotConfig

    def __init__(self, config):
        """Assemble the layer stack described by ``config``.

        Args:
            config: Configuration object providing ``state_dim``,
                ``hidden_layers`` (an iterable of layer widths) and
                ``action_dim``.
        """
        super().__init__(config)

        # Widths of every activation in order: the input, then each hidden layer.
        widths = [config.state_dim, *config.hidden_layers]

        stack = []
        # Consecutive width pairs give the (in, out) sizes of each hidden layer.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]

        # Output projection; when there are no hidden layers this is the
        # only layer and maps state_dim straight to action_dim.
        stack.append(nn.Linear(widths[-1], config.action_dim))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the network and return the final layer's output."""
        return self.net(x)