# dualvitok/modeling_dualvitok.py
from __future__ import annotations
import os
import sys
import math
from typing import Optional, Tuple, Union, List, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from einops import rearrange, repeat, pack, unpack
from einx import get_at
from torch.utils.checkpoint import checkpoint
from transformers import AutoImageProcessor
from transformers.modeling_utils import PreTrainedModel, get_parameter_device, get_parameter_dtype
from .configuration_dualvitok import DualViTokConfig
from .modeling_movqgan import MoVQModel, MoVQEncoder, MoVQDecoder, Decoder
from .configuration_qwen2vit import Qwen2VLVisionConfig
from .modeling_qwen2vit import Qwen2VisionTransformerPretrainedModel, \
VisionRotaryEmbedding, Qwen2VLBatchVisionBlock
try:
import xformers.ops as xops
is_xformers_available = True
except Exception as e:
is_xformers_available = False
# prefer PyTorch SDPA when the installed torch version supports it
IS_SDPA_AVAILABLE = torch.__version__ > "2.1.2"
cur_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(cur_dir)
# helper functions
def exists(v):
return v is not None
def identity(t):
return t
def default(v, d):
return v if exists(v) else d
def pack_one(t, pattern):
packed, packed_shape = pack([t], pattern)
def inverse(out, inv_pattern=None):
inv_pattern = default(inv_pattern, pattern)
out, = unpack(out, packed_shape, inv_pattern)
return out
return packed, inverse
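# Example (illustrative shapes, not taken from any config): pack_one flattens the middle
# dims of a channel-last tensor and returns a callable that restores them.
#   x = torch.randn(2, 4, 4, 8)              # (b, h, w, d)
#   packed, inverse = pack_one(x, 'b * d')   # packed: (2, 16, 8)
#   restored = inverse(packed)               # restored: (2, 4, 4, 8)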
# SimVQ: vector quantization with a frozen codebook mapped through a learned linear transform
class SimVQ(Module):
def __init__(
self,
dim,
codebook_size,
codebook_transform: Module | None = None,
init_fn: Callable = identity,
channel_first=True,
input_to_quantize_commit_loss_weight=0.25,
commitment_weight=1.,
        frozen_codebook_dim=None  # the frozen codebook may have a different dimension than the projected codes
):
super().__init__()
self.codebook_size = codebook_size
self.channel_first = channel_first
frozen_codebook_dim = default(frozen_codebook_dim, dim)
codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
codebook = init_fn(codebook)
        # the codebook is implicit: a frozen gaussian (or uniform) codebook passed through a learned linear transform
if not exists(codebook_transform):
codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias=False)
self.code_transform = codebook_transform
self.register_buffer('frozen_codebook', codebook)
        # commit loss weighting - weighting the input-to-quantize term a bit less is crucial for this to work
self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
# total commitment loss weight
self.commitment_weight = commitment_weight
@property
def codebook(self):
return self.code_transform(self.frozen_codebook)
def indices_to_codes(
self,
indices
):
frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
quantized = self.code_transform(frozen_codes)
if self.channel_first:
quantized = rearrange(quantized, 'b ... d -> b d ...')
return quantized
def forward(
self,
x
):
if self.channel_first:
x = rearrange(x, 'b d ... -> b ... d')
x, inverse_pack = pack_one(x, 'b * d')
implicit_codebook = self.codebook
with torch.no_grad():
dist = torch.cdist(x, implicit_codebook)
indices = dist.argmin(dim=-1)
# select codes
quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
# commit loss and straight through, as was done in the paper
commit_loss = (
F.mse_loss(x.detach(), quantized) +
F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
)
quantized = (quantized - x).detach() + x
quantized = inverse_pack(quantized)
indices = inverse_pack(indices, 'b *')
if self.channel_first:
quantized = rearrange(quantized, 'b ... d-> b d ...')
return quantized, commit_loss * self.commitment_weight, indices
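# Usage sketch for SimVQ (illustrative sizes; real dims come from DualViTokConfig):
#   quantizer = SimVQ(dim=32, codebook_size=1024, channel_first=True)
#   z = torch.randn(2, 32, 8, 8)                      # (b, d, h, w) continuous latents
#   quantized, commit_loss, indices = quantizer(z)    # quantized: (2, 32, 8, 8), indices: (2, 8, 8)
#   codes = quantizer.indices_to_codes(indices)       # recover the same codes from indices alone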
def init_weights(m):
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) \
or isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
w = m.weight.data
nn.init.xavier_uniform_(w)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, mean=0, std=1)
class ScalingLayerForQwen2ViT:
def __init__(
self,
min_pixels: int = 56 * 56,
max_pixels: int = 28 * 28 * 1280,
patch_size: int = 14,
temporal_patch_size: int = 2,
merge_size: int = 2,
**kwargs,
) -> None:
super().__init__(**kwargs)
OPENAI_CLIP_MEAN = torch.as_tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]
OPENAI_CLIP_STD = torch.as_tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]
self.image_mean = OPENAI_CLIP_MEAN
self.image_std = OPENAI_CLIP_STD
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.merge_size = merge_size
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
def __call__(self, images):
if images.ndim == 4:
images = images.unsqueeze(1)
batch_size, temporal, channel, height, width = images.shape
factor = self.patch_size * self.merge_size
resized_height, resized_width = height // factor * factor, width // factor * factor
images = (images + 1) / 2 # rescale to [0, 1.]
images = torch.nn.functional.interpolate(
images.flatten(0, 1).float(),
size=(resized_height, resized_width),
mode='bicubic',
align_corners=False,
antialias=True
).to(images.dtype)
        images = images.clamp(0, 1)  # clamp to [0, 1]
images = ((images - self.image_mean.to(images)) / self.image_std.to(images))
images = rearrange(images, '(b t) c h w -> b t c h w', b=batch_size, t=temporal)
if temporal == 1:
images = images.repeat_interleave(self.temporal_patch_size, dim=1)
temporal = self.temporal_patch_size
grid_t = temporal // self.temporal_patch_size
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
images = images.reshape(
batch_size * grid_t,
self.temporal_patch_size,
channel,
-1
)
images = rearrange(images, 'b p c n -> b n (c p)')
images = images.reshape(
batch_size * grid_t,
grid_h // self.merge_size,
self.merge_size,
self.patch_size,
grid_w // self.merge_size,
self.merge_size,
self.patch_size,
-1
)
images = rearrange(images, 'b h k s1 w l s2 n -> (b h w k l) (n s1 s2)')
return dict(image=images, image_grid_thw=torch.as_tensor([[grid_t, grid_h, grid_w] for _ in range(batch_size)]))
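# Usage sketch (illustrative sizes): images in [-1, 1] are resized to a multiple of
# patch_size * merge_size (= 28) and flattened into Qwen2-ViT patch sequences.
#   layer = ScalingLayerForQwen2ViT()
#   imgs = torch.rand(2, 3, 256, 256) * 2 - 1   # (b, c, h, w) in [-1, 1]
#   out = layer(imgs)
#   out['image'].shape            # (648, 1176): 2 * 1 * 18 * 18 patches, each 3 * 2 * 14 * 14 values
#   out['image_grid_thw']         # tensor([[1, 18, 18], [1, 18, 18]]) since 256 is resized to 252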
class SemanticEncoder(nn.Module):
def __init__(self,
semantic_encoder,
z_channels=4,
num_blocks=2,
embed_dim=1280,
proj_layer='linear',
attn_implementation='xformers',
target_mlp='identity',
):
super().__init__()
self.embed_dim = embed_dim
if isinstance(semantic_encoder, str):
self.model = Qwen2VisionTransformerPretrainedModel.from_pretrained(
semantic_encoder,
attn_implementation=attn_implementation
)
elif isinstance(semantic_encoder, dict):
config = Qwen2VLVisionConfig(**semantic_encoder, attn_implementation=attn_implementation)
self.model = Qwen2VisionTransformerPretrainedModel(config)
else:
raise ValueError(f"Invalid semantic_encoder: {semantic_encoder}")
input_channels = self.model.config.hidden_size
for p in self.model.parameters():
p.requires_grad = False
self.proj_in = nn.Conv2d(input_channels, embed_dim, 1, 1) if input_channels != embed_dim else nn.Identity()
config = Qwen2VLVisionConfig(depth=num_blocks,
embed_dim=embed_dim, )
head_dim = config.embed_dim // config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
)
if proj_layer == 'norm_linear':
self.proj_out = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(
embed_dim,
z_channels,
)
)
elif proj_layer == 'linear':
self.proj_out = nn.Sequential(
nn.Linear(
embed_dim,
z_channels,
)
)
elif proj_layer == 'mlp':
self.proj_out = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.Tanh(),
nn.Linear(embed_dim, z_channels),
)
else:
raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
if target_mlp == 'identity':
self.target_mlp = nn.Sequential(
nn.Identity(),
)
        elif target_mlp == 'norm':
            self.target_mlp = nn.Sequential(
                nn.LayerNorm(input_channels, eps=1e-6, elementwise_affine=False),
            )
        else:
            raise RuntimeError(f"Wrong target mlp. Got {target_mlp}")
self.init_weight()
def init_weight(self):
self.proj_in.apply(init_weights)
self.blocks.apply(init_weights)
self.proj_out.apply(init_weights)
self.target_mlp.apply(init_weights)
def rot_pos_emb(self, grid_thw, max_seq_len):
pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
for idx, (t, h, w) in enumerate(grid_thw):
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.flatten()
current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
return rotary_pos_emb
def forward(self, x, grid_thw):
x = self.model(x, grid_thw=grid_thw)
x = x_target = self.target_mlp(x)
x = F.linear(x,
self.proj_in.weight.view(self.proj_in.weight.shape[0], -1),
self.proj_in.bias)
new_grid_thw = torch.as_tensor([[t, h // 2, w // 2] for t, h, w in grid_thw])
seq_lens = [t_i * h_i * w_i for t_i, h_i, w_i in new_grid_thw]
max_seq_len = max(seq_lens)
x = rearrange(x, '(b h w) c -> b (h w) c', h=new_grid_thw[0, 1], w=new_grid_thw[0, 2])
rotary_pos_emb = self.rot_pos_emb(new_grid_thw, max_seq_len)
for blk in self.blocks:
x = blk(x, rotary_pos_emb=rotary_pos_emb)
x = self.proj_out(x) # [b, max_length, d]
t, h, w = new_grid_thw[0]
b = len(grid_thw)
x = rearrange(x, 'b (h w) c ->b c h w', b=b, h=h, w=w)
x_target = rearrange(x_target, '(b h w) c ->b c h w', b=b, h=h, w=w)
return x, x_target
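# Shape sketch (assuming the 2x2 patch merge of the Qwen2 vision tower): for a grid_thw of
# [1, 18, 18] the encoder returns x of shape (b, z_channels, 9, 9) and x_target of shape
# (b, hidden_size, 9, 9), i.e. one latent per merged 2x2 patch.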
class SemanticDecoder(nn.Module):
def __init__(self,
z_channels=4,
embed_dim=1280,
num_blocks=2,
output_channels=1280,
attn_implementation='xformers',
proj_layer='linear_norm'):
super().__init__()
self.proj_in = nn.Linear(z_channels, embed_dim)
self.output_channels = output_channels
config = Qwen2VLVisionConfig(depth=num_blocks, embed_dim=embed_dim)
self.blocks = nn.ModuleList(
[Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
)
head_dim = config.embed_dim // config.num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
if proj_layer == 'norm_linear':
self.proj_out = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(embed_dim, output_channels),
)
elif proj_layer == 'linear':
self.proj_out = nn.Sequential(
nn.Linear(embed_dim, output_channels)
)
elif proj_layer == 'mlp':
self.proj_out = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.Tanh(),
nn.Linear(embed_dim, output_channels),
)
        elif proj_layer == 'linear_norm':
            self.proj_out = nn.Sequential(
                nn.Linear(embed_dim, output_channels),
                nn.LayerNorm(output_channels),
            )
        else:
            raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
self.apply(init_weights)
@property
def last_layer(self):
return self.proj_out[-1].weight
def rot_pos_emb(self, grid_thw, max_seq_len):
pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
for idx, (t, h, w) in enumerate(grid_thw):
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.flatten()
current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
return rotary_pos_emb
def forward(self, z: torch.Tensor):
x = z
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
grid_thw = torch.as_tensor([[1, h, w] for _ in range(b)])
seq_lens = [t * h * w for t, h, w in grid_thw]
max_seq_len = max(seq_lens)
x = self.proj_in(x)
rotary_pos_emb = self.rot_pos_emb(grid_thw, max_seq_len)
for blk in self.blocks:
x = blk(x, rotary_pos_emb=rotary_pos_emb)
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
return x
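# Usage sketch (illustrative sizes; assumes the configured attention backend is available):
#   decoder = SemanticDecoder(z_channels=4, embed_dim=1280, num_blocks=2, output_channels=1280)
#   z = torch.randn(1, 4, 9, 9)    # (b, z_channels, h, w) quantized semantic latents
#   feats = decoder(z)             # (1, 1280, 9, 9) features in the semantic target space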
class DualViTokPretrainModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = DualViTokConfig
base_model_prefix = "dualvitok"
main_input_name = "pixel_values"
_no_split_modules = ["BatchQwen2VLVisionBlock", "MoVQResnetBlock", "MoVQAttnBlock", "MoVQResnetTemporalBlock"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
if isinstance(module, (nn.Conv2d, nn.Conv3d)):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
# copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
elif isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(module.bias, -bound, bound)
elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
class DualViTok(DualViTokPretrainModel):
def __init__(self, config: DualViTokConfig):
super().__init__(config)
self.config = config
self._semantic_channel = config.semantic_encoder.z_channels
self._pixel_channel = config.pixel_encoder.z_channels
self.semantic_encoder = SemanticEncoder(
semantic_encoder=config.semantic_encoder.pretrained_semantic_encoder,
z_channels=config.semantic_encoder.z_channels,
num_blocks=config.semantic_encoder.num_blocks,
embed_dim=config.semantic_encoder.embed_dim,
proj_layer=config.semantic_encoder.out_layer,
attn_implementation=config.attn_implementation,
target_mlp=config.semantic_encoder.target_mlp, )
self.semantic_decoder = SemanticDecoder(
z_channels=config.semantic_decoder.z_channels,
embed_dim=config.semantic_decoder.embed_dim,
num_blocks=config.semantic_decoder.num_blocks,
output_channels=config.semantic_decoder.out_channels,
attn_implementation=config.attn_implementation,
proj_layer=config.semantic_decoder.out_layer,
)
if config.semantic_quantizer_type.lower() == 'simvq':
self.semantic_quantizer = SimVQ(
dim=config.semantic_encoder.z_channels,
codebook_size=config.semantic_quantizer_codebook_size,
)
        elif config.semantic_quantizer_type.lower() == 'vq':
            # plain VQ is not wired up in this file; only SimVQ is supported.
            raise NotImplementedError("semantic_quantizer_type='vq' is not implemented; use 'simvq'.")
        else:
            raise ValueError(f"Unknown semantic_quantizer_type: {config.semantic_quantizer_type}")
self.pixel_encoder = MoVQEncoder(config.pixel_encoder)
self.pixel_quant_conv = nn.Conv2d(config.pixel_encoder.z_channels, config.pixel_encoder.embed_dim, 1)
if config.pixel_quantizer_type.lower() == 'simvq':
self.pixel_quantizer = SimVQ(
dim=config.pixel_encoder.z_channels,
codebook_size=config.pixel_quantizer_codebook_size,
)
        elif config.pixel_quantizer_type.lower() == 'vq':
            # plain VQ is not wired up in this file; only SimVQ is supported.
            raise NotImplementedError("pixel_quantizer_type='vq' is not implemented; use 'simvq'.")
        else:
            raise ValueError(f"Unknown pixel_quantizer_type: {config.pixel_quantizer_type}")
self.pixel_post_quant_conv = nn.Conv2d(config.pixel_decoder.embed_dim,
config.pixel_decoder.z_channels, 1)
self.pixel_decoder = MoVQDecoder(config.pixel_decoder)
self.scaling_layer = ScalingLayerForQwen2ViT()
@property
def device(self):
return get_parameter_device(self)
@property
def dtype(self):
return get_parameter_dtype(self)
@property
def pixel_channel(self):
return self._pixel_channel
@property
def semantic_channel(self):
return self._semantic_channel
def encode(self, image: torch.FloatTensor):
scale_output = self.scaling_layer(image)
image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
h_pixel = self.pixel_encoder(image_gen)
h_pixel = self.pixel_quant_conv(h_pixel)
quant_pixel, emb_loss_pixel, info_pixel = self.pixel_quantizer(h_pixel.float())
return (quant_semantic, emb_loss_semantic, info_semantic, target_semantic), \
(quant_pixel, emb_loss_pixel, info_pixel)
def encode_code(self, *args, **kwargs):
(_, _, semantic_indices, _), \
(_, _, pixel_indices) = self.encode(*args, **kwargs)
return semantic_indices, pixel_indices
def indices_to_codes(self, semantic_indices, pixel_indices):
quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
return quant_semantic, quant_pixel
def encode_semantic(self, image: torch.FloatTensor):
scale_output = self.scaling_layer(image)
image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
return quant_semantic, emb_loss_semantic, info_semantic, target_semantic
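    # merge_quants resizes the semantic latents (bicubic) to the pixel latent resolution and
    # concatenates them along channels, so the pixel decoder receives [semantic; pixel] features.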
def merge_quants(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
quant_semantic_resized = F.interpolate(
quant_semantic, quant_pixel.shape[-2:], mode='bicubic'
).to(quant_semantic.dtype)
quant_semantic = quant_semantic_resized
quant = torch.cat([quant_semantic, quant_pixel], dim=1)
return quant
def decode(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor, ):
quant = self.merge_quants(quant_semantic, quant_pixel)
quant2 = self.pixel_post_quant_conv(quant)
x = self.pixel_decoder(quant2, quant)
return x
def decode_code(self, semantic_indices, pixel_indices):
quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
return self.decode(quant_semantic, quant_pixel)
def decode_semantic(self, x: List[torch.Tensor]):
return self.semantic_decoder(x)
def forward(self, pixel_values: torch.FloatTensor):
(quant_semantic, diff_semantic, _, target_semantic), \
(quant_pixel, diff_pixel, _) = self.encode(pixel_values)
dec = self.decode(quant_semantic, quant_pixel)
dec_semantic = self.decode_semantic(quant_semantic)
return (dec_semantic, diff_semantic, target_semantic), (dec, diff_pixel)
def build_sdxl_decoder(self, path='ILLUME-MLLM/dualvitok-sdxl-decoder',
image_processor=None,
torch_dtype=torch.float16,
add_watermarker=False,
device='cuda',
):
from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
if image_processor is None:
image_processor = AutoImageProcessor.from_pretrained('ILLUME-MLLM/dualvitok', trust_remote_code=True)
return StableDiffusionXLDecoderPipeline.from_pretrained(path,
torch_dtype=torch_dtype,
add_watermarker=add_watermarker,
vq_model=self,
vq_image_processor=image_processor).to(device)
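# End-to-end usage sketch. The repo id below appears elsewhere in this file; the exact loading
# details (dtype, device, input size) are illustrative rather than prescribed here.
#   model = DualViTok.from_pretrained('ILLUME-MLLM/dualvitok').eval().to('cuda')
#   image = torch.rand(1, 3, 512, 512, device='cuda') * 2 - 1          # values in [-1, 1]
#   semantic_indices, pixel_indices = model.encode_code(image)         # discrete dual-branch tokens
#   recon = model.decode_code(semantic_indices, pixel_indices)         # reconstructed image tensor
#   # Optionally decode at higher quality through the SDXL decoder pipeline:
#   # sdxl_pipe = model.build_sdxl_decoder()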