| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| HuggingFace Diffusers-Compatible Vision Token to Image Pipeline. |
| |
| Usage: |
| from diffusers_vtoken_pipeline import VisionTokenToImagePipeline |
| |
| pipeline = VisionTokenToImagePipeline.from_pretrained("path/to/model") |
| image = pipeline(vision_tokens).images[0] |
| """ |
|
|
| import json |
| import logging |
| import math |
| import os |
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from diffusers import AutoencoderKL, DiffusionPipeline |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
| from diffusers.utils import BaseOutput |
| from einops import rearrange, repeat |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| |
logger = logging.getLogger("vision-decoder-api.pipeline")


# Optional SDPA backend-selection import (torch >= 2.x). Only availability is
# recorded here; the attention() helper below uses the default SDPA dispatch.
try:
    from torch.nn.attention import sdpa_kernel, SDPBackend
    # NOTE(review): flag is keyed to CUDA presence, not to the flash backend
    # actually being selectable — confirm this is the intended check.
    FLASH_ATTN_AVAILABLE = torch.cuda.is_available()
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    sdpa_kernel = None
    SDPBackend = None
|
|
|
|
| |
| |
| |
|
|
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Build 2x2 rotation matrices for rotary position embedding (RoPE).

    Args:
        pos: Position indices, shape (batch, n).
        dim: Embedding width for this axis; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        float32 tensor of shape (batch, n, dim // 2, 2, 2) holding the
        per-frequency rotation matrices [[cos, -sin], [sin, cos]].
    """
    assert dim % 2 == 0
    # Frequency ladder omega_k = theta^(-2k/dim), computed in float64.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    frequencies = 1.0 / (theta ** exponents)
    angles = torch.einsum("...n,d->...nd", pos, frequencies)
    # Stack the four matrix entries, then fold the last axis into 2x2 blocks.
    rotations = torch.stack(
        [torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)],
        dim=-1,
    )
    rotations = rotations.reshape(*rotations.shape[:-1], 2, 2)
    return rotations.float()
