prostochel097 commited on
Commit
88ca267
·
verified ·
1 Parent(s): 9bca349

Create modeling_alphapilot.py

Browse files
Files changed (1) hide show
  1. modeling_alphapilot.py +25 -0
modeling_alphapilot.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from .configuration_alphapilot import AlphaPilotConfig
5
+
6
+ class AlphaPilotModel(PreTrainedModel):
7
+ config_class = AlphaPilotConfig
8
+
9
+ def __init__(self, config):
10
+ super().__init__(config)
11
+
12
+ layers = []
13
+ input_dim = config.state_dim
14
+
15
+ for h_dim in config.hidden_layers:
16
+ layers.append(nn.Linear(input_dim, h_dim))
17
+ layers.append(nn.ReLU())
18
+ input_dim = h_dim
19
+
20
+ layers.append(nn.Linear(input_dim, config.action_dim))
21
+
22
+ self.net = nn.Sequential(*layers)
23
+
24
+ def forward(self, x):
25
+ return self.net(x)