File size: 645 Bytes
e1dc1a5
 
 
4ad132d
e1dc1a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch import nn
from unet import UNET, UNETOutputLayer
from unet import TimeEmbedding


class Diffusion(nn.Module):
    """Denoising network: time embedding + UNET backbone + output projection.

    The model embeds the timestep with ``TimeEmbedding`` and runs the UNET
    conditioned on ``context`` and the embedded time. It then maps the UNET
    features through ``UNETOutputLayer``.

    Args:
        h_dim: Hidden width passed to both ``UNET`` and ``UNETOutputLayer``.
        n_head: Number of attention heads passed to ``UNET``.
        time_dim: Argument passed to ``TimeEmbedding`` (default 320, the
            previously hard-coded value). NOTE(review): presumably the width
            of the incoming timestep encoding. Confirm it matches what
            callers feed as ``time`` and what ``UNET`` expects.
        out_channels: Second argument to ``UNETOutputLayer`` (default 4, the
            previously hard-coded value). Presumably the number of output
            (latent) channels. Verify against ``unet.UNETOutputLayer``.
    """

    def __init__(self, h_dim=128, n_head=4, *, time_dim=320, out_channels=4):
        super().__init__()
        self.time_embedding = TimeEmbedding(time_dim)
        self.unet = UNET(h_dim, n_head)
        self.unet_output = UNETOutputLayer(h_dim, out_channels)

    # Forward always runs under CUDA float16 autocast. When CUDA is not
    # available, torch.autocast warns and disables itself, so CPU runs stay
    # in full precision.
    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True, cache_enabled=True)
    def forward(self, latent, context, time):
        """Run one denoising pass.

        Args:
            latent: Input passed to the UNET as its first argument.
                NOTE(review): presumably the noisy latent tensor; shape not
                established here.
            context: Conditioning input passed to the UNET as its second
                argument.
            time: Timestep encoding, embedded by ``self.time_embedding``
                before it is handed to the UNET.

        Returns:
            The output of ``self.unet_output`` applied to the UNET features.
        """
        time_emb = self.time_embedding(time)
        features = self.unet(latent, context, time_emb)
        return self.unet_output(features)