|
|
|
|
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key channel pairs by precomputed RoPE rotation matrices.

    The last dimension of ``xq``/``xk`` is viewed as interleaved (even, odd)
    pairs; each pair is multiplied by a 2x2 rotation block from ``freqs_cis``.
    Math runs in float32; outputs are cast back to the input dtypes.
    """
    q_pairs = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    k_pairs = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # Row-times-vector product against the 2x2 rotation blocks.
    q_rot = freqs_cis[..., 0] * q_pairs[..., 0] + freqs_cis[..., 1] * q_pairs[..., 1]
    k_rot = freqs_cis[..., 0] * k_pairs[..., 0] + freqs_cis[..., 1] * k_pairs[..., 1]
    return q_rot.reshape(*xq.shape).type_as(xq), k_rot.reshape(*xk.shape).type_as(xk)
|
|
|
|
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
    """Rotary-embedded scaled dot-product attention.

    Args:
        q, k, v: Head-split projections of shape (B, H, L, D).
        pe: RoPE rotation tensor (see ``rope``/``EmbedND``).

    Returns:
        Attention output of shape (B, L, H * D) with heads merged.
    """
    q, k = apply_rope(q, k, pe)
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge heads: (B, H, L, D) -> (B, L, H*D).
    batch, heads, length, head_dim = out.shape
    return out.transpose(1, 2).reshape(batch, length, heads * head_dim)
|
|
|
|
@torch.no_grad()
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000, time_factor: float = 1000.0) -> torch.Tensor:
    """Create sinusoidal timestep embeddings (cos half, then sin half).

    Args:
        t: (B,) timesteps; multiplied by ``time_factor`` before embedding.
        dim: Output width; an odd ``dim`` is completed with one zero column.
        max_period: Controls the lowest sinusoid frequency.
        time_factor: Scale applied to ``t`` first (1000 by convention here).

    Returns:
        (B, dim) tensor; cast to ``t``'s dtype when ``t`` is floating point.
    """
    scaled = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(scaled.device)
    args = scaled[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Zero-pad one column so the output width is exactly `dim`.
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    if torch.is_floating_point(scaled):
        emb = emb.to(scaled)
    return emb
|
|
|
|
class EmbedND(nn.Module):
    """Concatenates per-axis RoPE tables into one N-dimensional embedding.

    Each axis ``i`` of the last dimension of ``ids`` gets its own rotary table
    of width ``axes_dim[i]``; tables are concatenated along the frequency axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        per_axis = []
        for axis in range(n_axes):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        # Insert a singleton axis at dim 1 so the embedding broadcasts across
        # the attention-head dimension of (B, H, L, D) queries/keys.
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
|
|
|
|
class MLPEmbedder(nn.Module):
    """Two-layer SiLU MLP used to embed timestep and vector conditioning."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
|
|
|
|
class RMSNorm(nn.Module):
    """RMS layer norm: divide by root-mean-square; no mean centering."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x32 = x.float()
        # Reciprocal RMS with a small epsilon for numerical stability;
        # computed in float32, cast back to the caller's dtype afterwards.
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        return (x32 * inv_rms).to(dtype=orig_dtype) * self.scale
|
|
|
|
class QKNorm(nn.Module):
    """Applies independent RMS normalization to queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Normalize, then match v's dtype/device so attention sees uniform inputs.
        normalized_q = self.query_norm(q).to(v)
        normalized_k = self.key_norm(k).to(v)
        return normalized_q, normalized_k
|
|
|
|
@dataclass
class ModulationOut:
    """One set of adaptive-LayerNorm (adaLN) modulation tensors.

    ``shift``/``scale`` modulate normalized activations as
    ``(1 + scale) * norm(x) + shift``; ``gate`` multiplies the block's
    residual update before it is added back to the stream
    (see SingleStreamBlock.forward).
    """
    shift: torch.Tensor
    scale: torch.Tensor
    gate: torch.Tensor
|
|
|
|
class Modulation(nn.Module):
    """Projects a conditioning vector into adaLN shift/scale/gate triplets.

    Emits one triplet (``double=False``) or two (``double=True``); the second
    triplet is ``None`` in the single case.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: torch.Tensor) -> Tuple[ModulationOut, Optional[ModulationOut]]:
        # SiLU-activate, project, add a sequence axis, split into dim-wide chunks.
        projected = self.lin(F.silu(vec))
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)
        first = ModulationOut(shift=chunks[0], scale=chunks[1], gate=chunks[2])
        second = ModulationOut(shift=chunks[3], scale=chunks[4], gate=chunks[5]) if self.is_double else None
        return first, second
