Spaces:
Paused
Paused
| # Copyright (c) 2023-2024 DeepSeek. | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy of | |
| # this software and associated documentation files (the "Software"), to deal in | |
| # the Software without restriction, including without limitation the rights to | |
| # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of | |
| # the Software, and to permit persons to whom the Software is furnished to do so, | |
| # subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS | |
| # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |
| # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER | |
| # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN | |
| # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |
| # modified from: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.distributed as dist | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple, Union | |
| import numpy as np | |
| import torchvision | |
| import torchvision.utils | |
| from diffusers.models.embeddings import Timesteps, TimestepEmbedding | |
| from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm | |
class ImageHead(nn.Module):
    """Final normalization followed by a tensor-parallel linear head.

    NOTE(review): this class references ``AttrDict``, ``PG``,
    ``DropoutAddRMSNorm``, ``DropoutAddLayerNorm`` and ``ColumnParallelLinear``.
    None of them is imported in the visible part of this file, so they must be
    supplied elsewhere before the class can be instantiated -- confirm.

    Parameters:
        decoder_cfg:
            config object; only ``decoder_cfg.in_channels`` is read (used as
            the output width of the head).
        gpt_cfg:
            config object merged on top of the built-in defaults with ``+``.
            ``n_embed`` must be provided by it, as the defaults do not set it.
        layer_id (optional):
            stored on the instance as ``self.layer_id``; not used otherwise.
    """

    def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
        # `warnings` is not among the module-level imports, so bring it into
        # scope here; otherwise the alignment warning below raises NameError.
        import warnings

        super().__init__()
        self.layer_id = layer_id

        # Built-in defaults, overridden by whatever `gpt_cfg` provides.
        cfg = (
            AttrDict(
                norm_type="layernorm",
                is_exp_norm=False,
                sequence_parallel=False,
                use_userbuffer=False,
                norm_eps=1e-5,
                norm_bias=True,
                gradient_accumulation_fusion=True,
                use_fp32_head_weight=False,
            )
            + gpt_cfg
        )
        group = PG.tensor_parallel_group()

        assert cfg.norm_type in [
            "layernorm",
            "rmsnorm",
        ], f"Norm type:{cfg.norm_type} not supported"
        if cfg.norm_type == "rmsnorm":
            self.norm = DropoutAddRMSNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
            )
        else:
            self.norm = DropoutAddLayerNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
                bias=cfg.norm_bias,
            )

        # Warn when the head's output width is not a multiple of 256.
        multiple_of = 256
        if decoder_cfg.in_channels % multiple_of != 0:
            warnings.warn(
                f"建议把 vocab_size 设置为 {multiple_of} 的倍数, 否则会影响矩阵乘法的性能"
            )

        default_dtype = torch.get_default_dtype()
        dtype = default_dtype
        if cfg.use_fp32_head_weight:
            dtype = torch.float32
            print(
                "使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
                end="",
                flush=True,
            )

        # The head is built under a temporarily changed global default dtype.
        # Restore it in `finally` so a failing constructor cannot leak the
        # altered default into the rest of the process.
        torch.set_default_dtype(dtype)
        try:
            self.head = ColumnParallelLinear(
                cfg.n_embed,
                decoder_cfg.in_channels,
                bias=True,
                group=group,
                sequence_parallel=cfg.sequence_parallel,
                use_userbuffer=cfg.use_userbuffer,
                gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
                use_fp32_output=False,
            )
        finally:
            torch.set_default_dtype(default_dtype)
        self.use_fp32_head_weight = cfg.use_fp32_head_weight

    def forward(
        self, input_args, images_split_mask: Optional[torch.BoolTensor] = None, **kwargs
    ):
        """Normalize the input and project it through the head.

        Parameters:
            input_args:
                either a tensor ``x`` or a tuple ``(x, residual)``; the
                residual (``None`` when absent) is forwarded to the norm.
            images_split_mask (optional):
                boolean mask whose first two dimensions are read as
                ``(bs, n_images)``. When given, only the selected positions
                are projected. Presumably ``x`` is (bs, seq, n_embed) and the
                mask is (bs, n_images, seq) -- TODO confirm against callers.
            **kwargs:
                accepted and ignored.

        Returns:
            The output of ``self.head``.
        """
        residual = None
        if isinstance(input_args, tuple):
            x, residual = input_args
        else:
            x = input_args
        x = self.norm(x, residual)

        if self.use_fp32_head_weight:
            assert (
                self.head.weight.dtype == torch.float32
            ), f"head.weight is {self.head.weight.dtype}"
            x = x.float()

        if images_split_mask is None:
            logits = self.head(x)
        else:
            bs, n_images = images_split_mask.shape[:2]
            n_embed = x.shape[-1]
            # Broadcast x against the mask, keep the selected elements, then
            # regroup them as one sequence per (batch, image) pair.
            images_embed = torch.masked_select(
                x.unsqueeze(1), images_split_mask.unsqueeze(-1)
            )
            images_embed = images_embed.view((bs * n_images, -1, n_embed))
            logits = self.head(images_embed)
        return logits
