Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers import AutoencoderKL | |
| from . import PatchEmbed | |
| class VAE(nn.Module): | |
| # VAE代替DINOv2作为encoder的实现 | |
| # Decoder的部分也顺便写了 | |
| def __init__(self): | |
| super(VAE, self).__init__() | |
| self.vae = AutoencoderKL.from_pretrained("/home/lihong/UPS_Lightning/vae/vae/stable-diffusion-3.5-large", subfolder="vae").requires_grad_(False) # vae不训 | |
| def encode(self, x): | |
| """ | |
| x: [B*f,3,H,W](multi-lihgt) or x: [B,3,H,W](nml) | |
| """ | |
| z = self.vae.encode(x).latent_dist.sample() # [B*f,16,64,64] | |
| return z | |
| def decode(self, latent): | |
| """ | |
| x: [B,16,64,64] nml的latent | |
| """ | |
| decode_nml = self.vae.decode(latent).sample | |
| return decode_nml # |