|
|
|
|
class SingleStreamBlock(nn.Module):
    """Single stream transformer block (parallel attention and MLP).

    One fused input projection (``linear1``) produces both the QKV tensor and
    the MLP up-projection; their outputs are concatenated and fused back with
    ``linear2``. The result is gated and added residually to the input.
    adaLN shift/scale/gate come from the conditioning vector via Modulation.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, qk_scale: Optional[float] = None):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE(review): self.scale is stored but forward() delegates to
        # attention(), which calls SDPA without a scale override — confirm
        # the custom qk_scale is meant to be unused.
        self.scale = qk_scale or head_dim ** -0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # Fused projection: QKV (3 * hidden) plus the MLP up-projection.
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # Fused output projection over concatenated [attention, mlp] features.
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
        """Run one modulated attention+MLP step.

        Args:
            x: Token stream (B, L, hidden_size).
            vec: Conditioning vector (B, hidden_size) for adaLN modulation.
            pe: RoPE rotation tensor from EmbedND.

        Returns:
            Updated token stream, same shape as ``x``.
        """
        mod, _ = self.modulation(vec)
        # adaLN: scale/shift the normalized stream before the fused projection.
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # Split heads: (B, L, 3*H*D) -> 3 tensors of (B, H, L, D).
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        attn = attention(q, k, v, pe=pe)
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # Gated residual connection.
        return x + mod.gate * output
|
|
|
|
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden states to output channels."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        # Derive per-sample shift/scale from the conditioning vector.
        modulation = self.adaLN_modulation(vec)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm_final(x)
        modulated = (1 + scale[:, None, :]) * normalized + shift[:, None, :]
        return self.linear(modulated)
|
|
|
|
| |
| |
| |
|
|
class VisionTransformer(ModelMixin, ConfigMixin):
    """
    Vision Transformer for vision token to image generation.

    A stack of single-stream DiT-style blocks, conditioned on a pooled vision
    vector and sinusoidal timestep embeddings via adaLN modulation.

    This model is fully compatible with HuggingFace Diffusers and can be:
    - Saved with `model.save_pretrained()`
    - Loaded with `VisionTransformer.from_pretrained()`
    - Uploaded to HuggingFace Hub
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        vec_in_dim: int = 1536,
        context_in_dim: int = 1536,
        hidden_size: int = 1920,
        mlp_ratio: float = 4.0,
        num_heads: int = 24,
        depth: int = 0,
        depth_single_blocks: int = 35,
        axes_dim: Tuple[int, int, int] = (8, 36, 36),
        theta: int = 10_000,
        qkv_bias: bool = True,
        guidance_embed: bool = False,
        use_patchify: bool = False,
    ):
        # NOTE(review): `depth`, `qkv_bias` and `guidance_embed` are recorded
        # in the config but never read below — presumably kept for checkpoint /
        # config compatibility; confirm before removing.
        super().__init__()

        self.in_channels = in_channels
        self.context_in_dim = context_in_dim
        # The model predicts in the same channel space it consumes.
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.use_patchify = use_patchify

        if hidden_size % num_heads != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")

        # Per-head dimension must equal the sum of the per-axis RoPE widths.
        pe_dim = hidden_size // num_heads
        axes_dim_list = list(axes_dim)
        if sum(axes_dim_list) != pe_dim:
            raise ValueError(f"Got {axes_dim_list} but expected positional dim {pe_dim}")

        self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim_list)

        # Input projection consumes latent channels concatenated with the
        # vision conditioning features along the channel axis.
        self.img_in = nn.Linear(in_channels + context_in_dim, hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size)
        self.vector_in = MLPEmbedder(vec_in_dim, hidden_size)

        self.single_blocks = nn.ModuleList([
            SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth_single_blocks)
        ])

        # patch_size=1: one output vector per input token.
        self.final_layer = LastLayer(hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        timesteps: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the transformer.

        Args:
            img: Input tensor (B, L, in_channels + context_in_dim)
            img_ids: Position IDs tensor (B, L, 3)
            timesteps: Sigma/timestep tensor (B,)
            y: Vision pooler output tensor (B, vec_in_dim)

        Returns:
            Output tensor (B, L, out_channels)
        """
        if img.ndim != 3:
            raise ValueError("Input img tensor must have 3 dimensions.")

        img = self.img_in(img)

        # Conditioning vector = timestep embedding + embedded pooled vision
        # vector; cast to the embedder's weight dtype before the MLP.
        vec = self.time_in(
            timestep_embedding(timesteps, 256).to(dtype=self.time_in.in_layer.weight.dtype, device=img.device)
        )
        vec = vec + self.vector_in(y)

        # Rotary position tables, shared by every block.
        pe = self.pe_embedder(img_ids)

        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)

        img = self.final_layer(img, vec)

        return img
|
|
|
|
| |
| |
| |
|
|
class VisionTokenEmbedder(ModelMixin, ConfigMixin):
    """
    Vision Token Embedder that converts discrete vision tokens to embeddings.

    The lookup table lives in a non-trainable buffer (`vocab_embeddings`)
    populated via `load_vocab_embeddings` / `from_numpy`, so it is stored and
    restored with the model's state dict.
    """

    @register_to_config
    def __init__(
        self,
        vocab_size: int = 65536,
        embedding_dim: int = 1536,
        token_length: int = 729,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.token_length = token_length

        # Buffers, not Parameters: serialized with the model, never trained.
        self.register_buffer("vocab_embeddings", torch.zeros(vocab_size, embedding_dim))
        self.register_buffer("uncond_embedding", torch.zeros(1, embedding_dim))

    def load_vocab_embeddings(self, embeddings: torch.Tensor):
        """Copy a (vocab_size, embedding_dim) tensor into the lookup buffer.

        Raises:
            ValueError: If `embeddings` does not match the configured shape.
        """
        if embeddings.shape != (self.vocab_size, self.embedding_dim):
            raise ValueError(
                f"Expected embeddings shape ({self.vocab_size}, {self.embedding_dim}), "
                f"got {embeddings.shape}"
            )
        self.vocab_embeddings.copy_(embeddings)

    def forward(self, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Convert vision tokens to embeddings.

        Args:
            tokens: Vision token IDs (B, L)

        Returns:
            Dictionary with vision_last_hidden_state (B, L, D) and
            vision_pooler_output (B, D) — the mean over the token axis.
        """
        # Plain table lookup by token id.
        hidden_states = self.vocab_embeddings[tokens]
        pooler_output = hidden_states.mean(dim=1)

        return {
            "vision_last_hidden_state": hidden_states,
            "vision_pooler_output": pooler_output,
        }

    def get_uncond_embeddings(self, batch_size: int, token_length: int) -> Dict[str, torch.Tensor]:
        """Get unconditional embeddings for classifier-free guidance."""
        # expand() broadcasts the single learned/zero row without copying.
        uncond_hidden = self.uncond_embedding.expand(batch_size, token_length, -1)
        uncond_pooler = uncond_hidden.mean(dim=1)

        return {
            "vision_last_hidden_state": uncond_hidden,
            "vision_pooler_output": uncond_pooler,
        }

    @classmethod
    def from_numpy(cls, npy_path: str) -> "VisionTokenEmbedder":
        """Create embedder from a .npy file of shape (vocab_size, embedding_dim)."""
        embeddings = torch.from_numpy(np.load(npy_path)).float()
        vocab_size, embedding_dim = embeddings.shape

        embedder = cls(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            token_length=729,  # NOTE(review): hard-coded token length — confirm vs tokenizer
        )
        embedder.load_vocab_embeddings(embeddings)
        return embedder
|
|
|
|
| |
| |
| |
|
|
@dataclass
class VisionTokenToImagePipelineOutput(BaseOutput):
    """Output class for VisionTokenToImagePipeline.

    `images` holds PIL images when the pipeline is called with
    output_type="pil"; for other output types the pipeline passes the raw
    decoded tensor through unchanged.
    """
    images: List[Image.Image]
|
|
|
|
class VisionTokenToImagePipeline(DiffusionPipeline):
    """
    HuggingFace Diffusers Pipeline for Vision Token to Image generation.

    Embeds discrete vision tokens, denoises Gaussian latents with the
    transformer conditioned on those embeddings, then decodes with the VAE.
    Supports autoguidance when transformer2 is provided and guidance_scale > 0.
    """

    # Offload order used by enable_model_cpu_offload(); transformer2 optional.
    model_cpu_offload_seq = "token_embedder->transformer->transformer2->vae"
    _optional_components = ["transformer2"]

    def __init__(
        self,
        transformer: VisionTransformer,
        vae: AutoencoderKL,
        scheduler: FlowMatchEulerDiscreteScheduler,
        token_embedder: VisionTokenEmbedder,
        transformer2: Optional[VisionTransformer] = None,
    ):
        super().__init__()

        self.register_modules(
            transformer=transformer,
            transformer2=transformer2,
            vae=vae,
            scheduler=scheduler,
            token_embedder=token_embedder,
        )

        # NOTE(review): hard-coded 8x spatial compression — assumes a standard
        # SD-style VAE; confirm against the actual vae config.
        self.vae_scale_factor = 8
        self._use_autoguidance = transformer2 is not None

    def _prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """Draw initial Gaussian latents of shape (B, 16, H/8, W/8)."""
        latent_h = height // self.vae_scale_factor
        latent_w = width // self.vae_scale_factor
        # Matches the transformer's in_channels default of 16.
        latent_channels = 16

        shape = (batch_size, latent_channels, latent_h, latent_w)
        latents = torch.randn(shape, device=device, dtype=dtype, generator=generator)

        return latents

    def _prepare_img_ids(
        self,
        batch_size: int,
        img_h: int,
        img_w: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Prepare position IDs for the transformer.

        Axis 0 is left at zero; axes 1 and 2 carry the row / column indices.
        Returns a (B, img_h * img_w, 3) tensor.
        """
        img_ids = torch.zeros(img_h, img_w, 3)
        img_ids[..., 1] = img_ids[..., 1] + torch.arange(img_h)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(img_w)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        return img_ids.to(device=device, dtype=dtype)

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents to images using VAE.

        Undoes the VAE's scaling / shift (defaulting to 1.0 / 0.0 when the
        vae config does not define them) before decoding.
        """
        scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
        shift_factor = getattr(self.vae.config, "shift_factor", 0.0)

        latents = latents / scaling_factor + shift_factor
        images = self.vae.decode(latents).sample

        return images

    def postprocess(self, images: torch.Tensor) -> List[Image.Image]:
        """Convert (B, C, H, W) tensors in [-1, 1] to a list of PIL images."""
        images = images.clamp(-1, 1)
        images = (images + 1) / 2  # map to [0, 1]
        images = images.float().cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).astype(np.uint8)

        return [Image.fromarray(img) for img in images]

    @torch.no_grad()
    def __call__(
        self,
        vision_tokens: Union[np.ndarray, torch.Tensor, List[int]],
        height: int = 768,
        width: int = 768,
        num_inference_steps: int = 50,
        guidance_scale: float = 0.0,
        generator: Optional[Union[torch.Generator, int]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[VisionTokenToImagePipelineOutput, Tuple]:
        """
        Generate images from vision tokens.

        Args:
            vision_tokens: Vision token IDs. Can be:
                - numpy array of shape (L,) or (B, L)
                - torch tensor of shape (L,) or (B, L)
                - List of integers
            height: Output image height (must be divisible by 16)
            width: Output image width (must be divisible by 16)
            num_inference_steps: Number of diffusion steps
            guidance_scale: Autoguidance scale (0 = no guidance, requires transformer2)
            generator: Random generator for reproducibility (int seed or torch.Generator)
            output_type: Output type ("pil" or "pt" for tensor)
            return_dict: Whether to return a dataclass or tuple

        Returns:
            VisionTokenToImagePipelineOutput with generated images
        """
        logger.info(f"[Pipeline] __call__ invoked with guidance_scale={guidance_scale}, num_inference_steps={num_inference_steps}")
        logger.info(f"[Pipeline] Image dimensions: {width}x{height}")
        logger.info(f"[Pipeline] transformer2 available: {self.transformer2 is not None}")
        logger.info(f"[Pipeline] _use_autoguidance flag: {self._use_autoguidance}")

        device = self._execution_device
        dtype = self.transformer.dtype if hasattr(self.transformer, 'dtype') else torch.bfloat16

        # Accept a bare int seed for convenience.
        if isinstance(generator, int):
            logger.info(f"[Pipeline] Using seed: {generator}")
            generator = torch.Generator(device=device).manual_seed(generator)

        # Normalize token input to a (B, L) long tensor on the execution device.
        if isinstance(vision_tokens, list):
            vision_tokens = np.array(vision_tokens, dtype=np.int64)
        if isinstance(vision_tokens, np.ndarray):
            vision_tokens = torch.from_numpy(vision_tokens)
        if vision_tokens.ndim == 1:
            vision_tokens = vision_tokens.unsqueeze(0)

        vision_tokens = vision_tokens.long().to(device)
        batch_size = vision_tokens.shape[0]

        vision_cond = self.token_embedder(vision_tokens)

        latents = self._prepare_latents(
            batch_size, height, width,
            dtype=dtype, device=device,
            generator=generator,
        )

        _, _, h, w = latents.shape
        use_patchify = self.transformer.use_patchify

        # With patchify, each 2x2 latent patch becomes one token, halving each axis.
        if use_patchify:
            img_h, img_w = h // 2, w // 2
        else:
            img_h, img_w = h, w

        img_ids = self._prepare_img_ids(batch_size, img_h, img_w, device, dtype)

        vision_hidden = vision_cond["vision_last_hidden_state"].to(dtype)
        vision_pooler = vision_cond["vision_pooler_output"].to(dtype)

        # Lay the token sequence out on a square grid and resize it to the
        # latent resolution so it can be concatenated per spatial position.
        # NOTE(review): assumes the token count is a perfect square
        # (e.g. 729 = 27^2) — confirm for all tokenizer configs.
        cond_h = cond_w = int(vision_hidden.shape[1] ** 0.5)
        vision_spatial = rearrange(vision_hidden, "b (h w) c -> b c h w", h=cond_h, w=cond_w)
        vision_spatial = F.interpolate(vision_spatial, size=(img_h, img_w), mode="bilinear")
        vision_spatial = rearrange(vision_spatial, "b d h w -> b (h w) d")

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        use_autoguidance = self._use_autoguidance and guidance_scale > 0

        if use_autoguidance:
            logger.info(f"[Pipeline] ✓ Denoising: transformer + transformer2 (guidance_scale={guidance_scale})")
        elif guidance_scale > 0 and not self._use_autoguidance:
            logger.warning(f"[Pipeline] ✗ transformer2 not available, guidance_scale={guidance_scale} ignored")
        else:
            logger.info(f"[Pipeline] Denoising: transformer only (guidance_scale={guidance_scale})")

        for i, t in enumerate(tqdm(timesteps, desc="Denoising", leave=False)):
            # Flatten latents into a token sequence (optionally 2x2-patchified).
            if use_patchify:
                x_t = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            else:
                x_t = rearrange(latents, "b c h w -> b (h w) c")

            # Concatenate vision conditioning features along the channel axis.
            x_t = torch.cat((x_t, vision_spatial), dim=2)

            # NOTE(review): t.item() is truncated by dtype=torch.long here;
            # FlowMatch scheduler timesteps can be fractional — confirm the
            # truncation is intended before the sigma division below.
            t_batch = torch.full((batch_size,), t.item(), device=device, dtype=torch.long)
            sigma = t_batch.float() / self.scheduler.config.num_train_timesteps

            pred = self.transformer(
                img=x_t,
                img_ids=img_ids,
                timesteps=sigma,
                y=vision_pooler,
            )

            # Autoguidance: extrapolate away from the second model's prediction.
            if use_autoguidance:
                pred2 = self.transformer2(img=x_t, img_ids=img_ids, timesteps=sigma, y=vision_pooler)
                pred = pred + guidance_scale * (pred - pred2)

            # Fold the token sequence back into the spatial latent layout.
            if use_patchify:
                pred = rearrange(pred, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h//2, w=w//2, ph=2, pw=2)
            else:
                pred = rearrange(pred, "b (h w) c -> b c h w", h=h, w=w)

            latents = self.scheduler.step(pred, t, latents, generator=generator).prev_sample

        images = self.decode_latents(latents)

        if output_type == "pil":
            images = self.postprocess(images)

        if not return_dict:
            return (images,)

        return VisionTokenToImagePipelineOutput(images=images)
|
|