File size: 914 Bytes
912c7e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Fully connected network conditioned on a timestep via random cosine features.

    The timestep input is multiplied by a fixed, non-trainable row of
    frequencies sampled once from a standard normal distribution, passed
    through ``cos``, concatenated with the main input along the last
    dimension, and fed through five ``nn.Linear`` layers with ``nn.SiLU``
    activations in between.

    The frequencies are registered as a buffer named ``rdFrequency`` so that
    they are stored in ``state_dict()`` (a reloaded model reproduces the same
    time embedding) and follow ``module.to(...)``.

    NOTE: state dicts saved before the frequencies were registered as a
    buffer do not contain the ``rdFrequency`` entry; load those with
    ``strict=False`` (the frequencies will then be the freshly sampled ones).

    Args:
        num_in: Size of the last dimension of ``noisy_y``.
        num_out: Size of the last dimension of the output.
        num_hid: Width of every hidden layer.
        num_freq: Number of random cosine features built from the timestep.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        num_hid: int = 500,
        num_freq: int = 100,
    ) -> None:
        super().__init__()
        self.num_in = num_in
        self.num_hid = num_hid
        self.num_out = num_out
        self.num_freq = num_freq

        # Sampled before the layers are built, so the RNG consumption order
        # matches the previous implementation. Registered as a buffer (not a
        # Parameter): it is never trained but must be checkpointed.
        self.register_buffer("rdFrequency", torch.normal(0, 1, (1, num_freq)))

        self.net = nn.Sequential(
            nn.Linear(num_in + num_freq, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_hid),
            nn.SiLU(),
            nn.Linear(num_hid, num_out),
        )

    def forward(self, noisy_y: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """Run the network on ``noisy_y`` conditioned on ``timesteps``.

        Args:
            noisy_y: Input whose last dimension has size ``num_in``.
            timesteps: Tensor with a trailing dimension of size 1 (it is
                matrix-multiplied with the ``(1, num_freq)`` frequency row);
                its leading dimensions must allow concatenation with
                ``noisy_y`` along the last dimension.

        Returns:
            Tensor whose last dimension has size ``num_out``.
        """
        freq = self.rdFrequency.to(device=timesteps.device)
        if timesteps.is_floating_point():
            # Follow the caller's floating dtype, as before.
            freq = freq.to(dtype=timesteps.dtype)
        else:
            # Integer timesteps: promote them instead of truncating the
            # real-valued frequencies to integers.
            timesteps = timesteps.to(dtype=freq.dtype)

        time_feature = torch.cos(torch.matmul(timesteps, freq))
        x_in = torch.cat([noisy_y, time_feature], dim=-1)
        return self.net(x_in)