class GlobalResponseNorm(nn.Module):
    """Global response normalization with a learnable per-channel gain and bias.

    Taken from
    https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

    The L2 norm is reduced over dimensions 1 and 2 and the parameters have
    shape ``(1, 1, 1, dim)``, so the input is expected to be 4-D with the
    feature dimension last. Both parameters start at zero, which makes the
    freshly constructed layer an identity mapping.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.weight = nn.Parameter(torch.zeros(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # Per-feature L2 norm aggregated over dimensions 1 and 2.
        response = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean response across the last dimension.
        mean_response = response.mean(dim=-1, keepdim=True)
        relative = response / (mean_response + 1e-6)
        # bias + (weight * relative + 1) * x
        gain = self.weight * relative + 1
        return torch.addcmul(self.bias, gain, x, value=1)
class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution; otherwise average pooling is used.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
        kernel_size (default `3`):
            kernel size of the convolution (ignored for average pooling).
        stride (default `2`):
            stride of the convolution; also the pooling window and stride.
        norm_type (`str`, optional):
            `"ln_norm"`, `"rms_norm"` or `None` (no normalization).
        eps, elementwise_affine:
            forwarded to the normalization layer (`elementwise_affine` only
            to `nn.LayerNorm`).
        bias (`bool`, default `True`):
            whether the convolution has a bias.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
        kernel_size=3,
        stride=2,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        # Optional normalization applied over the channel dimension.
        if norm_type is None:
            self.norm = None
        elif norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps)
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            reducer = nn.AvgPool2d(kernel_size=stride, stride=stride)
        else:
            reducer = nn.Conv2d(
                self.channels,
                self.out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # With name == "conv" the layer is additionally exposed as `Conv2d_0`;
        # in every case it is reachable as `self.conv`.
        if name == "conv":
            self.Conv2d_0 = reducer
        self.conv = reducer

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Downsample `hidden_states`; extra positional/keyword args are ignored."""
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # The norm acts on the last dimension: move channels last and back.
            channels_last = hidden_states.permute(0, 2, 3, 1)
            hidden_states = self.norm(channels_last).permute(0, 3, 1, 2)

        if self.use_conv and self.padding == 0:
            # Zero-pad one element at the end of each of the last two dims.
            hidden_states = F.pad(
                hidden_states, (0, 1, 0, 1), mode="constant", value=0
            )

        return self.conv(hidden_states)
