trioskosmos committed on
Commit
e129a67
·
verified ·
1 Parent(s): f9b81b4

Upload ai/models/student_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/models/student_model.py +41 -0
ai/models/student_model.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
# Width of both hidden layers of the student network.
STUDENT_HIDDEN = 32


class StudentActor(nn.Module):
    """Small MLP policy mapping observations to per-action logits.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` with two
    hidden layers of ``STUDENT_HIDDEN`` units each.

    Args:
        obs_dim: Size of the last axis of the observation input.
        action_dim: Number of logits (one per action) produced per observation.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, STUDENT_HIDDEN),
            nn.ReLU(),
            nn.Linear(STUDENT_HIDDEN, STUDENT_HIDDEN),
            nn.ReLU(),
            nn.Linear(STUDENT_HIDDEN, action_dim),
        )

    def forward(self, x):
        """Return the raw (unmasked) action logits for ``x``."""
        return self.net(x)

    def predict(self, obs, action_masks=None, deterministic=True):
        """Pick an action for each observation without tracking gradients.

        The signature is an "API matcher for batched_env" (original author's
        note); the second return value is always ``None``, presumably a
        placeholder for a recurrent state -- confirm against the caller.

        Args:
            obs: Observations as anything ``torch.as_tensor`` accepts, either
                a batch of shape ``(batch, obs_dim)`` or a single observation
                of shape ``(obs_dim,)``.
            action_masks: Optional mask, truthy where an action is allowed.
                It must be broadcastable to the logits, so both a per-row
                ``(batch, action_dim)`` mask and a shared ``(action_dim,)``
                mask work.
            deterministic: If true, take the arg-max action; otherwise sample
                from the softmax distribution over the (masked) logits.

        Returns:
            A tuple ``(actions, None)`` where ``actions`` is a NumPy array of
            action indices with the observation's leading shape.
        """
        with torch.no_grad():
            # Follow the parameters' device and dtype so the model works
            # wherever it lives, not only on CPU/CUDA in float32.
            reference = next(self.parameters())
            x = torch.as_tensor(obs).to(
                device=reference.device, dtype=reference.dtype
            )

            logits = self.net(x)

            if action_masks is not None:
                allowed = torch.as_tensor(
                    action_masks, device=logits.device
                ).bool()
                # Use the most negative finite value rather than -inf so a
                # fully masked row still yields a valid softmax (no NaNs).
                floor = torch.finfo(logits.dtype).min
                logits = logits.masked_fill(~allowed, floor)

            # Reduce over the last axis so unbatched observations work too.
            if deterministic:
                actions = torch.argmax(logits, dim=-1)
            else:
                probs = torch.softmax(logits, dim=-1)
                actions = torch.multinomial(probs, 1).squeeze(-1)

        return actions.cpu().numpy(), None