wjnwjn59's picture
fix import relative funcs
4ad132d
raw
history blame contribute delete
645 Bytes
import torch
from torch import nn
from unet import UNET, UNETOutputLayer
from unet import TimeEmbedding
class Diffusion(nn.Module):
    """Diffusion denoiser: time embedding -> UNET -> UNETOutputLayer.

    The three submodules are created (and therefore registered) in the
    order time_embedding, unet, unet_output; keep that order so parameter
    initialisation and ``state_dict`` key order stay the same.
    """

    def __init__(self, h_dim=128, n_head=4):
        """Build the submodules.

        Args:
            h_dim: feature width passed to both ``UNET`` and
                ``UNETOutputLayer``.
            n_head: head count passed to ``UNET`` (presumably attention
                heads — confirm against the ``unet`` module).
        """
        super().__init__()
        # NOTE(review): 320 looks like the width of the incoming timestep
        # encoding (Stable-Diffusion-style layout) — confirm in unet.TimeEmbedding.
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET(h_dim, n_head)
        # NOTE(review): 4 is presumably the number of latent channels produced
        # by the output head — confirm in unet.UNETOutputLayer.
        self.unet_output = UNETOutputLayer(h_dim, 4)

    # The whole forward pass runs under CUDA float16 autocast; the decorator
    # builds a single autocast object when the class is defined, and that
    # object is reused for every call.
    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True, cache_enabled=True)
    def forward(self, latent, context, time):
        """Embed ``time``, run the UNET, then apply the output head.

        Args:
            latent: tensor passed to the UNET as its first input (assumed to
                be a (batch, channels, H, W) latent — TODO confirm).
            context: conditioning tensor forwarded unchanged to the UNET.
            time: timestep encoding consumed by ``TimeEmbedding``.

        Returns:
            Whatever ``UNETOutputLayer`` produces from the UNET features.
        """
        time_features = self.time_embedding(time)
        unet_features = self.unet(latent, context, time_features)
        return self.unet_output(unet_features)