| import torch | |
| import torch.nn as nn | |
class MLP(nn.Module):
    """Time-conditioned MLP computing ``net(cat([noisy_y, cos(t @ W)], -1))``.

    ``W`` (stored as ``rdFrequency``) is a fixed, non-trainable tensor of
    shape ``(1, num_freq)`` drawn from N(0, 1) at construction time. The
    time feature ``cos(t @ W)`` is concatenated to ``noisy_y`` along the
    last dimension. The result is then fed through five ``nn.Linear``
    layers with ``nn.SiLU`` activations between them.

    Args:
        num_in: size of the last dimension of ``noisy_y``.
        num_out: size of the last dimension of the output.
        num_hid: width of the hidden layers.
        num_freq: number of random time frequencies (the previously
            hard-coded value 100 is the default).

    Note:
        ``rdFrequency`` is a persistent buffer, so it is saved in
        ``state_dict()`` and moves with ``.to()`` / ``.cuda()``. State dicts
        saved when it was a plain attribute lack this key. Loading them
        requires ``strict=False``, and their original frequencies cannot be
        recovered because they were never saved.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        num_hid: int = 500,
        num_freq: int = 100,
    ):
        super().__init__()
        self.num_in = num_in
        self.num_hid = num_hid
        self.num_out = num_out
        self.num_freq = num_freq
        # Drawn before the Linear layers so that, under a fixed seed, the
        # global RNG is consumed in the same order as before.
        self.register_buffer("rdFrequency", torch.normal(0, 1, (1, num_freq)))
        self.net = nn.Sequential(
            nn.Linear(num_in + num_freq, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_out),
        )

    def forward(self, noisy_y, timesteps):
        """Return the network output for ``noisy_y`` conditioned on ``timesteps``.

        Args:
            noisy_y: tensor whose last dimension is ``num_in``.
            timesteps: accepted in either of two shapes.
                - ``noisy_y``'s leading dims plus a trailing size-1 dim,
                  e.g. ``(B, 1)`` for a ``(B, num_in)`` input.
                - The same shape with that trailing dim omitted,
                  e.g. ``(B,)``.
                Non-floating timesteps are converted to the dtype of
                ``rdFrequency`` before use.

        Returns:
            Tensor with last dimension ``num_out``.
        """
        if timesteps.dim() == noisy_y.dim() - 1:
            # The (..., 1) @ (1, num_freq) product needs a trailing
            # singleton dim. Add it for (B,)-shaped (or scalar) timesteps.
            timesteps = timesteps.unsqueeze(-1)
        if not torch.is_floating_point(timesteps):
            # ``rdFrequency.to(int_tensor)`` would truncate the frequencies
            # to integers. Promote the timesteps instead.
            timesteps = timesteps.to(self.rdFrequency.dtype)
        freq = self.rdFrequency.to(timesteps)
        time_feature = torch.cos(torch.matmul(timesteps, freq))
        x_in = torch.cat([noisy_y, time_feature], dim=-1)
        return self.net(x_in)