Spaces:
Runtime error
Runtime error
| # Copyright 2024 EPFL and Apple Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| lm: latent mapping | |
| """ | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_up_block | |
# Names of submodules to keep frozen during training.
# NOTE(review): the freezing logic is not in this file — presumably a training
# loop matches these against model attribute names; confirm at the call site.
FREEZE_MODULES = ['encoder', 'quant_proj', 'quantize', 'cls_emb']
class Token2VAE(nn.Module):
    """Upsampling decoder that maps VQ-token feature maps to VAE latents.

    The module mirrors the tail of a diffusers VAE decoder: a conv stem, a
    UNet mid block, a stack of up blocks, and an output conv. Depending on
    ``output_type`` it predicts either a 4-channel latent sample ("sample")
    or the 8-channel concatenation of (mean, std) statistics ("stats").

    Args:
        in_channels: Channel count of the incoming quantized feature map.
        output_type: "stats" (predict mean/std, 8 output channels) or
            "sample" (predict a latent directly, 4 output channels).
        up_block_types: diffusers up-block type names, one per stage.
        block_out_channels: Channel widths per stage (shallowest first).
        layers_per_block: ResNet layers per up block (diffusers adds one more).
        norm_num_groups: Group count for all GroupNorm layers.
        act_fn: Activation name passed through to the diffusers blocks.
        vq_model: Optional frozen VQ encoder, used when ``forward`` receives a
            raw image instead of pre-computed tokens.
        vae: Optional VAE used by ``vae_encode``/``vae_decode``.
    """

    def __init__(
        self,
        in_channels=32,
        output_type="stats",
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D",),
        block_out_channels=(256, 512),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        vq_model=None,
        vae=None,
    ):
        super().__init__()
        assert output_type in ["stats", "sample"], "`output_type` can be either of 'stats' or 'sample'"
        self.output_type = output_type
        # "sample" predicts a 4-channel latent; "stats" predicts mean and std (4 + 4).
        out_channels = 4 if output_type == "sample" else 8
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Register up_blocks before mid_block to preserve the original
        # parameter-iteration order.
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )

        # up: walk the channel widths from deepest to shallowest; every stage
        # except the last upsamples spatially.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                resnet_time_scale_shift="group",
            )
            self.up_blocks.append(up_block)

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.vq_model = vq_model
        self.vae = vae

    def vae_encode(self, x):
        """Encode image ``x`` into scaled VAE latents (sample or stats)."""
        assert self.vae is not None, "VAE is not initialized"
        z = self.vae.encode(x).latent_dist
        if self.output_type == "sample":
            z = z.sample()
        else:
            z = torch.cat((z.mean, z.std), dim=1)
        # Both a sample and the (mean, std) stats scale linearly, so one
        # multiplication handles either representation.
        z = z * self.vae.config.scaling_factor
        return z

    def vae_decode(self, x, clip=True):
        """Decode latents (or latent stats) ``x`` back to image space.

        Args:
            x: 4-channel latent tensor, or 8-channel (mean, std) stats tensor
               which is first sampled via ``self.sample``.
            clip: If True, clamp decoded pixels to [-1, 1].
        """
        assert self.vae is not None, "VAE is not initialized"
        x = self.sample(x)  # no-op for 4-channel inputs
        x = self.vae.decode(x / self.vae.config.scaling_factor).sample
        if clip:
            x = torch.clip(x, min=-1, max=1)
        return x

    def sample(self, x):
        """Reparameterized draw from (mean, std) stats; pass samples through."""
        if x.shape[1] == 4:
            # Already a latent sample.
            return x
        mean, std = x.chunk(2, dim=1)
        x = mean + std * torch.randn_like(std)
        return x

    def forward(self, quant=None, image=None):
        """Map quantized tokens (or an image, via the VQ encoder) to latents.

        Args:
            quant: Pre-computed quantized feature map. If None, ``image`` is
                encoded with the (gradient-free) VQ model instead.
            image: Raw image, only used when ``quant`` is None.

        Returns:
            Tensor with 4 channels ("sample") or 8 channels ("stats").
        """
        if quant is None:
            assert image is not None, "Neither of `quant` or `image` are provided"
            assert self.vq_model is not None, "VQ encoder is not initialized"
            with torch.no_grad():
                quant, _, _ = self.vq_model.encode(image)
        x = self.conv_in(quant)
        # Up blocks may hold parameters in a different dtype (e.g. under mixed
        # precision); cast activations to match before feeding them in.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        # middle
        x = self.mid_block(x)
        x = x.to(upscale_dtype)
        # up
        for up_block in self.up_blocks:
            x = up_block(x)
        # post-process
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
def create_model(
    in_channels=32,
    output_type="stats",
    vq_model=None,
    vae=None,
):
    """Build a Token2VAE with the default two-stage decoder configuration.

    Args:
        in_channels: Channel count of the quantized input feature map.
        output_type: "stats" or "sample", forwarded to Token2VAE.
        vq_model: Optional VQ encoder forwarded to Token2VAE.
        vae: Optional VAE forwarded to Token2VAE.

    Returns:
        A Token2VAE instance using the fixed default architecture.
    """
    decoder_defaults = dict(
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D",),
        block_out_channels=(256, 512),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
    )
    return Token2VAE(
        in_channels=in_channels,
        output_type=output_type,
        vq_model=vq_model,
        vae=vae,
        **decoder_defaults,
    )