AlexSychovUN commited on
Commit
b3705e9
·
1 Parent(s): 1e0ff3f

Added files

Browse files
Files changed (1) hide show
  1. pinns_practice/basic_pinn.py +67 -0
pinns_practice/basic_pinn.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import matplotlib.pyplot as plt
4
+
5
+ class BasicPINN(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.net = nn.Sequential(
9
+ nn.Linear(1, 20),
10
+ nn.Tanh(), # for RELU 2 derivative is always 0, so use Tanh
11
+ nn.Linear(20, 20),
12
+ nn.Tanh(),
13
+ nn.Linear(20, 1)
14
+ )
15
+
16
+ def forward(self, x):
17
+ x =self.net(x)
18
+ return x
19
+
20
+ model = BasicPINN()
21
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
22
+
23
+ for epoch in range(2000):
24
+ optimizer.zero_grad()
25
+
26
+ t_physics = torch.rand(100, 1).requires_grad_(True) # requires_grad=True important for derivatives
27
+
28
+ y_pred = model(t_physics)
29
+
30
+ # y = e ^ (-2t)
31
+ # dy/dt = -2y
32
+
33
+ # Calculating derivative dy/dt
34
+ # We use PyTorch auto-differentiation to find the rate of change of y_pred with respect to t_physics.
35
+ dy_dt = torch.autograd.grad(
36
+ outputs=y_pred, # What we differentiate, y
37
+ inputs=t_physics, # What we differentiate with respect to, (time, t)
38
+ grad_outputs=torch.ones_like(y_pred), # vector from 1, for 100 examples, calculates gradients independently
39
+ create_graph=True, # history of calculations, critical for PINNs
40
+ )[0]
41
+
42
+ # Physical Loss dy/dt + 2y = 0
43
+ physical_loss = torch.mean((dy_dt + 2 * y_pred) ** 2)
44
+
45
+ # Initial condition, t = 0 -> 1.0
46
+ t_0 = torch.zeros(1, 1)
47
+ y_0_pred = model(t_0)
48
+ initial_condition_loss = torch.mean((y_0_pred - 1.0) ** 2)
49
+
50
+ loss = physical_loss + initial_condition_loss
51
+ loss.backward()
52
+ optimizer.step()
53
+
54
+ if epoch % 200 == 0:
55
+ print(f"Epoch {epoch}, Loss: {loss.item():.5f}")
56
+
57
+ t_test = torch.linspace(0, 2, 100).view(-1, 1)
58
+ with torch.no_grad():
59
+ y_test_pred = model(t_test)
60
+
61
+
62
+ y_exact = torch.exp(-2 * t_test)
63
+ plt.plot(t_test.numpy(), y_test_pred.numpy(), label="PINN model", color="red", linestyle="--")
64
+ plt.plot(t_test.numpy(), y_exact.numpy(), label="Exact solution (Math)", alpha=0.5)
65
+ plt.legend()
66
+ plt.title("Solving the differential equation!!")
67
+ plt.show()