| from __future__ import annotations
|
|
|
| import os
|
| import sys
|
| import math
|
| from typing import Optional, Tuple, Union, List, Callable
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torch.nn import Module
|
|
|
| from einops import rearrange, repeat, pack, unpack
|
| from einx import get_at
|
|
|
| from torch.utils.checkpoint import checkpoint
|
| from transformers import AutoImageProcessor
|
| from transformers.modeling_utils import PreTrainedModel, get_parameter_device, get_parameter_dtype
|
|
|
| from .configuration_dualvitok import DualViTokConfig
|
| from .modeling_movqgan import MoVQModel, MoVQEncoder, MoVQDecoder, Decoder
|
|
|
| from .configuration_qwen2vit import Qwen2VLVisionConfig
|
| from .modeling_qwen2vit import (Qwen2VisionTransformerPretrainedModel,
|
| VisionRotaryEmbedding, Qwen2VLBatchVisionBlock)
|
|
|
| try:
|
| import xformers.ops as xops
|
|
|
| is_xformers_available = True
|
| except Exception:
|
| is_xformers_available = False
|
|
|
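| # Prefer PyTorch's built-in scaled_dot_product_attention on newer builds. Note that this
| # is a plain string comparison on torch.__version__, which is only approximate.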
| if torch.__version__ > "2.1.2":
|
| IS_SDPA_AVAILABLE = True
|
| else:
|
| IS_SDPA_AVAILABLE = False
|
|
|
| cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| sys.path.append(cur_dir)
|
|
|
|
|
|
|
|
|
| def exists(v):
|
| return v is not None
|
|
|
|
|
| def identity(t):
|
| return t
|
|
|
|
|
| def default(v, d):
|
| return v if exists(v) else d
|
|
|
|
|
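| # Pack a single tensor with einops.pack and return it together with an `inverse`
| # callable that restores the original shape (optionally under a different pattern).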
| def pack_one(t, pattern):
|
| packed, packed_shape = pack([t], pattern)
|
|
|
| def inverse(out, inv_pattern=None):
|
| inv_pattern = default(inv_pattern, pattern)
|
| out, = unpack(out, packed_shape, inv_pattern)
|
| return out
|
|
|
| return packed, inverse
|
|
|
|
|
|
|
|
|
|
|
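| # SimVQ quantizer: the codebook entries are frozen random vectors and only a linear
| # transform on top of them (`code_transform`) is trained; nearest-neighbour lookup is
| # done in the transformed ("implicit") codebook.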
| class SimVQ(Module):
|
| def __init__(
|
| self,
|
| dim,
|
| codebook_size,
|
| codebook_transform: Module | None = None,
|
| init_fn: Callable = identity,
|
| channel_first=True,
|
| input_to_quantize_commit_loss_weight=0.25,
|
| commitment_weight=1.,
|
| frozen_codebook_dim=None
|
| ):
|
| super().__init__()
|
| self.codebook_size = codebook_size
|
| self.channel_first = channel_first
|
|
|
| frozen_codebook_dim = default(frozen_codebook_dim, dim)
|
| codebook = torch.randn(codebook_size, frozen_codebook_dim) * (frozen_codebook_dim ** -0.5)
|
| codebook = init_fn(codebook)
|
|
|
|
|
|
|
| if not exists(codebook_transform):
|
| codebook_transform = nn.Linear(frozen_codebook_dim, dim, bias=False)
|
|
|
| self.code_transform = codebook_transform
|
|
|
| self.register_buffer('frozen_codebook', codebook)
|
|
|
|
|
| self.input_to_quantize_commit_loss_weight = input_to_quantize_commit_loss_weight
|
|
|
|
|
| self.commitment_weight = commitment_weight
|
|
|
| @property
|
| def codebook(self):
|
| return self.code_transform(self.frozen_codebook)
|
|
|
| def indices_to_codes(
|
| self,
|
| indices
|
| ):
|
|
|
|
| frozen_codes = get_at('[c] d, b ... -> b ... d', self.frozen_codebook, indices)
|
| quantized = self.code_transform(frozen_codes)
|
|
|
| if self.channel_first:
|
| quantized = rearrange(quantized, 'b ... d -> b d ...')
|
|
|
| return quantized
|
|
|
| def forward(
|
| self,
|
| x
|
| ):
|
| if self.channel_first:
|
| x = rearrange(x, 'b d ... -> b ... d')
|
|
|
| x, inverse_pack = pack_one(x, 'b * d')
|
|
|
| implicit_codebook = self.codebook
|
|
|
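| # nearest-neighbour code lookup by L2 distance; the index selection itself needs no gradient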
| with torch.no_grad():
|
| dist = torch.cdist(x, implicit_codebook)
|
| indices = dist.argmin(dim=-1)
|
|
|
|
|
|
|
| quantized = get_at('[c] d, b n -> b n d', implicit_codebook, indices)
|
|
|
|
|
|
|
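| # codebook term (pulls the implicit codes toward the encoder output) plus a commitment
| # term on the input, down-weighted by input_to_quantize_commit_loss_weight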
| commit_loss = (
|
| F.mse_loss(x.detach(), quantized) +
|
| F.mse_loss(x, quantized.detach()) * self.input_to_quantize_commit_loss_weight
|
| )
|
|
|
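| # straight-through estimator: forward uses the quantized values, backward passes
| # gradients to x unchanged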
| quantized = (quantized - x).detach() + x
|
|
|
| quantized = inverse_pack(quantized)
|
| indices = inverse_pack(indices, 'b *')
|
|
|
| if self.channel_first:
|
| quantized = rearrange(quantized, 'b ... d -> b d ...')
|
|
|
| return quantized, commit_loss * self.commitment_weight, indices
|
|
|
|
|
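| # Module-wise initializer used via .apply(): xavier-uniform for linear/conv layers,
| # ones/zeros for norm layers, standard normal for embeddings.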
| def init_weights(m):
|
| if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
|
| if m.weight is not None:
|
| nn.init.constant_(m.weight, 1)
|
| if m.bias is not None:
|
| nn.init.constant_(m.bias, 0)
|
| elif isinstance(m, nn.Linear):
|
| nn.init.xavier_uniform_(m.weight)
|
| if m.bias is not None:
|
| nn.init.constant_(m.bias, 0)
|
| elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) \
|
| or isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
|
| w = m.weight.data
|
| nn.init.xavier_uniform_(w)
|
| if m.bias is not None:
|
| nn.init.constant_(m.bias, 0)
|
| elif isinstance(m, nn.Embedding):
|
| nn.init.normal_(m.weight, mean=0, std=1)
|
|
|
|
|
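| # Prepares generator-space images in [-1, 1] for the Qwen2-VL vision tower: resize to a
| # multiple of patch_size * merge_size, normalize with the OpenAI CLIP statistics, then
| # flatten into Qwen2-VL style patch tokens together with their (t, h, w) grid sizes.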
| class ScalingLayerForQwen2ViT:
|
| def __init__(
|
| self,
|
| min_pixels: int = 56 * 56,
|
| max_pixels: int = 28 * 28 * 1280,
|
| patch_size: int = 14,
|
| temporal_patch_size: int = 2,
|
| merge_size: int = 2,
|
| **kwargs,
|
| ) -> None:
|
| super().__init__(**kwargs)
|
| OPENAI_CLIP_MEAN = torch.as_tensor([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]
|
| OPENAI_CLIP_STD = torch.as_tensor([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]
|
|
|
| self.image_mean = OPENAI_CLIP_MEAN
|
| self.image_std = OPENAI_CLIP_STD
|
| self.min_pixels = min_pixels
|
| self.max_pixels = max_pixels
|
| self.patch_size = patch_size
|
| self.temporal_patch_size = temporal_patch_size
|
| self.merge_size = merge_size
|
| self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
|
|
| def __call__(self, images):
|
| if images.ndim == 4:
|
| images = images.unsqueeze(1)
|
| batch_size, temporal, channel, height, width = images.shape
|
|
|
| factor = self.patch_size * self.merge_size
|
|
|
| resized_height, resized_width = height // factor * factor, width // factor * factor
|
|
|
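| # inputs are expected in [-1, 1]; map them to [0, 1] before CLIP normalization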
| images = (images + 1) / 2
|
|
|
| images = torch.nn.functional.interpolate(
|
| images.flatten(0, 1).float(),
|
| size=(resized_height, resized_width),
|
| mode='bicubic',
|
| align_corners=False,
|
| antialias=True
|
| ).to(images.dtype)
|
|
|
| images = images.clamp(0, 1)
|
| images = ((images - self.image_mean.to(images)) / self.image_std.to(images))
|
|
|
| images = rearrange(images, '(b t) c h w -> b t c h w', b=batch_size, t=temporal)
|
| if temporal == 1:
|
| images = images.repeat_interleave(self.temporal_patch_size, dim=1)
|
| temporal = self.temporal_patch_size
|
|
|
| grid_t = temporal // self.temporal_patch_size
|
| grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
|
|
| images = images.reshape(
|
| batch_size * grid_t,
|
| self.temporal_patch_size,
|
| channel,
|
| -1
|
| )
|
|
|
| images = rearrange(images, 'b p c n -> b n (c p)')
|
| images = images.reshape(
|
| batch_size * grid_t,
|
| grid_h // self.merge_size,
|
| self.merge_size,
|
| self.patch_size,
|
| grid_w // self.merge_size,
|
| self.merge_size,
|
| self.patch_size,
|
| -1
|
| )
|
| images = rearrange(images, 'b h k s1 w l s2 n -> (b h w k l) (n s1 s2)')
|
|
|
| return dict(image=images, image_grid_thw=torch.as_tensor([[grid_t, grid_h, grid_w] for _ in range(batch_size)]))
|
|
|
|
|
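| # Semantic branch encoder: a frozen Qwen2-VL vision tower followed by a small trainable
| # transformer head that projects features down to z_channels for quantization. The frozen
| # tower features (after target_mlp) are also returned so callers can use them as a
| # reconstruction target.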
| class SemanticEncoder(nn.Module):
|
| def __init__(self,
|
| semantic_encoder,
|
| z_channels=4,
|
| num_blocks=2,
|
| embed_dim=1280,
|
| proj_layer='linear',
|
| attn_implementation='xformers',
|
| target_mlp='identity',
|
| ):
|
| super().__init__()
|
| self.embed_dim = embed_dim
|
|
|
| if isinstance(semantic_encoder, str):
|
| self.model = Qwen2VisionTransformerPretrainedModel.from_pretrained(
|
| semantic_encoder,
|
| attn_implementation=attn_implementation
|
| )
|
| elif isinstance(semantic_encoder, dict):
|
| config = Qwen2VLVisionConfig(**semantic_encoder, attn_implementation=attn_implementation)
|
| self.model = Qwen2VisionTransformerPretrainedModel(config)
|
| else:
|
| raise ValueError(f"Invalid semantic_encoder: {semantic_encoder}")
|
| input_channels = self.model.config.hidden_size
|
|
|
| for p in self.model.parameters():
|
| p.requires_grad = False
|
|
|
| self.proj_in = nn.Conv2d(input_channels, embed_dim, 1, 1) if input_channels != embed_dim else nn.Identity()
|
|
|
| config = Qwen2VLVisionConfig(depth=num_blocks,
|
| embed_dim=embed_dim, )
|
| head_dim = config.embed_dim // config.num_heads
|
| self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
|
|
| self.blocks = nn.ModuleList(
|
| [Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
|
| )
|
|
|
| if proj_layer == 'norm_linear':
|
| self.proj_out = nn.Sequential(
|
| nn.LayerNorm(embed_dim),
|
| nn.Linear(
|
| embed_dim,
|
| z_channels,
|
| )
|
| )
|
| elif proj_layer == 'linear':
|
| self.proj_out = nn.Sequential(
|
| nn.Linear(
|
| embed_dim,
|
| z_channels,
|
| )
|
| )
|
| elif proj_layer == 'mlp':
|
| self.proj_out = nn.Sequential(
|
| nn.Linear(embed_dim, embed_dim),
|
| nn.Tanh(),
|
| nn.Linear(embed_dim, z_channels),
|
| )
|
| else:
|
| raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
|
|
|
| if target_mlp == 'identity':
|
| self.target_mlp = nn.Sequential(
|
| nn.Identity(),
|
| )
|
| elif target_mlp == 'norm':
|
| self.target_mlp = nn.Sequential(
|
| nn.LayerNorm(input_channels, eps=1e-6, elementwise_affine=False),
|
| )
|
| else:
|
| raise RuntimeError(f"Wrong target_mlp. Got {target_mlp}")
|
| self.init_weight()
|
|
|
| def init_weight(self):
|
| self.proj_in.apply(init_weights)
|
| self.blocks.apply(init_weights)
|
| self.proj_out.apply(init_weights)
|
| self.target_mlp.apply(init_weights)
|
|
|
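| # Build 2D (height, width) rotary position embeddings for each grid in the batch,
| # zero-padded up to max_seq_len.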
| def rot_pos_emb(self, grid_thw, max_seq_len):
|
| pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
|
| for idx, (t, h, w) in enumerate(grid_thw):
|
| hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| hpos_ids = hpos_ids.flatten()
|
|
|
| wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| wpos_ids = wpos_ids.flatten()
|
|
|
| current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
|
| pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
|
| max_grid_size = grid_thw[:, 1:].max()
|
| rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
|
| return rotary_pos_emb
|
|
|
| def forward(self, x, grid_thw):
|
| x = self.model(x, grid_thw=grid_thw)
|
|
|
| x = x_target = self.target_mlp(x)
|
|
|
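| # apply the 1x1 Conv2d projection as a plain linear layer over the token dimension
| # (assumes input_channels != embed_dim, i.e. proj_in is not the Identity branch)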
| x = F.linear(x,
|
| self.proj_in.weight.view(self.proj_in.weight.shape[0], -1),
|
| self.proj_in.bias)
|
|
|
| new_grid_thw = torch.as_tensor([[t, h // 2, w // 2] for t, h, w in grid_thw])
|
|
|
| seq_lens = [t_i * h_i * w_i for t_i, h_i, w_i in new_grid_thw]
|
| max_seq_len = max(seq_lens)
|
|
|
| x = rearrange(x, '(b h w) c -> b (h w) c', h=new_grid_thw[0, 1], w=new_grid_thw[0, 2])
|
|
|
| rotary_pos_emb = self.rot_pos_emb(new_grid_thw, max_seq_len)
|
|
|
| for blk in self.blocks:
|
| x = blk(x, rotary_pos_emb=rotary_pos_emb)
|
|
|
| x = self.proj_out(x)
|
|
|
| t, h, w = new_grid_thw[0]
|
| b = len(grid_thw)
|
| x = rearrange(x, 'b (h w) c -> b c h w', b=b, h=h, w=w)
|
| x_target = rearrange(x_target, '(b h w) c -> b c h w', b=b, h=h, w=w)
|
| return x, x_target
|
|
|
|
|
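| # Semantic branch decoder: maps quantized semantic codes back to the vision-tower
| # feature dimension with a short stack of Qwen2-VL vision blocks.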
| class SemanticDecoder(nn.Module):
|
| def __init__(self,
|
| z_channels=4,
|
| embed_dim=1280,
|
| num_blocks=2,
|
| output_channels=1280,
|
| attn_implementation='xformers',
|
| proj_layer='linear_norm'):
|
| super().__init__()
|
| self.proj_in = nn.Linear(z_channels, embed_dim)
|
|
|
| self.output_channels = output_channels
|
| config = Qwen2VLVisionConfig(depth=num_blocks, embed_dim=embed_dim)
|
|
|
| self.blocks = nn.ModuleList(
|
| [Qwen2VLBatchVisionBlock(config, attn_implementation) for _ in range(num_blocks)]
|
| )
|
| head_dim = config.embed_dim // config.num_heads
|
| self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
|
|
| if proj_layer == 'norm_linear':
|
| self.proj_out = nn.Sequential(
|
| nn.LayerNorm(embed_dim),
|
| nn.Linear(embed_dim, output_channels),
|
| )
|
| elif proj_layer == 'linear':
|
| self.proj_out = nn.Sequential(
|
| nn.Linear(embed_dim, output_channels)
|
| )
|
| elif proj_layer == 'mlp':
|
| self.proj_out = nn.Sequential(
|
| nn.Linear(embed_dim, embed_dim),
|
| nn.Tanh(),
|
| nn.Linear(embed_dim, output_channels),
|
| )
|
| elif proj_layer == 'linear_norm':
|
| self.proj_out = nn.Sequential(
|
| nn.Linear(embed_dim, output_channels),
|
| nn.LayerNorm(output_channels),
|
| )
|
| else:
|
| raise RuntimeError(f"Wrong proj layer. Got {proj_layer}")
|
|
|
| self.apply(init_weights)
|
|
|
| @property
|
| def last_layer(self):
|
| return self.proj_out[-1].weight
|
|
|
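| # Same 2D rotary position construction as SemanticEncoder.rot_pos_emb.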
| def rot_pos_emb(self, grid_thw, max_seq_len):
|
| pos_ids = torch.zeros((len(grid_thw), max_seq_len, 2), dtype=torch.long)
|
| for idx, (t, h, w) in enumerate(grid_thw):
|
| hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| hpos_ids = hpos_ids.flatten()
|
|
|
| wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| wpos_ids = wpos_ids.flatten()
|
|
|
| current_pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
|
| pos_ids[idx, :current_pos_ids.shape[0]] = current_pos_ids
|
| max_grid_size = grid_thw[:, 1:].max()
|
| rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(2)
|
| return rotary_pos_emb
|
|
|
| def forward(self, z: torch.Tensor):
|
| x = z
|
|
|
| b, c, h, w = x.shape
|
|
|
| x = rearrange(x, 'b c h w -> b (h w) c')
|
|
|
| grid_thw = torch.as_tensor([[1, h, w] for _ in range(b)])
|
| seq_lens = [t * h * w for t, h, w in grid_thw]
|
| max_seq_len = max(seq_lens)
|
|
|
| x = self.proj_in(x)
|
|
|
| rotary_pos_emb = self.rot_pos_emb(grid_thw, max_seq_len)
|
|
|
| for blk in self.blocks:
|
| x = blk(x, rotary_pos_emb=rotary_pos_emb)
|
|
|
| x = self.proj_out(x)
|
| x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
| return x
|
|
|
|
|
| class DualViTokPretrainModel(PreTrainedModel):
|
| """
|
| An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| models.
|
| """
|
|
|
| config_class = DualViTokConfig
|
| base_model_prefix = "dualvitok"
|
| main_input_name = "pixel_values"
|
| _no_split_modules = ["Qwen2VLBatchVisionBlock", "MoVQResnetBlock", "MoVQAttnBlock", "MoVQResnetTemporalBlock"]
|
| _supports_flash_attn_2 = True
|
| _supports_sdpa = True
|
| _supports_cache_class = True
|
| _supports_static_cache = True
|
|
|
| def _init_weights(self, module):
|
| if isinstance(module, (nn.Conv2d, nn.Conv3d)):
|
| nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
|
|
| elif isinstance(module, nn.Linear):
|
| nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
|
| if module.bias is not None:
|
| fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
|
| bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| nn.init.uniform_(module.bias, -bound, bound)
|
| elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
|
| nn.init.constant_(module.weight, 1)
|
| nn.init.constant_(module.bias, 0)
|
|
|
|
|
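| # DualViTok: dual-branch image tokenizer. The semantic branch quantizes frozen Qwen2-VL
| # features, the pixel branch is a MoVQ-style autoencoder, and the two code maps are merged
| # (semantic codes upsampled to the pixel grid) before pixel decoding.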
| class DualViTok(DualViTokPretrainModel):
|
| def __init__(self, config: DualViTokConfig):
|
| super().__init__(config)
|
| self.config = config
|
|
|
| self._semantic_channel = config.semantic_encoder.z_channels
|
| self._pixel_channel = config.pixel_encoder.z_channels
|
|
|
| self.semantic_encoder = SemanticEncoder(
|
| semantic_encoder=config.semantic_encoder.pretrained_semantic_encoder,
|
| z_channels=config.semantic_encoder.z_channels,
|
| num_blocks=config.semantic_encoder.num_blocks,
|
| embed_dim=config.semantic_encoder.embed_dim,
|
| proj_layer=config.semantic_encoder.out_layer,
|
| attn_implementation=config.attn_implementation,
|
| target_mlp=config.semantic_encoder.target_mlp, )
|
| self.semantic_decoder = SemanticDecoder(
|
| z_channels=config.semantic_decoder.z_channels,
|
| embed_dim=config.semantic_decoder.embed_dim,
|
| num_blocks=config.semantic_decoder.num_blocks,
|
| output_channels=config.semantic_decoder.out_channels,
|
| attn_implementation=config.attn_implementation,
|
| proj_layer=config.semantic_decoder.out_layer,
|
| )
|
|
|
| if config.semantic_quantizer_type.lower() == 'simvq':
|
| self.semantic_quantizer = SimVQ(
|
| dim=config.semantic_encoder.z_channels,
|
| codebook_size=config.semantic_quantizer_codebook_size,
|
| )
|
| elif config.semantic_quantizer_type.lower() == 'vq':
|
| raise NotImplementedError("quantizer type 'vq' is not implemented; use 'simvq'.")
|
|
|
| self.pixel_encoder = MoVQEncoder(config.pixel_encoder)
|
| self.pixel_quant_conv = nn.Conv2d(config.pixel_encoder.z_channels, config.pixel_encoder.embed_dim, 1)
|
|
|
| if config.pixel_quantizer_type.lower() == 'simvq':
|
| self.pixel_quantizer = SimVQ(
|
| dim=config.pixel_encoder.z_channels,
|
| codebook_size=config.pixel_quantizer_codebook_size,
|
| )
|
| elif config.pixel_quantizer_type.lower() == 'vq':
|
| raise NotImplementedError("quantizer type 'vq' is not implemented; use 'simvq'.")
|
|
|
| self.pixel_post_quant_conv = nn.Conv2d(config.pixel_decoder.embed_dim,
|
| config.pixel_decoder.z_channels, 1)
|
|
|
| self.pixel_decoder = MoVQDecoder(config.pixel_decoder)
|
|
|
| self.scaling_layer = ScalingLayerForQwen2ViT()
|
|
|
| @property
|
| def device(self):
|
| return get_parameter_device(self)
|
|
|
| @property
|
| def dtype(self):
|
| return get_parameter_dtype(self)
|
|
|
| @property
|
| def pixel_channel(self):
|
| return self._pixel_channel
|
|
|
| @property
|
| def semantic_channel(self):
|
| return self._semantic_channel
|
|
|
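| # Encode an image with both branches; returns a semantic tuple
| # (quant, emb_loss, indices, target_features) and a pixel tuple (quant, emb_loss, indices).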
| def encode(self, image: torch.FloatTensor):
|
| scale_output = self.scaling_layer(image)
|
| image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
|
|
|
| h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
|
| quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
|
|
|
| h_pixel = self.pixel_encoder(image_gen)
|
| h_pixel = self.pixel_quant_conv(h_pixel)
|
|
|
| quant_pixel, emb_loss_pixel, info_pixel = self.pixel_quantizer(h_pixel.float())
|
|
|
| return (quant_semantic, emb_loss_semantic, info_semantic, target_semantic), \
|
| (quant_pixel, emb_loss_pixel, info_pixel)
|
|
|
| def encode_code(self, *args, **kwargs):
|
| (_, _, semantic_indices, _), \
|
| (_, _, pixel_indices) = self.encode(*args, **kwargs)
|
| return semantic_indices, pixel_indices
|
|
|
| def indices_to_codes(self, semantic_indices, pixel_indices):
|
| quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
|
| quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
|
| return quant_semantic, quant_pixel
|
|
|
| def encode_semantic(self, image: torch.FloatTensor):
|
| scale_output = self.scaling_layer(image)
|
| image, image_grid_thw, image_gen = scale_output['image'], scale_output['image_grid_thw'], image
|
|
|
| h_semantic, target_semantic = self.semantic_encoder(image, image_grid_thw)
|
| quant_semantic, emb_loss_semantic, info_semantic = self.semantic_quantizer(h_semantic.float())
|
| return quant_semantic, emb_loss_semantic, info_semantic, target_semantic
|
|
|
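| # Upsample the semantic code map to the pixel code map's spatial size and concatenate
| # along channels so the pixel decoder can condition on both branches.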
| def merge_quants(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
|
| quant_semantic_resized = F.interpolate(
|
| quant_semantic, quant_pixel.shape[-2:], mode='bicubic'
|
| ).to(quant_semantic.dtype)
|
| quant_semantic = quant_semantic_resized
|
|
|
| quant = torch.cat([quant_semantic, quant_pixel], dim=1)
|
|
|
| return quant
|
|
|
| def decode(self, quant_semantic: torch.Tensor, quant_pixel: torch.Tensor):
|
| quant = self.merge_quants(quant_semantic, quant_pixel)
|
| quant2 = self.pixel_post_quant_conv(quant)
|
| x = self.pixel_decoder(quant2, quant)
|
| return x
|
|
|
| def decode_code(self, semantic_indices, pixel_indices):
|
| quant_semantic = self.semantic_quantizer.indices_to_codes(semantic_indices)
|
| quant_pixel = self.pixel_quantizer.indices_to_codes(pixel_indices)
|
| return self.decode(quant_semantic, quant_pixel)
|
|
|
| def decode_semantic(self, x: torch.Tensor):
|
| return self.semantic_decoder(x)
|
|
|
| def forward(self, pixel_values: torch.FloatTensor):
|
| (quant_semantic, diff_semantic, _, target_semantic), \
|
| (quant_pixel, diff_pixel, _) = self.encode(pixel_values)
|
| dec = self.decode(quant_semantic, quant_pixel)
|
| dec_semantic = self.decode_semantic(quant_semantic)
|
| return (dec_semantic, diff_semantic, target_semantic), (dec, diff_pixel)
|
|
|
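| # Convenience helper that loads the optional SDXL decoder pipeline with this tokenizer
| # plugged in as vq_model; the image processor defaults to the released DualViTok processor.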
| def build_sdxl_decoder(self, path='ILLUME-MLLM/dualvitok-sdxl-decoder',
|
| image_processor=None,
|
| torch_dtype=torch.float16,
|
| add_watermarker=False,
|
| device='cuda',
|
| ):
|
| from .sdxl_decoder_pipe import StableDiffusionXLDecoderPipeline
|
|
|
| if image_processor is None:
|
| image_processor = AutoImageProcessor.from_pretrained('ILLUME-MLLM/dualvitok', trust_remote_code=True)
|
|
|
| return StableDiffusionXLDecoderPipeline.from_pretrained(path,
|
| torch_dtype=torch_dtype,
|
| add_watermarker=add_watermarker,
|
| vq_model=self,
|
| vq_image_processor=image_processor).to(device)
|
|
|
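| # Minimal usage sketch (illustrative only; the checkpoint name and image size are assumptions):
| #
| #   model = DualViTok.from_pretrained('ILLUME-MLLM/dualvitok').eval()
| #   pixel_values = torch.randn(1, 3, 448, 448)   # a batch of images scaled to [-1, 1]
| #   semantic_indices, pixel_indices = model.encode_code(pixel_values)
| #   reconstruction = model.decode_code(semantic_indices, pixel_indices)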
|
|