File size: 645 Bytes
e1dc1a5 4ad132d e1dc1a5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | import torch
from torch import nn
from unet import UNET, UNETOutputLayer
from unet import TimeEmbedding
class Diffusion(nn.Module):
    """Denoising network: time embedding -> UNET -> output projection.

    Args:
        h_dim: Channel width passed to ``UNET`` and used as the input width of
            ``UNETOutputLayer``.
        n_head: Number of attention heads passed to ``UNET``.
        time_dim: Size argument passed to ``TimeEmbedding``. Defaults to 320,
            the value that was previously hard-coded here.
            NOTE(review): this must agree with the time-embedding width that
            ``UNET`` expects; it is not derived from ``h_dim``. Confirm in the
            ``unet`` module.
        out_channels: Number of output channels of ``UNETOutputLayer``.
            Defaults to 4, the value that was previously hard-coded here
            (presumably the latent channel count; verify against the latent
            tensors fed to ``forward``).
    """

    def __init__(self, h_dim=128, n_head=4, *, time_dim=320, out_channels=4):
        super().__init__()
        self.time_embedding = TimeEmbedding(time_dim)
        self.unet = UNET(h_dim, n_head)
        self.unet_output = UNETOutputLayer(h_dim, out_channels)

    # The whole forward pass runs under CUDA autocast with float16, which is
    # hard-wired here (device_type='cuda', enabled=True).
    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True, cache_enabled=True)
    def forward(self, latent, context, time):
        """Run one denoising step.

        Args:
            latent: Input passed to ``UNET`` as its first argument.
                Presumably a latent image batch; confirm the shape in ``unet``.
            context: Conditioning input passed to ``UNET``.
                Presumably text/context embeddings; confirm in ``unet``.
            time: Timestep input, embedded by ``TimeEmbedding`` before being
                passed to ``UNET``.

        Returns:
            The output of ``UNETOutputLayer`` applied to the ``UNET`` features.
        """
        time = self.time_embedding(time)
        output = self.unet(latent, context, time)
        output = self.unet_output(output)
        return output
|