from dataclasses import dataclass
from typing import Any, Optional

import torch
from torch import Tensor, nn
from torch.nn.functional import fold, unfold

from prx_layers import (
    EmbedND,
    LastLayer,
    MLPEmbedder,
    PRXBlock,
    get_image_ids,
    timestep_embedding,
)


@dataclass
class PRXParams:
    in_channels: int
    patch_size: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    axes_dim: list[int]
    theta: int
    use_image_guidance: bool = False
    use_dyn_tanh: bool = False
    image_guidance_hidden_size: int = 1280

    # Scale and max period of the sinusoidal timestep embedding.
    time_factor: float = 1000.0
    time_max_period: int = 10_000

    # Indices of the blocks that receive conditioning; None means all blocks.
    conditioning_block_ids: list[int] | None = None
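

# The rotary position axes must tile the per-head dimension exactly:
#   sum(axes_dim) == hidden_size // num_heads
# (enforced in PRX.__init__ below).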

# NOTE: as written, hidden_size // num_heads == 72 while sum(axes_dim) == 128,
# so PRX(PRXTinyConfig) fails the axes_dim check in __init__.
PRXTinyConfig = PRXParams(
    in_channels=4,
    patch_size=2,
    context_in_dim=512,
    hidden_size=2304,
    mlp_ratio=3.5,
    num_heads=32,
    depth=3,
    axes_dim=[64, 64],
    theta=10_000,
)


PRXSmallConfig = PRXParams(
    in_channels=16,
    patch_size=2,
    context_in_dim=2304,
    hidden_size=1792,
    mlp_ratio=3.5,
    num_heads=28,
    depth=16,
    axes_dim=[32, 32],
    theta=10_000,
)


# DC-AE variant: the latent is presumably already spatially compressed by a
# deep-compression autoencoder, hence 32 channels and patch_size=1.
PRXDCAESmallConfig = PRXParams(
    in_channels=32,
    patch_size=1,
    context_in_dim=2304,
    hidden_size=1792,
    mlp_ratio=3.5,
    num_heads=28,
    depth=16,
    axes_dim=[32, 32],
    theta=10_000,
)


def img2seq(img: Tensor, patch_size: int) -> Tensor:
    """
    Flatten an image into a sequence of non-overlapping patches:
    b c (h ph) (w pw) -> b (h w) (c ph pw)
    e.g. (B, 16, 64, 64) with patch_size=2 -> (B, 1024, 64)
    """
    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


def seq2img(seq: Tensor, patch_size: int, shape: tuple[int, ...] | Tensor) -> Tensor:
    """
    Invert img2seq:
    b (h w) (c ph pw) -> b c (h ph) (w pw)
    """
    if isinstance(shape, tuple):
        # A full (b, c, h, w) shape (including torch.Size) works here; only the
        # spatial dims are needed.
        shape = shape[-2:]
    elif isinstance(shape, torch.Tensor):
        # A tensor is expected to hold exactly the two spatial dims (h, w).
        shape = (int(shape[0]), int(shape[1]))
    else:
        raise NotImplementedError(f"shape type {type(shape)} not supported")
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
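
# A minimal round-trip sketch (hypothetical sizes): because stride equals
# kernel_size, the patches do not overlap and fold exactly inverts unfold:
#   x = torch.randn(2, 16, 64, 64)
#   seq = img2seq(x, patch_size=2)  # (2, 32 * 32, 16 * 2 * 2) == (2, 1024, 64)
#   assert torch.equal(seq2img(seq, 2, x.shape), x)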


class PRX(nn.Module):
    """
    PRX diffusion transformer: denoises a latent image conditioned on text
    embeddings (via cross-attention in each block) and a diffusion timestep.
    """

    # Subclasses can override this to specialize the conditioning blocks.
    transformer_block_class = PRXBlock

    def __init__(self, params: PRXParams | dict[str, Any] | None = None, **kwargs: Any):
        super().__init__()

        if params is None:
            # Allow PRX(**config_dict) style construction.
            params = kwargs

        if isinstance(params, dict):
            # Ignore private keys when building the dataclass.
            params_dict = {k: v for k, v in params.items() if not k.startswith("_")}
            params = PRXParams(**params_dict)
        elif not isinstance(params, PRXParams):
            raise TypeError("params must be either PRXParams, a dictionary, or keyword arguments")

        self.params = params
        self.in_channels = params.in_channels
        self.patch_size = params.patch_size
        self.use_image_guidance = params.use_image_guidance
        self.image_guidance_hidden_size = params.image_guidance_hidden_size

        # Each output token carries one patch worth of channels.
        self.out_channels = self.in_channels * self.patch_size**2

        self.time_factor = params.time_factor
        self.time_max_period = params.time_max_period

        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")

        # Per-head dimension, which the rotary axes must tile exactly.
        pe_dim = params.hidden_size // params.num_heads

        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got axes_dim {params.axes_dim}, whose sum must equal the head dim {pe_dim}")

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        # Blocks listed here use transformer_block_class; the rest fall back to
        # the vanilla PRXBlock. None selects every block.
        conditioning_block_ids: list[int] = params.conditioning_block_ids or list(
            range(params.depth)
        )

        def block_class(idx: int) -> type[PRXBlock]:
            return self.transformer_block_class if idx in conditioning_block_ids else PRXBlock

        self.blocks = nn.ModuleList(
            [
                block_class(i)(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    use_image_guidance=self.use_image_guidance,
                    image_guidance_hidden_size=params.image_guidance_hidden_size,
                )
                for i in range(params.depth)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        if params.use_dyn_tanh:
            # convert_ln_to_dyt was referenced but never imported; it is assumed
            # to live in prx_layers -- adjust the path to wherever it is defined.
            from prx_layers import convert_ln_to_dyt

            print("Replacing all LayerNorms with DynTanh")
            self.blocks = convert_ln_to_dyt(self.blocks)
            self.final_layer = convert_ln_to_dyt(self.final_layer)

    def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
        """Compute the timestep-independent inputs: projected text tokens, the
        image patch sequence, and its positional embeddings. These can be
        computed once and reused across sampling steps."""
        txt = self.txt_in(txt)
        img = img2seq(image_latent, self.patch_size)
        bs, _, h, w = image_latent.shape
        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
        pe = self.pe_embedder(img_ids)
        return img, txt, pe

    def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
        return self.time_in(
            timestep_embedding(t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor).to(
                dtype
            )
        )
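
    # compute_timestep_embedding returns the per-sample conditioning vector fed
    # to every block and to the final layer (presumably of size hidden_size,
    # given MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)).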

    def forward_transformers(
        self,
        image_latent: Tensor,
        cross_attn_conditioning: Tensor,
        timestep: Optional[Tensor] = None,
        time_embedding: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        **block_kwargs: Any,
    ) -> Tensor:
        img = self.img_in(image_latent)

        if time_embedding is not None:
            # A precomputed embedding takes precedence over a raw timestep.
            vec = time_embedding
        else:
            if timestep is None:
                raise ValueError("Provide either `timestep` or `time_embedding`")
            vec = self.compute_timestep_embedding(timestep, dtype=img.dtype)

        for block in self.blocks:
            img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs)

        img = self.final_layer(img, vec)
        return img
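
    # Sketch of a sampling loop that reuses the timestep-independent state
    # (`schedule` and `update_rule` are hypothetical):
    #   img_seq, txt, pe = model.process_inputs(latent, text_emb)
    #   for t in schedule:
    #       pred = model.forward_transformers(img_seq, txt, t, pe=pe)
    #       img_seq = update_rule(img_seq, pred)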

    def forward(
        self,
        image_latent: Tensor,
        timestep: Tensor,
        cross_attn_conditioning: Tensor,
        micro_conditioning: None | Tensor = None,  # unused in this implementation
        cross_attn_mask: None | Tensor = None,
        image_conditioning: None | Tensor = None,
        image_guidance_scale: None | float | Tensor = None,
        guidance: None = None,  # unused in this implementation
        apply_token_drop: bool = False,  # unused in this implementation
    ) -> Tensor:
        img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning)
        img_seq = self.forward_transformers(
            img_seq,
            txt,
            timestep,
            pe=pe,
            image_conditioning=image_conditioning,
            image_guidance_scale=image_guidance_scale,
            attention_mask=cross_attn_mask,
        )
        return seq2img(img_seq, self.patch_size, image_latent.shape)


if __name__ == "__main__":
    DEVICE = torch.device("cuda")
    DTYPE = torch.bfloat16

    BS = 2
    LATENT_C = 16
    FEATURE_H, FEATURE_W = 512 // 8, 512 // 8
    PROMPT_L = 120
    config = PRXSmallConfig

    denoiser = PRX(config)
    total_params = sum(p.numel() for p in denoiser.parameters())
    print(f"Total number of parameters: {total_params / 1e9:.3f}B")
    denoiser = denoiser.to(DEVICE, DTYPE)

    # Smoke test: one forward and one backward pass on random inputs.
    out = denoiser(
        image_latent=torch.randn(BS, LATENT_C, FEATURE_H, FEATURE_W, device=DEVICE, dtype=DTYPE),
        timestep=torch.zeros(BS, device=DEVICE, dtype=DTYPE),
        cross_attn_conditioning=torch.zeros(BS, PROMPT_L, 2304, device=DEVICE, dtype=DTYPE),
        micro_conditioning=None,
        cross_attn_mask=torch.ones(BS, PROMPT_L, device=DEVICE, dtype=DTYPE),
    )
    loss = out.sum()
    loss.backward()
    print("ok")

    checkpoint_path = "../diffusers_ok/old_and_checkpoints/computer_vision_checkpoints/denoiser_sft_weights.pth"
    print(f"Loading checkpoint from: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    # strict=True raises on any key mismatch; on success the returned named
    # tuple has empty missing_keys and unexpected_keys lists.
    load_result = denoiser.load_state_dict(state_dict, strict=True)
    print(f"Load result: {load_result}")