| import torch | |
| from torch import nn | |
| from unet import UNET, UNETOutputLayer | |
| from unet import TimeEmbedding | |
class Diffusion(nn.Module):
    """Compose a time embedding, a UNET backbone and an output projection.

    ``forward`` embeds the timestep input with ``self.time_embedding``, runs
    ``self.unet`` on ``(latent, context, embedded_time)``, and projects the
    result with ``self.unet_output``.
    """

    def __init__(self, h_dim=128, n_head=4):
        """Build the three submodules.

        Args:
            h_dim: Hidden width forwarded to ``UNET`` and used as the input
                width of ``UNETOutputLayer``.
            n_head: Attention-head count forwarded to ``UNET``.
        """
        super().__init__()
        # NOTE(review): fixed input width handed to TimeEmbedding; the `time`
        # argument of forward presumably needs this as its feature size —
        # confirm against unet.TimeEmbedding.
        time_in_features = 320
        # NOTE(review): output channel count of the final projection; looks
        # like the latent channel count — TODO confirm.
        out_channels = 4
        # Registration order (time_embedding, unet, unet_output) is kept
        # as-is so parameter / state_dict ordering does not change.
        self.time_embedding = TimeEmbedding(time_in_features)
        self.unet = UNET(h_dim, n_head)
        self.unet_output = UNETOutputLayer(h_dim, out_channels)

    def forward(self, latent, context, time):
        """Run one denoising-network pass.

        Args:
            latent: Input passed as the first positional argument to the UNET.
            context: Conditioning passed as the second positional argument.
            time: Raw timestep input, embedded before reaching the UNET.

        Returns:
            The output of ``self.unet_output`` applied to the UNET result.
        """
        time_features = self.time_embedding(time)
        backbone_out = self.unet(latent, context, time_features)
        return self.unet_output(backbone_out)