class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer. With `"conv"` the layer is
            registered as `self.conv`, otherwise as `self.Conv2d_0`.
        kernel_size (`int`, optional):
            defaults to 4 for the transposed convolution and 3 for the
            regular convolution.
        padding (default `1`):
            padding of the (transposed) convolution.
        stride (default `2`):
            stride of the transposed convolution, and the interpolation
            scale factor when no `output_size` is given.
        norm_type (`str`, optional):
            `"ln_norm"`, `"rms_norm"` or `None` (no normalization).
        eps, elementwise_affine:
            forwarded to the normalization layer (`elementwise_affine` only
            to `nn.LayerNorm`).
        bias (`bool`, default `True`):
            whether the (transposed) convolution has a bias.
        interpolate (`bool`, default `True`):
            whether to run nearest-neighbour interpolation in `forward`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        stride=2,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        interpolate=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.interpolate = interpolate
        self.stride = stride

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        conv = None
        if use_conv_transpose:
            if kernel_size is None:
                kernel_size = 4
            conv = nn.ConvTranspose2d(
                channels,
                self.out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            conv = nn.Conv2d(
                self.channels,
                self.out_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=bias,
            )

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_size: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Upsample `hidden_states`.

        Parameters:
            hidden_states: tensor whose dimension 1 must equal `self.channels`.
            output_size (optional): when given, interpolation targets this
                size and `scale_factor=self.stride` is not used.
            *args, **kwargs: accepted and ignored.

        Returns:
            The upsampled tensor.
        """
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # The norm acts on the last dimension: move channels last and back.
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
                0, 3, 1, 2
            )

        # Resolve the layer under the attribute `__init__` actually registered.
        # Hard-coding `self.conv` here would raise AttributeError for a
        # transposed convolution created with name != "conv".
        conv = self.conv if self.name == "conv" else self.Conv2d_0

        if self.use_conv_transpose:
            return conv(hidden_states)

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if self.interpolate:
            if output_size is None:
                hidden_states = F.interpolate(
                    hidden_states, scale_factor=self.stride, mode="nearest"
                )
            else:
                hidden_states = F.interpolate(
                    hidden_states, size=output_size, mode="nearest"
                )

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            hidden_states = conv(hidden_states)
        return hidden_states
class ConvNextBlock(nn.Module):
    """Residual block: depthwise conv, channel-wise MLP, conditional modulation.

    Parameters:
        channels: number of input/output channels.
        norm_eps: epsilon passed to the RMS norm.
        elementwise_affine: accepted but not used by this block.
        use_bias: whether the conv and linear layers carry a bias.
        hidden_dropout: dropout probability applied after the MLP.
        hidden_size: width of the conditioning embedding.
        res_ffn_factor: expansion factor of the MLP's hidden width.
    """

    def __init__(
        self,
        channels,
        norm_eps,
        elementwise_affine,
        use_bias,
        hidden_dropout,
        hidden_size,
        res_ffn_factor: int = 4,
    ):
        super().__init__()
        inner_dim = int(channels * res_ffn_factor)

        # groups == channels gives one 7x7 filter per channel.
        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=7,
            padding=3,
            groups=channels,
            bias=use_bias,
        )
        self.norm = RMSNorm(channels, norm_eps)
        self.channelwise_linear_1 = nn.Linear(channels, inner_dim, bias=use_bias)
        self.channelwise_act = nn.GELU()
        self.channelwise_norm = GlobalResponseNorm(inner_dim)
        self.channelwise_linear_2 = nn.Linear(inner_dim, channels, bias=use_bias)
        self.channelwise_dropout = nn.Dropout(hidden_dropout)
        # Produces a (scale, shift) pair per channel from the conditioning.
        self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)

    def forward(self, x, cond_embeds):
        """Apply the block to `x`, modulated by `cond_embeds`.

        `x` is permuted between channels-first and channels-last, so it is
        expected to be 4-D; `cond_embeds` is chunked along dim 1.
        """
        shortcut = x

        h = self.depthwise(x)
        h = h.permute(0, 2, 3, 1)  # channels last for the norm / linear layers
        h = self.norm(h)
        h = self.channelwise_linear_1(h)
        h = self.channelwise_act(h)
        h = self.channelwise_norm(h)
        h = self.channelwise_linear_2(h)
        h = self.channelwise_dropout(h)
        h = h.permute(0, 3, 1, 2)  # back to channels first
        h = h + shortcut

        modulation = self.cond_embeds_mapper(F.silu(cond_embeds))
        scale, shift = modulation.chunk(2, dim=1)
        # Equivalent to h * (1 + scale) + shift, broadcast over the last two dims.
        return torch.addcmul(
            shift[:, :, None, None], h, (1 + scale)[:, :, None, None], value=1
        )
class Patchify(nn.Module):
    """Strided convolution followed by an RMS norm over the channel dimension.

    Parameters:
        in_channels: channels of the input.
        block_out_channels: channels of the output.
        patch_size: stride of the convolution.
        bias: whether the convolution has a bias.
        elementwise_affine: accepted but not used by this block.
        eps: epsilon passed to the RMS norm.
        kernel_size (optional): convolution kernel size; defaults to
            `patch_size`.
    """

    def __init__(
        self,
        in_channels,
        block_out_channels,
        patch_size,
        bias,
        elementwise_affine,
        eps,
        kernel_size=None,
    ):
        super().__init__()
        effective_kernel = patch_size if kernel_size is None else kernel_size
        self.patch_conv = nn.Conv2d(
            in_channels,
            block_out_channels,
            kernel_size=effective_kernel,
            stride=patch_size,
            bias=bias,
        )
        self.norm = RMSNorm(block_out_channels, eps)

    def forward(self, x):
        """Return the normalized patch embeddings, channels first."""
        patches = self.patch_conv(x)
        # The norm acts on the last dimension: move channels last and back.
        normed = self.norm(patches.permute(0, 2, 3, 1))
        return normed.permute(0, 3, 1, 2)
class Unpatchify(nn.Module):
    """RMS norm, 1x1 convolution and pixel shuffle.

    The 1x1 convolution expands to `out_channels * patch_size ** 2` channels
    and `nn.PixelShuffle(patch_size)` rearranges them into a feature map that
    is `patch_size` times larger along each of the last two dimensions.

    Parameters:
        in_channels: channels of the input.
        out_channels: channels of the output after the pixel shuffle.
        patch_size: upscale factor of the pixel shuffle.
        bias: whether the convolution has a bias.
        elementwise_affine: accepted but not used by this block.
        eps: epsilon passed to the RMS norm.
    """

    def __init__(
        self, in_channels, out_channels, patch_size, bias, elementwise_affine, eps
    ):
        super().__init__()
        expanded_channels = out_channels * patch_size * patch_size
        self.norm = RMSNorm(in_channels, eps)
        self.unpatch_conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            kernel_size=1,
            bias=bias,
        )
        self.pixel_shuffle = nn.PixelShuffle(patch_size)
        self.patch_size = patch_size

    def forward(self, x):
        """Map a [b, c, h, w] input to the upscaled output."""
        # The norm acts on the last dimension: move channels last and back.
        normed = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        expanded = self.unpatch_conv(normed)
        return self.pixel_shuffle(expanded)
class UVitBlock(nn.Module):
    """A stack of `ConvNextBlock`s with optional down- and/or up-sampling.

    Parameters:
        channels: channels processed by the residual blocks.
        out_channels: output channels of the resampling layers; has no effect
            when both `downsample` and `upsample` are false.
        num_res_blocks: number of `ConvNextBlock`s.
        stride: stride handed to the resampling layers.
        hidden_size: width of the conditioning embedding.
        hidden_dropout: dropout probability inside the residual blocks.
        elementwise_affine, norm_eps, use_bias: forwarded to the sub-layers.
        downsample: append a `Downsample2D` layer.
        upsample: append an `Upsample2D` layer.
        res_ffn_factor: MLP expansion factor of the residual blocks.
        seq_len, concat_input, original_input_channels, use_zero, norm_type:
            accepted but not used by this block.
    """

    def __init__(
        self,
        channels,
        out_channels,
        num_res_blocks,
        stride,
        hidden_size,
        hidden_dropout,
        elementwise_affine,
        norm_eps,
        use_bias,
        downsample: bool,
        upsample: bool,
        res_ffn_factor: int = 4,
        seq_len=None,
        concat_input=False,
        original_input_channels=None,
        use_zero=True,
        norm_type="RMS",
    ):
        super().__init__()

        self.res_blocks = nn.ModuleList(
            [
                ConvNextBlock(
                    channels,
                    norm_eps,
                    elementwise_affine,
                    use_bias,
                    hidden_dropout,
                    hidden_size,
                    res_ffn_factor=res_ffn_factor,
                )
                for _ in range(num_res_blocks)
            ]
        )

        self.downsample = None
        if downsample:
            self.downsample = Downsample2D(
                channels=channels,
                out_channels=out_channels,
                use_conv=True,
                name="Conv2d_0",
                kernel_size=3,
                padding=1,
                stride=stride,
                norm_type="rms_norm",
                eps=norm_eps,
                elementwise_affine=elementwise_affine,
                bias=use_bias,
            )

        self.upsample = None
        if upsample:
            self.upsample = Upsample2D(
                channels=channels,
                out_channels=out_channels,
                use_conv_transpose=False,
                use_conv=True,
                kernel_size=3,
                padding=1,
                stride=stride,
                name="conv",
                norm_type="rms_norm",
                eps=norm_eps,
                elementwise_affine=elementwise_affine,
                bias=use_bias,
                interpolate=True,
            )

    def forward(self, x, emb, recompute=False):
        """Run the residual blocks on `x` with conditioning `emb`, then resample.

        `recompute` is accepted but not used.
        """
        for block in self.res_blocks:
            x = block(x, emb)
        # Downsampling (if any) runs before upsampling (if any).
        for resample in (self.downsample, self.upsample):
            if resample is not None:
                x = resample(x)
        return x
class ShallowUViTEncoder(nn.Module):
    """Strided-conv encoder with a timestep embedding and an optional mid block.

    Parameters:
        input_channels: channels of the input.
        stride, kernel_size: stride and kernel size of the input convolution.
        padding (optional): padding of the input convolution; defaults to
            `math.ceil(kernel_size - stride)`.
        block_out_channels: channel widths; the first entry sizes the input
            convolution and timestep projection, the last the mid block.
        layers_in_middle: number of residual blocks in the mid block.
        hidden_size: width of the timestep embedding.
        elementwise_affine, use_bias, norm_eps, dropout: forwarded to the
            sub-layers.
        use_mid_block: whether to build the mid block.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        input_channels=3,
        stride=4,
        kernel_size=7,
        padding=None,
        block_out_channels=(768,),
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        use_bias=True,
        norm_eps=1e-6,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()
        first_width = block_out_channels[0]
        last_width = block_out_channels[-1]

        self.time_proj = Timesteps(
            first_width, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.time_embed = TimestepEmbedding(
            first_width, hidden_size, sample_proj_bias=use_bias
        )

        if padding is None:
            padding = math.ceil(kernel_size - stride)
        self.in_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=first_width,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.mid_block = None
        if use_mid_block:
            self.mid_block = UVitBlock(
                last_width,
                last_width,
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )

    def get_num_extra_tensors(self):
        """Return the constant 2."""
        return 2

    def forward(self, x, timesteps):
        """Encode `x` and embed `timesteps`.

        Returns:
            A tuple `(x_emb, t_emb, hs)` where `hs` is a one-element list
            holding `x_emb`.
        """
        batch_size = x.shape[0]
        # Project the flattened timesteps, one row per batch element, and
        # match the input dtype before the embedding MLP.
        projected = self.time_proj(timesteps.flatten())
        t_emb = self.time_embed(projected.view(batch_size, -1).to(x.dtype))

        x_emb = self.in_conv(x)
        if self.mid_block is not None:
            x_emb = self.mid_block(x_emb, t_emb)

        return x_emb, t_emb, [x_emb]
class ShallowUViTDecoder(nn.Module):
    """Decoder: normalizes its input, concatenates a skip tensor, then unpatchifies.

    Parameters:
        in_channels: channels of the decoder input.
        out_channels: channels produced by the last `Unpatchify` stage.
        block_out_channels: channel widths; only the last entry is used.
        upsamples: number of `Unpatchify` stages (each with `patch_size=2`).
        layers_in_middle: number of residual blocks in the mid block.
        hidden_size: width of the timestep embedding.
        elementwise_affine, norm_eps, use_bias, dropout: forwarded to the
            sub-layers.
        use_mid_block: whether to build the mid block.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        in_channels=768,
        out_channels=3,
        block_out_channels: Tuple[int] = (768,),
        upsamples=2,
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        norm_eps=1e-6,
        use_bias=True,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()
        skip_width = block_out_channels[-1]

        self.mid_block = None
        if use_mid_block:
            self.mid_block = UVitBlock(
                in_channels + skip_width,
                # In fact, the parameter is not used because it has no effect
                # when both downsample and upsample are set to false.
                skip_width,
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )

        self.out_convs = nn.ModuleList()
        final_stage = upsamples - 1
        for stage in range(upsamples):
            # The first stage consumes the concatenated tensor; the last one
            # emits the final channel count.
            stage_in = skip_width + in_channels if stage == 0 else skip_width
            stage_out = out_channels if stage == final_stage else skip_width
            self.out_convs.append(
                Unpatchify(
                    stage_in,
                    stage_out,
                    patch_size=2,
                    bias=use_bias,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps,
                )
            )

        self.input_norm = RMSNorm(in_channels, norm_eps)

    def forward(self, x, hs, t_emb):
        """Decode `x` using the skip tensors in `hs` and conditioning `t_emb`.

        `hs` is consumed in place: its last element is popped, and the list
        must be empty afterwards.
        """
        # The norm acts on the last dimension: move channels last and back.
        normed = self.input_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        skip = hs.pop()
        x = torch.cat([normed, skip], dim=1)

        if self.mid_block is not None:
            x = self.mid_block(x, t_emb)

        for stage in self.out_convs:
            x = stage(x)

        assert len(hs) == 0
        return x