LordXido committed on
Commit
2bdc5ba
·
verified ·
1 Parent(s): eaa19f0

Create policy.py

Browse files
Files changed (1) hide show
  1. policy.py +35 -0
policy.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+
5
+
6
class PsiPolicy(nn.Module):
    """Small feed-forward network whose output width equals its input width.

    The network is ``Linear(dim, 32) -> ReLU -> Linear(32, dim)`` and is
    stored as ``self.model``.
    """

    def __init__(self, dim):
        """Build the network for feature vectors of size ``dim``."""
        super().__init__()
        hidden_width = 32
        # Layers are instantiated in forward order, so sub-module indices
        # (0, 1, 2) and parameter initialisation order follow the data flow.
        stages = [
            nn.Linear(dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, dim),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the network output for ``x``; its last dimension must be ``dim``."""
        output = self.model(x)
        return output
17
+
18
+
19
class PolicyController:
    """Drive a system with a learned policy and train that policy online.

    The ``system`` object must expose ``n`` (used as the policy's input and
    output width), ``Xi`` (read as the current state) and ``step(action)``.
    What those mean physically is defined by the system, not by this class.
    """

    def __init__(self, system):
        """Create a ``PsiPolicy`` sized by ``system.n`` and an Adam optimiser.

        Args:
            system: Object providing ``n``, ``Xi`` and ``step``; stored as
                ``self.system``.
        """
        self.system = system
        self.policy = PsiPolicy(system.n)
        self.opt = optim.Adam(self.policy.parameters(), lr=0.001)

    def policy_step(self):
        """Apply one policy action to the system, then update the policy.

        The policy is evaluated on ``system.Xi``, the result is passed to
        ``system.step`` as a NumPy array, and one Adam step is taken on the
        norm of the policy output. Returns ``None``.
        """
        state = torch.tensor(self.system.Xi, dtype=torch.float32)
        Psi = self.policy(state)

        # Give the system its own copy. ``Tensor.numpy()`` shares memory with
        # the tensor, so an in-place write inside ``system.step`` would
        # otherwise alter ``Psi`` before the loss and its gradient are
        # computed, without autograd noticing.
        Psi_np = Psi.detach().numpy().copy()

        self.system.step(Psi_np)

        # The loss depends only on the policy output, not on the system's
        # response to the action.
        loss = torch.norm(Psi)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()