diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d763c293f353e053bf107a3fd0a1f479a623bb55 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +2D_Stage/material/examples/1.png filter=lfs diff=lfs merge=lfs -text +2D_Stage/material/examples/2.png filter=lfs diff=lfs merge=lfs -text +2D_Stage/material/examples/3.png filter=lfs diff=lfs merge=lfs -text +2D_Stage/material/examples/4.png filter=lfs diff=lfs merge=lfs -text +2D_Stage/material/examples/5.png filter=lfs diff=lfs merge=lfs -text +2D_Stage/material/examples/7.png filter=lfs diff=lfs merge=lfs -text +2D_Stage/material/examples/8.png filter=lfs diff=lfs merge=lfs -text +3D_Stage/material/examples/1/1.png filter=lfs diff=lfs merge=lfs -text +3D_Stage/material/examples/1/2.png filter=lfs diff=lfs merge=lfs -text +3D_Stage/material/examples/1/3.png filter=lfs diff=lfs merge=lfs -text +3D_Stage/material/examples/1/4.png filter=lfs diff=lfs merge=lfs -text +final_texture.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..01f427e77212f86512a5f6474878c0f4d27cf74b --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +# Ignore Python bytecode files +*.pyc +*.pyo +__pycache__/ + +# Ignore virtual environment directory +venv/ + +/3D_stage/outputs/ +input_3D.png +input.png + +# LFS pointer files (large model files downloaded at runtime) +3D_Stage/load/tets/*.npz \ No newline at end of file diff --git a/2D_Stage/configs/infer.yaml b/2D_Stage/configs/infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee53fe601bb4599bbdea936acc7f4d462a8c4e11 --- /dev/null +++ b/2D_Stage/configs/infer.yaml @@ -0,0 +1,24 @@ +pretrained_model_path: "sd2-community/stable-diffusion-2-1" 
+image_encoder_path: "./models/image_encoder" +ckpt_dir: "./models/checkpoint" + +validation: + guidance_scale: 5.0 + use_inv_latent: False + video_length: 4 + +use_pose_guider: True +use_noise: False +use_shifted_noise: False +unet_condition_type: image + +unet_from_pretrained_kwargs: + camera_embedding_type: 'e_de_da_sincos' + projection_class_embeddings_input_dim: 10 # modify + joint_attention: false # modify + num_views: 4 + sample_size: 96 + zero_init_conv_in: false + zero_init_camera_projection: false + in_channels: 4 + use_safetensors: true \ No newline at end of file diff --git a/2D_Stage/material/examples/1.png b/2D_Stage/material/examples/1.png new file mode 100644 index 0000000000000000000000000000000000000000..64e5183ffec3b03a4aaa82e828d03be5f5d2b434 --- /dev/null +++ b/2D_Stage/material/examples/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8fd677efa043cc71fbe0d78e30b93c1f49fd88fd8d2a00ae6946f6f0a06d3b2 +size 620509 diff --git a/2D_Stage/material/examples/2.png b/2D_Stage/material/examples/2.png new file mode 100644 index 0000000000000000000000000000000000000000..d29a82ce617bd746cfbfbaf31b0689c056750335 --- /dev/null +++ b/2D_Stage/material/examples/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b342e025d7170708fdc19c1fa816835d70de61a315d6bde3e2afb390977882d0 +size 304610 diff --git a/2D_Stage/material/examples/3.png b/2D_Stage/material/examples/3.png new file mode 100644 index 0000000000000000000000000000000000000000..d31da2adddca3f6a1a3cfcf7a3a1f2c6852e1480 --- /dev/null +++ b/2D_Stage/material/examples/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cf15c6d6321c8ae8177aefd6ae1d3c41b3cbd51f1f0df13f1be53633273fa457 +size 187843 diff --git a/2D_Stage/material/examples/4.png b/2D_Stage/material/examples/4.png new file mode 100644 index 0000000000000000000000000000000000000000..930b0842d94851302b5407e282c6df08632c847b --- /dev/null +++ 
b/2D_Stage/material/examples/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7291130686044453c3c1a7fc7c9b66bb2424a825cf2406149d6a193cdcf7638c +size 179319 diff --git a/2D_Stage/material/examples/5.png b/2D_Stage/material/examples/5.png new file mode 100644 index 0000000000000000000000000000000000000000..9bf77f6bc52e74fac152ef963acc9a3e7b4d0849 --- /dev/null +++ b/2D_Stage/material/examples/5.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:385007de9cb47cafd6f63971afaf72e7d8ab6f7c146dc10cf0ad4788072d592c +size 141777 diff --git a/2D_Stage/material/examples/6.png b/2D_Stage/material/examples/6.png new file mode 100644 index 0000000000000000000000000000000000000000..ba5dc79b7d3283af18f57e3686836762e2aa03da Binary files /dev/null and b/2D_Stage/material/examples/6.png differ diff --git a/2D_Stage/material/examples/7.png b/2D_Stage/material/examples/7.png new file mode 100644 index 0000000000000000000000000000000000000000..78d984a3f88770f95c8a5ba7063423431f9d3df8 --- /dev/null +++ b/2D_Stage/material/examples/7.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b8e002752dbae219db8c7f6ecab6b79d62c31c06a5a36520effe82eb53418a4 +size 166635 diff --git a/2D_Stage/material/examples/8.png b/2D_Stage/material/examples/8.png new file mode 100644 index 0000000000000000000000000000000000000000..2facade339f89502e74a7a9ebf6296e1fe8ea74d --- /dev/null +++ b/2D_Stage/material/examples/8.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd3def91bb231c82fb39a770a8b645078ab47d4a8f74e2068ccb8db0a0438837 +size 198307 diff --git a/2D_Stage/material/pose.json b/2D_Stage/material/pose.json new file mode 100644 index 0000000000000000000000000000000000000000..40ed1ce28aa4cd809afd2d1f637887ed88b6c5fd --- /dev/null +++ b/2D_Stage/material/pose.json @@ -0,0 +1,38 @@ +[ + [ + [ + 0, 0, -1, 0, + 0, 1, 0, 0, + 1, 0, 0, 0, + 1.5, 0, 0, 1 + ], + "pose0.png" + ], + [ + [ + 0, 0, 1, 0, + 0, 1, 0, 
import os

import torch
import torch.nn as nn
import torch.nn.init as init


class PoseGuider(nn.Module):
    """Encode a 3-channel pose image into the UNet latent space.

    Three stride-2 convolutions give an 8x spatial downsample (matching the VAE
    latent resolution), followed by a 1x1 projection to ``noise_latent_channels``.
    The projection is zero-initialized so the guider starts as a no-op and is
    learned during training.
    """

    def __init__(self, noise_latent_channels: int = 4):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        # Final projection layer into the latent channel count.
        self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)

        # Initialize layers (zeroed final projection => initial output is all zeros).
        self._initialize_weights()

    def _initialize_weights(self):
        """Gaussian-init conv weights (std 0.02) and zero the final projection."""
        for m in self.conv_layers:
            if isinstance(m, nn.Conv2d):
                init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    init.zeros_(m.bias)

        init.zeros_(self.final_proj.weight)
        if self.final_proj.bias is not None:
            init.zeros_(self.final_proj.bias)

    def forward(self, pose_image):
        """Map a (B, 3, H, W) pose image to a (B, C_latent, H/8, W/8) feature map."""
        x = self.conv_layers(pose_image)
        return self.final_proj(x)

    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        """Load a PoseGuider from a torch checkpoint path.

        BUG FIX: the original signature was ``def from_pretrained(pretrained_model_path)``
        under ``@classmethod``, so the implicit ``cls`` argument was bound to the path
        parameter and any call with an actual path raised TypeError. Also raise instead
        of merely printing when the checkpoint file does not exist.
        """
        if not os.path.exists(pretrained_model_path):
            raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
        print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...")

        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        model = cls(noise_latent_channels=4)
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
        print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")

        return model
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm

from einops import rearrange, repeat


@dataclass
class Transformer3DModelOutput(BaseOutput):
    # sample: (batch, channel, frames, height, width)
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    """Spatial transformer inflated to 5-D video tensors.

    Accepts hidden states of shape (b, c, f, h, w), flattens frames into the
    batch, runs ``num_layers`` BasicTransformerBlocks, and projects back with a
    residual connection.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_attn_temp: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_attn_temp=use_attn_temp,
                )
                for d in range(num_layers)
            ]
        )

        # Define output layers.
        # BUG FIX: the original wrote nn.Linear(in_channels, inner_dim) here — the
        # dimensions were swapped relative to upstream diffusers, where proj_out maps
        # inner_dim back to in_channels. In practice inner_dim == in_channels for the
        # SD UNet, which is why the bug went unnoticed.
        if use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input: fold frames into the batch dimension, replicating the
        # conditioning for each frame. NOTE(review): encoder_hidden_states must
        # not be None here — `repeat` would fail; confirm callers always pass it.
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length
            )

        # Output (reshape assumes inner_dim == in_channels, as upstream does)
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    """Transformer block: sparse-causal self-attention -> (optional) cross-attention
    -> feed-forward -> (optional) temporal attention, each with a residual."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_attn_temp: bool = False
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.use_attn_temp = use_attn_temp

        # SC-Attn (self-attention across frames)
        self.attn1 = SparseCausalAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn (only when a conditioning dimension is configured)
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn: zero-init output projection so it starts as identity.
        if self.use_attn_temp:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Toggle xformers memory-efficient attention, validating availability first."""
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
            self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
            if self.attn2 is not None:
                self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        if self.only_cross_attention:
            hidden_states = (
                self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention: attend across frames at each spatial position.
        if self.use_attn_temp:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states


class SparseCausalAttention(CrossAttention):
    """Self-attention whose keys/values are gathered from other frames.

    With ``use_full_attn`` (the default) every frame attends to all frames'
    keys/values; otherwise each frame attends only to the first frame and the
    previous frame (the original sparse-causal scheme).
    """

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_full_attn=True):
        batch_size, sequence_length, _ = hidden_states.shape

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        # Index of the previous frame for each frame (frame 0 maps to itself).
        former_frame_index = torch.arange(video_length) - 1
        former_frame_index[0] = 0

        key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
        if not use_full_attn:
            # Sparse-causal: first frame + previous frame.
            key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
        else:
            # Full attention: concatenate every frame's keys for each frame.
            key = torch.cat([key[:, [i] * video_length] for i in range(video_length)], dim=2)
        key = rearrange(key, "b f d c -> (b f) d c")

        value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
        if not use_full_attn:
            value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
        else:
            value = torch.cat([value[:, [i] * video_length] for i in range(video_length)], dim=2)
        value = rearrange(value, "b f d c -> (b f) d c")

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                # NOTE(review): pads by `target_length` rather than the shortfall,
                # mirroring the old diffusers CrossAttention this is adapted from — confirm.
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
            attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
import torch.nn as nn


# FFN: pre-norm two-layer MLP (LayerNorm -> Linear -> GELU -> Linear), no biases.
def FeedForward(dim, mult=4):
    hidden = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    )


def reshape_tensor(x, heads):
    """Split channels into heads: (bs, length, width) -> (bs, heads, length, width // heads)."""
    bs, length, _ = x.shape
    return x.view(bs, length, heads, -1).permute(0, 2, 1, 3).reshape(bs, heads, length, -1)


class PerceiverAttention(nn.Module):
    """Cross-attention where a small set of latents queries [image features ++ latents]."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D)
            latents (torch.Tensor): latent features, shape (b, n2, D)
        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        bsz, n_latents, _ = latents.shape

        q = reshape_tensor(self.to_q(latents), self.heads)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = (reshape_tensor(t, self.heads) for t in self.to_kv(kv_input).chunk(2, dim=-1))

        # Scale q and k symmetrically before the matmul — more stable in fp16
        # than dividing the product afterwards.
        s = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (q * s) @ (k * s).transpose(-2, -1)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)

        out = (attn @ v).permute(0, 2, 1, 3).reshape(bsz, n_latents, -1)
        return self.to_out(out)


class Resampler(nn.Module):
    """Perceiver resampler: compresses variable-length image embeddings into
    ``num_queries`` output tokens via repeated attention + feed-forward layers."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        # Learned query latents, scaled down at init.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x):
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm_out(self.proj_out(latents))
import torch
from einops import rearrange
from typing import Any, Dict, Optional
from diffusers.utils.import_utils import is_xformers_available
from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor


class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention-processor wrapper implementing reference-only conditioning.

    Modes: 'w' stashes the reference features into ``ref_dict`` and runs
    single-view attention; 'r' folds the views together and appends the stashed
    reference tokens; 'n' folds the views without reference tokens. When
    ``enabled`` is False the wrapped processor is called unchanged.
    """

    def __init__(self, chained_proc, enabled=False, name=None) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance=False, num_views=4,
        multiview_attention=True,
        cross_domain_attention=False,
    ) -> Any:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        if not self.enabled:
            return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)

        if mode == 'w':
            # Write pass: record reference features, attend within one view.
            ref_dict[self.name] = encoder_hidden_states
            res = self.chained_proc(
                attn, hidden_states, encoder_hidden_states, attention_mask,
                num_views=1, multiview_attention=False, cross_domain_attention=False,
            )
        elif mode == 'r':
            # Read pass: concatenate tokens of all views, append reference tokens.
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
            if self.name in ref_dict:
                encoder_hidden_states = (
                    torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
                    .unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                )
            res = self.chained_proc(
                attn, hidden_states, encoder_hidden_states, attention_mask,
                num_views=num_views, multiview_attention=False, cross_domain_attention=False,
            )
        elif mode == 'm':
            # NOTE(review): this branch never assigns `res`, so reaching it raises
            # UnboundLocalError at the return below — preserved as in the original;
            # confirm mode 'm' is never used by callers.
            encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
        elif mode == 'n':
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
            encoder_hidden_states = (
                torch.cat([encoder_hidden_states], dim=1)
                .unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
            )
            res = self.chained_proc(
                attn, hidden_states, encoder_hidden_states, attention_mask,
                num_views=num_views, multiview_attention=False, cross_domain_attention=False,
            )
        else:
            assert False, mode
        return res


class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper: noises a reference latent, runs a 'write' pass to collect
    reference attention features, then runs the main pass reading them."""

    def __init__(self, unet, train_sched, val_sched) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        # Wrap every attention processor; only self-attention (attn1) is enabled.
        procs = dict()
        for proc_name, _ in unet.attn_processors.items():
            if is_xformers_available():
                base_proc = XFormersMVAttnProcessor()
            else:
                base_proc = MVAttnProcessor()
            procs[proc_name] = ReferenceOnlyAttnProc(
                base_proc, enabled=proc_name.endswith("attn1.processor"), name=proc_name)

        self.unet.set_attn_processor(procs)

    def __getattr__(self, name: str):
        # Fall through to the wrapped UNet for anything not defined here.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        """Run the reference ('write') pass, dropping the unconditional row under CFG."""
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)

        # Noise the reference latent to the current timestep (fresh noise each call).
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))

        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )

        weight_dtype = self.unet.dtype
        extra_down = (
            [res.to(dtype=weight_dtype) for res in down_block_res_samples]
            if down_block_res_samples is not None else None
        )
        extra_mid = (
            mid_block_res_sample.to(dtype=weight_dtype)
            if mid_block_res_sample is not None else None
        )
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=extra_down,
            mid_block_additional_residual=extra_mid,
            **kwargs
        )
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class InflatedConv3d(nn.Conv2d):
    """2-D convolution applied per-frame to a 5-D (b, c, f, h, w) tensor.

    Frames are folded into the batch dimension, convolved with the inherited
    Conv2d weights, and unfolded again — weights stay loadable from 2-D UNets.
    """

    def forward(self, x):
        b, c, video_length, h, w = x.shape

        # (b, c, f, h, w) -> (b*f, c, h, w); equivalent to the usual einops
        # rearrange but without the third-party dependency.
        x = x.permute(0, 2, 1, 3, 4).reshape(b * video_length, c, h, w)
        x = super().forward(x)
        # (b*f, c', h', w') -> (b, c', f, h', w')
        _, c_out, h_out, w_out = x.shape
        x = x.reshape(b, video_length, c_out, h_out, w_out).permute(0, 2, 1, 3, 4)

        return x


class Upsample3D(nn.Module):
    """Nearest-neighbor 2x spatial upsampling for (b, c, f, h, w) tensors,
    optionally followed by an inflated 3x3 convolution."""

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

        # Attribute name preserved for checkpoint compatibility.
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16.
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes.
        # see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # If `output_size` is passed we force the interpolation output size
        # and do not make use of `scale_factor=2`. Frames are never resized.
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input was bfloat16, cast back.
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    """Stride-2 spatial downsampling via an inflated 3x3 convolution."""

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

        # Original had a redundant three-way branch that assigned self.conv in
        # every case; only name == "conv" additionally aliases Conv2d_0.
        if name == "conv":
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    """Pre-norm ResNet block over (b, c, f, h, w) tensors with optional
    time-embedding conditioning ("default" add or "scale_shift")."""

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        # The `pre_norm` parameter is accepted for signature compatibility but
        # ignored: this block always normalizes before each convolution.
        # (Original contained a dead `self.pre_norm = pre_norm` immediately
        # overwritten by `True`.)
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        """Args:
            input_tensor: (b, c_in, f, h, w) features.
            temb: per-frame time embedding of shape (b, f, temb_channels), or None.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Project to (b, f, c), then reorder to (b, c, f, 1, 1) so it
            # broadcasts over the spatial dims of hidden_states.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, :, None, None].permute(0, 2, 1, 3, 4)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(torch.nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import BaseOutput, deprecate
# `maybe_allow_in_graph` moved between diffusers versions; try both locations.
try:
    from diffusers.utils import maybe_allow_in_graph
except:
    from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange
import pdb
import random


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data, extended with multi-view attention.

    This is diffusers' `Transformer2DModel` with `BasicTransformerBlock` replaced by
    `BasicMVTransformerBlock`, adding `num_views` / `joint_attention` /
    `multiview_attention` / `cross_domain_attention` options that are threaded
    down into each block.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        joint_attention: bool=False,
        joint_attention_twice: bool=False,
        multiview_attention: bool=True,
        cross_domain_attention: bool=False
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        # Exactly one of the three input modes must be active (checked below).
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers (one branch per input mode)
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                # 1x1 conv projection into the transformer width.
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks — multi-view variant; all MV options
        # are forwarded to every block.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    num_views=num_views,
                    joint_attention=joint_attention,
                    joint_attention_twice=joint_attention_twice,
                    multiview_attention=multiview_attention,
                    cross_domain_attention=cross_domain_attention
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input: project/flatten into (batch, tokens, inner_dim)
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                # conv first, then flatten spatial dims into tokens
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # flatten first, then linear projection
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output: undo the input projection and (for continuous) add the residual
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            # AdaLayerNormZero-style conditioning taken from the first block's norm embedder.
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            # NOTE(review): assumes a square token grid — verify for non-square inputs.
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return TransformerMV2DModelOutput(sample=output)
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool = False + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + self.multiview_attention = multiview_attention + self.cross_domain_attention = cross_domain_attention + # import pdb;pdb.set_trace() + self.attn1 = CustomAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=MVAttnProcessor() + ) + + # 2. 
Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.num_views = num_views + + self.joint_attention = joint_attention + + if self.joint_attention: + # Joint task -Attn + self.attn_joint = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint.to_out[0].weight.data) + self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + + self.joint_attention_twice = joint_attention_twice + + if self.joint_attention_twice: + print("joint twice") + # Joint task -Attn + self.attn_joint_twice = CustomJointAttention( + query_dim=dim, + heads=num_attention_heads, + 
dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + processor=JointAttnProcessor() + ) + nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data) + self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + assert attention_mask is None # not supported yet + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. 
Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + num_views=self.num_views, + multiview_attention=self.multiview_attention, + cross_domain_attention=self.cross_domain_attention, + **cross_attention_kwargs, + ) + + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # joint attention twice + if self.joint_attention_twice: + norm_hidden_states = ( + self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states) + ) + hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. 
Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + if self.joint_attention: + norm_hidden_states = ( + self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states) + ) + hidden_states = self.attn_joint(norm_hidden_states) + hidden_states + + return hidden_states + + +class CustomAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersMVAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + + +class CustomJointAttention(Attention): + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, *args, **kwargs + ): + processor = XFormersJointAttnProcessor() + self.set_processor(processor) + # print("using xformers attention processor") + +class MVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1, + multiview_attention=True + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) + # pdb.set_trace() + # multi-view self-attention + if multiview_attention: + if num_views <= 6: + # after use xformer; possible to train with 6 views + # key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + # value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + key = rearrange(key, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1) + value = rearrange(value, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1) + + else:# apply sparse attention + pass + # print("use sparse attention") + # # seems that the 
sparse random sampling cause problems + # # don't use random sampling, just fix the indexes + # onekey = rearrange(key, "(b t) d c -> b t d c", t=num_views) + # onevalue = rearrange(value, "(b t) d c -> b t d c", t=num_views) + # allkeys = [] + # allvalues = [] + # all_indexes = { + # 0 : [0, 2, 3, 4], + # 1: [0, 1, 3, 5], + # 2: [0, 2, 3, 4], + # 3: [0, 2, 3, 4], + # 4: [0, 2, 3, 4], + # 5: [0, 1, 3, 5] + # } + # for jj in range(num_views): + # # valid_index = [x for x in range(0, num_views) if x!= jj] + # # indexes = random.sample(valid_index, 3) + [jj] + [0] + # indexes = all_indexes[jj] + + # indexes = torch.tensor(indexes).long().to(key.device) + # allkeys.append(onekey[:, indexes]) + # allvalues.append(onevalue[:, indexes]) + # keys = torch.stack(allkeys, dim=1) # checked, should be dim=1 + # values = torch.stack(allvalues, dim=1) + # key = rearrange(keys, 'b t f d c -> (b t) (f d) c') + # value = rearrange(values, 'b t f d c -> (b t) (f d) c') + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class XFormersMVAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_views=1., + multiview_attention=True, + cross_domain_attention=False, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
+ _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key_raw = attn.to_k(encoder_hidden_states) + value_raw = attn.to_v(encoder_hidden_states) + + # print('query', query.shape, 'key', key.shape, 'value', value.shape) + #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320]) + # pdb.set_trace() + # multi-view self-attention + if multiview_attention: + key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0) + + if cross_domain_attention: + # memory efficient, cross domain attention + key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2) + key_cross = torch.concat([key_1, key_0], dim=0) + value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c + key = torch.cat([key, key_cross], dim=1) + value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c + else: + # print("don't use multiview attention.") + key = key_raw + value = value_raw + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = 
hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + + +class XFormersJointAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + # from yuancheng; here attention_mask is None + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
+ _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class JointAttnProcessor: + r""" + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + num_tasks=2 + ): + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + assert num_tasks == 2 # only support two tasks now + + key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c + value_0, value_1 = torch.chunk(value, dim=0, chunks=2) + key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c + value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c + key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c + value = torch.cat([value]*2, dim=0) # (2 b t) 2d c + + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = 
attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/2D_Stage/tuneavideo/models/unet.py b/2D_Stage/tuneavideo/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fa156afe8f0f179cd5d4b2aef569073eafeb690c --- /dev/null +++ b/2D_Stage/tuneavideo/models/unet.py @@ -0,0 +1,497 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import json + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .resnet import InflatedConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = 
"UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_attn_temp: bool = False, + camera_input_dim: int = 12, + camera_hidden_dim: int = 320, + camera_output_dim: int = 1280, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + # camera metrix + # def init_linear(l, stddev): + # nn.init.normal_(l.weight, std=stddev) + # if l.bias is not None: + # nn.init.constant_(l.bias, 0.0) + # self.camera_embedding_1 = nn.Linear(camera_input_dim, camera_hidden_dim) + # 
self.camera_embedding_2 = nn.Linear(camera_hidden_dim, camera_output_dim) + # init_linear(self.camera_embedding_1, 0.25) + # init_linear(self.camera_embedding_2, 0.25) + + self.camera_embedding = nn.Sequential( + nn.Linear(camera_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + 
attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = 
InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
            )

        # Validate each requested slice against its layer's head dimension.
        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # The recursion pops from the end of the list, so hand it the sizes
        # reversed to keep them aligned with the retrieval order above.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle gradient checkpointing on every block type that supports it.
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        camera_matrixs: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            camera_matrixs (`torch.Tensor`, *optional*):
                Per-view camera parameters, added to the time embedding after a
                small MLP; presumably (batch, num_views, camera_input_dim) —
                TODO confirm against callers.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
+ """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
+ t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) #torch.Size([32, 1280]) + emb = torch.unsqueeze(emb, 1) + if camera_matrixs is not None: + # came emb + cam_emb = self.camera_embedding(camera_matrixs) + # cam_emb = self.camera_embedding_2(cam_emb) + emb = emb.repeat(1,cam_emb.shape[1],1) #torch.Size([32, 4, 1280]) + emb = emb + cam_emb + + # import pdb;pdb.set_trace() + if self.class_embedding is not None: + # if class_labels is None: + # raise ValueError("class_labels should be provided when num_class_embeds > 0") + if class_labels is not None: + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + 
hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME + # model = cls.from_config(config) + # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + # if not os.path.isfile(model_file): + # raise RuntimeError(f"{model_file} does not exist") + # state_dict = torch.load(model_file, map_location="cpu") + + import safetensors + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + else: + state_dict = safetensors.torch.load_file(model_file, device="cpu") + else: + state_dict = 
torch.load(model_file, map_location="cpu") + + for k, v in model.state_dict().items(): + if '_temp.' in k or 'camera_embedding' in k or 'class_embedding' in k: + state_dict.update({k: v}) + for k in list(state_dict.keys()): + if 'camera_embedding_' in k: + v = state_dict.pop(k) + model.load_state_dict(state_dict) + + return model \ No newline at end of file diff --git a/2D_Stage/tuneavideo/models/unet_blocks.py b/2D_Stage/tuneavideo/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..ac9b78d42fbd35261c23775e54cfa88f6507925d --- /dev/null +++ b/2D_Stage/tuneavideo/models/unet_blocks.py @@ -0,0 +1,596 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +import torch +from torch import nn + +# from .attention import Transformer3DModel +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_attn_temp=False, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 
+ return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_attn_temp=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + 
resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_attn_temp=use_attn_temp, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + 
resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_attn_temp=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if 
dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_attn_temp=use_attn_temp, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + 
temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + 
resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_attn_temp=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_attn_temp=use_attn_temp, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in 
zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + 
output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/2D_Stage/tuneavideo/models/unet_mv2d_blocks.py b/2D_Stage/tuneavideo/models/unet_mv2d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..009c774337feb5f5d94a7958174400a150e351f3 --- /dev/null +++ b/2D_Stage/tuneavideo/models/unet_mv2d_blocks.py @@ -0,0 +1,926 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +# from diffusers.models.attention import AdaGroupNorm +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from tuneavideo.models.transformer_mv2d import TransformerMV2DModel + +from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D +from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, + num_views=1, + joint_attention: bool = False, + joint_attention_twice: bool = 
False, + multiview_attention: bool = True, + cross_domain_attention: bool=False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is 
None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif down_block_type == "CrossAttnDownBlockMV2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D") + return CrossAttnDownBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be 
specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, 
+ resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, + num_views=1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool=False +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + # custom MV2D attention block + elif up_block_type == "CrossAttnUpBlockMV2D": + if cross_attention_dim is None: + raise 
ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D") + return CrossAttnUpBlockMV2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + 
out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + 
temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlockMV2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool=False + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + TransformerMV2DModel( + num_attention_heads, + in_channels // 
                        num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        num_views=num_views,
                        joint_attention=joint_attention,
                        joint_attention_twice=joint_attention_twice,
                        multiview_attention=multiview_attention,
                        cross_domain_attention=cross_domain_attention
                    )
                )
            else:
                # Only the standard (non-dual) cross-attention path is implemented for MV2D.
                raise NotImplementedError
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Mid-block layout: initial resnet, then alternating (attention, resnet) pairs.
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


class CrossAttnUpBlockMV2D(nn.Module):
    """UNet up-sampling block whose attention layers are multi-view aware (TransformerMV2DModel).

    Mirrors diffusers' CrossAttnUpBlock2D, with extra multi-view flags (num_views,
    joint_attention, joint_attention_twice, multiview_attention, cross_domain_attention)
    forwarded to each transformer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool=False
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            # The last resnet takes the skip connection coming from the matching down block.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if not dual_cross_attention:
                attentions.append(
                    TransformerMV2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        num_views=num_views,
                        joint_attention=joint_attention,
                        joint_attention_twice=joint_attention_twice,
                        multiview_attention=multiview_attention,
                        cross_domain_attention=cross_domain_attention
                    )
                )
            else:
                raise NotImplementedError
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None
        # NOTE(review): both branches assign False, so the num_views == 4 test is dead code;
        # confirm whether gradient checkpointing was meant to be enabled for the 4-view case.
        if num_views == 4:
            self.gradient_checkpointing = False
        else:
            self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                # use_reentrant=False is only available from torch 1.11 onwards.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )[0]
                # hidden_states = attn(
                #     hidden_states,
                #     encoder_hidden_states=encoder_hidden_states,
                #     cross_attention_kwargs=cross_attention_kwargs,
                #     attention_mask=attention_mask,
                #     encoder_attention_mask=encoder_attention_mask,
                #     return_dict=False,
                # )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class CrossAttnDownBlockMV2D(nn.Module):
    """UNet down-sampling block whose attention layers are multi-view aware (TransformerMV2DModel)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool=False
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if not dual_cross_attention:
                attentions.append(
                    TransformerMV2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        num_views=num_views,
                        joint_attention=joint_attention,
                        joint_attention_twice=joint_attention_twice,
                        multiview_attention=multiview_attention,
                        cross_domain_attention=cross_domain_attention
                    )
                )
            else:
                # Only the standard (non-dual) cross-attention path is implemented for MV2D.
                raise NotImplementedError
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None
        # NOTE(review): both branches assign False, so the num_views == 4 test is dead code;
        # confirm whether gradient checkpointing was meant to be enabled for the 4-view case.
        if num_views == 4:
            self.gradient_checkpointing = False
        else:
            self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        additional_residuals=None,
    ):
        # Collects the hidden state after every (resnet, attn) pair plus the final
        # downsampled state; these become the skip connections for the up blocks.
        output_states = ()

        blocks = list(zip(self.resnets, self.attentions))

        for i, (resnet, attn) in enumerate(blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                # use_reentrant=False is only available from torch 1.11 onwards.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )[0]
            else:
                # import ipdb
                # ipdb.set_trace()
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

            # apply additional residuals to the output of the last pair of resnet and attention blocks
            if i == len(blocks) - 1 and additional_residuals is not None:
                hidden_states = hidden_states + additional_residuals

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states

diff --git a/2D_Stage/tuneavideo/models/unet_mv2d_condition.py b/2D_Stage/tuneavideo/models/unet_mv2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..c719b84c6b8cc1cdbcf81d449511e05eede20d92
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/unet_mv2d_condition.py
@@ -0,0 +1,1509 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import os

import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange


from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.embeddings import (
    GaussianFourierProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    UpBlock2D,
)
from diffusers.utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    FLAX_WEIGHTS_NAME,
    HF_HUB_OFFLINE,
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    _add_variant,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_torch_version,
    logging,
)
from diffusers import __version__
from tuneavideo.models.unet_mv2d_blocks import (
    CrossAttnDownBlockMV2D,
    CrossAttnUpBlockMV2D,
    UNetMidBlockMV2DCrossAttn,
    get_down_block,
    get_up_block,
)
from diffusers.models.attention_processor import Attention, AttnProcessor
from diffusers.utils.import_utils import is_xformers_available
from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
from tuneavideo.models.refunet import ReferenceOnlyAttnProc

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNetMV2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor = None

class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
            otherwise.
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + 
class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool = False, + camera_input_dim: int = 12, + camera_hidden_dim: int = 320, + camera_output_dim: int = 1280, + + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. 
`down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
+ ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if 
time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.camera_embedding = nn.Sequential( + nn.Linear(camera_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + 
output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # custom MV2D attention block + elif mid_block_type == "UNetMidBlockMV2DCrossAttn": + self.mid_block = UNetMidBlockMV2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the 
images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + 
joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. 
This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + # def _set_gradient_checkpointing(self, module, value=False): + # if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)): + # module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + camera_matrixs: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNetMV2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 
+ encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # import pdb; pdb.set_trace() + if camera_matrixs is not None: + emb = torch.unsqueeze(emb, 1) + # came emb + cam_emb = self.camera_embedding(camera_matrixs) + # cam_emb = self.camera_embedding_2(cam_emb) + # import ipdb + # ipdb.set_trace() + emb = emb.repeat(1,cam_emb.shape[1],1) #torch.Size([32, 4, 1280]) + emb = emb + cam_emb + emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1]) + + aug_emb = None + + if self.class_embedding is not None and class_labels is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which 
requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = 
self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2]) + sample = self.conv_in(sample) + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + # print("after down: ", sample.mean(), emb.mean()) + # 4. 
mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. 
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNetMV2DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + camera_embedding_type: str, num_views: int, sample_size: int, + zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False, + projection_class_embeddings_input_dim: int=6, joint_attention: bool = False, + joint_attention_twice: bool = False, multiview_attention: bool = True, + cross_domain_attention: bool = False, + in_channels: int = 8, out_channels: int = 4, local_crossattn=False, + **kwargs + ): + r""" + Instantiate a pretrained PyTorch model from a pretrained model configuration. + + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. 
+ resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info (`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + mirror (`str`, *optional*): + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the + same device. 
+ + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + The path to offload weights if `device_map` contains the value `"disk"`. + offload_state_dict (`bool`, *optional*): + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + variant (`str`, *optional*): + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. + + + + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. 
You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. + + + + Example: + + ```py + from diffusers import UNet2DConditionModel + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + # if use_safetensors and not is_safetensors_available(): + # raise ValueError( + # "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors" + # ) + + allow_pickle = False + if use_safetensors is None: + # use_safetensors = is_safetensors_available() + use_safetensors = False + allow_pickle = True + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + user_agent=user_agent, + **kwargs, + ) + + # modify config + config["_class_name"] = cls.__name__ + config['in_channels'] = in_channels + config['out_channels'] = out_channels + config['sample_size'] = sample_size # training resolution + config['num_views'] = num_views + config['joint_attention'] = joint_attention + config['joint_attention_twice'] = joint_attention_twice + config['multiview_attention'] = multiview_attention + config['cross_domain_attention'] = cross_domain_attention + 
config["down_block_types"] = [ + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D" + ] + config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn" + config["up_block_types"] = [ + "UpBlock2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D", + "CrossAttnUpBlockMV2D" + ] + config['class_embed_type'] = 'projection' + if camera_embedding_type == 'e_de_da_sincos': + config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6 + else: + raise NotImplementedError + + # load model + model_file = None + if from_flax: + raise NotImplementedError + else: + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + if local_crossattn: + unet_lora_attn_procs = dict() + for name, _ in model.attn_processors.items(): + if not name.endswith("attn1.processor"): + default_attn_proc = AttnProcessor() + elif is_xformers_available(): + default_attn_proc = XFormersMVAttnProcessor() + else: + default_attn_proc = MVAttnProcessor() + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, 
enabled=name.endswith("attn1.processor"), name=name + ) + model.set_attn_processor(unet_lora_attn_procs) + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + conv_in_weight = state_dict['conv_in.weight'] + conv_out_weight = state_dict['conv_out.weight'] + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=True, + ) + if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_in.weight.data[:,:4] = conv_in_weight + + # whether to place all zero to new layers? + if zero_init_conv_in: + model.conv_in.weight.data[:,4:] = 0. + + if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]): + # initialize from the original SD structure + model.conv_out.weight.data[:,:4] = conv_out_weight + if out_channels == 8: # copy for the last 4 channels + model.conv_out.weight.data[:, 4:] = conv_out_weight + + if zero_init_camera_projection: + for p in model.class_embedding.parameters(): + torch.nn.init.zeros_(p) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
+ ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model_2d( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." 
+ ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + diff --git a/2D_Stage/tuneavideo/models/unet_mv2d_ref.py b/2D_Stage/tuneavideo/models/unet_mv2d_ref.py new file mode 100644 index 0000000000000000000000000000000000000000..add370667c096523d47d77c3313e86ab2ab730e7 --- /dev/null +++ b/2D_Stage/tuneavideo/models/unet_mv2d_ref.py @@ -0,0 +1,1570 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import os + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange + + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.lora import LoRALinearLayer + +from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model +from diffusers.models.unet_2d_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + UpBlock2D, +) +from diffusers.utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HF_HUB_OFFLINE, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + deprecate, + is_accelerate_available, + is_torch_version, + logging, +) +from diffusers import __version__ +from tuneavideo.models.unet_mv2d_blocks import ( + CrossAttnDownBlockMV2D, + CrossAttnUpBlockMV2D, + UNetMidBlockMV2DCrossAttn, + get_down_block, + get_up_block, +) +from diffusers.models.attention_processor import Attention, AttnProcessor +from diffusers.utils.import_utils import is_xformers_available +from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor +from tuneavideo.models.refunet import ReferenceOnlyAttnProc + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMV2DRefOutput(BaseOutput): + """ + 
The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + +class Identity(torch.nn.Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 20]) + + """ + def __init__(self, scale=None, *args, **kwargs) -> None: + super(Identity, self).__init__() + + def forward(self, input, *args, **kwargs): + return input + + + +class _LoRACompatibleLinear(nn.Module): + """ + A Linear layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self): + pass + + def _unfuse_lora(self): + pass + + def forward(self, hidden_states, scale=None, lora_scale: int = 1): + return hidden_states + +class UNetMV2DRefModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
+ + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. 
Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. 
If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "CrossAttnDownBlockMV2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: 
Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + num_views: int = 1, + joint_attention: bool = False, + joint_attention_twice: bool = False, + multiview_attention: bool = True, + cross_domain_attention: bool = False, + camera_input_dim: int = 12, + camera_hidden_dim: int = 320, + camera_output_dim: int = 1280, + + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. 
`up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
+ ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if 
time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.camera_embedding = nn.Sequential( + nn.Linear(camera_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + 
output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + # custom MV2D attention block + elif mid_block_type == "UNetMidBlockMV2DCrossAttn": + self.mid_block = UNetMidBlockMV2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + num_views=num_views, + joint_attention=joint_attention, + joint_attention_twice=joint_attention_twice, + multiview_attention=multiview_attention, + cross_domain_attention=cross_domain_attention + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the 
images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + num_views=num_views, + 
                # Multi-view / cross-domain attention switches are threaded into every
                # up block; they are only consumed by the MV2D block variants.
                joint_attention=joint_attention,
                joint_attention_twice=joint_attention_twice,
                multiview_attention=multiview_attention,
                cross_domain_attention=cross_domain_attention
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        # NOTE(review): the standard diffusers output head (GroupNorm + activation +
        # conv_out) is intentionally disabled -- this reference UNet returns raw
        # up-block features instead of predicted noise (see `forward`, step 6).
        # if norm_num_groups is not None:
        #     self.conv_norm_out = nn.GroupNorm(
        #         num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        #     )
        #     self.conv_act = get_activation(act_fn)
        # else:
        #     self.conv_norm_out = None
        #     self.conv_act = None
        # conv_out_padding = (conv_out_kernel - 1) // 2
        # self.conv_out = nn.Conv2d(
        #     block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        # )

        # Strip the very last transformer block of the last up block down to a bare
        # self-attention skeleton: q/k/v are replaced with fresh _LoRACompatibleLinear
        # layers, the output projection / cross-attention / feed-forward sub-layers
        # become Identity (attn2 is removed outright).  Presumably this block then only
        # serves as a feature tap for reference-only attention -- TODO confirm against
        # the matching attention processor.
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
        self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
        self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
        self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
        self.up_blocks[3].attentions[2].proj_out = Identity()

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: a dictionary containing all attention processors used in the model,
            indexed by weight name (e.g. `"down_blocks.0.attentions.0...processor"`).
        """
        # set recursively: walk the module tree, collecting every submodule that
        # exposes `set_processor` (i.e. every Attention layer).
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor (the same keys as returned by `attn_processors`). This is strongly recommended when setting
                trainable attention processors.

        Raises:
            ValueError: if a dict is passed whose size does not match the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        # A dict must cover every attention layer exactly once.
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    # A single processor instance is shared by all layers.
                    module.set_processor(processor)
                else:
                    # Pop so that each entry is consumed exactly once.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        # Normalize a scalar slice size to a per-layer list.
        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        # Each slice size must not exceed the head dim of the layer it applies to.
        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                # pop() from the end -- the list was reversed below so layers are
                # consumed in traversal order.
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value: bool = False):
        # Toggle gradient checkpointing on every block type that supports it
        # (both the stock 2D blocks and the custom MV2D variants).
        if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        camera_matrixs: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNetMV2DRefOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method (multi-view reference variant).

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor. NOTE(review): step 2 below rearranges
                `"b c f h w -> (b f) c h w"`, so `sample` is expected 5-D here --
                `(batch, channel, views, height, width)` -- not the 4-D shape the
                inherited diffusers docstring describes.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            camera_matrixs (`torch.Tensor`, *optional*):
                Per-view camera conditioning, fed through `self.camera_embedding` and
                added to the time embedding (one embedding per view).
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # import pdb; pdb.set_trace()
        if camera_matrixs is not None:
            # Fuse camera conditioning into the time embedding: one shared time
            # embedding is repeated per view, the per-view camera embedding is
            # added, then (batch, views) are folded into one batch dimension.
            emb = torch.unsqueeze(emb, 1)
            # came emb
            cam_emb = self.camera_embedding(camera_matrixs)
            # cam_emb = self.camera_embedding_2(cam_emb)
            emb = emb.repeat(1,cam_emb.shape[1],1) #torch.Size([32, 4, 1280])
            emb = emb + cam_emb
            emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])

        aug_emb = None

        # NOTE(review): the added `and class_labels is not None` guard makes the
        # inner `raise` below unreachable, and silently skips class embedding when
        # class_labels is None even though self.class_embedding exists. Upstream
        # diffusers raises in that case -- confirm this relaxation is intentional
        # (it likely is, since camera conditioning reuses class_embed_type='projection').
        if self.class_embedding is not None and class_labels is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        # Fold the view axis into the batch axis so every block sees 4-D tensors.
        sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
        sample = self.conv_in(sample)
        # 3. down

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                # NOTE(review): pop(0) mutates the caller's list -- confirm callers
                # do not reuse down_block_additional_residuals afterwards.
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            # Add the controlnet residuals onto each skip connection.
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples
        # print("after down: ", sample.mean(), emb.mean())

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # print("after mid: ", sample.mean())
        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Consume skip connections from the end of the tuple (deepest first).
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6.
        # post-process: the standard conv_norm_out / conv_act / conv_out head is
        # disabled for this reference UNet -- raw up-block features are returned.
        # if self.conv_norm_out:
        #     sample = self.conv_norm_out(sample)
        #     sample = self.conv_act(sample)
        #     sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNetMV2DRefOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(
        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        camera_embedding_type: str, num_views: int, sample_size: int,
        zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
        projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
        joint_attention_twice: bool = False, multiview_attention: bool = True,
        cross_domain_attention: bool = False,
        in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
        **kwargs
    ):
        r"""
        Instantiate this multi-view UNet from a pretrained *2D* UNet checkpoint.

        The base config is loaded from the checkpoint, then rewritten to use the
        multi-view block classes (`CrossAttnDownBlockMV2D`, `UNetMidBlockMV2DCrossAttn`,
        `CrossAttnUpBlockMV2D`), camera conditioning (`class_embed_type='projection'`),
        and the requested `in_channels`/`out_channels`/`sample_size`. Weights are then
        loaded with `ignore_mismatched_sizes=True` and the mismatched `conv_in`/`conv_out`
        kernels are re-initialized from the original SD weights (see below).

        The model is set in evaluation mode - `model.eval()` - by default. To train it,
        set it back in training mode with `model.train()`.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                A Hub model id (for example `google/ddpm-celebahq-256`) or a path to a
                *directory* containing weights saved with [`~ModelMixin.save_pretrained`].
            camera_embedding_type (`str`): only `'e_de_da_sincos'` is implemented.
            num_views (`int`): number of views folded into the batch dimension.
            sample_size (`int`): training resolution written into the config.
            zero_init_conv_in (`bool`): zero the extra (non-SD) input channels of `conv_in`.
            zero_init_camera_projection (`bool`): zero-init the camera/class embedding.
            local_crossattn: install reference-only / multi-view attention processors.

        Standard loading kwargs (forwarded through `**kwargs`, same semantics as
        `ModelMixin.from_pretrained`): `cache_dir`, `torch_dtype`, `force_download`,
        `resume_download`, `proxies`, `output_loading_info`, `local_files_only`,
        `use_auth_token`, `revision`, `from_flax` (not implemented here), `subfolder`,
        `mirror`, `device_map`, `max_memory`, `offload_folder`, `offload_state_dict`,
        `low_cpu_mem_usage`, `variant`, `use_safetensors`.

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        Example:

        ```py
        from diffusers import UNet2DConditionModel

        unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
        ```

        If you get a "weights were not initialized ... shapes did not match" warning for
        `conv_in.weight`/`conv_out.weight`, that is expected: those layers are re-initialized
        from the SD weights below.
        """
        # Split the standard HF loading kwargs off; the remainder is passed to load_config.
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        from_flax = kwargs.pop("from_flax", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        # if use_safetensors and not is_safetensors_available():
        #     raise ValueError(
        #         "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
        #     )

        allow_pickle = False
        if use_safetensors is None:
            # use_safetensors = is_safetensors_available()
            # NOTE(review): safetensors auto-detection is disabled; .bin is the
            # default unless use_safetensors=True is passed explicitly.
            use_safetensors = False
            allow_pickle = True

        if device_map is not None and not is_accelerate_available():
            raise NotImplementedError(
                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                " `device_map=None`. You can install accelerate with `pip install accelerate`."
            )

        # Check if we can handle device_map and dispatching the weights
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # load config
        config, unused_kwargs, commit_hash = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            user_agent=user_agent,
            **kwargs,
        )

        # modify config: swap in the multi-view block classes and this class's
        # channel/view/attention settings before instantiation.
        config["_class_name"] = cls.__name__
        config['in_channels'] = in_channels
        config['out_channels'] = out_channels
        config['sample_size'] = sample_size # training resolution
        config['num_views'] = num_views
        config['joint_attention'] = joint_attention
        config['joint_attention_twice'] = joint_attention_twice
        config['multiview_attention'] = multiview_attention
        config['cross_domain_attention'] = cross_domain_attention
        config["down_block_types"] = [
            "CrossAttnDownBlockMV2D",
            "CrossAttnDownBlockMV2D",
            "CrossAttnDownBlockMV2D",
            "DownBlock2D"
        ]
        config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
        config["up_block_types"] = [
            "UpBlock2D",
            "CrossAttnUpBlockMV2D",
            "CrossAttnUpBlockMV2D",
            "CrossAttnUpBlockMV2D"
        ]
        # Camera conditioning is injected through the class-embedding pathway.
        config['class_embed_type'] = 'projection'
        if camera_embedding_type == 'e_de_da_sincos':
            config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
        else:
            raise NotImplementedError

        # load model: prefer safetensors when requested, fall back to the
        # pickled .bin weights when allowed.
        model_file = None
        if from_flax:
            raise NotImplementedError
        else:
            if use_safetensors:
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                    )
                except IOError as e:
                    if not allow_pickle:
                        raise e
                    pass
            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )

        model = cls.from_config(config, **unused_kwargs)
        if local_crossattn:
            # Self-attention (attn1) layers get reference-only multi-view processors
            # (xformers variant when available); everything else keeps the default.
            unet_lora_attn_procs = dict()
            for name, _ in model.attn_processors.items():
                if not name.endswith("attn1.processor"):
                    default_attn_proc = AttnProcessor()
                elif is_xformers_available():
                    default_attn_proc = XFormersMVAttnProcessor()
                else:
                    default_attn_proc = MVAttnProcessor()
                unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                    default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
                )
            model.set_attn_processor(unet_lora_attn_procs)
        state_dict = load_state_dict(model_file, variant=variant)
        model._convert_deprecated_attention_blocks(state_dict)

        # Keep the original SD conv kernels around: _load_pretrained_model_2d drops
        # mismatched-shape entries, and we re-seed the first 4 channels from them below.
        conv_in_weight = state_dict['conv_in.weight']
        conv_out_weight = state_dict['conv_out.weight']
        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
            model,
            state_dict,
            model_file,
            pretrained_model_name_or_path,
            ignore_mismatched_sizes=True,
        )
        if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
            # initialize from the original SD structure
            model.conv_in.weight.data[:,:4] = conv_in_weight

            # whether to place all zero to new layers?
            if zero_init_conv_in:
                model.conv_in.weight.data[:,4:] = 0.

        if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
            # initialize from the original SD structure
            model.conv_out.weight.data[:,:4] = conv_out_weight
            if out_channels == 8: # copy for the last 4 channels
                model.conv_out.weight.data[:, 4:] = conv_out_weight

        if zero_init_camera_projection:
            # Start camera conditioning as a no-op so early training matches the 2D prior.
            for p in model.class_embedding.parameters():
                torch.nn.init.zeros_(p)

        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
            "error_msgs": error_msgs,
        }

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model_2d( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ): + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + loaded_keys = list(state_dict.keys()) + + expected_keys = list(model_state_dict.keys()) + + original_loaded_keys = loaded_keys + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Make sure we are able to load base models as well as derived models (with heads) + model_to_load = model + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" + " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" + " identical (initializing a BertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" + f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" + " without further training." 
+ ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" + " able to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + diff --git a/2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py b/2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py new file mode 100644 index 0000000000000000000000000000000000000000..59687723d4190ba283616e68c6e799b4d2bd6225 --- /dev/null +++ b/2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py @@ -0,0 +1,585 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py + +import tqdm + +import inspect +from typing import Callable, List, Optional, Union +from dataclasses import dataclass + +import numpy as np +import torch + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer +import torchvision.transforms.functional as TF + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import deprecate, logging, BaseOutput + +from einops import rearrange + +from ..models.unet import UNet3DConditionModel +from torchvision.transforms import InterpolationMode + +import 
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class TuneAVideoPipelineOutput(BaseOutput):
    # Generated multi-view "video" as (b, c, t, h, w); tensor or numpy depending on `output_type`.
    videos: Union[torch.Tensor, np.ndarray]


class TuneAVideoPipeline(DiffusionPipeline):
    """Image+text conditioned multi-view generation pipeline (Tune-A-Video style).

    Wraps a VAE, CLIP text encoder/tokenizer, a multi-view 3D UNet and a scheduler.
    Optionally takes a reference UNet (`ref_unet`) whose attention features are
    shared with the main UNet via a `ref_dict`, plus a CLIP image encoder /
    feature extractor for image conditioning.
    """

    _optional_components = []

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet3DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        ref_unet=None,
        feature_extractor=None,
        image_encoder=None,
    ):
        super().__init__()
        # These three are optional and deliberately not registered as modules.
        self.ref_unet = ref_unet
        self.feature_extractor = feature_extractor
        self.image_encoder = image_encoder

        # Backward-compat shims for outdated scheduler/unet configs (mirrors
        # upstream diffusers StableDiffusionPipeline behavior).
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def enable_vae_slicing(self):
        """Enable sliced VAE decoding to save memory."""
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        """Disable sliced VAE decoding."""
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        """Offload unet/text_encoder/vae to CPU, moving them to `cuda:{gpu_id}` only while used."""
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        """Device the UNet actually executes on (accounts for accelerate hooks)."""
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance, img_proj=None):
        """Encode the reference image(s) into CLIP embeddings and VAE latents.

        `image_pil` is expected to be an image tensor batch in [0, 1] (it is fed
        through TF.resize and the VAE) — TODO confirm against callers.
        Returns (image_embeddings, image_latents); embeddings are CFG-doubled
        when guidance is on, latents are not.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        # CLIP preprocessing done by hand so gradients/precision are controlled:
        # normalize in float32, then cast to the encoder dtype.
        clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:, None, None].to(device, dtype=torch.float32)
        clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:, None, None].to(device, dtype=torch.float32)
        imgs_in_proc = TF.resize(
            image_pil,
            (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']),
            interpolation=InterpolationMode.BICUBIC,
        )
        # do the normalization in float32 to preserve precision
        imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(dtype)
        if img_proj is None:
            # (B*Nv, 1, 768) pooled CLIP image embedding
            image_embeddings = self.image_encoder(imgs_in_proc).image_embeds.unsqueeze(1)
            # duplicate image embeddings for each generation per prompt, using mps friendly method
            # Note: repeat differently from official pipelines
            # B1B2B3B4 -> B1B2B3B4B1B2B3B4
            bs_embed, seq_len, _ = image_embeddings.shape
            image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
            if do_classifier_free_guidance:
                negative_prompt_embeds = torch.zeros_like(image_embeddings)
                # Concatenate unconditional and conditional embeddings into a
                # single batch to avoid doing two forward passes.
                image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
        else:
            if do_classifier_free_guidance:
                negative_image_proc = torch.zeros_like(imgs_in_proc)
                imgs_in_proc = torch.cat([negative_image_proc, imgs_in_proc])

            # Fix: original referenced a bare `image_encoder` (NameError); the
            # encoder is held on `self`.
            image_embeds = self.image_encoder(imgs_in_proc, output_hidden_states=True).hidden_states[-2]
            image_embeddings = img_proj(image_embeds)

        # VAE-encode the reference image ([0,1] -> [-1,1]) for latent conditioning.
        image_latents = self.vae.encode(image_pil * 2.0 - 1.0).latent_dist.mode() * self.vae.config.scaling_factor

        # Note: repeat differently from official pipelines
        # B1B2B3B4 -> B1B2B3B4B1B2B3B4
        image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)

        return image_embeddings, image_latents

    def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
        """Tokenize + CLIP-encode `prompt`; CFG-doubled ([uncond; cond]) when guidance is on."""
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)

            # Single batch [uncond; cond] to avoid two forward passes.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def decode_latents(self, latents):
        """Decode (b, c, f, h, w) latents into a float32 numpy video in [0, 1]."""
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents  # SD v1/v2 latent scaling factor
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video

    def prepare_extra_step_kwargs(self, generator, eta):
        """Build scheduler.step kwargs, passing eta/generator only if accepted.

        eta (η) is only used with DDIMScheduler (https://arxiv.org/abs/2010.02502)
        and should be in [0, 1]; it is ignored by other schedulers.
        """
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, height, width, callback_steps):
        """Validate user-facing __call__ arguments; raise ValueError on misuse."""
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        """Sample (or validate) initial latents and scale by the scheduler's init sigma."""
        shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            # randn on mps is not reproducible; sample on CPU there.
            rand_device = "cpu" if device.type == "mps" else device

            if isinstance(generator, list):
                shape = (1,) + shape[1:]
                latents = [
                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
                    for i in range(batch_size)
                ]
                latents = torch.cat(latents, dim=0).to(device)
            else:
                # Fix: original dropped `generator` (and rand_device) here, so a
                # seeded torch.Generator had no effect on the sampled noise.
                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        camera_matrixs=None,
        class_labels=None,
        prompt_ids=None,
        unet_condition_type=None,
        pose_guider=None,
        pose_image=None,
        img_proj=None,
        use_noise=True,
        use_shifted_noise=False,
        rescale=0.7,
        **kwargs,
    ):
        """Run the full denoising loop and return the generated multi-view video.

        `class_labels`, `prompt_ids` and `**kwargs` are accepted for interface
        compatibility but unused here.
        """
        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct.
        self.check_inputs(prompt, height, width, callback_steps)
        if isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        # Define call parameters (NOTE: this deliberately overrides the
        # image-derived batch_size above with the prompt-derived one, as in the
        # original implementation).
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # `guidance_scale` is defined analog to the guidance weight `w` of
        # Imagen eq. (2): https://arxiv.org/pdf/2205.11487.pdf ; scale 1 = no CFG.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode the reference image: CLIP embeddings + VAE latents.
        image_embeddings, image_latents = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, img_proj=img_proj)
        image_latents = rearrange(image_latents, "(b f) c h w -> b c f h w", f=1)

        # Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # Duplicate camera matrices for CFG ([uncond; cond] halves share cameras).
        if camera_matrixs is not None:
            camera_matrixs = torch.cat([camera_matrixs] * 2) if do_classifier_free_guidance else camera_matrixs
        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        if pose_guider is not None:
            if len(pose_image.shape) == 5:
                pose_embeds = pose_guider(rearrange(pose_image, "b f c h w -> (b f) c h w"))
                pose_embeds = rearrange(pose_embeds, "(b f) c h w-> b c f h w ", f=video_length)
            else:
                pose_embeds = pose_guider(pose_image).unsqueeze(0)
            # NOTE(review): assumes CFG is enabled whenever pose_guider is used
            # (pose_embeds is unconditionally doubled) — confirm with callers.
            pose_embeds = torch.cat([pose_embeds] * 2, dim=0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(tqdm.tqdm(timesteps)):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                if pose_guider is not None:
                    latent_model_input = latent_model_input + pose_embeds

                # Noise the reference-image latents to the current timestep.
                noise_cond = torch.randn_like(image_latents)
                if use_noise:
                    cond_latents = self.scheduler.add_noise(image_latents, noise_cond, t)
                else:
                    cond_latents = image_latents
                cond_latent_model_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents
                cond_latent_model_input = self.scheduler.scale_model_input(cond_latent_model_input, t)

                # Reference pass: mode="w" writes attention features into
                # ref_dict; the returned sample itself is not used.
                ref_dict = {}
                if self.ref_unet is not None:
                    noise_pred_cond = self.ref_unet(
                        cond_latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings.to(torch.float32),
                        cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
                    ).sample.to(dtype=latents_dtype)

                # Per-view conditioning: tile text/image embeddings across views.
                text_embeddings_unet = text_embeddings.unsqueeze(1).repeat(1, latents.shape[2], 1, 1)
                text_embeddings_unet = rearrange(text_embeddings_unet, 'B Nv d c -> (B Nv) d c')
                image_embeddings_unet = image_embeddings.unsqueeze(1).repeat(1, latents.shape[2], 1, 1)
                image_embeddings_unet = rearrange(image_embeddings_unet, 'B Nv d c -> (B Nv) d c')

                if unet_condition_type == 'text':
                    encoder_hidden_states_unet_cond = text_embeddings_unet
                elif unet_condition_type == 'image':
                    encoder_hidden_states_unet_cond = image_embeddings_unet
                else:
                    # Fix: original did `raise('need unet_condition_type')`,
                    # which is a TypeError (str is not an exception).
                    raise ValueError('need unet_condition_type')

                # Main UNet pass: mode="r" reads the reference features when a
                # ref_unet is present, mode="n" runs without them.
                if self.ref_unet is not None:
                    noise_pred = self.unet(
                        latent_model_input.to(torch.float32),
                        t,
                        encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
                        camera_matrixs=camera_matrixs.to(torch.float32),
                        cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance),
                    ).sample.to(dtype=latents_dtype)
                else:
                    noise_pred = self.unet(
                        latent_model_input.to(torch.float32),
                        t,
                        encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
                        camera_matrixs=camera_matrixs.to(torch.float32),
                        cross_attention_kwargs=dict(mode="n", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance),
                    ).sample.to(dtype=latents_dtype)
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    if use_shifted_noise:
                        # Apply regular classifier-free guidance, then rescale
                        # toward the conditional std (guidance rescale).
                        cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        std_pos = noise_pred_text.std([1, 2, 3], keepdim=True)
                        std_cfg = cfg.std([1, 2, 3], keepdim=True)
                        factor = std_pos / std_cfg
                        factor = rescale * factor + (1 - rescale)
                        noise_pred = cfg * factor
                    else:
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                noise_pred = rearrange(noise_pred, "(b f) c h w -> b c f h w", f=video_length)
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Post-processing
        video = self.decode_latents(latents)

        # Convert to tensor
        if output_type == "tensor":
            video = torch.from_numpy(video)

        if not return_dict:
            return video

        return TuneAVideoPipelineOutput(videos=video)
def shifted_noise(betas, image_d=512, noise_d=256, shifted_noise=True):
    """Rescale a beta schedule so the terminal step has zero signal.

    Optionally first shifts the cumulative alphas by the squared resolution
    ratio (image_d/noise_d)**2, then shifts/rescales sqrt(alpha_bar) so the
    last timestep reaches exactly zero while the first keeps its value, and
    converts back to betas.
    """
    cum_alphas = torch.cumprod(1 - betas, dim=0)
    ratio = (image_d / noise_d) ** 2
    if shifted_noise:
        cum_alphas = cum_alphas / (ratio - (ratio - 1) * cum_alphas)
    sqrt_cum = torch.sqrt(cum_alphas)
    first = sqrt_cum[0].clone()
    last = sqrt_cum[-1].clone()
    # Shift so the last timestep is zero, then scale so the first timestep
    # returns to its original value.
    sqrt_cum = (sqrt_cum - last) * (first / (first - last))

    # Convert the adjusted sqrt(alpha_bar) back to per-step betas.
    cum_alphas = sqrt_cum ** 2
    step_alphas = torch.cat([cum_alphas[0:1], cum_alphas[1:] / cum_alphas[:-1]])
    return 1 - step_alphas

def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    """Write a (b, c, t, h, w) tensor as an animated image file at `path`."""
    frames = []
    for frame in rearrange(videos, "b c t h w -> t b c h w"):
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        frames.append((grid * 255).numpy().astype(np.uint8))

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, frames, duration=1000 / fps)

def save_imgs_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    """Save each timestep of a (b, c, t, h, w) tensor as `view_{i}.png` under `path`.

    `fps` is accepted for signature parity with save_videos_grid but unused.
    """
    for i, frame in enumerate(rearrange(videos, "b c t h w -> t b c h w")):
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        img = (grid * 255).numpy().astype(np.uint8)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # cv2 expects BGR, so reverse the channel axis.
        cv2.imwrite(os.path.join(path, f'view_{i}.png'), img[:, :, ::-1])
def imgs_grid(videos: torch.Tensor, rescale=False, n_rows=4, fps=8):
    """Tile a (b, c, t, h, w) batch into per-frame grids and return them as a
    list of uint8 HWC numpy arrays (RGB order)."""
    frames = []
    for view in rearrange(videos, "b c t h w -> t b c h w"):
        grid = torchvision.utils.make_grid(view, nrow=n_rows)
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # map [-1, 1] -> [0, 1]
        frames.append((grid * 255).numpy().astype(np.uint8))
    return frames

# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    """Encode ``prompt`` and the empty prompt with the pipeline's text
    encoder and return them concatenated as [uncond, cond] along dim 0."""
    tok = pipeline.tokenizer
    enc = pipeline.text_encoder
    empty_ids = tok(
        [""], padding="max_length", max_length=tok.model_max_length,
        return_tensors="pt"
    ).input_ids
    negative = enc(empty_ids.to(pipeline.device))[0]
    prompt_ids = tok(
        [prompt], padding="max_length", max_length=tok.model_max_length,
        truncation=True, return_tensors="pt",
    ).input_ids
    positive = enc(prompt_ids.to(pipeline.device))[0]
    return torch.cat([negative, positive])

def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    """One *forward* DDIM step (used for inversion): maps the sample at the
    previous timestep to ``timestep`` using the model's noise prediction."""
    stride = ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps
    cur_t = min(timestep - stride, 999)
    nxt_t = timestep
    if cur_t >= 0:
        alpha_cur = ddim_scheduler.alphas_cumprod[cur_t]
    else:
        alpha_cur = ddim_scheduler.final_alpha_cumprod
    alpha_nxt = ddim_scheduler.alphas_cumprod[nxt_t]
    beta_cur = 1 - alpha_cur
    # Predict x0 from the current sample, then re-noise it to the next step.
    x0_pred = (sample - beta_cur ** 0.5 * model_output) / alpha_cur ** 0.5
    direction = (1 - alpha_nxt) ** 0.5 * model_output
    return alpha_nxt ** 0.5 * x0_pred + direction

def get_noise_pred_single(latents, t, context, unet):
    """Run the UNet once and return its predicted noise."""
    return unet(latents, t, encoder_hidden_states=context)["sample"]
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    """Run DDIM inversion for ``num_inv_steps`` steps, walking the scheduler
    timesteps from last to first, and return the full latent trajectory
    (starting with the input latent)."""
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    trajectory = [latent]
    latent = latent.clone().detach()
    for step in tqdm(range(num_inv_steps)):
        # Walk timesteps from the end: index len-step-1 == -(step+1).
        t = ddim_scheduler.timesteps[-(step + 1)]
        # UNet is queried in float32 regardless of the latent dtype.
        eps = get_noise_pred_single(
            latent.to(torch.float32), t, cond_embeddings.to(torch.float32), pipeline.unet
        )
        latent = next_step(eps, t, latent, ddim_scheduler)
        trajectory.append(latent)
    return trajectory

@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    """Convenience wrapper: invert ``video_latent`` and return all latents."""
    return ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
class rm_bg_api:
    """Anime-style background remover backed by the skytnt/anime-seg ONNX
    model (downloaded from the Hugging Face Hub on construction)."""

    def __init__(self, force_cpu: Optional[bool] = True):
        # Download (or reuse the cached copy of) the segmentation model.
        session_infer_path = hf_hub_download(
            repo_id="skytnt/anime-seg", filename="isnetis.onnx",
        )
        providers: list[str] = ["CPUExecutionProvider"]
        # Only use CUDA when explicitly requested AND available.
        if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
            providers = ["CUDAExecutionProvider"]

        self.session_infer = rt.InferenceSession(
            session_infer_path, providers=providers,
        )

    def remove_background(
        self,
        imgs: list[np.ndarray],
        alpha_min: float,
        alpha_max: float,
    ) -> list:
        """Segment each image and composite it over a flat background.

        Mask values below ``alpha_min`` are clamped to 0 and above
        ``alpha_max`` to 1 before compositing. Returns a list of PIL RGBA
        images where the alpha channel is the (scaled) mask.
        # NOTE(review): assumes input arrays are RGBA uint8 — confirm callers.
        """
        process_imgs = []
        for img in imgs:
            # CHANGE to RGB
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            mask = get_mask(self.session_infer, img)

            mask[mask < alpha_min] = 0.0  # type: ignore
            mask[mask > alpha_max] = 1.0  # type: ignore

            # Blend foreground over a flat background of value SCALE.
            img_after = (mask * img + SCALE * (1 - mask)).astype(np.uint8)  # type: ignore
            mask = (mask * SCALE).astype(np.uint8)  # type: ignore
            # Append the mask as a 4th (alpha) channel.
            img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
            # NOTE(review): this repeated mask is never used afterwards —
            # looks like dead code left from a debug path.
            mask = mask.repeat(3, axis=2)
            process_imgs.append(Image.fromarray(img_after))
        return process_imgs
def get_bg_color(bg_color):
    """Map a background-color spec to an RGB float array in [0, 1].

    Accepts the names 'white' / 'black' / 'gray', 'random' (uniform random
    RGB), or a single float used for all three channels.

    Raises:
        NotImplementedError: for any other value.
    """
    if bg_color == 'white':
        bg_color = np.array([1., 1., 1.], dtype=np.float32)
    elif bg_color == 'black':
        bg_color = np.array([0., 0., 0.], dtype=np.float32)
    elif bg_color == 'gray':
        bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    elif bg_color == 'random':
        bg_color = np.random.rand(3)
    elif isinstance(bg_color, float):
        bg_color = np.array([bg_color] * 3, dtype=np.float32)
    else:
        raise NotImplementedError
    return bg_color

def process_image(image, totensor):
    """Crop a character image to its non-transparent content, pad it into a
    2:3 canvas, resize to 512x768, composite over gray, and tensorize.

    Args:
        image: a PIL image (converted to RGBA if needed).
        totensor: callable mapping an HWC float array to a tensor
            (e.g. torchvision ``transforms.ToTensor``).

    Returns:
        The result of ``totensor`` applied to the 512x768 RGB float image.

    Raises:
        ValueError: if the image has no non-transparent pixels.
    """
    if not image.mode == "RGBA":
        image = image.convert("RGBA")

    # Find non-transparent pixels.
    non_transparent = np.nonzero(np.array(image)[..., 3])
    if non_transparent[0].size == 0:
        # Previously this crashed with an opaque numpy reduction error.
        raise ValueError("process_image: input image is fully transparent")
    min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
    min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
    # PIL's crop box is right/bottom-EXCLUSIVE; +1 keeps the last
    # non-transparent row and column (previously off by one).
    image = image.crop((min_x, min_y, max_x + 1, max_y + 1))

    # Paste centered into a 2:3 (width:height) canvas.
    # NOTE(review): if the cropped content is wider than 2/3 of its height,
    # `left` goes negative and the sides get clipped — presumably acceptable
    # for full-body character images; confirm.
    max_dim = max(image.width, image.height)
    max_height = max_dim
    max_width = int(max_dim / 3 * 2)
    new_image = Image.new("RGBA", (max_width, max_height))
    left = (max_width - image.width) // 2
    top = (max_height - image.height) // 2
    new_image.paste(image, (left, top))

    image = new_image.resize((512, 768), resample=PIL.Image.BICUBIC)
    image = np.array(image)
    image = image.astype(np.float32) / 255.
    assert image.shape[-1] == 4  # RGBA
    # Alpha-composite over a gray background.
    alpha = image[..., 3:4]
    bg_color = get_bg_color("gray")
    image = image[..., :3] * alpha + bg_color * (1 - alpha)
    return totensor(image)
class Inference_API:
    """Caches a TuneAVideoPipeline and produces 4 posed character views."""

    def __init__(self):
        # Pipeline is built lazily on the first `inference` call.
        self.validation_pipeline = None

    @torch.no_grad()
    def inference(self, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
                  pose_guider=None, use_noise=True, use_shifted_noise=False, noise_d=256, crop=False, seed=100, timestep=20):
        """Generate 4 views of the character in ``input_image``.

        Returns a list of 4 PIL images decoded from the pipeline output.
        NOTE(review): the `pose_guider` parameter is accepted but the
        pipeline below is called with `pose_guider=None` — confirm intended.
        """
        set_seed(seed)
        # Get the validation pipeline (built once, then cached; the
        # shifted-noise rescale therefore only takes effect on first call).
        if self.validation_pipeline is None:
            noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
            if use_shifted_noise:
                print(f"enable shifted noise for {val_height} to {noise_d}")
                betas = shifted_noise(noise_scheduler.betas, image_d=val_height, noise_d=noise_d)
                # Keep alphas/alphas_cumprod consistent with the new betas.
                noise_scheduler.betas = betas
                noise_scheduler.alphas = 1 - betas
                noise_scheduler.alphas_cumprod = torch.cumprod(noise_scheduler.alphas, dim=0)
            self.validation_pipeline = TuneAVideoPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,feature_extractor=feature_extractor,image_encoder=image_encoder,
                scheduler=noise_scheduler
            )
            self.validation_pipeline.enable_vae_slicing()
            self.validation_pipeline.set_progress_bar_config(disable=True)

        totensor = transforms.ToTensor()

        # Load the fixed camera poses / pose images.
        # Assumes each entry is [flattened 4x4 camera matrix, image path]
        # relative to ./material — TODO confirm against pose.json.
        metas = json.load(open("./material/pose.json", "r"))
        cameras = []
        pose_images = []
        input_path = "./material"
        for lm in metas:
            # Transpose and keep the top 3x4 of the camera matrix, flattened.
            cameras.append(torch.tensor(np.array(lm[0]).reshape(4, 4).transpose(1,0)[:3, :4]).reshape(-1))
            if not crop:
                pose_images.append(totensor(np.asarray(Image.open(os.path.join(input_path, lm[1])).resize(
                    (val_height, val_width), resample=PIL.Image.BICUBIC)).astype(np.float32) / 255.))
            else:
                # Hard-coded crop window; presumably matches the pose image
                # layout shipped in ./material — verify.
                pose_image = Image.open(os.path.join(input_path, lm[1]))
                crop_area = (128, 0, 640, 768)
                pose_images.append(totensor(np.array(pose_image.crop(crop_area)).astype(np.float32)) / 255.)

        camera_matrixs = torch.stack(cameras).unsqueeze(0).to("cuda")
        pose_imgs_in = torch.stack(pose_images).to("cuda")
        prompts = "high quality, best quality"
        prompt_ids = tokenizer(
            prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
            return_tensors="pt"
        ).input_ids[0]

        # (B*Nv, 3, H, W)
        B = 1
        weight_dtype = torch.bfloat16
        imgs_in = process_image(input_image, totensor)
        imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

        with torch.autocast("cuda", dtype=weight_dtype):
            imgs_in = imgs_in.to("cuda")
            # B*Nv images
            out = self.validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
                                           num_inference_steps=timestep,
                                           camera_matrixs=camera_matrixs.to(weight_dtype), prompt_ids=prompt_ids,
                                           height=val_height, width=val_width, unet_condition_type=unet_condition_type,
                                           pose_guider=None, pose_image=pose_imgs_in, use_noise=use_noise,
                                           use_shifted_noise=use_shifted_noise, **validation).videos
        out = rearrange(out, "B C f H W -> (B f) C H W", f=validation.video_length)

        # Round-trip each view through an in-memory PNG to return PIL images.
        image_outputs = []
        for bs in range(4):
            img_buf = io.BytesIO()
            save_image(out[bs], img_buf, format='PNG')
            img_buf.seek(0)
            img = Image.open(img_buf)
            image_outputs.append(img)
        torch.cuda.empty_cache()
        return image_outputs
@torch.no_grad()
def main(
    pretrained_model_path: str,
    image_encoder_path: str,
    ckpt_dir: str,
    validation: Dict,
    local_crossattn: bool = True,
    unet_from_pretrained_kwargs=None,
    unet_condition_type=None,
    use_pose_guider=False,
    use_noise=True,
    use_shifted_noise=False,
    noise_d=256
):
    """Load all 2D-stage models and launch the Gradio demo.

    Parameters mirror configs/infer.yaml (loaded via OmegaConf below).
    """
    # Capture the full argument dict (kept for parity with training scripts;
    # not used further in this function).
    *_, config = inspect.getargvalues(inspect.currentframe())

    device = "cuda"

    # Text/image encoders and VAE from the base SD checkpoint.
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)
    feature_extractor = CLIPImageProcessor()
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    # Both multi-view UNets are initialised from the base "unet" weights and
    # then overwritten by the fine-tuned checkpoint below.
    unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    if use_pose_guider:
        pose_guider = PoseGuider(noise_latent_channels=4).to("cuda")
    else:
        pose_guider = None

    # Checkpoint layout depends on use_pose_guider: with a pose guider the
    # ref-unet weights shift to pytorch_model_2.bin.
    unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
    if use_pose_guider:
        pose_guider_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
        ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_2.bin"), map_location="cpu")
        pose_guider.load_state_dict(pose_guider_params)
    else:
        ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
    unet.load_state_dict(unet_params)
    ref_unet.load_state_dict(ref_unet_params)

    # NOTE(review): models are stored as fp16 here while inference autocasts
    # to bfloat16 — confirm this mix is intended.
    weight_dtype = torch.float16

    text_encoder.to(device, dtype=weight_dtype)
    image_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    ref_unet.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    vae.requires_grad_(False)
    unet.requires_grad_(False)
    ref_unet.requires_grad_(False)

    # NOTE(review): generator is never seeded here; seeding is done through
    # the module-level set_seed inside Inference_API.inference.
    generator = torch.Generator(device="cuda")
    inferapi = Inference_API()
    remove_api = rm_bg_api()

    def gen4views(image, width, height, seed, timestep, remove_bg):
        # Gradio callback: optional background removal, then 4-view synthesis.
        if remove_bg:
            image = remove_api.remove_background(
                imgs=[np.array(image)],
                alpha_min=0.1,
                alpha_max=0.9,
            )[0]
        return inferapi.inference(
            image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, pretrained_model_path,
            generator, validation, width, height, unet_condition_type,
            pose_guider=pose_guider, use_noise=use_noise, use_shifted_noise=use_shifted_noise, noise_d=noise_d,
            crop=True, seed=seed, timestep=timestep
        )

    # ---- Gradio UI wiring ----
    with gr.Blocks() as demo:
        gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration")
        gr.Markdown("# 2D Stage: One Image to Four Views of Character Image")
        gr.Markdown("**Please Upload the Image without background, and the pictures uploaded should preferably be full-body frontal photos.**")
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Upload Image(without background)", image_mode="RGBA", width=768, height=512)
                gr.Examples(
                    label="Example Images",
                    examples=glob.glob("./material/examples/*.png"),
                    inputs=[img_input]
                )
                with gr.Row():
                    width_input = gr.Number(label="Width", value=512)
                    height_input = gr.Number(label="Height", value=768)
                    seed_input = gr.Number(label="Seed", value=2333)
                remove_bg = gr.Checkbox(label="Remove Background (with algorithm)", value=False)
                timestep = gr.Slider(minimum=10, maximum=70, step=1, value=40, label="Timesteps")
            with gr.Column():
                button = gr.Button(value="Generate")
                output = gr.Gallery(label="4 views of Character Image")

        button.click(
            fn=gen4views,
            inputs=[img_input, width_input, height_input, seed_input, timestep, remove_bg],
            outputs=[output]
        )

    demo.launch()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/infer.yaml")
    args = parser.parse_args()

    # Config keys map one-to-one onto main()'s parameters.
    main(**OmegaConf.load(args.config))
weights_ignore_modules: + - decoder.heads.density + + check_train_every_n_steps: 100 + + camera_embedder_cls: lrm.models.camera.LinearCameraEmbedder + camera_embedder: + in_channels: 16 + out_channels: 768 + conditions: + - c2w_cond + + # image tokenizer transforms input images to tokens + image_tokenizer_cls: lrm.models.tokenizers.image.DINOV2SingleImageTokenizer + image_tokenizer: + pretrained_model_name_or_path: "./models/base" + freeze_backbone_params: false + enable_memory_efficient_attention: true + enable_gradient_checkpointing: true + # camera modulation to the DINO transformer layers + modulation: true + modulation_zero_init: true + modulation_single_layer: true + modulation_cond_dim: ${system.camera_embedder.out_channels} + + # tokenizer gives a tokenized representation for the 3D scene + # triplane tokens in this case + tokenizer_cls: lrm.models.tokenizers.triplane.TriplaneLearnablePositionalEmbedding + tokenizer: + plane_size: 32 + num_channels: 512 + + # backbone network is a transformer that takes scene tokens (potentially with conditional image tokens) + # and outputs scene tokens of the same size + backbone_cls: lrm.models.transformers.transformer_1d.Transformer1D + backbone: + in_channels: ${system.tokenizer.num_channels} + num_attention_heads: 16 + attention_head_dim: 64 + num_layers: 12 + cross_attention_dim: 768 # hard-code, =DINO feature dim + # camera modulation to the transformer layers + # if not needed, set norm_type=layer_norm and do not specify cond_dim_ada_norm_continuous + norm_type: "layer_norm" + enable_memory_efficient_attention: true + gradient_checkpointing: true + + # post processor takes scene tokens and outputs the final scene parameters that will be used for rendering + # in this case, triplanes are upsampled and the features are condensed + post_processor_cls: lrm.models.networks.TriplaneUpsampleNetwork + post_processor: + in_channels: 512 + out_channels: 80 + + renderer_cls: 
lrm.models.renderers.triplane_dmtet.TriplaneDMTetRenderer + renderer: + radius: 0.6 # slightly larger than 0.5 + feature_reduction: concat + sdf_bias: -2. + tet_dir: "./load/tets/" + isosurface_resolution: 256 + enable_isosurface_grid_deformation: false + sdf_activation: negative + + decoder_cls: lrm.models.networks.MultiHeadMLP + decoder: + in_channels: 240 # 3 * 80 + n_neurons: 64 + n_hidden_layers_share: 8 + heads: + - name: sdf + out_channels: 1 + n_hidden_layers: 1 + output_activation: null + - name: features + out_channels: 3 + n_hidden_layers: 1 + output_activation: null # activate in material + activation: silu + chunk_mode: deferred + chunk_size: 131072 + + exporter: + fmt: "obj" + #visual: "vertex" + visual: "uv" + save_uv: True + save_texture: True + uv_unwrap_method: "open3d" + output_path: "./outputs" + + material_cls: lrm.models.materials.no_material.NoMaterial + + background_cls: lrm.models.background.solid_color_background.SolidColorBackground + background: + color: [0.5, 0.5, 0.5] \ No newline at end of file diff --git a/3D_Stage/load/tets/generate_tets.py b/3D_Stage/load/tets/generate_tets.py new file mode 100644 index 0000000000000000000000000000000000000000..424d852d25105fae2d03d6a0d269bc7c2797c6bd --- /dev/null +++ b/3D_Stage/load/tets/generate_tets.py @@ -0,0 +1,58 @@ +# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. 
def generate_tetrahedron_grid_file(res=32, root=".."):
    """Run the external Quartet binary to tetrahedralize a unit cube.

    Args:
        res: grid resolution; cell size is 1/res.
        root: directory containing the compiled `quartet` binary and its
            `meshes/` folder.

    Side effect: shells out to quartet, producing
    ``meshes/cube_{res}_tet.tet`` and ``meshes/cube_boundary_{res}.obj``.
    """
    frac = 1.0 / res
    command = f"cd {root}; ./quartet meshes/cube.obj {frac} meshes/cube_{res}_tet.tet -s meshes/cube_boundary_{res}.obj"
    os.system(command)


def convert_from_quartet_to_npz(quartetfile="cube_32_tet.tet", npzfile="32_tets"):
    """Convert a Quartet .tet file into a compressed .npz.

    The .tet header is ``tet <numvertices> <numtets>``, followed by the
    vertex rows and then the tetrahedron index rows.

    Args:
        quartetfile: path to the input .tet file.
        npzfile: output base name; numpy appends the ``.npz`` suffix.
    """
    # Use a context manager so the file handle is always closed
    # (the previous version leaked it).
    with open(quartetfile, "r") as f:
        header = f.readline()
    parts = header.split(" ")
    numvertices = int(parts[1])
    numtets = int(parts[2])
    print(numvertices, numtets)

    # load vertices
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)
    print(vertices.shape)

    # load indices
    indices = np.loadtxt(
        quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets
    )
    print(indices.shape)

    np.savez_compressed(npzfile, vertices=vertices, indices=indices)


if __name__ == "__main__":
    # Guarded so importing this module no longer shells out to quartet
    # (paths below are from the original authors' machine).
    root = "/home/gyc/quartet"
    for res in [300, 350, 400]:
        generate_tetrahedron_grid_file(res, root)
        convert_from_quartet_to_npz(
            os.path.join(root, f"meshes/cube_{res}_tet.tet"), npzfile=f"{res}_tets"
        )
def find(cls_string):
    """Resolve a dotted path like ``"pkg.module.ClassName"`` to the object.

    Imports everything up to the final dot as a module and returns the
    attribute named after the final dot.
    """
    module_path, _, attr_name = cls_string.rpartition(".")
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr_name)
class SolidColorBackground(BaseBackground):
    """Background that shades every ray with a single solid color."""

    @dataclass
    class Config(BaseBackground.Config):
        n_output_dims: int = 3          # number of output color channels
        color: Tuple = (1.0, 1.0, 1.0)  # default background color (white)
        learned: bool = False           # make the color a trainable parameter
        random_aug: bool = False        # randomize the color during training
        random_aug_prob: float = 0.5    # probability of the random override

    cfg: Config

    def configure(self) -> None:
        # env_color is a parameter when learned, otherwise a fixed buffer
        # (buffers move with the module but are not optimized).
        self.env_color: Float[Tensor, "Nc"]
        if self.cfg.learned:
            self.env_color = nn.Parameter(
                torch.as_tensor(self.cfg.color, dtype=torch.float32)
            )
        else:
            self.register_buffer(
                "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32)
            )

    def forward(
        self,
        dirs: Float[Tensor, "B H W Nc"],
        color_spec: Optional[Float[Tensor, "Nc"]] = None,
    ) -> Float[Tensor, "B H W Nc"]:
        # Broadcast the (optionally overridden) color over all ray directions.
        color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) * (
            color_spec if color_spec is not None else self.env_color
        )
        if (
            self.training
            and self.cfg.random_aug
            and random.random() < self.cfg.random_aug_prob
        ):
            # use random background color with probability random_aug_prob
            # color = color * 0 + (
            #     torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) *
            #     torch.rand(self.cfg.n_output_dims).to(dirs)
            # )
            # One random color per batch element; the `color * 0 +` form keeps
            # env_color in the graph so DDP does not flag it as unused.
            color = color * 0 + (  # prevent checking for unused parameters in DDP
                torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims)
                .to(dirs)
                .expand(*dirs.shape[:-1], -1)
            )
        return color
@dataclass
class ExporterOutput:
    """One artifact produced by an exporter."""
    # save_name: output filename (including extension).
    save_name: str
    # save_type: format tag used by the saving code (e.g. "obj", "glb").
    save_type: str
    # params: format-specific payload consumed by the saver.
    params: Dict[str, Any]


class Exporter(BaseObject):
    """Base class for exporters that turn scene codes into saved assets."""

    @dataclass
    class Config(BaseObject.Config):
        save_video: bool = False  # whether to also save a turntable video

    cfg: Config

    def configure(self, renderer: BaseRenderer) -> None:
        # Exporters query geometry/appearance through this renderer.
        self.renderer = renderer

    def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
        raise NotImplementedError


class DummyExporter(Exporter):
    """No-op exporter used when nothing should be written to disk."""

    def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
        # DummyExporter does not export anything
        return []
def uv_padding_cpu(image, hole_mask, padding):
    """Inpaint texture seams on CPU with OpenCV's Telea algorithm.

    Args:
        image: HWC float tensor in [0, 1].
        hole_mask: HW mask tensor; nonzero marks pixels to fill.
        padding: inpaint radius in pixels.

    Returns:
        Inpainted image as a tensor with the same dtype/device as ``image``.
    """
    uv_padding_size = padding
    inpaint_image = (
        cv2.inpaint(
            (image.detach().cpu().numpy() * 255).astype(np.uint8),
            (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
            uv_padding_size,
            cv2.INPAINT_TELEA,
        )
        / 255.0
    )
    return torch.from_numpy(inpaint_image).to(image)


def uv_padding_cvc(image, hole_mask, padding):
    """GPU variant of uv_padding_cpu using NVIDIA CV-CUDA (imported lazily
    so the dependency stays optional)."""
    import cvcuda

    torch_to_cvc = lambda x, layout: cvcuda.as_tensor(x, layout)
    cvc_to_torch = lambda x: torch.as_tensor(x.cuda())

    uv_padding_size = padding
    image_cvc = torch_to_cvc((image.detach() * 255).to(torch.uint8), "HWC")
    hole_mask_cvc = torch_to_cvc((hole_mask.detach() * 255).to(torch.uint8), "HW")
    inpaint_image = cvcuda.inpaint(image_cvc, hole_mask_cvc, uv_padding_size)
    inpaint_image = cvc_to_torch(inpaint_image) / 255.0
    return inpaint_image.to(image)


def uv_padding(image, hole_mask, padding):
    """Inpaint UV seams, preferring the CV-CUDA path and falling back to CPU.

    The fallback covers both a missing cvcuda install and runtime failures.
    """
    try:
        inpaint_image = uv_padding_cvc(image, hole_mask, padding)
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still
        # propagate instead of silently triggering the CPU fallback.
        lrm.info(f"CVCUDA not available, fallback to CPU UV padding.")
        inpaint_image = uv_padding_cpu(image, hole_mask, padding)
    return inpaint_image
field(default_factory=dict) + smartuv_options: dict = field(default_factory=dict) + uv_padding_size: int = 2 + subdivide: bool = False + post_process: bool = False + post_process_options: dict = field(default_factory=dict) + context_type: str = "gl" + output_path: str = "outputs" + + cfg: Config + + def configure(self, renderer: BaseRenderer) -> None: + super().configure(renderer) + self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device) + if self.cfg.fmt == "obj-mtl": + lrm.warn( + f"fmt=obj-mtl is deprecated, please us fmt=obj and visual=uv instead." + ) + self.cfg.fmt = "obj" + self.cfg.visual = "uv" + + if self.cfg.fmt == "glb": + assert self.cfg.visual in [ + "vertex", + "uv-blender", + ], "GLB format only supports visual=vertex and visual=uv-blender!" + + def get_geometry(self, scene_code: torch.Tensor) -> Mesh: + tr.start("Surface extraction") + mesh: Mesh = self.renderer.isosurface(scene_code) + tr.end("Surface extraction") + return mesh + + def get_texture_maps( + self, scene_code: torch.Tensor, mesh: Mesh + ) -> Dict[str, torch.Tensor]: + assert mesh.has_uv + # clip space transform + uv_clip = mesh.v_tex * 2.0 - 1.0 + # pad to four component coordinate + uv_clip4 = torch.cat( + ( + uv_clip, + torch.zeros_like(uv_clip[..., 0:1]), + torch.ones_like(uv_clip[..., 0:1]), + ), + dim=-1, + ) + # rasterize + rast, _ = self.ctx.rasterize_one( + uv_clip4, + mesh.t_tex_idx, + (self.cfg.texture_size, self.cfg.texture_size), + ) + + hole_mask = ~(rast[:, :, 3] > 0) + + # Interpolate world space position + gb_pos, _ = self.ctx.interpolate_one( + mesh.v_pos, rast[None, ...], mesh.t_pos_idx + ) + gb_pos = gb_pos[0] + + # Sample out textures from MLP + tr.start("Query color") + geo_out = self.renderer.query(scene_code, points=gb_pos) + tr.end("Query color") + mat_out = self.renderer.material.export(points=gb_pos, **geo_out) + + textures = {} + tr.start("UV padding") + if "albedo" in mat_out: + textures["map_Kd"] = uv_padding( + mat_out["albedo"], hole_mask, 
    def __call__(self, names, scene_codes) -> List[ExporterOutput]:
        """Export one mesh per (name, scene_code) pair.

        Extracts geometry via the renderer, optionally post-processes the
        mesh, then dispatches on cfg.visual to pick the export path.
        """
        outputs = []
        for name, scene_code in zip(names, scene_codes):
            mesh = self.get_geometry(scene_code)
            if self.cfg.post_process:
                tr.start("Mesh post-processing")
                mesh = mesh.post_process(self.cfg.post_process_options)
                tr.end("Mesh post-processing")
            # Dispatch by visual mode: UV-textured, per-vertex color, or the
            # Blender-based UV path.
            if self.cfg.visual == "uv":
                output = self.export_model_with_mtl(
                    name, self.cfg.fmt, scene_code, mesh
                )
            elif self.cfg.visual == "vertex":
                output = self.export_model(name, self.cfg.fmt, scene_code, mesh)
            elif self.cfg.visual == "uv-blender":
                # NOTE(review): export_model_blender is not visible in this
                # file chunk — confirm it exists on this class.
                output = self.export_model_blender(name, self.cfg.fmt, scene_code, mesh)
            else:
                raise ValueError(f"Unsupported visual format: {self.cfg.visual}")
            outputs.append(output)
        return outputs
self.cfg.xatlas_pack_options, + self.cfg.smartuv_options, + ) + + if self.cfg.save_texture: + lrm.info("Exporting textures ...") + assert self.cfg.save_uv, "save_uv must be True when save_texture is True" + + with time_recorder_enabled(): + textures = self.get_texture_maps(scene_code, mesh) + params.update(textures) + os.makedirs(self.cfg.output_path, exist_ok=True) + np.savez(f"{self.cfg.output_path}/tex_info.npz", v_tex=mesh.v_tex.cpu().numpy(), t_tex_idx=mesh.t_tex_idx.cpu().numpy()) + return ExporterOutput( + save_name=f"{self.cfg.save_name}-{name}.{fmt}", save_type=fmt, params=params + ) + + def export_model( + self, name: str, fmt: str, scene_code, mesh: Mesh + ) -> ExporterOutput: + params = { + "mesh": mesh, + "save_mat": False, + "save_normal": self.cfg.save_normal, + "save_uv": self.cfg.save_uv, + "save_vertex_color": False, + "map_Kd": None, # Base Color + "map_Ks": None, # Specular + "map_Bump": None, # Normal + # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering + "map_Pm": None, # Metallic + "map_Pr": None, # Roughness + "map_format": self.cfg.texture_format, + } + + if self.cfg.save_uv: + mesh.unwrap_uv( + self.cfg.uv_unwrap_method, + self.cfg.xatlas_chart_options, + self.cfg.xatlas_pack_options, + self.cfg.smartuv_options, + ) + + if self.cfg.save_texture: + lrm.info("Exporting textures ...") + geo_out = self.renderer.query(scene_code, points=mesh.v_pos) + mat_out = self.renderer.material.export(points=mesh.v_pos, **geo_out) + + if "albedo" in mat_out: + mesh.set_vertex_color(mat_out["albedo"]) + params["save_vertex_color"] = True + else: + lrm.warn( + "save_texture is True but no albedo texture found, not saving vertex color" + ) + + return ExporterOutput( + save_name=f"{self.cfg.save_name}-{name}.{fmt}", save_type=fmt, params=params + ) diff --git a/3D_Stage/lrm/models/isosurface.py b/3D_Stage/lrm/models/isosurface.py new file mode 100644 index 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import lrm
from ..models.mesh import Mesh
from ..utils.typing import *
from ..utils.ops import scale_tensor


class IsosurfaceHelper(nn.Module):
    """Base class for isosurface extractors; grid coordinates live in [0, 1]^3."""

    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        raise NotImplementedError


class MarchingCubeCPUHelper(IsosurfaceHelper):
    """Marching-cubes isosurface extraction on CPU via PyMCubes."""

    def __init__(self, resolution: int) -> None:
        super().__init__()
        self.resolution = resolution
        import mcubes  # deferred: optional CPU-only dependency

        self.mc_func: Callable = mcubes.marching_cubes
        self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None
        # zero-sized buffer used only to track which device the module is on
        self._dummy: Float[Tensor, "..."]
        self.register_buffer(
            "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
        )

    @property
    def grid_vertices(self) -> Float[Tensor, "N3 3"]:
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            axis = torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
            verts = torch.cat(
                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
            ).reshape(-1, 3)
            self._grid_vertices = verts
        return self._grid_vertices

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Run marching cubes on a flattened level set; returns a Mesh in [0, 1]^3.

        Deformation is unsupported here (warned and ignored).
        """
        if deformation is not None:
            lrm.warn(
                f"{self.__class__.__name__} does not support deformation. Ignoring."
            )
        # mcubes extracts the positive side; negate to match the SDF convention
        level = -level.view(self.resolution, self.resolution, self.resolution)
        # removed a leftover debug print of level.shape/min/max
        v_pos, t_pos_idx = self.mc_func(
            level.detach().cpu().numpy(), 0.0
        )  # transform to numpy
        v_pos, t_pos_idx = (
            torch.from_numpy(v_pos).float().to(self._dummy.device),
            torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device),
        )  # transform back to torch tensor on the module's device
        # rescale voxel indices into the [0, 1] grid range
        v_pos = v_pos / (self.resolution - 1.0)
        return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)


def get_center_boundary_index(verts):
    """Return (index of the vertex nearest the origin, indices of boundary verts).

    Assumes verts are in range [-1.0, 1.0]; a vertex counts as boundary when any
    coordinate equals the global min or max over all vertices.
    """
    # NOTE(review): this measures ||verts**2|| rather than ||verts||; it still
    # selects a near-origin vertex, but confirm the squaring is intentional.
    length_ = torch.linalg.norm(verts**2, ord=2, dim=-1, keepdim=False)
    center_idx = torch.argmin(length_)
    boundary_neg = verts == verts.max()
    boundary_pos = verts == verts.min()
    boundary = torch.bitwise_or(boundary_pos, boundary_neg)
    boundary = torch.sum(boundary.float(), dim=-1)
    boundary_idx = torch.nonzero(boundary)
    return center_idx.unsqueeze(0), boundary_idx.squeeze(dim=-1)


class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Differentiable marching tetrahedra over a precomputed tetrahedral grid."""

    def __init__(self, resolution: int, tets_path: str):
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # lookup table: 4-bit tet occupancy pattern -> up to two triangles
        # expressed as indices into the tet's 6 edges (-1 = unused slot)
        self.triangle_table: Float[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # number of output triangles per occupancy pattern
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # the 6 edges of a tetrahedron as local vertex-index pairs
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        tets = np.load(self.tets_path)
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer(
            "_grid_vertices",
            torch.from_numpy(tets["vertices"]).float(),
            persistent=False,
        )
        self.indices: Integer[Tensor, "..."]
        self.register_buffer(
            "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
        )

        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.center_indices, self.boundary_indices = get_center_boundary_index(
            scale_tensor(self.grid_vertices, self.points_range, (-1.0, 1.0))
        )

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Bound per-vertex offsets to roughly half a tet via tanh."""
        return (
            (self.points_range[1] - self.points_range[0])
            / (self.resolution)  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        if self._all_edges is None:
            # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
            edges = torch.tensor(
                [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
                dtype=torch.long,
                device=self.indices.device,
            )
            _all_edges = self.indices[:, edges].reshape(-1, 2)
            _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
            _all_edges = torch.unique(_all_edges_sorted, dim=0)
            self._all_edges = _all_edges
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Canonicalize edges so the smaller vertex index comes first."""
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)
    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Core marching-tetrahedra kernel.

        pos_nx3: grid vertex positions; sdf_n: per-vertex signed values;
        tet_fx4: tetrahedra as 4 vertex indices. Returns (verts, faces) of the
        zero-level-set triangle mesh; verts stay differentiable w.r.t. inputs.
        """
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # keep only tets crossed by the surface (mixed occupancy)
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # an edge crosses the surface iff exactly one endpoint is inside
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]
        # linearly interpolate the crossing position along each surface edge
        # (differentiable: gradients flow into pos_nx3 and sdf_n from here on)
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # 4-bit occupancy code per tet indexes the triangle lookup tables
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract a mesh from per-grid-vertex values `level`, optionally
        deforming the grid vertices first (tanh-bounded offsets, see
        normalize_grid_deformation)."""
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh


from dataclasses import dataclass
from typing import Any
import lpips

from ..utils.ops import scale_tensor
from ..utils.misc import get_device


class LPIPS:
    """Frozen VGG-LPIPS perceptual distance; rescales inputs to [-1, 1]."""

    def __init__(self):
        self.model = lpips.LPIPS(net="vgg").to(get_device())
        self.model.eval()
        # frozen: used as a fixed perceptual metric, never trained
        for params in self.model.parameters():
            params.requires_grad = False
        self.model_input_range = (-1, 1)

    def __call__(self, x1, x2, return_layers=False, input_range=(0, 1)):
        # map the caller's value range onto the range the LPIPS net expects;
        # normalize=False because we already did the rescaling ourselves
        x1 = scale_tensor(x1, input_range, self.model_input_range)
        x2 = scale_tensor(x2, input_range, self.model_input_range)
        return self.model(x1, x2, retPerLayer=return_layers, normalize=False)
+from ...utils.typing import * + + +class BaseMaterial(BaseModule): + @dataclass + class Config(BaseModule.Config): + pass + + cfg: Config + requires_normal: bool = False + requires_tangent: bool = False + + def configure(self): + pass + + def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]: + raise NotImplementedError + + def export(self, *args, **kwargs) -> Dict[str, Any]: + return {} diff --git a/3D_Stage/lrm/models/materials/no_material.py b/3D_Stage/lrm/models/materials/no_material.py new file mode 100644 index 0000000000000000000000000000000000000000..754fec75a42ff1835e9df5e2c9fcdefcc5099814 --- /dev/null +++ b/3D_Stage/lrm/models/materials/no_material.py @@ -0,0 +1,60 @@ +import random +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import lrm +from .base import BaseMaterial +from ..networks import get_encoding, get_mlp +from ...utils.ops import dot, get_activation +from ...utils.typing import * + + +class NoMaterial(BaseMaterial): + @dataclass + class Config(BaseMaterial.Config): + n_output_dims: int = 3 + color_activation: str = "sigmoid" + input_feature_dims: Optional[int] = None + mlp_network_config: Optional[dict] = None + requires_normal: bool = False + + cfg: Config + + def configure(self) -> None: + self.use_network = False + if ( + self.cfg.input_feature_dims is not None + and self.cfg.mlp_network_config is not None + ): + self.network = get_mlp( + self.cfg.input_feature_dims, + self.cfg.n_output_dims, + self.cfg.mlp_network_config, + ) + self.use_network = True + self.requires_normal = self.cfg.requires_normal + + def forward( + self, features: Float[Tensor, "B ... Nf"], **kwargs + ) -> Float[Tensor, "B ... Nc"]: + if not self.use_network: + assert ( + features.shape[-1] == self.cfg.n_output_dims + ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input." 
+ color = get_activation(self.cfg.color_activation)(features) + else: + color = self.network(features.view(-1, features.shape[-1])).view( + *features.shape[:-1], self.cfg.n_output_dims + ) + color = get_activation(self.cfg.color_activation)(color) + return color + + def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: + color = self(features, **kwargs).clamp(0, 1) + assert color.shape[-1] >= 3, "Output color must have at least 3 channels" + if color.shape[-1] > 3: + lrm.warn("Output color has >3 channels, treating the first 3 as RGB") + return {"albedo": color[..., :3]} diff --git a/3D_Stage/lrm/models/mesh.py b/3D_Stage/lrm/models/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..9765a8131347b9abe316608aa7a5bbd509db7566 --- /dev/null +++ b/3D_Stage/lrm/models/mesh.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import functools +import numpy as np +import torch +import torch.nn.functional as F + +import lrm +from ..utils.ops import dot +from ..utils.typing import * +from ..utils.misc import time_recorder as tr, time_recorder_enabled + + +class Mesh: + def __init__( + self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs + ) -> None: + self.v_pos: Float[Tensor, "Nv 3"] = v_pos + self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx + self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None + self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None + self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None + self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None + self._edges: Optional[Integer[Tensor, "Ne 2"]] = None + self.extras: Dict[str, Any] = {} + for k, v in kwargs.items(): + self.add_extra(k, v) + + def add_extra(self, k, v) -> None: + self.extras[k] = v + + def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh: + if self.requires_grad: + lrm.debug("Mesh is differentiable, not removing outliers") + return 
self + + # use trimesh to first split the mesh into connected components + # then remove the components with less than n_face_threshold faces + import trimesh + + # construct a trimesh object + mesh = trimesh.Trimesh( + vertices=self.v_pos.detach().cpu().numpy(), + faces=self.t_pos_idx.detach().cpu().numpy(), + ) + + # split the mesh into connected components + components = mesh.split(only_watertight=False) + # log the number of faces in each component + lrm.debug( + "Mesh has {} components, with faces: {}".format( + len(components), [c.faces.shape[0] for c in components] + ) + ) + + n_faces_threshold: int + if isinstance(outlier_n_faces_threshold, float): + # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold + n_faces_threshold = int( + max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold + ) + else: + # set the threshold directly to outlier_n_faces_threshold + n_faces_threshold = outlier_n_faces_threshold + + # log the threshold + lrm.debug( + "Removing components with less than {} faces".format(n_faces_threshold) + ) + + # remove the components with less than n_face_threshold faces + components = [c for c in components if c.faces.shape[0] >= n_faces_threshold] + + # log the number of faces in each component after removing outliers + lrm.debug( + "Mesh has {} components after removing outliers, with faces: {}".format( + len(components), [c.faces.shape[0] for c in components] + ) + ) + # merge the components + mesh = trimesh.util.concatenate(components) + + # convert back to our mesh format + v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos) + t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx) + + clean_mesh = Mesh(v_pos, t_pos_idx) + # keep the extras unchanged + + if len(self.extras) > 0: + clean_mesh.extras = self.extras + lrm.debug( + f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}" + ) + return clean_mesh + + def 
subdivide(self): + if self.requires_grad: + lrm.debug("Mesh is differentiable, not performing subdivision") + return self + + import trimesh + + mesh = trimesh.Trimesh( + vertices=self.v_pos.detach().cpu().numpy(), + faces=self.t_pos_idx.detach().cpu().numpy(), + ) + + mesh.subdivide_loop() + + v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos) + t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx) + + subdivided_mesh = Mesh(v_pos, t_pos_idx) + + if len(self.extras) > 0: + subdivided_mesh.extras = self.extras + lrm.debug( + f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}" + ) + + return subdivided_mesh + + def post_process(self, options): + if self.requires_grad: + lrm.debug("Mesh is differentiable, not performing post processing") + return self + + from extern.mesh_process.MeshProcess import process_mesh + + v_pos, t_pos_idx = process_mesh( + vertices=self.v_pos.detach().cpu().numpy(), + faces=self.t_pos_idx.detach().cpu().numpy(), + **options, + ) + + v_pos = torch.from_numpy(v_pos).to(self.v_pos).contiguous() + t_pos_idx = torch.from_numpy(t_pos_idx).to(self.t_pos_idx).contiguous() + + processed_mesh = Mesh(v_pos, t_pos_idx) + + if len(self.extras) > 0: + processed_mesh.extras = self.extras + lrm.debug( + f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}" + ) + + return processed_mesh + + @property + def requires_grad(self): + return self.v_pos.requires_grad + + @property + def v_nrm(self): + if self._v_nrm is None: + self._v_nrm = self._compute_vertex_normal() + return self._v_nrm + + @property + def v_tng(self): + if self._v_tng is None: + self._v_tng = self._compute_vertex_tangent() + return self._v_tng + + @property + def v_tex(self): + if self._v_tex is None: + self._v_tex, self._t_tex_idx = self._unwrap_uv() + return self._v_tex + + @property + def t_tex_idx(self): + if self._t_tex_idx is None: + self._v_tex, self._t_tex_idx = 
    @property
    def t_tex_idx(self):
        # UV-space triangle indices; computed together with v_tex on demand
        if self._t_tex_idx is None:
            # NOTE(review): `_unwrap_uv` (no suffix) is not defined in this
            # chunk — only _unwrap_uv_xatlas/_open3d/_smartuv are; verify it
            # exists elsewhere or this path raises AttributeError.
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._t_tex_idx

    @property
    def v_rgb(self):
        # per-vertex colors, or None until set_vertex_color() is called
        return self._v_rgb

    @property
    def edges(self):
        # unique undirected edges, computed lazily
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Area-weighted vertex normals: splat face normals, then normalize."""
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        # NOTE(review): torch.cross without dim= is deprecated in newer torch;
        # this relies on the default selecting the size-3 dimension.
        face_normals = torch.cross(v1 - v0, v2 - v0)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Per-vertex tangents from UV-space gradients, then orthogonalized
        against the vertex normals (standard tangent-frame accumulation)."""
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(0, 3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = self.v_tex[self.t_tex_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        uve1 = tex[1] - tex[0]
        uve2 = tex[2] - tex[0]
        pe1 = pos[1] - pos[0]
        pe2 = pos[2] - pos[0]

        nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
        denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates
        tang = nom / torch.where(
            denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
        )

        # Update all 3 vertices
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def _unwrap_uv_open3d(
        self
    ):
        """UV unwrapping via Open3D compute_uvatlas; emits per-corner UVs.

        NOTE(review): results are moved with .cuda(), so this path hard-requires
        a CUDA device — confirm that is acceptable for all callers.
        """
        import open3d as o3d
        mesh = o3d.t.geometry.TriangleMesh()
        mesh.vertex.positions = o3d.core.Tensor(self.v_pos.detach().cpu().numpy())
        mesh.triangle.indices = o3d.core.Tensor(self.t_pos_idx.cpu().numpy())
        mesh.compute_uvatlas(size=1024)
        texture_uvs = torch.from_numpy(mesh.triangle.texture_uvs.numpy()).reshape(-1, 2).cuda()
        indices = torch.arange(self.t_pos_idx.shape[0] * 3).reshape(-1, 3).to(torch.int64).cuda()
        # Add a wood texture and visualize
        return texture_uvs, indices

    def _unwrap_uv_xatlas(
        self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
    ):
        """UV unwrapping via xatlas; returns (uvs, uv-face indices).

        NOTE(review): the mutable dict defaults are shared across calls — safe
        only because they are never mutated here.
        """
        lrm.info("Using xatlas to perform UV unwrapping, may take a while ...")

        import xatlas

        atlas = xatlas.Atlas()
        atlas.add_mesh(
            self.v_pos.detach().cpu().numpy(),
            self.t_pos_idx.cpu().numpy(),
        )
        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        for k, v in xatlas_chart_options.items():
            setattr(co, k, v)
        for k, v in xatlas_pack_options.items():
            setattr(po, k, v)
        atlas.generate(co, po)
        vmapping, indices, uvs = atlas.get_mesh(0)
        # reinterpret xatlas' uint64 outputs as int64 bit patterns for torch
        vmapping = (
            torch.from_numpy(
                vmapping.astype(np.uint64, casting="same_kind").view(np.int64)
            )
            .to(self.v_pos.device)
            .long()
        )
        uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
        indices = (
            torch.from_numpy(
                indices.astype(np.uint64, casting="same_kind").view(np.int64)
            )
            .to(self.v_pos.device)
            .long()
        )
        return uvs, indices
extern.mesh_process.MeshProcess import ( + mesh_to_bpy, + get_uv_from_bpy, + bpy_context, + bpy_export, + ) + from lrm.utils.misc import time_recorder as tr + + v_pos, t_pos_idx = self.v_pos.cpu().numpy(), self.t_pos_idx.cpu().numpy() + with bpy_context(): + mesh_bpy = mesh_to_bpy("_", v_pos, t_pos_idx) + v_tex = get_uv_from_bpy(mesh_bpy, **options).astype(np.float32) + + assert v_tex.shape[0] == self.t_pos_idx.shape[0] * 3 + + t_tex_idx = torch.arange( + self.t_pos_idx.shape[0] * 3, device=self.t_pos_idx.device, dtype=torch.long + ).reshape(-1, 3) + + """ + # super efficient de-duplication + v_tex_u_uint32 = v_tex[..., 0].view(np.uint32) + v_tex_v_uint32 = v_tex[..., 1].view(np.uint32) + v_hashed = (v_tex_u_uint32.astype(np.uint64) << 32) | v_tex_v_uint32 + v_hashed = torch.from_numpy(v_hashed.view(np.int64)).to(self.v_pos.device) + + v_tex = torch.from_numpy(v_tex).to( + device=self.v_pos.device, dtype=torch.float32 + ) + t_pos_idx_f3 = torch.arange( + self.t_pos_idx.shape[0] * 3, device=self.t_pos_idx.device, dtype=torch.long + ).reshape(-1, 3) + v_pos_f3 = self.v_pos[self.t_pos_idx].reshape(-1, 3) + + # super efficient de-duplication + v_hashed_dedup, inverse_indices = torch.unique(v_hashed, return_inverse=True) + dedup_size, full_size = v_hashed_dedup.shape[0], inverse_indices.shape[0] + indices = torch.scatter_reduce( + torch.full( + [dedup_size], + fill_value=full_size, + device=inverse_indices.device, + dtype=torch.long, + ), + index=inverse_indices, + src=torch.arange( + full_size, device=inverse_indices.device, dtype=torch.int64 + ), + dim=0, + reduce="amin", + ) + v_tex = v_tex[indices] + t_tex_idx = inverse_indices.reshape(-1, 3) + + v_pos = v_pos_f3[indices] + t_pos_idx = inverse_indices[t_pos_idx_f3] + """ + + return self.v_pos, self.t_pos_idx, v_tex, t_tex_idx + + def unwrap_uv( + self, + method: str, + xatlas_chart_options: dict = {}, + xatlas_pack_options: dict = {}, + smartuv_options: dict = {}, + ): + if method == "xatlas": + with 
time_recorder_enabled(): + tr.start("UV unwrapping xatlas") + self._v_tex, self._t_tex_idx = self._unwrap_uv_xatlas( + xatlas_chart_options, xatlas_pack_options + ) + tr.end("UV unwrapping xatlas") + elif method == "open3d": + with time_recorder_enabled(): + tr.start("UV unwrapping o3d") + self._v_tex, self._t_tex_idx = self._unwrap_uv_open3d() + tr.end("UV unwrapping o3d") + elif method == "smartuv": + with time_recorder_enabled(): + tr.start("UV unwrapping smartuv") + ( + self.v_pos, + self.t_pos_idx, + self._v_tex, + self._t_tex_idx, + ) = self._unwrap_uv_smartuv(smartuv_options) + tr.end("UV unwrapping smartuv") + else: + raise NotImplementedError + + def set_vertex_color(self, v_rgb): + assert v_rgb.shape[0] == self.v_pos.shape[0] + self._v_rgb = v_rgb + + def set_uv(self, v_tex, t_tex_idx): + self._v_tex = v_tex + self._t_tex_idx = t_tex_idx + + @property + def has_uv(self): + return self._v_tex is not None and self._t_tex_idx is not None + + def _compute_edges(self): + # Compute edges + edges = torch.cat( + [ + self.t_pos_idx[:, [0, 1]], + self.t_pos_idx[:, [1, 2]], + self.t_pos_idx[:, [2, 0]], + ], + dim=0, + ) + edges = edges.sort()[0] + edges = torch.unique(edges, dim=0) + return edges + + def normal_consistency(self) -> Float[Tensor, ""]: + edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges] + nc = ( + 1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1) + ).mean() + return nc + + def _laplacian_uniform(self): + # from stable-dreamfusion + # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224 + verts, faces = self.v_pos, self.t_pos_idx + + V = verts.shape[0] + F = faces.shape[0] + + # Neighbor indices + ii = faces[:, [1, 2, 0]].flatten() + jj = faces[:, [2, 0, 1]].flatten() + adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique( + dim=1 + ) + adj_values = torch.ones(adj.shape[1]).to(verts) + + # Diagonal indices + diag_idx = adj[0] + + # Build 
the sparse matrix + idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1) + values = torch.cat((-adj_values, adj_values)) + + # The coalesce operation sums the duplicate indices, resulting in the + # correct diagonal + return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce() + + def laplacian(self) -> Float[Tensor, ""]: + with torch.no_grad(): + L = self._laplacian_uniform() + loss = L.mm(self.v_pos) + loss = loss.norm(dim=1) + loss = loss.mean() + return loss diff --git a/3D_Stage/lrm/models/networks.py b/3D_Stage/lrm/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..44ffc4052bd089dc9190ac9b9b98b8ec58fba85d --- /dev/null +++ b/3D_Stage/lrm/models/networks.py @@ -0,0 +1,390 @@ +from dataclasses import dataclass, field +from copy import deepcopy + +import torch +import torch.nn as nn +from einops import rearrange + +from ..utils.base import BaseModule +from ..utils.ops import get_activation +from ..utils.typing import * + + +class TriplaneUpsampleNetwork(BaseModule): + @dataclass + class Config(BaseModule.Config): + in_channels: int = 1024 + out_channels: int = 80 + + cfg: Config + + def configure(self) -> None: + super().configure() + self.upsample = nn.ConvTranspose2d( + self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2 + ) + + def forward( + self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"] + ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]: + triplanes_up = rearrange( + self.upsample( + rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3) + ), + "(B Np) Co Hp Wp -> B Np Co Hp Wp", + Np=3, + ) + return triplanes_up + + +class MLP(nn.Module): + def __init__( + self, + dim_in: int, + dim_out: int, + n_neurons: int, + n_hidden_layers: int, + activation: str = "relu", + output_activation: Optional[str] = None, + bias: bool = True, + weight_init: Optional[str] = "kaiming_uniform", + bias_init: Optional[str] = None, + ): + super().__init__() + layers = [ + self.make_linear( + dim_in, + 
class MLP(nn.Module):
    """Plain fully-connected MLP with configurable init and activations."""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
        weight_init: Optional[str] = "kaiming_uniform",
        bias_init: Optional[str] = None,
    ):
        super().__init__()
        layers = [
            self.make_linear(
                dim_in,
                n_neurons,
                is_first=True,
                is_last=False,
                bias=bias,
                weight_init=weight_init,
                bias_init=bias_init,
            ),
            self.make_activation(activation),
        ]
        # n_hidden_layers counts linear layers before the output layer,
        # so only (n_hidden_layers - 1) extra hidden blocks are added here
        for i in range(n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    n_neurons,
                    n_neurons,
                    is_first=False,
                    is_last=False,
                    bias=bias,
                    weight_init=weight_init,
                    bias_init=bias_init,
                ),
                self.make_activation(activation),
            ]
        layers += [
            self.make_linear(
                n_neurons,
                dim_out,
                is_first=False,
                is_last=True,
                bias=bias,
                weight_init=weight_init,
                bias_init=bias_init,
            )
        ]
        self.layers = nn.Sequential(*layers)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        x = self.layers(x)
        x = self.output_activation(x)
        return x

    def make_linear(
        self,
        dim_in,
        dim_out,
        is_first,
        is_last,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        # is_first/is_last are accepted for API symmetry but unused here
        layer = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init is None:
            pass
        elif weight_init == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
        else:
            raise NotImplementedError

        if bias:
            if bias_init is None:
                pass
            elif bias_init == "zero":
                torch.nn.init.zeros_(layer.bias)
            else:
                raise NotImplementedError

        return layer

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError


@dataclass
class HeadSpec:
    # one output head of MultiHeadMLP: name, width, depth, final activation
    name: str
    out_channels: int
    n_hidden_layers: int
    output_activation: Optional[str] = None


class MultiHeadMLP(BaseModule):
    """Shared MLP trunk with several named output heads.

    Supports optional chunked evaluation of the trunk ("deferred" custom
    autograd or gradient "checkpointing") to bound activation memory.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 0
        n_neurons: int = 0
        n_hidden_layers_share: int = 0
        heads: List[HeadSpec] = field(default_factory=lambda: [])
        activation: str = "relu"
        bias: bool = True
        weight_init: Optional[str] = "kaiming_uniform"
        bias_init: Optional[str] = None
        chunk_mode: Optional[str] = None
        chunk_size: int = -1

    cfg: Config

    def configure(self) -> None:
        super().configure()
        shared_layers = [
            self.make_linear(
                self.cfg.in_channels,
                self.cfg.n_neurons,
                bias=self.cfg.bias,
                weight_init=self.cfg.weight_init,
                bias_init=self.cfg.bias_init,
            ),
            self.make_activation(self.cfg.activation),
        ]
        for i in range(self.cfg.n_hidden_layers_share - 1):
            shared_layers += [
                self.make_linear(
                    self.cfg.n_neurons,
                    self.cfg.n_neurons,
                    bias=self.cfg.bias,
                    weight_init=self.cfg.weight_init,
                    bias_init=self.cfg.bias_init,
                ),
                self.make_activation(self.cfg.activation),
            ]
        self.shared_layers = nn.Sequential(*shared_layers)

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for i in range(head.n_hidden_layers):
                head_layers += [
                    self.make_linear(
                        self.cfg.n_neurons,
                        self.cfg.n_neurons,
                        bias=self.cfg.bias,
                        weight_init=self.cfg.weight_init,
                        bias_init=self.cfg.bias_init,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            head_layers += [
                self.make_linear(
                    self.cfg.n_neurons,
                    head.out_channels,
                    bias=self.cfg.bias,
                    weight_init=self.cfg.weight_init,
                    bias_init=self.cfg.bias_init,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

        if self.cfg.chunk_mode is not None:
            assert self.cfg.chunk_size > 0

    def make_linear(
        self,
        dim_in,
        dim_out,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        layer = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init is None:
            pass
        elif weight_init == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
        else:
            raise NotImplementedError

        if bias:
            if bias_init is None:
                pass
            elif bias_init == "zero":
                torch.nn.init.zeros_(layer.bias)
            else:
                raise NotImplementedError

        return layer

    def make_activation(self, activation):
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Run trunk then heads; include/exclude select heads by name
        (mutually exclusive)."""
        inp_shape = x.shape[:-1]
        # trunk operates on a flat (N, C) view regardless of input rank
        x = x.reshape(-1, x.shape[-1])

        if self.cfg.chunk_mode is None:
            shared_features = self.shared_layers(x)
        elif self.cfg.chunk_mode == "deferred":
            shared_features = DeferredFunc.apply(
                self.shared_layers, x, self.cfg.chunk_size
            )
        elif self.cfg.chunk_mode == "checkpointing":
            shared_features = apply_batch_checkpointing(
                self.shared_layers, x, self.cfg.chunk_size
            )
        else:
            raise NotImplementedError

        shared_features = shared_features.reshape(*inp_shape, -1)

        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](shared_features)
            )
            for head in heads
        }
        """
        # TypeError
        if self.cfg.chunk_mode is None:
            out = {
                head.name: get_activation(head.output_activation)(
                    self.heads[head.name](shared_features)
                )
                for head in heads
            }
        elif self.cfg.chunk_mode == "deferred":
            out = {
                head.name: get_activation(head.output_activation)(
                    DeferredFunc.apply(self.heads[head.name], shared_features, self.cfg.chunk_size)
                )
                for head in heads
            }
        else:
            raise NotImplementedError
        """
        return out


class DeferredFunc(torch.autograd.Function):
    """Chunked forward with recomputation in backward.

    Forward runs a frozen copy of the model chunk-by-chunk under no_grad;
    backward re-runs each chunk with grad enabled, accumulating parameter
    grads onto the live model and returning the input gradient.
    """

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(ctx, model, x, chunk_size):
        model_copy = deepcopy(model)
        model_copy.requires_grad_(False)

        ret = []
        x_split = torch.split(x, chunk_size, dim=0)

        with torch.no_grad():
            for cur_x in x_split:
                ret.append(model_copy(cur_x))

        # keep a handle to the live model so backward can deposit param grads
        ctx.model = model
        ctx.save_for_backward(x.detach(), torch.as_tensor(chunk_size))

        ret = torch.cat(ret, dim=0)

        return ret

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        model = ctx.model
        x, chunk_size = ctx.saved_tensors
        chunk_size = chunk_size.item()

        model_copy = deepcopy(model)

        x_split = torch.split(x, chunk_size, dim=0)
        grad_output_split = torch.split(grad_output, chunk_size, 0)
        grad_input_split = []

        with torch.set_grad_enabled(True):
            model_copy.requires_grad_(True)
            model_copy.zero_grad()
            for cur_x, cur_grad_output in zip(x_split, grad_output_split):
                cur_x.requires_grad_(True)
                cur_y = model_copy(cur_x)
                cur_y.backward(cur_grad_output)

                grad_input_split.append(cur_x.grad.clone())

        grad_input = torch.cat(grad_input_split, dim=0)

        # fold the copy's accumulated parameter grads back into the live model
        model_copy_params = list(model_copy.parameters())
        model_params = list(model.parameters())

        for param, param_copy in zip(model_params, model_copy_params):
            if param.grad is None:
                param.grad = param_copy.grad.clone()
            else:
                param.grad.add_(param_copy.grad)

        return None, grad_input, None


def apply_batch_checkpointing(func, x, chunk_size):
    """Evaluate func over x in chunks with gradient checkpointing, so only one
    chunk's activations are live at a time."""
    if chunk_size >= len(x):
        # return func(x)
        return torch.utils.checkpoint.checkpoint(func, x, use_reentrant=False)

    x_split = torch.split(x, chunk_size, dim=0)

    def cat_and_query(y_all, x):
        return torch.cat([y_all, func(x)])

    y_all = func(x_split[0])
    for cur_x in x_split[1:]:
        y_all = torch.utils.checkpoint.checkpoint(
            cat_and_query, y_all, cur_x, use_reentrant=False
        )

    return y_all


def get_encoding(n_input_dims: int, config) -> nn.Module:
    # placeholder: encoding factory not implemented in this build
    raise NotImplementedError


def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module:
    # placeholder: MLP factory not implemented in this build
    raise NotImplementedError
a/3D_Stage/lrm/models/renderers/base.py b/3D_Stage/lrm/models/renderers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0a733d00b8a8726cf23006a37f81df8781b7728e --- /dev/null +++ b/3D_Stage/lrm/models/renderers/base.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +import lrm +from ..networks import MultiHeadMLP +from ..background.base import BaseBackground +from ..materials.base import BaseMaterial +from ...utils.base import BaseModule +from ...utils.typing import * + + +class BaseRenderer(BaseModule): + @dataclass + class Config(BaseModule.Config): + radius: float = 1.0 + + cfg: Config + + def configure( + self, + decoder: MultiHeadMLP, + material: BaseMaterial, + background: BaseBackground, + ) -> None: + super().configure() + + self.set_decoder(decoder) + self.set_material(material) + self.set_background(background) + + # set up bounding box + self.bbox: Float[Tensor, "2 3"] + self.register_buffer( + "bbox", + torch.as_tensor( + [ + [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], + [self.cfg.radius, self.cfg.radius, self.cfg.radius], + ], + dtype=torch.float32, + ), + ) + + def forward(self, *args, **kwargs) -> Dict[str, Any]: + raise NotImplementedError + + @property + def decoder(self) -> MultiHeadMLP: + return self.non_module("decoder") + + @property + def material(self) -> BaseMaterial: + return self.non_module("material") + + @property + def background(self) -> BaseBackground: + return self.non_module("background") + + def set_decoder(self, decoder: MultiHeadMLP) -> None: + self.register_non_module("decoder", decoder) + + def set_material(self, material: BaseMaterial) -> None: + self.register_non_module("material", material) + + def set_background(self, background: BaseBackground) -> None: + self.register_non_module("background", background) diff --git a/3D_Stage/lrm/models/renderers/triplane_dmtet.py b/3D_Stage/lrm/models/renderers/triplane_dmtet.py new file mode 
import os
from dataclasses import dataclass, field
from collections import defaultdict
from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange, reduce

import lrm
from ..renderers.base import BaseRenderer
from ..isosurface import MarchingTetrahedraHelper
from ...utils.ops import (
    get_activation,
    scale_tensor,
    dot,
    chunk_batch
)
from ...utils.rasterize import NVDiffRasterizerContext
from ..mesh import Mesh
from ...utils.typing import *


class TriplaneDMTetRenderer(BaseRenderer):
    """DMTet-based mesh renderer driven by triplane features.

    SDF / deformation / material features are decoded from a triplane via
    grid-sampling, a mesh is extracted with marching tetrahedra, and the
    mesh is rasterized with nvdiffrast.
    """

    @dataclass
    class Config(BaseRenderer.Config):
        # How the three plane features are merged per point: "concat" or "mean".
        feature_reduction: str = "concat"
        sdf_activation: Optional[str] = None
        # Either a constant float offset or a named shape prior
        # ("sphere" / "ellipsoid") applied to the raw SDF.
        sdf_bias: Union[str, float] = 0.0
        sdf_bias_params: Any = None
        # If True, flip the SDF sign at isosurface extraction time.
        inside_out: bool = False

        isosurface_resolution: int = 128
        tet_dir: str = "load/tets/"
        enable_isosurface_grid_deformation: bool = False
        eval_chunk_size: int = 262144
        context_type: str = "gl"

    cfg: Config

    def configure(self, *args, **kwargs) -> None:
        """Set up the rasterizer context and the marching-tets helper."""
        super().configure(*args, **kwargs)

        assert self.cfg.feature_reduction in ["concat", "mean"]

        self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device)
        # Pre-tessellated tet grid is loaded from disk per resolution.
        self.isosurface_helper = MarchingTetrahedraHelper(
            self.cfg.isosurface_resolution,
            os.path.join(self.cfg.tet_dir, f"{self.cfg.isosurface_resolution}_tets.npz"),
        ).to(self.device)

    def query_triplane(
        self,
        positions: Float[Tensor, "*B N 3"],
        triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
    ) -> Dict[str, Tensor]:
        """Sample triplane features at 3D points and decode them.

        Accepts batched or unbatched inputs; unbatched inputs are temporarily
        given a batch dim and squeezed again on return. Returns the decoder's
        head outputs; the "sdf" head additionally gets the configured
        activation and shape bias applied.
        """
        batched = positions.ndim == 3
        if not batched:
            # no batch dimension
            triplanes = triplanes[None, ...]
            positions = positions[None, ...]
        # import ipdb
        # ipdb.set_trace()
        assert triplanes.ndim == 5 and positions.ndim == 3

        # assume positions in [-1, 1]
        # normalized to (-1, 1) for grid sample
        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        # Project each point onto the XY, XZ and YZ planes.
        indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
            (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
            dim=-3,
        )
        out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
            rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3),
            rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3),
            align_corners=False,
            mode="bilinear",
        )
        if self.cfg.feature_reduction == "concat":
            out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
        elif self.cfg.feature_reduction == "mean":
            out = reduce(out, "(B Np) Cp () N -> B N Cp", Np=3, reduction="mean")
        else:
            raise NotImplementedError

        net_out: Dict[str, Float[Tensor, "B N ..."]] = self.decoder(out)
        assert "sdf" in net_out
        net_out["sdf"] = get_activation(self.cfg.sdf_activation)(
            self.get_shifted_sdf(positions, net_out["sdf"])
        )

        if not batched:
            net_out = {k: v.squeeze(0) for k, v in net_out.items()}

        return net_out

    def get_shifted_sdf(
        self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"]
    ) -> Float[Tensor, "*N 1"]:
        """Add the configured shape-prior bias to a raw SDF prediction."""
        sdf_bias: Union[float, Float[Tensor, "*N 1"]]
        if self.cfg.sdf_bias == "ellipsoid":
            assert (
                isinstance(self.cfg.sdf_bias_params, Sized)
                and len(self.cfg.sdf_bias_params) == 3
            )
            size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
            sdf_bias = ((points / size) ** 2).sum(
                dim=-1, keepdim=True
            ).sqrt() - 1.0  # pseudo signed distance of an ellipsoid
        elif self.cfg.sdf_bias == "sphere":
            assert isinstance(self.cfg.sdf_bias_params, float)
            radius = self.cfg.sdf_bias_params
            sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
        elif isinstance(self.cfg.sdf_bias, float):
            sdf_bias = self.cfg.sdf_bias
        else:
            raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
        return sdf + sdf_bias

    def forward_single(
        self,
        triplane: Float[Tensor, "3 Cp Hp Wp"],
        mvp_mtx: Float[Tensor, "Nv 4 4"],
        camera_positions: Float[Tensor, "Nv 3"],
        height: int,
        width: int,
        background_color: Optional[Float[Tensor, "3"]],
        extra_sdf_query: Any = None,
    ) -> Dict[str, Tensor]:
        """Render one scene (one triplane) from Nv views.

        Pipeline: query SDF (+optional deformation) on the tet grid, repair
        degenerate (all-positive / all-negative) SDFs, extract a mesh with
        marching tets, rasterize, and shade foreground pixels via the
        material model. Optionally also evaluates the SDF at
        ``extra_sdf_query`` points in the same decoder pass.
        """
        Nv = mvp_mtx.shape[0]

        out = {}

        query_vertices = []
        query_sizes = []

        # Tet-grid vertices live in the helper's canonical range; map to bbox.
        grid_vertices = scale_tensor(
            self.isosurface_helper.grid_vertices,
            self.isosurface_helper.points_range,
            self.bbox,
        )

        query_vertices.append(grid_vertices)
        query_sizes.append(len(grid_vertices))

        if extra_sdf_query is not None:
            query_vertices.append(extra_sdf_query)
            query_sizes.append(len(extra_sdf_query))

        # Single decoder call for grid + extra points; split results after.
        query_vertices = torch.cat(query_vertices, dim=0)
        triplane_out = self.query_triplane(query_vertices, triplane)

        all_sdf = triplane_out["sdf"]
        if extra_sdf_query is not None:
            sdf, sdf_ex_query = torch.split(all_sdf, query_sizes)
        else:
            sdf, sdf_ex_query = all_sdf, None

        out.update({"sdf_ex_query": sdf_ex_query})

        if self.cfg.enable_isosurface_grid_deformation:
            all_deformation = triplane_out["deformation"]
            if extra_sdf_query is not None:
                deformation, _ = torch.split(all_deformation, query_sizes)
            else:
                deformation, _ = all_deformation, None
        else:
            deformation = None

        # Fix some sdf if we observe empty shape (full positive or full negative)
        pos_shape = torch.sum((sdf.squeeze(dim=-1) > 0).int(), dim=-1)
        neg_shape = torch.sum((sdf.squeeze(dim=-1) < 0).int(), dim=-1)
        zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
        if torch.sum(zero_surface).item() > 0:
            lrm.warn("Empty mesh! Fixing by adding fake faces.")
            sdf = torch.nan_to_num(sdf, nan=0.0, posinf=1.0, neginf=-1.0)
            update_sdf = torch.zeros_like(sdf)
            max_sdf = sdf.max()
            min_sdf = sdf.min()
            update_sdf[self.isosurface_helper.center_indices] += (
                -1.0 - max_sdf
            )  # greater than zero
            update_sdf[self.isosurface_helper.boundary_indices] += (
                1.0 - min_sdf
            )  # smaller than zero
            new_sdf = sdf.clone().detach()
            if zero_surface:
                new_sdf += update_sdf
            update_mask = (update_sdf == 0).float()
            # Regularization here is used to push the sdf to be a different sign (make it not fully positive or fully negative)
            sdf_reg_loss = sdf.abs().mean()
            sdf_reg_loss = sdf_reg_loss * zero_surface.float()
            sdf = sdf * update_mask + new_sdf * (1 - update_mask)
            lrm.debug(
                "max sdf: {}, min sdf: {}".format(sdf.max().item(), sdf.min().item())
            )
            out.update({"sdf_reg": sdf_reg_loss})

            # Here we remove the gradient for the bad sdf (full positive or full negative)
            if zero_surface:
                sdf = sdf.detach()

        mesh: Mesh = self.isosurface_helper(sdf, deformation=deformation)

        mesh.v_pos = scale_tensor(
            mesh.v_pos, self.isosurface_helper.points_range, self.bbox
        )  # scale to bbox as the grid vertices are in [0, 1]
        # import ipdb
        # ipdb.set_trace()
        v_pos_clip: Float[Tensor, "Nv V 4"] = self.ctx.vertex_transform(
            mesh.v_pos, mvp_mtx
        )
        rast, _ = self.ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width))
        mask = rast[..., 3:] > 0
        mask_aa = self.ctx.antialias(mask.float(), rast, v_pos_clip, mesh.t_pos_idx)

        out.update({"opacity": mask_aa, "mesh": mesh})

        # Interpolated, antialiased normals mapped from [-1,1] to [0,1].
        gb_normal, _ = self.ctx.interpolate_one(mesh.v_nrm, rast, mesh.t_pos_idx)
        gb_normal = F.normalize(gb_normal, dim=-1)
        gb_normal_aa = torch.lerp(
            torch.zeros_like(gb_normal), (gb_normal + 1.0) / 2.0, mask.float()
        )
        gb_normal_aa = self.ctx.antialias(
            gb_normal_aa, rast, v_pos_clip, mesh.t_pos_idx
        )
        out.update({"comp_normal": gb_normal_aa})  # in [0, 1]

        gb_pos, _ = self.ctx.interpolate_one(mesh.v_pos, rast, mesh.t_pos_idx)

        # FIXME: this depth corresponds to the one provided in the dataset, but assumes camera looking at scene center
        gb_depth = dot(
            gb_pos - camera_positions[:, None, None, :],
            F.normalize(-camera_positions[:, None, None, :], dim=-1),
        )

        gb_depth = torch.lerp(torch.zeros_like(gb_depth), gb_depth, mask.float())
        out.update({"depth": gb_depth})

        gb_viewdirs = F.normalize(gb_pos - camera_positions[:, None, None, :], dim=-1)
        gb_rgb_fg = torch.zeros(
            (Nv, height, width, 3), device=self.device, dtype=torch.float32
        )
        gb_rgb_bg = self.background(dirs=gb_viewdirs, color_spec=background_color)

        # Shade only foreground pixels to avoid wasted decoder/material work.
        selector = mask[..., 0]
        if selector.sum() > 0:
            positions = gb_pos[selector]
            geo_out = self.query_triplane(positions, triplane)

            extra_geo_info = {}
            if self.material.requires_normal:
                extra_geo_info["shading_normal"] = gb_normal[selector]
            if self.material.requires_tangent:
                gb_tangent, _ = self.ctx.interpolate_one(
                    mesh.v_tng, rast, mesh.t_pos_idx
                )
                gb_tangent = F.normalize(gb_tangent, dim=-1)
                extra_geo_info["tangent"] = gb_tangent[selector]

            rgb_fg = self.material(
                viewdirs=gb_viewdirs[selector],
                positions=positions,
                **extra_geo_info,
                **geo_out,
            )

            gb_rgb_fg[selector] = rgb_fg.to(
                gb_rgb_fg.dtype
            )  # TODO: don't know if this is correct

        gb_rgb = torch.lerp(gb_rgb_bg, gb_rgb_fg, mask.float())
        gb_rgb_aa = self.ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx)

        out.update(
            {"comp_rgb": gb_rgb_aa, "comp_rgb_fg": gb_rgb_fg, "comp_rgb_bg": gb_rgb_bg}
        )

        return out

    def forward(
        self,
        triplanes: Float[Tensor, "B 3 Cp Hp Wp"],
        mvp_mtx: Float[Tensor, "B Nv 4 4"],
        camera_positions: Float[Tensor, "B Nv 3"],
        height: int,
        width: int,
        background_color: Optional[Float[Tensor, "B 3"]] = None,
        extra_sdf_query: Optional[List[Float[Tensor, "N 3"]]] = None,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """Render a batch of scenes by looping forward_single per sample.

        Per-sample outputs are re-assembled: tensors with consistent shapes
        are stacked along a new batch dim, everything else (e.g. meshes)
        stays as a list.
        """
        batch_size = triplanes.shape[0]
        out_list = []
        for b in range(batch_size):
            out_list.append(
                self.forward_single(
                    triplanes[b],
                    mvp_mtx[b],
                    camera_positions[b],
                    height,
                    width,
                    background_color=background_color[b]
                    if background_color is not None
                    else None,
                    extra_sdf_query=extra_sdf_query[b]
                    if extra_sdf_query is not None
                    else None,
                )
            )

        out = defaultdict(list)
        for out_ in out_list:
            for k, v in out_.items():
                out[k].append(v)

        for k, v in out.items():
            # some properties cannot be batched
            if isinstance(v[0], torch.Tensor) and (
                all([vv.ndim == 0 for vv in v])
                or all([vv.shape[0] == v[0].shape[0] for vv in v])
            ):
                out[k] = torch.stack(v, dim=0)
            else:
                out[k] = v

        return out

    def isosurface(self, triplane: Float[Tensor, "3 Cp Hp Wp"]):
        """Extract a mesh from a single triplane (no rendering).

        SDF queries run in chunks of ``cfg.eval_chunk_size`` to bound memory;
        ``cfg.inside_out`` flips the SDF sign before extraction.
        """
        grid_vertices = scale_tensor(
            self.isosurface_helper.grid_vertices,
            self.isosurface_helper.points_range,
            self.bbox,
        )
        triplane_out = chunk_batch(
            partial(self.query_triplane, triplanes=triplane), self.cfg.eval_chunk_size, grid_vertices,
        )

        sdf = triplane_out["sdf"]

        if self.cfg.inside_out:
            sdf = -sdf

        if self.cfg.enable_isosurface_grid_deformation:
            deformation = triplane_out["deformation"]
        else:
            deformation = None

        mesh: Mesh = self.isosurface_helper(sdf, deformation=deformation)

        mesh.v_pos = scale_tensor(
            mesh.v_pos, self.isosurface_helper.points_range, self.bbox
        )

        return mesh

    def query(
        self, triplane: Float[Tensor, "3 Cp Hp Wp"], points: Float[Tensor, "*N 3"]
    ):
        """Chunked decoder query at arbitrary points, preserving input shape."""
        input_shape = points.shape[:-1]
        triplane_out = chunk_batch(
            partial(self.query_triplane, triplanes=triplane), self.cfg.eval_chunk_size, points.view(-1, 3)
        )
        triplane_out = {
            k: v.view(*input_shape, v.shape[-1]) for k, v in triplane_out.items()
        }
        return triplane_out
a/3D_Stage/lrm/models/tokenizers/dinov2.py b/3D_Stage/lrm/models/tokenizers/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d7eae6382313cb06ed423c09f0537c5bd446fc --- /dev/null +++ b/3D_Stage/lrm/models/tokenizers/dinov2.py @@ -0,0 +1,1223 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DINOv2 model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union +from dataclasses import dataclass + +import torch +import torch.utils.checkpoint +from torch import nn +import torch.nn.functional as F +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.backbone_utils import BackboneMixin +from transformers.models.dinov2.configuration_dinov2 import Dinov2Config +import xformers + +from ..transformers.attention import MemoryEfficientAttentionMixin 
+from ...utils.typing import * + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Dinov2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2-base" +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base" + + +DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/dinov2-base", + # See all DINOv2 models at https://huggingface.co/models?filter=dinov2 +] + + +class Dinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # register as mask token as it's not used in optimization + # to avoid the use of find_unused_parameters_true + # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_buffer("mask_token", torch.zeros(1, config.hidden_size)) + self.patch_embeddings = Dinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, config.hidden_size) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.config.patch_size + width = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeddings = self.patch_embeddings(pixel_values) + embeddings = patch_embeddings + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), + self.mask_token.to(embeddings.dtype).unsqueeze(0), + embeddings, + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = 
self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." 
+ ) + """ + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 +class Dinov2SelfAttention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + self.query = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.key = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + self.value = nn.Linear( + config.hidden_size, self.all_head_size, bias=config.qkv_bias + ) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.use_memory_efficient_attention_xformers: bool = False + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + if self.use_memory_efficient_attention_xformers: + assert head_mask is None and not output_attentions + new_size = hidden_states.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = self.key(hidden_states).view(new_size) + value_layer = 
self.value(hidden_states).view(new_size) + query_layer = mixed_query_layer.view(new_size) + context_layer = xformers.ops.memory_efficient_attention( + query_layer, key_layer, value_layer, p=self.attention_probs_dropout_prob + ) + context_layer = context_layer.view(*hidden_states.size()[:-1], -1) + elif hasattr(F, "scaled_dot_product_attention"): + assert head_mask is None and not output_attentions + new_size = hidden_states.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2) + value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2) + query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2) + context_layer = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + dropout_p=self.attention_probs_dropout_prob, + is_causal=False, + ) + context_layer = context_layer.transpose(1, 2).reshape( + *hidden_states.size()[:-1], -1 + ) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + return outputs + + def set_use_memory_efficient_attention_xformers( + self, valid: bool, attention_op: Optional[Callable] = None + ): + self.use_memory_efficient_attention_xformers = valid + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 +class Dinov2SelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 +class Dinov2Attention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.attention = Dinov2SelfAttention(config) + self.output = Dinov2SelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.attention.num_attention_heads, + self.attention.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers 
+ self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len( + heads + ) + self.attention.all_head_size = ( + self.attention.attention_head_size * self.attention.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class Dinov2LayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter( + config.layerscale_value * torch.ones(config.hidden_size) + ) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path( + input: torch.Tensor, drop_prob: float = 0.0, training: bool = False +) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
    I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (
        input.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape, dtype=input.dtype, device=input.device
    )
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class Dinov2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class Dinov2MLP(nn.Module):
    """Standard transformer feed-forward block: linear -> activation -> linear."""

    def __init__(self, config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class Dinov2SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: one fused input projection, gated SiLU, output projection."""

    def __init__(self, config) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        # 2/3 scaling keeps parameter count comparable to a plain MLP;
        # (x + 7) // 8 * 8 rounds up to a multiple of 8 for hardware-friendly sizes.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
        self.weights_out = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.weights_in(hidden_state)
        # Split the fused projection into the gated branch (x1) and value branch (x2).
        x1, x2 = hidden_state.chunk(2, dim=-1)
        hidden = nn.functional.silu(x1) * x2
        return self.weights_out(hidden)


class Dinov2Layer(nn.Module, MemoryEfficientAttentionMixin):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Optional AdaLN-style modulation hooks, installed later via
        # `register_ada_norm_modulation`; None means plain LayerNorm.
        self.norm1_modulation = None
        self.attention = Dinov2Attention(config)
        self.layer_scale1 = Dinov2LayerScale(config)
        self.drop_path1 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2_modulation = None

        if config.use_swiglu_ffn:
            self.mlp = Dinov2SwiGLUFFN(config)
        else:
            self.mlp = Dinov2MLP(config)
        self.layer_scale2 = Dinov2LayerScale(config)
        self.drop_path2 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        hidden_states_norm = self.norm1(hidden_states)
        if self.norm1_modulation is not None:
            assert modulation_cond is not None
            hidden_states_norm = self.norm1_modulation(
                hidden_states_norm, modulation_cond
            )
        self_attention_outputs = self.attention(
            hidden_states_norm,  # in Dinov2, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights

        # first residual connection
        # NOTE(review): drop_path1 / drop_path2 are constructed in __init__ but
        # never applied in this forward — stochastic depth is effectively
        # disabled even when drop_path_rate > 0. Confirm this is intentional
        # (the upstream transformers implementation applies them around both
        # residual branches).
        hidden_states = attention_output + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        if self.norm2_modulation is not None:
            assert modulation_cond is not None
            layer_output = self.norm2_modulation(layer_output, modulation_cond)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = layer_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs

    def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
        # Install conditional-normalization modules; both forward branches then
        # require a non-None `modulation_cond`.
        self.norm1_modulation = norm1_mod
        self.norm2_modulation = norm2_mod


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
class Dinov2Encoder(nn.Module, MemoryEfficientAttentionMixin):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Closure binds `output_attentions` so checkpoint only sees
                # tensor arguments it can save/restore.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
                    modulation_cond,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, layer_head_mask, modulation_cond, output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class Dinov2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
+ """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2Embeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + def _set_gradient_checkpointing( + self, module: Dinov2Encoder, value: bool = False + ) -> None: + if isinstance(module, Dinov2Encoder): + module.gradient_checkpointing = value + + +DINOV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + +DINOV2_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@dataclass +class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling): + patch_embeddings: Optional[torch.FloatTensor] = None + + +@add_start_docstrings( + "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class Dinov2Model(Dinov2PreTrainedModel, MemoryEfficientAttentionMixin): + def __init__(self, config: Dinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + def expand_input_channels(self, extra_input_channels: int) -> None: + if extra_input_channels == 0: + return + conv_old = self.embeddings.patch_embeddings.projection + conv_new = nn.Conv2d( + self.config.num_channels + extra_input_channels, + self.config.hidden_size, + kernel_size=self.config.patch_size, + stride=self.config.patch_size, + ).to(self.device) + with torch.no_grad(): + conv_new.weight[:, :3] = conv_old.weight + conv_new.bias = conv_old.bias + self.embeddings.patch_embeddings.projection = conv_new + del conv_old + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. 
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Resolve per-call overrides against config defaults.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            modulation_cond=modulation_cond,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # CLS token embedding is used as the pooled representation.
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        return CustomBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            patch_embeddings=embedding_output,
        )

    def set_gradient_checkpointing(self, value: bool = False) -> None:
        self._set_gradient_checkpointing(self.encoder, value)


@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # Classifier head
        # Input is [CLS] concatenated with mean-pooled patch tokens -> 2 * hidden_size.
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.dinov2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size

        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]

        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Infer the problem type once from label dtype / num_labels, then cache
            # it on the config (standard transformers behavior).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        # One feature width per stage (embeddings + every transformer layer).
        self.num_features = [
            config.hidden_size for _ in range(config.num_hidden_layers + 1)
        ]
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ...
) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state[:, 1:, :].reshape( + batch_size, width // patch_size, height // patch_size, -1 + ) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +class CustomPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
+ """ + + def __init__( + self, image_size: int, patch_size: int, num_channels: int, hidden_size: int + ): + super().__init__() + + image_size = ( + image_size + if isinstance(image_size, collections.abc.Iterable) + else (image_size, image_size) + ) + patch_size = ( + patch_size + if isinstance(patch_size, collections.abc.Iterable) + else (patch_size, patch_size) + ) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class CustomEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. 
+ """ + + def __init__( + self, image_size: int, patch_size: int, num_channels: int, hidden_size: int + ) -> None: + super().__init__() + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.hidden_size = hidden_size + + self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size)) + + self.patch_embeddings = CustomPatchEmbeddings( + image_size, patch_size, num_channels, hidden_size + ) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter( + torch.randn(1, num_patches + 1, self.hidden_size) + ) + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + patch_pos_embed = patch_pos_embed.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if ( + int(height) != patch_pos_embed.shape[-2] + or 
int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with the interpolated position embeddings" + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + patch_embeddings = self.patch_embeddings(pixel_values) + embeddings = patch_embeddings + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + + return embeddings diff --git a/3D_Stage/lrm/models/tokenizers/image.py b/3D_Stage/lrm/models/tokenizers/image.py new file mode 100644 index 0000000000000000000000000000000000000000..1663470778c514654eaa63e104bc511e4783ad47 --- /dev/null +++ b/3D_Stage/lrm/models/tokenizers/image.py @@ -0,0 +1,268 @@ +from dataclasses import dataclass +import random + +import torch +import torch.nn as nn +from einops import rearrange + +from ...utils.base import BaseModule +from .dinov2 import Dinov2Model, CustomEmbeddings +from ..transformers.attention import Modulation +from ...utils.typing import * + + +class NaiveImageTokenizer(BaseModule): + @dataclass + class Config(BaseModule.Config): + num_tokens: int = 1024 + num_channels: int = 768 + + cfg: Config + + def configure(self) -> None: + super().configure() + + def forward(self, images: Float[Tensor, "B N C H W"]) -> Float[Tensor, "B Ct Nt"]: + return torch.rand( + ( + images.shape[0], + self.cfg.num_channels, + self.cfg.num_tokens, + ), + device=images.device, + dtype=images.dtype, + ) + + def detokenize(self, *args, **kwargs): + raise NotImplementedError + + +class DINOV2SingleImageTokenizer(BaseModule): + @dataclass + 
    @dataclass
    class Config(BaseModule.Config):
        # HF hub id or local path of the DINOv2 backbone.
        pretrained_model_name_or_path: str = "facebook/dinov2-base"
        width: int = 224
        height: int = 224
        # Optional FiLM-style modulation injected into every encoder layer.
        modulation: bool = False
        modulation_zero_init: bool = False
        modulation_single_layer: bool = False
        modulation_cond_dim: int = 16
        freeze_backbone_params: bool = True
        enable_memory_efficient_attention: bool = False
        enable_gradient_checkpointing: bool = False
        # Aggregate the backbone's raw patch embeddings into the output.
        use_patch_embeddings: bool = False
        patch_embeddings_aggr_method: str = "concat"
        # Append 6-channel Plucker ray maps to the RGB input.
        append_plucker_rays: bool = False
        # Training-time view dropout.
        drop_rate: float = 0.0
        drop_type: str = "all_but_first"

    cfg: Config

    def configure(self) -> None:
        """Instantiate the DINOv2 backbone and optional modulation layers."""
        super().configure()
        model: Dinov2Model

        if self.cfg.freeze_backbone_params:
            # freeze dino backbone parameters
            # register_non_module keeps the frozen backbone out of the module's
            # parameter collection / state_dict.
            self.register_non_module(
                "model",
                Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path).to(
                    self.device
                ),
            )

            model = self.non_module("model")
            for p in model.parameters():
                p.requires_grad_(False)
            model.eval()
        else:
            self.model = Dinov2Model.from_pretrained(
                self.cfg.pretrained_model_name_or_path
            ).to(self.device)
            model = self.model

        if self.cfg.append_plucker_rays:
            # Widen the patch projection by 6 extra input channels for rays.
            model.expand_input_channels(6)

        model.set_use_memory_efficient_attention_xformers(
            self.cfg.enable_memory_efficient_attention
        )
        model.set_gradient_checkpointing(self.cfg.enable_gradient_checkpointing)

        # add modulation: one Modulation pair (norm1/norm2) per encoder layer.
        if self.cfg.modulation:
            modulations = []
            for layer in model.encoder.layer:
                norm1_modulation = Modulation(
                    model.config.hidden_size,
                    self.cfg.modulation_cond_dim,
                    zero_init=self.cfg.modulation_zero_init,
                    single_layer=self.cfg.modulation_single_layer,
                )
                norm2_modulation = Modulation(
                    model.config.hidden_size,
                    self.cfg.modulation_cond_dim,
                    zero_init=self.cfg.modulation_zero_init,
                    single_layer=self.cfg.modulation_single_layer,
                )
                layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
                modulations += [norm1_modulation, norm2_modulation]
            self.modulations = nn.ModuleList(modulations)

        # ImageNet normalization constants, broadcast over (B, N, C, H, W).
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        plucker_rays: Optional[Float[Tensor, "B *N 6 H W"]],
        **kwargs
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """Tokenize one or more views into per-view DINOv2 patch features.

        Accepts a single view (B, C, H, W) or multiple views (B, N, C, H, W);
        single-view input is returned without the view axis.
        """
        model: Dinov2Model
        if self.cfg.freeze_backbone_params:
            model = self.non_module("model")
        else:
            model = self.model

        # Normalize everything to the multi-view layout (B, N, ...).
        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)
            if plucker_rays is not None:
                assert plucker_rays.ndim == 4
                plucker_rays = plucker_rays.unsqueeze(1)

        # Randomly drop conditioning views during training (regularization).
        if (
            self.training
            and self.cfg.drop_rate > 0
            and random.random() < self.cfg.drop_rate
        ):
            if self.cfg.drop_type == "all_but_first":
                drop_func = lambda x: x if x is None else x[:, 0:1]
                images = drop_func(images)
                modulation_cond = drop_func(modulation_cond)
                plucker_rays = drop_func(plucker_rays)
            else:
                raise NotImplementedError

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        if self.cfg.append_plucker_rays and plucker_rays is not None:
            # Ray maps are concatenated on the channel axis (dim=2).
            images = torch.cat([images, plucker_rays], dim=2)
        # Fold views into the batch dimension for the backbone pass.
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features, global_features = out.last_hidden_state, out.pooler_output
        if self.cfg.use_patch_embeddings:
            patch_embeddings = out.patch_embeddings
            if self.cfg.patch_embeddings_aggr_method == "concat":
                # Concatenate along the token axis.
                local_features = torch.cat([local_features, patch_embeddings], dim=1)
            elif self.cfg.patch_embeddings_aggr_method == "add":
                local_features = local_features + patch_embeddings
            else:
                raise NotImplementedError
        # (BN, Nt, Ct) -> (BN, Ct, Nt), then split the views back out.
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        # Image tokens are an encoder-only representation; no inverse mapping.
        raise NotImplementedError


class DINOV2CustomLowLevelSingleImageTokenizer(DINOV2SingleImageTokenizer):
    """DINOv2 tokenizer that additionally mixes in low-level (shallow) patch
    embeddings produced by a trainable copy of the backbone's embedding layer."""

    @dataclass
    class Config(DINOV2SingleImageTokenizer.Config):
        custom_embeddings_aggr_method: str = "concat"
        custom_embeddings_scale: float = 1.0

    cfg: Config

    def configure(self) -> None:
        super().configure()
        # Trainable clone of the embedding layer, initialized from the backbone.
        self.custom_embeddings = CustomEmbeddings(
            self.model.config.image_size,
            self.model.config.patch_size,
            self.model.config.num_channels,
            self.model.config.hidden_size,
        )
        self.custom_embeddings.load_state_dict(
            self.model.embeddings.state_dict(), strict=False
        )
        # NOTE(review): accesses self.model directly, so this subclass appears
        # to require freeze_backbone_params=False — confirm against configs.

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """As the parent forward, plus scaled low-level custom embeddings."""
        model: Dinov2Model
        if self.cfg.freeze_backbone_params:
            model = self.non_module("model")
        else:
            model = self.model

        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        images = rearrange(images, "B N C H W -> (B N) C H W")
        out = model(
            images,
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features, global_features = out.last_hidden_state, out.pooler_output
        if self.cfg.use_patch_embeddings:
            patch_embeddings = out.patch_embeddings
            if self.cfg.patch_embeddings_aggr_method == "concat":
                local_features = torch.cat([local_features, patch_embeddings], dim=1)
            elif self.cfg.patch_embeddings_aggr_method == "add":
                local_features = local_features + patch_embeddings
            else:
                raise NotImplementedError

        # Low-level embeddings from the trainable copy, scaled then aggregated.
        custom_embeddings = (
            self.custom_embeddings(images) * self.cfg.custom_embeddings_scale
        )
        if self.cfg.custom_embeddings_aggr_method == "concat":
            local_features = torch.cat([local_features, custom_embeddings], dim=1)
        elif self.cfg.custom_embeddings_aggr_method == "add":
            local_features = local_features + custom_embeddings
        else:
            raise NotImplementedError

        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features
self.cfg.patch_embeddings_aggr_method == "concat": + local_features = torch.cat([local_features, patch_embeddings], dim=1) + elif self.cfg.patch_embeddings_aggr_method == "add": + local_features = local_features + patch_embeddings + else: + raise NotImplementedError + + custom_embeddings = ( + self.custom_embeddings(images) * self.cfg.custom_embeddings_scale + ) + if self.cfg.custom_embeddings_aggr_method == "concat": + local_features = torch.cat([local_features, custom_embeddings], dim=1) + elif self.cfg.custom_embeddings_aggr_method == "add": + local_features = local_features + custom_embeddings + else: + raise NotImplementedError + + local_features = local_features.permute(0, 2, 1) + local_features = rearrange( + local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size + ) + if packed: + local_features = local_features.squeeze(1) + + return local_features diff --git a/3D_Stage/lrm/models/tokenizers/triplane.py b/3D_Stage/lrm/models/tokenizers/triplane.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbbba59683fb30998c204eaf89536e1cbe06e01 --- /dev/null +++ b/3D_Stage/lrm/models/tokenizers/triplane.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +import math + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from ...utils.base import BaseModule +from ...utils.typing import * + + +class TriplaneLearnablePositionalEmbedding(BaseModule): + @dataclass + class Config(BaseModule.Config): + plane_size: int = 32 + num_channels: int = 1024 + + cfg: Config + + def configure(self) -> None: + super().configure() + self.embeddings = nn.Parameter( + torch.randn( + (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size), + dtype=torch.float32, + ) + * 1 + / math.sqrt(self.cfg.num_channels) + ) + + def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]: + return rearrange( + repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size), + "B Np Ct Hp Wp -> B Ct (Np Hp Wp)", + ) + + def 
detokenize( + self, tokens: Float[Tensor, "B Ct Nt"] + ) -> Float[Tensor, "B 3 Ct Hp Wp"]: + batch_size, Ct, Nt = tokens.shape + assert Nt == self.cfg.plane_size**2 * 3 + assert Ct == self.cfg.num_channels + return rearrange( + tokens, + "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", + Np=3, + Hp=self.cfg.plane_size, + Wp=self.cfg.plane_size, + ) diff --git a/3D_Stage/lrm/models/transformers/__init__.py b/3D_Stage/lrm/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3D_Stage/lrm/models/transformers/attention.py b/3D_Stage/lrm/models/transformers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5891fcf694d94aa11385e057300841f2f4f326 --- /dev/null +++ b/3D_Stage/lrm/models/transformers/attention.py @@ -0,0 +1,669 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class MemoryEfficientAttentionMixin:
    """Mixin that toggles xFormers memory-efficient attention across a module
    tree by forwarding the flag to every descendant that implements
    ``set_use_memory_efficient_attention_xformers``."""

    def enable_xformers_memory_efficient_attention(
        self, attention_op: Optional[Callable] = None
    ):
        r"""
        Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).

        Lowers GPU memory usage and may speed up inference; speed up during
        training is not guaranteed. When memory efficient attention and sliced
        attention are both enabled, memory efficient attention takes precedent.

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as `op` argument to the
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function of xFormers.
        """
        self.set_use_memory_efficient_attention_xformers(True, attention_op)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        # Iterative pre-order walk over the module tree (same visit order as
        # the recursive formulation): any descendant exposing the setter
        # receives the new flag, and its children are still visited afterwards.
        pending = [
            child
            for child in self.children()
            if isinstance(child, torch.nn.Module)
        ]
        pending.reverse()
        while pending:
            module = pending.pop()
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(
                    valid, attention_op
                )
            pending.extend(reversed(list(module.children())))
    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        """Build projection, attention, feed-forward, norms and learned gates."""
        super().__init__()

        # we need a linear projection since we need cat visual feature and obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Gates start at 0 (tanh(0) == 0), so the layer is initially an
        # identity and the fused branch fades in during training.
        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        # Runtime switch: when False, forward() is a pass-through.
        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        """Fuse object features `objs` into visual tokens `x` via a gated
        residual self-attention over the concatenated sequence."""
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

        # Attend over [visual ; objs], keep only the visual positions.
        x = (
            x
            + self.alpha_attn.tanh()
            * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
        )
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        cond_dim_ada_norm_continuous: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        attention_type: str = "default",
    ):
        """Build norms, self-/cross-attention and feed-forward sub-layers.

        At most one adaptive-norm mode (ada_norm, ada_norm_zero,
        ada_norm_continuous) may be active; plain LayerNorm is used otherwise.
        """
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Resolve which adaptive-norm flavor is in use (mutually exclusive).
        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm"
        self.use_ada_layer_norm_continuous = (
            cond_dim_ada_norm_continuous is not None
        ) and norm_type == "ada_norm_continuous"

        assert (
            int(self.use_ada_layer_norm)
            + int(self.use_ada_layer_norm_continuous)
            + int(self.use_ada_layer_norm_zero)
            <= 1
        )

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_continuous:
            self.norm1 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            if self.use_ada_layer_norm:
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif self.use_ada_layer_norm_continuous:
                self.norm2 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
            else:
                self.norm2 = nn.LayerNorm(
                    dim, elementwise_affine=norm_elementwise_affine
                )

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim
                if not double_self_attention
                else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if self.use_ada_layer_norm_continuous:
            self.norm3 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
        else:
            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # 4. Fuser (GLIGEN-style gated object injection)
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(
                dim, cross_attention_dim, num_attention_heads, attention_head_dim
            )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward (memory/compute trade-off in the FF pass).
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        modulation_cond: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Run self-attention, optional cross-attention, then feed-forward,
        each behind its (possibly adaptive) normalization layer."""
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm1(hidden_states, modulation_cond)
        elif self.use_ada_layer_norm_zero:
            # adaLN-Zero also yields the gates/shifts used further below.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = (
            cross_attention_kwargs.get("scale", 1.0)
            if cross_attention_kwargs is not None
            else 1.0
        )

        # 2. Prepare GLIGEN inputs
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states
            if self.only_cross_attention
            else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
        # 2.5 ends

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_continuous:
                norm_hidden_states = self.norm2(hidden_states, modulation_cond)
            else:
                norm_hidden_states = self.norm2(hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm3(hidden_states, modulation_cond)
        else:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = nn.Linear

        # BUGFIX: the original used a bare `if` for "gelu" followed by a
        # separate `if/elif` chain, so an unrecognized activation_fn left
        # `act_fn` unbound and crashed later with UnboundLocalError.
        # Use one elif chain and fail fast with a clear error.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # `scale` is accepted for signature compatibility with LoRA-aware
        # callers (BasicTransformerBlock passes lora_scale) but is unused by
        # the plain nn.Linear layers here.
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
            dtype=gate.dtype
        )

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states
    def __init__(self, dim_in: int, dim_out: int):
        """Project `dim_in` -> 2 * dim_out: one half is the value, the other
        half the GELU gate."""
        super().__init__()
        linear_cls = nn.Linear

        self.proj = linear_cls(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        # `scale` kept for LoRA-style call compatibility; unused here.
        args = ()
        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    r"""
    The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # x * sigmoid(1.702 * x) is the sigmoid approximation of GELU.
        return x * torch.sigmoid(1.702 * x)
    def __init__(self, embedding_dim: int, num_embeddings: int):
        """Timestep-conditioned LayerNorm: embedding lookup -> SiLU -> linear
        producing a per-sample (scale, shift) pair."""
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        # Affine-free norm: scale/shift come from the timestep embedding.
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x


class AdaLayerNormContinuous(nn.Module):
    r"""
    Norm layer modified to incorporate arbitrary continuous embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        condition_dim (`int`): The size of the conditioning vector.
    """

    def __init__(self, embedding_dim: int, condition_dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear1 = nn.Linear(condition_dim, condition_dim)
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # Two-layer MLP on the condition produces per-sample scale/shift.
        emb = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x


class Modulation(nn.Module):
    """FiLM-style (scale, shift) modulation WITHOUT a preceding norm; used to
    condition the encoder layers of a (frozen) backbone, see the tokenizers."""

    def __init__(self, embedding_dim: int, condition_dim: int, zero_init: bool = False, single_layer: bool = False):
        super().__init__()
        self.silu = nn.SiLU()
        if single_layer:
            self.linear1 = nn.Identity()
        else:
            self.linear1 = nn.Linear(condition_dim, condition_dim)

        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Only zero init the last linear layer
        # (zero scale/shift -> modulation starts as an identity mapping).
        if zero_init:
            nn.init.zeros_(self.linear2.weight)
            nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        emb = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        # No normalization here: modulate the input directly.
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x


class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the dictionary of embeddings.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()

        # Combined timestep + class-label embedding (diffusers helper).
        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(
            self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))
        )
        # Six chunks: attention shift/scale/gate + MLP shift/scale/gate.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
            6, dim=1
        )
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        # The remaining chunks are applied by the caller (BasicTransformerBlock).
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
    def __init__(
        self,
        embedding_dim: int,
        out_dim: int,
        num_groups: int,
        act_fn: Optional[str] = None,
        eps: float = 1e-5,
    ):
        """Map an `embedding_dim`-dim conditioning vector to per-channel
        (scale, shift) applied after a GroupNorm over `out_dim` channels."""
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        if act_fn is None:
            self.act = None
        else:
            # Resolved via diffusers' activation registry (e.g. "silu").
            self.act = get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Apply conditioned GroupNorm: GN(x) * (1 + scale) + shift."""
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        # Broadcast (B, 2*C) over the spatial dims -> (B, 2*C, 1, 1).
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        # Normalize without affine params; scale/shift come from `emb`.
        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x
class Transformer1D(BaseModule, MemoryEfficientAttentionMixin):
    """
    A 1D Transformer model for sequence data.

    Pipeline: group-normalize the `(batch, channels, seq_len)` input, project
    each position into the transformer width, run `num_layers`
    `BasicTransformerBlock`s, project back, and add a residual connection.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 1
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False
        activation_fn: str = "geglu"
        num_embeds_ada_norm: Optional[int] = None
        cond_dim_ada_norm_continuous: Optional[int] = None
        only_cross_attention: bool = False
        double_self_attention: bool = False
        upcast_attention: bool = False
        norm_type: str = "layer_norm"
        norm_elementwise_affine: bool = True
        attention_type: str = "default"
        enable_memory_efficient_attention: bool = False
        gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        super().configure()

        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        # Transformer width: heads * per-head channels.
        inner_dim = self.num_attention_heads * self.attention_head_dim

        linear_cls = nn.Linear

        # Plain layer_norm is incompatible with the adaptive-norm variants,
        # which both require extra conditioning inputs.
        if self.cfg.norm_type == "layer_norm" and (
            self.cfg.num_embeds_ada_norm is not None
            or self.cfg.cond_dim_ada_norm_continuous is not None
        ):
            raise ValueError("Incorrect norm_type.")

        # 2. Define input layers
        self.in_channels = self.cfg.in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=self.cfg.norm_num_groups,
            num_channels=self.cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    dropout=self.cfg.dropout,
                    cross_attention_dim=self.cfg.cross_attention_dim,
                    activation_fn=self.cfg.activation_fn,
                    num_embeds_ada_norm=self.cfg.num_embeds_ada_norm,
                    cond_dim_ada_norm_continuous=self.cfg.cond_dim_ada_norm_continuous,
                    attention_bias=self.cfg.attention_bias,
                    only_cross_attention=self.cfg.only_cross_attention,
                    double_self_attention=self.cfg.double_self_attention,
                    upcast_attention=self.cfg.upcast_attention,
                    norm_type=self.cfg.norm_type,
                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
                    attention_type=self.cfg.attention_type,
                )
                for d in range(self.cfg.num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = (
            self.cfg.in_channels
            if self.cfg.out_channels is None
            else self.cfg.out_channels
        )

        # NOTE(review): projects back to `in_channels` rather than
        # `self.out_channels` (which is computed above but unused), so that
        # the residual add in `forward` type-checks — confirm this is the
        # intended contract before relying on `out_channels`.
        self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)

        self.gradient_checkpointing = self.cfg.gradient_checkpointing

        self.set_use_memory_efficient_attention_xformers(
            self.cfg.enable_memory_efficient_attention
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        modulation_cond: Optional[torch.FloatTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        The [`Transformer1DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, seq_len)`):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            modulation_cond ( `torch.FloatTensor`, *optional*):
                Conditioning passed through to each transformer block's adaptive modulation.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.

        Returns:
            `torch.Tensor`: the transformed hidden states plus the residual input,
            same shape as `hidden_states`.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (
                1 - encoder_attention_mask.to(hidden_states.dtype)
            ) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input: (batch, channels, seq) -> (batch, seq, inner_dim)
        batch, _, seq_len = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 1).reshape(
            batch, seq_len, inner_dim
        )
        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                # Positional argument order must match BasicTransformerBlock.forward.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    modulation_cond,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    modulation_cond=modulation_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output: project back and restore (batch, channels, seq) layout.
        # NOTE(review): `proj_out` maps to `in_channels`, but the reshape below
        # uses `inner_dim` (taken from the post-GroupNorm tensor). This is only
        # valid when `in_channels == inner_dim` — TODO confirm configs satisfy
        # this.
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states.reshape(batch, seq_len, inner_dim)
            .permute(0, 2, 1)
            .contiguous()
        )

        output = hidden_states + residual

        return output
class Transformer2D(BaseModule, MemoryEfficientAttentionMixin):
    """
    A 2D Transformer model for image-like data.

    Supports three mutually exclusive input modes, selected from the config:
    continuous feature maps (`in_channels` set, no `patch_size`), vector-quantized
    latents (`num_vector_embeds` set), and patched inputs (`in_channels` and
    `patch_size` set).

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 1
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False
        sample_size: Optional[int] = None
        num_vector_embeds: Optional[int] = None
        patch_size: Optional[int] = None
        activation_fn: str = "geglu"
        num_embeds_ada_norm: Optional[int] = None
        use_linear_projection: bool = False
        only_cross_attention: bool = False
        double_self_attention: bool = False
        upcast_attention: bool = False
        norm_type: str = "layer_norm"
        norm_elementwise_affine: bool = True
        attention_type: str = "default"
        enable_memory_efficient_attention: bool = False
        gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        super().configure()

        self.use_linear_projection = self.cfg.use_linear_projection
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        # Transformer width: heads * per-head channels.
        inner_dim = self.num_attention_heads * self.attention_head_dim

        conv_cls = nn.Conv2d
        linear_cls = nn.Linear

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (self.cfg.in_channels is not None) and (
            self.cfg.patch_size is None
        )
        self.is_input_vectorized = self.cfg.num_vector_embeds is not None
        self.is_input_patches = (
            self.cfg.in_channels is not None and self.cfg.patch_size is not None
        )

        # Legacy configs that combine layer_norm with adaptive-norm embeddings
        # are silently upgraded to ada_norm (with a warning) for compatibility.
        if (
            self.cfg.norm_type == "layer_norm"
            and self.cfg.num_embeds_ada_norm is not None
        ):
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            lrm.warn(deprecation_message)
            self.cfg.norm_type = "ada_norm"

        # The three input modes are mutually exclusive; reject ambiguous configs.
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {self.cfg.in_channels} and `num_vector_embeds`: {self.cfg.num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {self.cfg.num_vector_embeds} and `patch_size`: {self.cfg.patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif (
            not self.is_input_continuous
            and not self.is_input_vectorized
            and not self.is_input_patches
        ):
            raise ValueError(
                f"Has to define `in_channels`: {self.cfg.in_channels}, `num_vector_embeds`: {self.cfg.num_vector_embeds}, or patch_size:"
                f" {self.cfg.patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = self.cfg.in_channels

            self.norm = torch.nn.GroupNorm(
                num_groups=self.cfg.norm_num_groups,
                num_channels=self.cfg.in_channels,
                eps=1e-6,
                affine=True,
            )
            if self.cfg.use_linear_projection:
                self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
            else:
                self.proj_in = conv_cls(
                    self.cfg.in_channels, inner_dim, kernel_size=1, stride=1, padding=0
                )
        elif self.is_input_vectorized:
            assert (
                self.cfg.sample_size is not None
            ), "Transformer2DModel over discrete input must provide sample_size"
            assert (
                self.cfg.num_vector_embeds is not None
            ), "Transformer2DModel over discrete input must provide num_embed"

            self.height = self.cfg.sample_size
            self.width = self.cfg.sample_size
            self.num_vector_embeds = self.cfg.num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=self.cfg.num_vector_embeds,
                embed_dim=inner_dim,
                height=self.height,
                width=self.width,
            )
        elif self.is_input_patches:
            assert (
                self.cfg.sample_size is not None
            ), "Transformer2DModel over patched input must provide sample_size"

            self.height = self.cfg.sample_size
            self.width = self.cfg.sample_size

            self.patch_size = self.cfg.patch_size
            self.pos_embed = PatchEmbed(
                height=self.cfg.sample_size,
                width=self.cfg.sample_size,
                patch_size=self.cfg.patch_size,
                in_channels=self.cfg.in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    dropout=self.cfg.dropout,
                    cross_attention_dim=self.cfg.cross_attention_dim,
                    activation_fn=self.cfg.activation_fn,
                    num_embeds_ada_norm=self.cfg.num_embeds_ada_norm,
                    attention_bias=self.cfg.attention_bias,
                    only_cross_attention=self.cfg.only_cross_attention,
                    double_self_attention=self.cfg.double_self_attention,
                    upcast_attention=self.cfg.upcast_attention,
                    norm_type=self.cfg.norm_type,
                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
                    attention_type=self.cfg.attention_type,
                )
                for d in range(self.cfg.num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = (
            self.cfg.in_channels
            if self.cfg.out_channels is None
            else self.cfg.out_channels
        )
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if self.cfg.use_linear_projection:
                self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
            else:
                self.proj_out = conv_cls(
                    inner_dim, self.cfg.in_channels, kernel_size=1, stride=1, padding=0
                )
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(
                inner_dim, self.cfg.patch_size * self.cfg.patch_size * self.out_channels
            )

        self.gradient_checkpointing = self.cfg.gradient_checkpointing

        self.set_use_memory_efficient_attention_xformers(
            self.cfg.enable_memory_efficient_attention
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.

        Returns:
            `torch.Tensor`: for continuous inputs, the transformed feature map plus
            residual; for vectorized inputs, per-pixel log-probabilities; for
            patched inputs, the unpatchified output map.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (
                1 - encoder_attention_mask.to(hidden_states.dtype)
            ) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input: flatten spatial dims / embed latents / embed patches.
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
                    batch, height * width, inner_dim
                )
            else:
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
                    batch, height * width, inner_dim
                )
                hidden_states = self.proj_in(hidden_states)

        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                # Positional argument order must match BasicTransformerBlock.forward.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        if self.is_input_continuous:
            # NOTE(review): in the linear-projection branch `proj_out` maps to
            # `in_channels`, but the reshape uses `inner_dim` — only valid when
            # `in_channels == inner_dim`. TODO confirm configs satisfy this.
            if not self.use_linear_projection:
                hidden_states = (
                    hidden_states.reshape(batch, height, width, inner_dim)
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = (
                    hidden_states.reshape(batch, height, width, inner_dim)
                    .permute(0, 3, 1, 2)
                    .contiguous()
                )

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            # AdaLN-style final modulation driven by timestep/class conditioning.
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = (
                self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            )
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify: (batch, n_patches, p*p*c) -> (batch, c, H, W)
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(
                    -1,
                    height,
                    width,
                    self.patch_size,
                    self.patch_size,
                    self.out_channels,
                )
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(
                    -1,
                    self.out_channels,
                    height * self.patch_size,
                    width * self.patch_size,
                )
            )

        return output
import os
from dataclasses import dataclass, field

import pytorch_lightning as pl
import torch.nn.functional as F

import lrm
from .utils import parse_optimizer, parse_scheduler
from ..utils.base import (
    Updateable,
    update_end_if_possible,
    update_if_possible,
)
from ..models.exporters.base import Exporter, ExporterOutput
from ..utils.config import parse_structured
from ..utils.misc import C, cleanup, get_device, load_module_weights
from ..utils.saving import SaverMixin
from ..utils.typing import *


@dataclass
class BaseLossConfig:
    # Marker base for system-specific loss configs; subclasses add lambda_* fields.
    pass


class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
    """
    Base LightningModule for LRM systems.

    Handles structured-config parsing, optional weight loading, optimizer/
    scheduler construction, per-batch dataset update hooks, and export during
    prediction. Subclasses implement `training_step`/`validation_step`/
    `test_step` and typically build their modules in `configure`.
    """

    @dataclass
    class Config:
        loss: BaseLossConfig = BaseLossConfig()
        optimizer: dict = field(default_factory=dict)
        scheduler: Optional[dict] = None
        # Checkpoint path to load before training/eval; None skips loading.
        weights: Optional[str] = None
        weights_ignore_modules: Optional[List[str]] = None
        weights_mapping: Optional[List[Dict[str, str]]] = None
        check_train_every_n_steps: int = 0
        check_val_limit_rank: int = 8
        cleanup_after_validation_step: bool = False
        cleanup_after_test_step: bool = False

        exporter_cls: str = "lrm.models.exporters.mesh_exporter.MeshExporter"
        exporter: dict = field(default_factory=lambda: {"fmt": "obj", "save_uv": False})

    cfg: Config

    def __init__(self, cfg, resumed=False) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self._save_dir: Optional[str] = None
        self._resumed: bool = resumed
        self._resumed_eval: bool = False
        self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}

        self.configure()
        print(self.cfg.weights)
        if self.cfg.weights is not None:
            self.load_weights(
                self.cfg.weights,
                self.cfg.weights_ignore_modules,
                self.cfg.weights_mapping,
            )
            print("finish loading!!")
        self.post_configure()

    def load_weights(
        self,
        weights: str,
        ignore_modules: Optional[List[str]] = None,
        mapping: Optional[List[Dict[str, str]]] = None,
    ):
        """Load a checkpoint (non-strict) and restore step-dependent state."""
        state_dict, epoch, global_step = load_module_weights(
            weights,
            ignore_modules=ignore_modules,
            mapping=mapping,
            map_location="cpu",
        )
        self.load_state_dict(state_dict, strict=False)
        # restore step-dependent states
        self.do_update_step(epoch, global_step, on_load_weights=True)

    def set_resume_status(self, current_epoch: int, global_step: int):
        # restore correct epoch and global step in eval
        self._resumed_eval = True
        self._resumed_eval_status["current_epoch"] = current_epoch
        self._resumed_eval_status["global_step"] = global_step

    @property
    def resumed(self):
        # whether from resumed checkpoint
        return self._resumed

    @property
    def true_global_step(self):
        # Resume-aware global step: uses the recorded value during eval-only runs.
        if self._resumed_eval:
            return self._resumed_eval_status["global_step"]
        else:
            return self.global_step

    @property
    def true_current_epoch(self):
        # Resume-aware epoch counterpart of `true_global_step`.
        if self._resumed_eval:
            return self._resumed_eval_status["current_epoch"]
        else:
            return self.current_epoch

    def configure(self) -> None:
        # Subclasses build their modules here (called before weight loading).
        pass

    def post_configure(self) -> None:
        """
        executed after weights are loaded
        """
        pass

    def C(self, value: Any) -> float:
        # Resolve a possibly-scheduled config value at the current epoch/step.
        return C(value, self.true_current_epoch, self.true_global_step)

    def configure_optimizers(self):
        optim = parse_optimizer(self.cfg.optimizer, self)
        ret = {
            "optimizer": optim,
        }
        if self.cfg.scheduler is not None:
            ret.update(
                {
                    "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
                }
            )
        return ret

    def on_fit_start(self) -> None:
        if self._save_dir is not None:
            lrm.info(f"Validation results will be saved to {self._save_dir}")
        else:
            lrm.warn(
                f"Saving directory not set for the system, visualization results will not be saved"
            )

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def check_train(self, batch, outputs, **kwargs):
        # Rank-0-only periodic training inspection hook (disabled when
        # check_train_every_n_steps == 0).
        if (
            self.global_rank == 0
            and self.cfg.check_train_every_n_steps > 0
            and self.true_global_step % self.cfg.check_train_every_n_steps == 0
        ):
            self.on_check_train(batch, outputs, **kwargs)

    def on_check_train(self, batch, outputs, **kwargs):
        pass

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_test_epoch_end(self):
        pass

    def on_test_end(self) -> None:
        if self._save_dir is not None:
            lrm.info(f"Test results saved to {self._save_dir}")

    def on_predict_start(self) -> None:
        pass

    def predict_step(self, batch, batch_idx):
        # Export assets for each scene; exporting is triggered only on the
        # batch entries whose first view index is 0 (one export per scene).
        batch_size = batch["index"].shape[0]
        scene_codes = self(batch)
        for b in range(batch_size):
            if batch["view_index"][b, 0] == 0:
                exporter_output: List[ExporterOutput] = self.exporter(
                    batch["index"][b][None], scene_codes[b][None]
                )
                for out in exporter_output:
                    # Dispatch to SaverMixin.save_<type> by ExporterOutput kind.
                    save_func_name = f"save_{out.save_type}"
                    if not hasattr(self, save_func_name):
                        raise ValueError(
                            f"{save_func_name} not supported by the SaverMixin"
                        )
                    save_func = getattr(self, save_func_name)
                    save_func(
                        f"it{self.true_global_step}-export/{out.save_name}",
                        **out.params,
                    )
        if self.exporter.cfg.save_video:
            self.test_step(batch, batch_idx)

    def on_predict_epoch_end(self) -> None:
        if self.exporter.cfg.save_video:
            self.on_test_epoch_end()

    def on_predict_end(self) -> None:
        if self._save_dir is not None:
            lrm.info(f"Export assets saved to {self._save_dir}")

    def preprocess_data(self, batch, stage):
        pass

    """
    Implementing on_after_batch_transfer of DataModule does the same.
    But on_after_batch_transfer does not support DP.
    """

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.preprocess_data(batch, "train")
        self.dataset = self.trainer.train_dataloader.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "validation")
        self.dataset = self.trainer.val_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "test")
        self.dataset = self.trainer.test_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "predict")
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.train_dataloader.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.val_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_validation_step:
            # cleanup to save vram
            cleanup()

    def on_test_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.test_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        pass

    def on_before_optimizer_step(self, optimizer):
        """
        # some gradient-related debugging goes here, example:
        from lightning.pytorch.utilities import grad_norm
        norms = grad_norm(self.geometry, norm_type=2)
        print(norms)
        for name, p in self.named_parameters():
            if p.grad is None:
                lrm.info(f"{name} does not receive gradients!")
        """
        pass
= 0.0 + lambda_mask: Any = 0.0 + lambda_mask_coarse: Any = 0.0 + + +class MultiviewLRM(BaseSystem): + @dataclass + class Config(BaseSystem.Config): + loss: MultiviewLRMLossConfig = MultiviewLRMLossConfig() + + camera_embedder_cls: str = "" + camera_embedder: dict = field(default_factory=dict) + + image_tokenizer_cls: str = "" + image_tokenizer: dict = field(default_factory=dict) + + tokenizer_cls: str = "" + tokenizer: dict = field(default_factory=dict) + + backbone_cls: str = "" + backbone: dict = field(default_factory=dict) + + post_processor_cls: str = "" + post_processor: dict = field(default_factory=dict) + + decoder_cls: str = "" + decoder: dict = field(default_factory=dict) + + material_cls: str = "" + material: dict = field(default_factory=dict) + + background_cls: str = "" + background: dict = field(default_factory=dict) + + renderer_cls: str = "" + renderer: dict = field(default_factory=dict) + + resume_ckpt_path: str = "" + + cfg: Config + + def configure(self): + super().configure() + self.image_tokenizer = lrm.find(self.cfg.image_tokenizer_cls)( + self.cfg.image_tokenizer + ) + if self.cfg.image_tokenizer.modulation: + self.camera_embedder = lrm.find(self.cfg.camera_embedder_cls)( + self.cfg.camera_embedder + ) + self.tokenizer = lrm.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer) + self.backbone = lrm.find(self.cfg.backbone_cls)(self.cfg.backbone) + self.post_processor = lrm.find(self.cfg.post_processor_cls)( + self.cfg.post_processor + ) + self.decoder = lrm.find(self.cfg.decoder_cls)(self.cfg.decoder) + self.material = lrm.find(self.cfg.material_cls)(self.cfg.material) + self.background = lrm.find(self.cfg.background_cls)(self.cfg.background) + self.renderer = lrm.find(self.cfg.renderer_cls)( + self.cfg.renderer, self.decoder, self.material, self.background + ) + + self.exporter = lrm.find(self.cfg.exporter_cls)( + self.cfg.exporter, self.renderer + ) + + def on_fit_start(self): + super().on_fit_start() + self.lpips_loss_fn = LPIPS() + + def 
forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: + # batch["rgb_cond"]: B, N_cond, H, W, 3 + # batch["rgb"]: B, N_render, H, W, 3 + # batch["c2w_cond"]: B, N_cond, 4, 4 + # for single image input (like LRM), N_cond = 1 + + batch_size, n_input_views = batch["rgb_cond"].shape[:2] + + # Camera modulation + camera_embeds: Optional[Float[Tensor, "B Nv Cc"]] + if self.cfg.image_tokenizer.modulation: + camera_embeds = self.camera_embedder(**batch) + else: + camera_embeds = None + + tr.start("image tokenizer") + input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer( + rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"), + modulation_cond=camera_embeds, + plucker_rays=rearrange( + get_plucker_rays(batch["rays_o_cond"], batch["rays_d_cond"]), + "B Nv H W C -> B Nv C H W", + ) + if "rays_o_cond" in batch + else None, + ) + tr.end("image tokenizer") + + input_image_tokens = rearrange( + input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views + ) + + tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size) + + tr.start("backbone") + tokens = self.backbone( + tokens, + encoder_hidden_states=input_image_tokens, + modulation_cond=None, + ) + tr.end("backbone") + + scene_codes = self.post_processor(self.tokenizer.detokenize(tokens)) + return scene_codes + + def forward_renderer_nerf( + self, batch: Dict[str, Any], scene_codes + ) -> Dict[str, Any]: + tr.start("render") + render_out = self.renderer(scene_codes, **batch) + tr.end("render") + return render_out + + def training_step(self, batch, batch_idx): + scene_codes = self(batch) + out = self.forward_renderer_nerf(batch, scene_codes) + + loss = 0.0 + + for suffix in ["", "_coarse"]: + if not f"comp_rgb{suffix}" in out: + continue + + comp_rgb: Float[Tensor, "B Nv H W 3"] = out["comp_rgb{}".format(suffix)] + gt_rgb: Float[Tensor, "B Nv H W 3"] = batch["rgb"] + + self.log(f"train/comp_rgb_min{suffix}", comp_rgb.min()) + + loss_mse = F.mse_loss(comp_rgb, gt_rgb, reduction="mean") + 
self.log(f"train/loss_mse{suffix}", loss_mse) + loss += loss_mse * self.C(self.cfg.loss[f"lambda_mse{suffix}"]) + + loss_smooth_l1 = F.smooth_l1_loss( + comp_rgb, gt_rgb, beta=0.1, reduction="mean" + ) + self.log(f"train/loss_smooth_l1{suffix}", loss_smooth_l1) + loss += loss_smooth_l1 * self.C(self.cfg.loss[f"lambda_smooth_l1{suffix}"]) + + if self.C(self.cfg.loss[f"lambda_lpips{suffix}"]) > 0: + loss_lpips = self.lpips_loss_fn( + rearrange(comp_rgb, "B Nv H W C -> (B Nv) C H W"), + rearrange(gt_rgb, "B Nv H W C -> (B Nv) C H W"), + input_range=(0, 1), + ).mean() + self.log(f"train/loss_lpips{suffix}", loss_lpips) + loss += loss_lpips * self.C(self.cfg.loss[f"lambda_lpips{suffix}"]) + + loss_mask = binary_cross_entropy( + out[f"opacity{suffix}"].clamp(1e-5, 1 - 1e-5), batch["mask"] + ) + self.log(f"train/loss_mask{suffix}", loss_mask) + loss += loss_mask * self.C(self.cfg.loss[f"lambda_mask{suffix}"]) + + for name, value in self.cfg.loss.items(): + self.log(f"train_params/{name}", self.C(value)) + + # will execute self.on_check_train every self.cfg.check_train_every_n_steps steps + self.check_train( + batch, + out, + extra=f"m{loss_mse:.2f}_l{loss_smooth_l1:.2f}_p{loss_lpips:.2f}_ma{loss_mask:.2f}", + ) + + return {"loss": loss} + + def get_input_visualizations(self, batch): + return [ + { + "type": "rgb", + "img": rearrange(batch["rgb_cond"], "B N H W C -> (B H) (N W) C"), + "kwargs": {"data_format": "HWC"}, + } + ] + + def get_output_visualizations(self, batch, outputs): + out = outputs + images = [] + if "rgb" in batch: + images += [ + { + "type": "rgb", + "img": rearrange(batch["rgb"], "B N H W C -> (B H) (N W) C"), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "grayscale", + "img": rearrange(batch["mask"], "B N H W C -> (B H) (N W) C")[ + ..., 0 + ], + "kwargs": {"cmap": None, "data_range": None}, + }, + ] + for suffix in ["", "_coarse"]: + if not f"comp_rgb{suffix}" in out: + continue + images += [ + { + "type": "rgb", + "img": rearrange( + 
out[f"comp_rgb{suffix}"], "B N H W C -> (B H) (N W) C" + ), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "grayscale", + "img": rearrange( + out[f"opacity{suffix}"], "B N H W C -> (B H) (N W) C" + )[..., 0], + "kwargs": {"cmap": None, "data_range": None}, + }, + { + "type": "grayscale", + "img": rearrange( + out[f"depth{suffix}"], "B N H W C -> (B H) (N W) C" + )[..., 0], + "kwargs": {"cmap": None, "data_range": None}, + }, + ] + return images + + # def check_train(self, batch, outputs, **kwargs): + # self.on_check_train(batch, outputs, **kwargs) + + def on_check_train(self, batch, outputs, extra=""): + self.save_image_grid( + f"it{self.true_global_step}-train.jpg", + self.get_output_visualizations(batch, outputs), + name="train_step_output", + step=self.true_global_step, + ) + # self.save_image_grid( + # f"debug/it{self.true_global_step}-{self.global_rank}-{extra}.jpg", + # self.get_output_visualizations(batch, outputs), + # name="train_step_output", + # step=self.true_global_step, + # ) + # self.save_json( + # f"debug_list/it{self.true_global_step}-{self.global_rank}-ids.json", + # batch["scene_id"], + # ) + + def validation_step(self, batch, batch_idx): + scene_codes = self(batch) + out = self.forward_renderer_nerf(batch, scene_codes) + if ( + self.cfg.check_val_limit_rank > 0 + and self.global_rank < self.cfg.check_val_limit_rank + ): + self.save_image_grid( + f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}-input.jpg", + self.get_input_visualizations(batch), + name=f"validation_step_input_{self.global_rank}_{batch_idx}", + step=self.true_global_step, + ) + self.save_image_grid( + f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg", + self.get_output_visualizations(batch, out), + name=f"validation_step_output_{self.global_rank}_{batch_idx}", + step=self.true_global_step, + ) + + def test_step(self, batch, batch_idx): + # not saved to wandb + scene_codes = self(batch) + out = self.forward_renderer_nerf(batch, 
import torch
import torch.nn as nn
from torch.optim import lr_scheduler


def get_scheduler(name):
    """Return the ``torch.optim.lr_scheduler`` class named ``name``.

    Raises:
        NotImplementedError: if torch provides no scheduler of that name.
    """
    if hasattr(lr_scheduler, name):
        return getattr(lr_scheduler, name)
    raise NotImplementedError


def getattr_recursive(m, attr):
    """Resolve a dotted attribute path, e.g. ``"decoder.mlp.0"``."""
    for name in attr.split("."):
        m = getattr(m, name)
    return m


def get_parameters(model, name):
    """Return the parameters at dotted path ``name`` inside ``model``.

    Returns an iterator of parameters for an ``nn.Module``, the parameter
    itself for an ``nn.Parameter`` (torch param groups accept a single
    tensor), and an empty list for anything else.
    """
    module = getattr_recursive(model, name)
    if isinstance(module, nn.Module):
        return module.parameters()
    if isinstance(module, nn.Parameter):
        return module
    return []


def parse_optimizer(config, model):
    """Build an optimizer from an OmegaConf-style config.

    ``config.name`` selects the optimizer class and ``config.args`` its
    keyword arguments. If ``config.params`` is present, it maps dotted module
    paths to per-group overrides (e.g. per-module learning rates).
    """
    if hasattr(config, "params"):
        params = [
            {"params": get_parameters(model, name), "name": name, **args}
            for name, args in config.params.items()
        ]
        # Imported lazily: this module lives inside the `lrm` package, and a
        # module-level `import lrm` would be circular at package init.
        import lrm

        lrm.debug(f"Specify optimizer params: {config.params}")
    else:
        params = model.parameters()
    if config.name in ["FusedAdam"]:
        import apex

        optim = getattr(apex.optimizers, config.name)(params, **config.args)
    elif config.name in ["Adam8bit", "AdamW8bit"]:
        import bitsandbytes as bnb

        # BUG FIX: look the class up by name — previously bnb.optim.Adam8bit
        # was instantiated even when AdamW8bit was configured, silently
        # dropping decoupled weight decay.
        optim = getattr(bnb.optim, config.name)(params, **config.args)
    else:
        optim = getattr(torch.optim, config.name)(params, **config.args)
    return optim


def parse_scheduler_to_instance(config, optimizer):
    """Recursively build a bare scheduler instance (no interval metadata).

    NOTE(review): this accepts the name "Sequential" while parse_scheduler
    below expects "SequentialLR" — both are kept for config compatibility.
    """
    if config.name == "ChainedScheduler":
        schedulers = [
            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
        ]
        scheduler = lr_scheduler.ChainedScheduler(schedulers)
    elif config.name == "Sequential":
        schedulers = [
            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
        ]
        scheduler = lr_scheduler.SequentialLR(
            optimizer, schedulers, milestones=config.milestones
        )
    else:
        scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args)
    return scheduler


def parse_scheduler(config, optimizer):
    """Build a lightning-style scheduler dict: {"scheduler", "interval"}.

    ``interval`` is "epoch" (default) or "step"; Sequential/Chained configs
    are built recursively from ``config.schedulers``.
    """
    interval = config.get("interval", "epoch")
    assert interval in ["epoch", "step"]
    if config.name == "SequentialLR":
        scheduler = {
            "scheduler": lr_scheduler.SequentialLR(
                optimizer,
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ],
                milestones=config.milestones,
            ),
            "interval": interval,
        }
    elif config.name == "ChainedScheduler":
        scheduler = {
            "scheduler": lr_scheduler.ChainedScheduler(
                [
                    parse_scheduler(conf, optimizer)["scheduler"]
                    for conf in config.schedulers
                ]
            ),
            "interval": interval,
        }
    else:
        scheduler = {
            "scheduler": get_scheduler(config.name)(optimizer, **config.args),
            "interval": interval,
        }
    return scheduler
from dataclasses import dataclass

import torch
import torch.nn as nn

from .config import parse_structured
from .misc import get_device, load_module_weights
from .typing import *


class Configurable:
    """Mixin that parses a raw dict/DictConfig into a typed ``Config`` dataclass."""

    @dataclass
    class Config:
        pass

    def __init__(self, cfg: Optional[dict] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)


class Updateable:
    """Mixin for objects that react to training progress.

    ``do_update_step``/``do_update_step_end`` first recurse into every
    attribute that is itself ``Updateable``, then invoke this object's own
    ``update_step``/``update_step_end`` hook.
    """

    def do_update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ):
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                module = getattr(self, attr)
            except Exception:
                # Some attributes (e.g. properties) may raise when accessed
                # outside their valid state; skip them. (Was a bare except.)
                continue
            if isinstance(module, Updateable):
                module.do_update_step(
                    epoch, global_step, on_load_weights=on_load_weights
                )
        self.update_step(epoch, global_step, on_load_weights=on_load_weights)

    def do_update_step_end(self, epoch: int, global_step: int):
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                module = getattr(self, attr)
            except Exception:
                # See do_update_step: skip attributes that raise on access.
                continue
            if isinstance(module, Updateable):
                module.do_update_step_end(epoch, global_step)
        self.update_step_end(epoch, global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # Override to implement custom update logic. When on_load_weights is
        # True, be careful with model evaluation: modules and tensors are not
        # guaranteed to be on the same device yet.
        pass

    def update_step_end(self, epoch: int, global_step: int):
        pass


def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Forward a step-start update to ``module`` if it is Updateable."""
    if isinstance(module, Updateable):
        module.do_update_step(epoch, global_step)


def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Forward a step-end update to ``module`` if it is Updateable."""
    if isinstance(module, Updateable):
        module.do_update_step_end(epoch, global_step)


class BaseObject(Updateable):
    @dataclass
    class Config:
        pass

    # Redeclare `cfg: Config` in every subclass to enable static type checking.
    cfg: Config

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        pass


class BaseModule(nn.Module, Updateable):
    @dataclass
    class Config:
        weights: Optional[str] = None

    # Redeclare `cfg: Config` in every subclass to enable static type checking.
    cfg: Config

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        self._non_modules = {}
        self.configure(*args, **kwargs)
        if self.cfg.weights is not None:
            # format: path/to/weights:module_name
            weights_path, module_name = self.cfg.weights.split(":")
            state_dict, epoch, global_step = load_module_weights(
                weights_path, module_name=module_name, map_location="cpu"
            )
            self.load_state_dict(state_dict)
            # Replay the update schedule so stateful sub-modules restore
            # their step-dependent state.
            self.do_update_step(epoch, global_step, on_load_weights=True)

    def configure(self, *args, **kwargs) -> None:
        pass

    def register_non_module(self, name: str, module: nn.Module) -> None:
        # Non-modules bypass nn.Module bookkeeping, so they are not treated
        # as model parameters (not saved, not optimized).
        self._non_modules[name] = module

    def non_module(self, name: str):
        return self._non_modules.get(name, None)
import os
import shutil
import subprocess

import pytorch_lightning

from .config import dump_config
from .misc import parse_version

# Callback moved between lightning versions; support both import paths.
if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
    from pytorch_lightning.callbacks import Callback
else:
    from pytorch_lightning.callbacks.base import Callback

from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn


class VersionedCallback(Callback):
    """Base for callbacks that write into a versioned subdirectory
    (``save_root/version_N``), auto-incrementing N per run."""

    def __init__(self, save_root, version=None, use_version=True):
        self.save_root = save_root
        self._version = version
        self.use_version = use_version

    @property
    def version(self) -> int:
        """The experiment version if specified, else the next free one."""
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self):
        existing_versions = []
        if os.path.isdir(self.save_root):
            for f in os.listdir(self.save_root):
                bn = os.path.basename(f)
                if bn.startswith("version_"):
                    dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
                    try:
                        existing_versions.append(int(dir_ver))
                    except ValueError:
                        # FIX: skip non-numeric entries like "version_backup"
                        # instead of crashing on int().
                        continue
        if not existing_versions:
            return 0
        return max(existing_versions) + 1

    @property
    def savedir(self):
        if not self.use_version:
            return self.save_root
        return os.path.join(
            self.save_root,
            self.version
            if isinstance(self.version, str)
            else f"version_{self.version}",
        )


class CodeSnapshotCallback(VersionedCallback):
    """Copies all git-tracked (and untracked-but-not-ignored) files into the
    versioned save dir at fit start, for reproducibility."""

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)

    def get_file_list(self):
        # Tracked files (excluding load/*) plus untracked-but-not-ignored ones.
        # TODO: use config to exclude folders or files instead of hard-coding.
        return [
            b.decode()
            for b in set(
                subprocess.check_output(
                    'git ls-files -- ":!:load/*"', shell=True
                ).splitlines()
            )
            | set(
                subprocess.check_output(
                    "git ls-files --others --exclude-standard", shell=True
                ).splitlines()
            )
        ]

    @rank_zero_only
    def save_code_snapshot(self):
        os.makedirs(self.savedir, exist_ok=True)
        for f in self.get_file_list():
            if not os.path.exists(f) or os.path.isdir(f):
                continue
            os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
            shutil.copyfile(f, os.path.join(self.savedir, f))

    def on_fit_start(self, trainer, pl_module):
        try:
            self.save_code_snapshot()
        except Exception:
            # Best-effort: snapshotting must never abort training.
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )


class ConfigSnapshotCallback(VersionedCallback):
    """Saves both the raw and the parsed config into the versioned save dir."""

    def __init__(self, config_path, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config_path = config_path
        self.config = config

    @rank_zero_only
    def save_config_snapshot(self):
        os.makedirs(self.savedir, exist_ok=True)
        dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config)
        shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml"))

    def on_fit_start(self, trainer, pl_module):
        self.save_config_snapshot()


class CustomProgressBar(TQDMProgressBar):
    def get_metrics(self, *args, **kwargs):
        # Don't show the version number in the progress bar.
        items = super().get_metrics(*args, **kwargs)
        items.pop("v_num", None)
        return items


class ProgressCallback(Callback):
    """Writes a single-line, human-readable progress message to a file,
    overwriting it on every update (rank zero only)."""

    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
        self._file_handle = None

    @property
    def file_handle(self):
        # Opened lazily so the callback can be constructed before the
        # output directory exists.
        if self._file_handle is None:
            self._file_handle = open(self.save_path, "w")
        return self._file_handle

    @rank_zero_only
    def write(self, msg: str) -> None:
        self.file_handle.seek(0)
        self.file_handle.truncate()
        self.file_handle.write(msg)
        self.file_handle.flush()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.write(
            f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%"
        )

    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
        self.write("Rendering validation image ...")

    @rank_zero_only
    def on_test_start(self, trainer, pl_module):
        self.write("Rendering video ...")

    @rank_zero_only
    def on_predict_start(self, trainer, pl_module):
        self.write("Exporting mesh assets ...")

    def teardown(self, trainer, pl_module, stage=None):
        # FIX: close the lazily-opened progress file so the handle does not
        # leak across runs.
        if self._file_handle is not None:
            self._file_handle.close()
            self._file_handle = None
import os
from dataclasses import dataclass, field
from datetime import datetime

from omegaconf import OmegaConf

import lrm
from .typing import *

# ============ Register OmegaConf Resolvers ============= #
OmegaConf.register_new_resolver(
    "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
)
OmegaConf.register_new_resolver("add", lambda a, b: a + b)
OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("div", lambda a, b: a / b)
OmegaConf.register_new_resolver("idiv", lambda a, b: a // b)
OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub))
OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)])
OmegaConf.register_new_resolver("gt0", lambda s: s > 0)
OmegaConf.register_new_resolver("not", lambda s: not s)


def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8):
    """Total optimizer steps for `max_epochs` given the global batch layout."""
    return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs


OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps)

# ======================================================= #


# ============== Automatic Name Resolvers =============== #
def get_naming_convention(cfg):
    """Derive an experiment name from the config (used when name == "auto")."""
    # TODO: richer naming convention
    name = f"lrm_{cfg.system.backbone.num_layers}"
    return name


# ======================================================= #


@dataclass
class ExperimentConfig:
    name: str = "default"
    description: str = ""
    tag: str = ""
    seed: int = 0
    use_timestamp: bool = True
    timestamp: Optional[str] = None
    exp_root_dir: str = "outputs"

    ### these shouldn't be set manually — derived in load_config()
    exp_dir: str = "outputs/default"
    trial_name: str = "exp"
    trial_dir: str = "outputs/default/exp"
    n_gpus: int = 1
    ###

    resume: Optional[str] = None

    data_cls: str = ""
    data: dict = field(default_factory=dict)

    system_cls: str = ""
    system: dict = field(default_factory=dict)

    # accept pytorch-lightning trainer parameters
    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api
    trainer: dict = field(default_factory=dict)

    # accept pytorch-lightning checkpoint callback parameters
    # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    checkpoint: dict = field(default_factory=dict)


def load_config(
    *yamls: str,
    cli_args: Optional[list] = None,
    from_string=False,
    makedirs=True,
    **kwargs,
) -> Any:
    """Load and merge YAML config(s), CLI overrides and kwargs into a
    resolved ExperimentConfig, deriving trial/exp directories.

    Each YAML may declare ``extends: path`` to inherit from another file
    (one level). ``cli_args`` are dotlist overrides; ``from_string`` parses
    the YAML arguments as literal strings instead of file paths.
    """
    # FIX: avoid a mutable default argument (was `cli_args: list = []`).
    cli_args = [] if cli_args is None else cli_args
    parse_func = OmegaConf.create if from_string else OmegaConf.load
    yaml_confs = []
    for y in yamls:
        conf = parse_func(y)
        extends = conf.pop("extends", None)
        if extends:
            assert os.path.exists(extends), f"File {extends} does not exist."
            yaml_confs.append(OmegaConf.load(extends))
        yaml_confs.append(conf)
    cli_conf = OmegaConf.from_cli(cli_args)
    cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)

    # post processing
    # auto naming
    if scfg.name == "auto":
        scfg.name = get_naming_convention(scfg)
    # add timestamp
    if not scfg.tag and not scfg.use_timestamp:
        raise ValueError("Either tag is specified or use_timestamp is True.")
    scfg.trial_name = scfg.tag
    # if resuming from an existing config, scfg.timestamp should not be None
    if scfg.timestamp is None:
        scfg.timestamp = ""
        if scfg.use_timestamp:
            if scfg.n_gpus > 1:
                lrm.warn(
                    "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag."
                )
            else:
                scfg.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S")
    # make directories
    scfg.trial_name += scfg.timestamp
    scfg.exp_dir = os.path.join(scfg.exp_root_dir, scfg.name)
    scfg.trial_dir = os.path.join(scfg.exp_dir, scfg.trial_name)

    if makedirs:
        os.makedirs(scfg.trial_dir, exist_ok=True)

    return scfg


def config_to_primitive(config, resolve: bool = True) -> Any:
    """Convert an OmegaConf container to plain Python containers."""
    return OmegaConf.to_container(config, resolve=resolve)


def dump_config(path: str, config) -> None:
    """Serialize a config to YAML at ``path``."""
    with open(path, "w") as fp:
        OmegaConf.save(config=config, f=fp)


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Merge a raw config into a structured (dataclass-typed) OmegaConf."""
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
    return scfg


# ---------------------------------------------------------------------------
# 3D_Stage/lrm/utils/misc.py (module head)
# ---------------------------------------------------------------------------
import gc
import os
import re
import time
from collections import defaultdict
from contextlib import contextmanager

import torch
from packaging import version

import lrm
from .config import config_to_primitive
from .typing import *


def parse_version(ver: str):
    """Parse a version string into a comparable ``packaging`` Version."""
    return version.parse(ver)


def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the
    # multiprocessing, therefore LOCAL_RANK needs to be checked first.
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


def get_device():
    """CUDA device for the current process, keyed by its distributed rank."""
    return torch.device(f"cuda:{get_rank()}")
def load_module_weights(
    path, module_name=None, ignore_modules=None, mapping=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load a (sub-)module state dict from a lightning checkpoint.

    At most one of ``module_name`` (keep only keys under that prefix, with the
    prefix stripped) and ``ignore_modules`` (drop keys under those prefixes)
    may be given. ``mapping`` is a list of ``{"from": ..., "to": ...}`` key
    prefix rewrites applied first.

    Returns:
        (state_dict, epoch, global_step) as stored in the checkpoint.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if mapping is not None:
        state_dict_to_load = {}
        # Keep keys not shadowed by a rewrite target, then add the rewritten
        # keys (cloned, so tensors are not aliased with the source keys).
        for k, v in state_dict.items():
            if any(k.startswith(m["to"]) for m in mapping):
                continue
            state_dict_to_load[k] = v
        for k, v in state_dict.items():
            for m in mapping:
                if k.startswith(m["from"]):
                    k_dest = k.replace(m["from"], m["to"])
                    lrm.info(f"Mapping {k} => {k_dest}")
                    state_dict_to_load[k_dest] = v.clone()
        state_dict = state_dict_to_load

    state_dict_to_load = state_dict

    if ignore_modules is not None:
        state_dict_to_load = {
            k: v
            for k, v in state_dict.items()
            if not any(k.startswith(im + ".") for im in ignore_modules)
        }

    if module_name is not None:
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]


def C(value: Any, epoch: int, global_step: int) -> float:
    """Resolve a possibly-scheduled scalar.

    Plain numbers pass through unchanged. A list (or OmegaConf ListConfig) is
    interpreted as ``[start_step, start_value, end_value, end_step]`` — a
    3-element form gets ``start_step = 0`` — and linearly interpolated. An
    int ``end_step`` is measured in global steps, a float one in epochs.
    """
    if isinstance(value, (int, float)):
        return value
    if not isinstance(value, list):
        # e.g. an OmegaConf ListConfig; plain lists skip the conversion.
        value = config_to_primitive(value)
    if not isinstance(value, list):
        raise TypeError("Scalar specification only supports list, got", type(value))
    if len(value) == 3:
        value = [0] + value
    assert len(value) == 4
    start_step, start_value, end_value, end_step = value
    # FIX: previously a non-int/float end_step silently returned the raw
    # list; now int selects global-step time, anything else epoch time.
    current_step = global_step if isinstance(end_step, int) else epoch
    t = max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
    return start_value + (end_value - start_value) * t


def cleanup():
    """Free Python, CUDA and (if present) tiny-cuda-nn temporary memory."""
    gc.collect()
    torch.cuda.empty_cache()
    try:
        import tinycudann as tcnn

        tcnn.free_temporary_memory()
    except Exception:
        # tinycudann is optional; best-effort only. (Was a bare except.)
        pass


def finish_with_cleanup(func: Callable):
    """Decorator: run ``func`` then cleanup() before returning its result."""
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper


def _distributed_available():
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def barrier():
    """No-op unless torch.distributed is initialized."""
    if not _distributed_available():
        return
    torch.distributed.barrier()


def broadcast(tensor, src=0):
    """Broadcast in place from ``src``; identity when not distributed."""
    if not _distributed_available():
        return tensor
    torch.distributed.broadcast(tensor, src=src)
    return tensor


def enable_gradient(model, enabled: bool = True) -> None:
    """Set requires_grad on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(enabled)


class TimeRecorder:
    """Singleton, CUDA-synchronized wall-clock timer for named sections.

    Disabled by default; when disabled all methods are no-ops. NOTE:
    ``__init__`` runs on every ``TimeRecorder()`` call and resets the shared
    singleton state (pre-existing behavior, kept for compatibility).
    """

    _instance = None

    def __new__(cls):
        # singleton
        if cls._instance is None:
            cls._instance = super(TimeRecorder, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        self.items = {}
        self.accumulations = defaultdict(list)
        self.time_scale = 1000.0  # report in milliseconds
        self.time_unit = "ms"
        self.enabled = False

    def enable(self, enabled: bool) -> None:
        self.enabled = enabled

    def start(self, name: str) -> None:
        if not self.enabled:
            return
        # Synchronize so async CUDA kernels don't skew wall-clock timings.
        torch.cuda.synchronize()
        self.items[name] = time.time()

    def end(self, name: str, accumulate: bool = False) -> Optional[float]:
        if not self.enabled or name not in self.items:
            return None
        torch.cuda.synchronize()
        delta = time.time() - self.items.pop(name)
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        lrm.info(f"{name}: {t:.2f}{self.time_unit}")
        # FIX: was annotated -> float but always returned None.
        return t

    def get_accumulation(self, name: str, average: bool = False) -> Optional[float]:
        if not self.enabled or name not in self.accumulations:
            return None
        acc = self.accumulations.pop(name)
        total = sum(acc)
        t = (total / len(acc) if average else total) * self.time_scale
        lrm.info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")
        # FIX: was annotated -> float but always returned None.
        return t


### global time recorder
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Affinely remap ``dat`` from the ``inp_scale`` range into ``tgt_scale``.

    Either scale may be ``None`` (interpreted as ``(0, 1)``), a
    ``(low, high)`` pair, or a ``(2, D)`` tensor for per-channel scaling.
    """
    src = (0, 1) if inp_scale is None else inp_scale
    dst = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(dst, Tensor):
        assert dat.shape[-1] == dst.shape[-1]
    normalized = (dat - src[0]) / (src[1] - src[0])
    return normalized * (dst[1] - dst[0]) + dst[0]
== "none": + return lambda x: x + elif name == "lin2srgb": + return lambda x: torch.where( + x > 0.0031308, + torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055, + 12.92 * x, + ).clamp(0.0, 1.0) + elif name == "exp": + return lambda x: torch.exp(x) + elif name == "shifted_exp": + return lambda x: torch.exp(x - 1.0) + elif name == "trunc_exp": + return trunc_exp + elif name == "shifted_trunc_exp": + return lambda x: trunc_exp(x - 1.0) + elif name == "sigmoid": + return lambda x: torch.sigmoid(x) + elif name == "tanh": + return lambda x: torch.tanh(x) + elif name == "shifted_softplus": + return lambda x: F.softplus(x - 1.0) + elif name == "scale_-11_01": + return lambda x: x * 0.5 + 0.5 + elif name == "negative": + return lambda x: -x + else: + try: + return getattr(F, name) + except AttributeError: + raise ValueError(f"Unknown activation function: {name}") + + +def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any: + if chunk_size <= 0: + return func(*args, **kwargs) + B = None + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + assert ( + B is not None + ), "No tensor found in args or kwargs, cannot determine batch size." 
+ out = defaultdict(list) + out_type = None + # max(1, B) to support B == 0 + for i in range(0, max(1, B), chunk_size): + out_chunk = func( + *[ + arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for arg in args + ], + **{ + k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg + for k, arg in kwargs.items() + }, + ) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + print( + f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}." + ) + exit(1) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + out[k].append(v) + + if out_type is None: + return None + + out_merged: Dict[Any, Optional[torch.Tensor]] = {} + for k, v in out.items(): + if all([vv is None for vv in v]): + # allow None in return value + out_merged[k] = None + elif all([isinstance(vv, torch.Tensor) for vv in v]): + out_merged[k] = torch.cat(v, dim=0) + else: + raise TypeError( + f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" + ) + + if out_type is torch.Tensor: + return out_merged[0] + elif out_type in [tuple, list]: + return out_type([out_merged[i] for i in range(chunk_length)]) + elif out_type is dict: + return out_merged + + +def get_ray_directions( + H: int, + W: int, + focal: Union[float, Tuple[float, float]], + principal: Optional[Tuple[float, float]] = None, + use_pixel_centers: bool = True, + normalize: bool = True, +) -> Float[Tensor, "H W 3"]: + """ + Get ray directions for all pixels in camera coordinate. 
def get_rays(
    directions: Float[Tensor, "... 3"],
    c2w: Float[Tensor, "... 4 4"],
    keepdim=False,
    noise_scale=0.0,
    normalize=False,
) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]:
    """Rotate camera-space ray directions into world space and expand origins.

    Supported shapes: directions (N,3) with c2w (4,4)/(N,4,4)/(1,4,4);
    directions (H,W,3) with c2w (4,4) or (B,4,4); directions (B,H,W,3)
    with c2w (B,4,4). When ``keepdim`` is False the outputs are flattened
    to (-1, 3). ``noise_scale`` perturbs both origin and direction.

    Raises:
        ValueError: for any other ``directions.ndim`` (previously this fell
            through and crashed later with an unbound local variable).
    """
    # Rotate ray directions from camera coordinate to the world coordinate
    assert directions.shape[-1] == 3

    if directions.ndim == 2:  # (N_rays, 3)
        if c2w.ndim == 2:  # (4, 4)
            c2w = c2w[None, :, :]
        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 3:  # (H, W, 3)
        assert c2w.ndim in [2, 3]
        if c2w.ndim == 2:  # (4, 4)
            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
                -1
            )  # (H, W, 3)
            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
        elif c2w.ndim == 3:  # (B, 4, 4)
            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
                -1
            )  # (B, H, W, 3)
            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 4:  # (B, H, W, 3)
        assert c2w.ndim == 3  # (B, 4, 4)
        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
            -1
        )  # (B, H, W, 3)
        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    else:
        raise ValueError(f"Unsupported directions shape: {tuple(directions.shape)}")

    # add camera noise to avoid grid-like artifect
    # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
    if noise_scale > 0:
        rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
        rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale

    if normalize:
        rays_d = F.normalize(rays_d, dim=-1)
    if not keepdim:
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    return rays_o, rays_d
def get_mvp_matrix(
    c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
) -> Float[Tensor, "*B 4 4"]:
    """Compose model-view-projection matrices from camera poses.

    The world-to-camera matrix is built analytically from the rigid c2w
    transform (R' = R^T, t' = -R^T t) — mathematically equal to inverting
    c2w but cheaper and numerically exact. Accepts a single (4, 4) pair or
    batched (B, 4, 4) inputs.
    """
    single = c2w.ndim == 2
    if single:
        assert proj_mtx.ndim == 2
        c2w_b, proj_b = c2w[None], proj_mtx[None]
    else:
        c2w_b, proj_b = c2w, proj_mtx
    rot_t = c2w_b[:, :3, :3].transpose(1, 2)
    w2c = torch.zeros_like(c2w_b)
    w2c[:, :3, :3] = rot_t
    w2c[:, :3, 3:] = -rot_t @ c2w_b[:, :3, 3:]
    w2c[:, 3, 3] = 1.0
    # mvp = projection @ model-view
    mvp = proj_b @ w2c
    return mvp[0] if single else mvp
+ """ + return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean() + + +def tet_sdf_diff( + vert_sdf: Float[Tensor, "Nv 1"], tet_edges: Integer[Tensor, "Ne 2"] +) -> Float[Tensor, ""]: + sdf_f1x6x2 = vert_sdf[:, 0][tet_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float() + ) + F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float() + ) + return sdf_diff + + +def validate_empty_rays(ray_indices, t_start, t_end): + if ray_indices.nelement() == 0: + lrm.warn("Empty rays_indices!") + ray_indices = torch.LongTensor([0]).to(ray_indices) + t_start = torch.Tensor([0]).to(ray_indices) + t_end = torch.Tensor([0]).to(ray_indices) + return ray_indices, t_start, t_end + + +def rays_intersect_bbox( + rays_o: Float[Tensor, "N 3"], + rays_d: Float[Tensor, "N 3"], + radius: Float, + near: Float = 0.0, + valid_thresh: Float = 0.01, + background: bool = False, +): + input_shape = rays_o.shape[:-1] + rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3) + rays_d_valid = torch.where( + rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d + ) + if type(radius) in [int, float]: + radius = torch.FloatTensor( + [[-radius, radius], [-radius, radius], [-radius, radius]] + ).to(rays_o.device) + radius = ( + 1.0 - 1.0e-3 + ) * radius # tighten the radius to make sure the intersection point lies in the bounding box + interx0 = (radius[..., 1] - rays_o) / rays_d_valid + interx1 = (radius[..., 0] - rays_o) / rays_d_valid + t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near) + t_far = torch.maximum(interx0, interx1).amin(dim=-1) + + # check wheter a ray intersects the bbox or not + rays_valid = t_far - t_near > valid_thresh + + t_near_valid, t_far_valid = t_near[rays_valid], t_far[rays_valid] + global_near = 
def c2w_to_polar(c2w: Float[Tensor, "4 4"]) -> Tuple[float, float, float]:
    """Decompose a camera-to-world pose into (elevation, azimuth, distance).

    Elevation is measured from the xy-plane toward +z; azimuth is in
    [0, 2*pi) around +z starting at +x, and defined as 0 when the camera
    sits (numerically) on the z-axis.
    """
    position = c2w[:3, 3]
    px, py, pz = position.tolist()
    distance = position.norm().item()
    elevation = math.asin(pz / distance)
    if abs(px) < 1.0e-5 and abs(py) < 1.0e-5:
        azimuth = 0
    else:
        azimuth = math.atan2(py, px)
        if azimuth < 0:
            # atan2 is in (-pi, pi]; shift negatives into [0, 2*pi)
            azimuth += 2 * math.pi
    return elevation, azimuth, distance
def polar_to_c2w(elevation: float, azimuth: float, distance: float) -> Float[Tensor, "4 4"]:
    """Build a camera-to-world matrix looking at the origin from polar coords.

    Standard look-at construction with world up = +z:
    forward = normalize(origin - position), right = forward x up,
    camera-up = right x forward; columns of R are [right, up, -forward].
    """
    cam_z = distance * math.sin(elevation)
    cam_x = distance * math.cos(elevation) * math.cos(azimuth)
    cam_y = distance * math.cos(elevation) * math.sin(azimuth)
    position = torch.tensor([cam_x, cam_y, cam_z], dtype=torch.float32)
    forward = F.normalize(-position, dim=0)
    world_up = torch.tensor([0.0, 0.0, 1.0])
    right = F.normalize(forward.cross(world_up), dim=0)
    cam_up = right.cross(forward)  # unit-length: cross of orthonormal vectors
    c2w = torch.eye(4, dtype=torch.float32)
    c2w[:3, :3] = torch.stack([right, cam_up, -forward], dim=0).T
    c2w[:3, 3] = position
    return c2w
dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True) + + def rasterize_one( + self, + pos: Float[Tensor, "Nv 4"], + tri: Integer[Tensor, "Nf 3"], + resolution: Union[int, Tuple[int, int]], + ): + # rasterize one single mesh under a single viewpoint + rast, rast_db = self.rasterize(pos[None, ...], tri, resolution) + return rast[0], rast_db[0] + + def antialias( + self, + color: Float[Tensor, "B H W C"], + rast: Float[Tensor, "B H W 4"], + pos: Float[Tensor, "B Nv 4"], + tri: Integer[Tensor, "Nf 3"], + ) -> Float[Tensor, "B H W C"]: + return dr.antialias(color.float(), rast, pos.float(), tri.int()) + + def interpolate( + self, + attr: Float[Tensor, "B Nv C"], + rast: Float[Tensor, "B H W 4"], + tri: Integer[Tensor, "Nf 3"], + rast_db=None, + diff_attrs=None, + ) -> Float[Tensor, "B H W C"]: + return dr.interpolate( + attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs + ) + + def interpolate_one( + self, + attr: Float[Tensor, "Nv C"], + rast: Float[Tensor, "B H W 4"], + tri: Integer[Tensor, "Nf 3"], + rast_db=None, + diff_attrs=None, + ) -> Float[Tensor, "B H W C"]: + return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs) diff --git a/3D_Stage/lrm/utils/saving.py b/3D_Stage/lrm/utils/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8fa010ad1dc816f9d46d7c131df5e2ca7e856c --- /dev/null +++ b/3D_Stage/lrm/utils/saving.py @@ -0,0 +1,725 @@ +import json +import os +import re +import shutil + +import cv2 +import imageio +import matplotlib.pyplot as plt +import numpy as np +import torch +import wandb +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +from PIL import Image, ImageDraw +from pytorch_lightning.loggers import WandbLogger + +import lrm +from ..models.mesh import Mesh +from ..utils.typing import * + + +class SaverMixin: + _save_dir: Optional[str] = None + _wandb_logger: Optional[WandbLogger] = None + + def set_save_dir(self, save_dir: str): + 
self._save_dir = save_dir + + def get_save_dir(self): + if self._save_dir is None: + raise ValueError("Save dir is not set") + return self._save_dir + + def convert_data(self, data): + if data is None: + return None + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + if data.dtype in [torch.float16, torch.bfloat16]: + data = data.float() + return data.detach().cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.get_save_dir(), filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "HWC", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + DEFAULT_GRID_KWARGS = {"align": "max"} + + def get_rgb_image_(self, img, data_format, data_range, rgba=False): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + if img.dtype != np.uint8: + img = img.clip(min=data_range[0], max=data_range[1]) + img = ( + (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0 + ).astype(np.uint8) + nc = 4 if rgba else 3 + imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)] + imgs = [ + img_ + if img_.shape[-1] == nc + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], nc - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + if rgba: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + else: + img = cv2.cvtColor(img, 
cv2.COLOR_RGB2BGR) + return img + + def _save_rgb_image( + self, + filename, + img, + data_format, + data_range, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + self._wandb_logger.log_image( + key=name, images=[self.get_save_path(filename)], step=step + ) + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_rgb_image(save_path, img, data_format, data_range, name, step) + return save_path + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ["checkerboard", "color"] + if cmap == "checkerboard": + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[..., 0] + mask[..., 1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == "color": + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image( + self, + filename, + img, + data_format=DEFAULT_UV_KWARGS["data_format"], + data_range=DEFAULT_UV_KWARGS["data_range"], + cmap=DEFAULT_UV_KWARGS["cmap"], + ) -> str: + save_path = self.get_save_path(filename) + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(save_path, 
img) + return save_path + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma", "spectral"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + elif cmap == "spectral": + colormap = plt.get_cmap("Spectral") + + def blend_rgba(image): + image = image[..., :3] * image[..., -1:] + ( + 1.0 - image[..., -1:] + ) # blend A to RGB + return image + + img = colormap(img) + img = blend_rgba(img) + img = (img * 255).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def _save_grayscale_image( + self, + filename, + img, + data_range, + cmap, + name: Optional[str] = None, + step: Optional[int] = None, + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(filename, img) + if name and self._wandb_logger: + self._wandb_logger.log_image( + key=name, images=[self.get_save_path(filename)], step=step + ) + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + name: Optional[str] = None, + 
step: Optional[int] = None, + ) -> str: + save_path = self.get_save_path(filename) + self._save_grayscale_image(save_path, img, data_range, cmap, name, step) + return save_path + + def get_image_grid_(self, imgs, align): + if isinstance(imgs[0], list): + return np.concatenate( + [self.get_image_grid_(row, align) for row in imgs], axis=0 + ) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + + if align == "max": + h = max([col.shape[0] for col in cols]) + elif align == "min": + h = min([col.shape[0] for col in cols]) + elif isinstance(align, int): + h = align + else: + raise ValueError( + f"Unsupported image grid align: {align}, should be min, max, or int" + ) + + for i in range(len(cols)): + if cols[i].shape[0] != h: + w = int(cols[i].shape[1] * h / cols[i].shape[0]) + cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_CUBIC) + return np.concatenate(cols, axis=1) + + def save_image_grid( + self, + filename, + imgs, + align=DEFAULT_GRID_KWARGS["align"], + name: Optional[str] = None, + step: Optional[int] = None, + texts: Optional[List[float]] = None, + ): + save_path = self.get_save_path(filename) + img = self.get_image_grid_(imgs, align=align) + + if texts is not None: + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + black, white = (0, 0, 0), (255, 255, 255) + for i, text in enumerate(texts): + draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((0, 
(img.size[1] // len(texts)) * i + 1), f"{text}", white) + draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white) + draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black) + img = np.asarray(img) + + cv2.imwrite(save_path, img) + if name and self._wandb_logger: + self._wandb_logger.log_image(key=name, images=[save_path], step=step) + return save_path + + def save_image(self, filename, img) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.dtype == np.uint8 or img.dtype == np.uint16 + if img.ndim == 3 and img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.ndim == 3 and img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + cv2.imwrite(save_path, img) + return save_path + + def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str: + save_path = self.get_save_path(filename) + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[..., start : start + 3] + img_ = np.stack( + [ + self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba) + for i in range(img_.shape[0]) + ], + axis=0, + ) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate( + [ + np.concatenate( + [placeholder, img_[2], placeholder, placeholder], axis=1 + ), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate( + [placeholder, img_[3], placeholder, placeholder], axis=1 + ), + ], + axis=0, + ) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(save_path, imgs_full) + return save_path + + def save_data(self, filename, data) -> str: + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith(".npz"): + filename += ".npz" + save_path 
= self.get_save_path(filename) + np.savez(save_path, **data) + else: + if not filename.endswith(".npy"): + filename += ".npy" + save_path = self.get_save_path(filename) + np.save(save_path, data) + return save_path + + def save_state_dict(self, filename, data) -> str: + save_path = self.get_save_path(filename) + torch.save(data, save_path) + return save_path + + def save_img_sequence( + self, + filename, + img_dir, + matcher, + save_format="mp4", + fps=30, + name: Optional[str] = None, + step: Optional[int] = None, + ) -> str: + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + save_path = self.get_save_path(filename) + matcher = re.compile(matcher) + img_dir = os.path.join(self.get_save_dir(), img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps, palettesize=256) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(save_path, imgs, fps=fps) + if name and self._wandb_logger: + lrm.warn("Wandb logger does not support video logging yet!") + return save_path + + def save_img_sequences( + self, + seq_dir, + matcher, + save_format="mp4", + fps=30, + delete=True, + name: Optional[str] = None, + step: Optional[int] = None, + ): + seq_dir_ = os.path.join(self.get_save_dir(), seq_dir) + for f in os.listdir(seq_dir_): + img_dir_ = os.path.join(seq_dir_, f) + if not os.path.isdir(img_dir_): + continue + try: + self.save_img_sequence( + os.path.join(seq_dir, f), + os.path.join(seq_dir, f), + matcher, + save_format=save_format, + fps=fps, + name=f"{name}_{f}", + step=step, + ) + if delete: + shutil.rmtree(img_dir_) + except: + lrm.warn(f"Video saving for 
directory {seq_dir_} failed!") + + def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str: + import trimesh + + save_path = self.get_save_path(filename) + v_pos = self.convert_data(v_pos) + t_pos_idx = self.convert_data(t_pos_idx) + mesh = trimesh.Trimesh(vertices=v_pos, faces=t_pos_idx) + mesh.export(save_path) + return save_path + + def save_obj( + self, + filename: str, + mesh: Mesh, + save_mat: bool = False, + save_normal: bool = False, + save_uv: bool = False, + save_vertex_color: bool = True, + map_Kd: Optional[Float[Tensor, "H W 3"]] = None, + map_Ks: Optional[Float[Tensor, "H W 3"]] = None, + map_Bump: Optional[Float[Tensor, "H W 3"]] = None, + map_Pm: Optional[Float[Tensor, "H W 1"]] = None, + map_Pr: Optional[Float[Tensor, "H W 1"]] = None, + map_format: str = "jpg", + ) -> List[str]: + save_paths: List[str] = [] + if not filename.endswith(".obj"): + filename += ".obj" + v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data( + mesh.t_pos_idx + ) + v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None + if save_normal: + v_nrm = self.convert_data(mesh.v_nrm) + if save_uv: + v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data( + mesh.t_tex_idx + ) + if save_vertex_color: + v_rgb = self.convert_data(mesh.v_rgb) + matname, mtllib = None, None + if save_mat: + matname = "default" + mtl_filename = filename.replace(".obj", ".mtl") + mtllib = os.path.basename(mtl_filename) + mtl_save_paths = self._save_mtl( + mtl_filename, + matname, + map_Kd=self.convert_data(map_Kd), + map_Ks=self.convert_data(map_Ks), + map_Bump=self.convert_data(map_Bump), + map_Pm=self.convert_data(map_Pm), + map_Pr=self.convert_data(map_Pr), + map_format=map_format, + ) + save_paths += mtl_save_paths + obj_save_path = self._save_obj( + filename, + v_pos, + t_pos_idx, + v_nrm=v_nrm, + v_tex=v_tex, + t_tex_idx=t_tex_idx, + v_rgb=v_rgb, + matname=matname, + mtllib=mtllib, + ) + save_paths.append(obj_save_path) + return save_paths + + 
def _save_obj( + self, + filename, + v_pos, + t_pos_idx, + v_nrm=None, + v_tex=None, + t_tex_idx=None, + v_rgb=None, + matname=None, + mtllib=None, + ) -> str: + obj_str = "" + if matname is not None: + obj_str += f"mtllib {mtllib}\n" + obj_str += f"g object\n" + obj_str += f"usemtl {matname}\n" + for i in range(len(v_pos)): + obj_str += f"v {v_pos[i][0]} {v_pos[i][1]} {v_pos[i][2]}" + if v_rgb is not None: + obj_str += f" {v_rgb[i][0]} {v_rgb[i][1]} {v_rgb[i][2]}" + obj_str += "\n" + if v_nrm is not None: + for v in v_nrm: + obj_str += f"vn {v[0]} {v[1]} {v[2]}\n" + if v_tex is not None: + for v in v_tex: + obj_str += f"vt {v[0]} {1.0 - v[1]}\n" + + for i in range(len(t_pos_idx)): + obj_str += "f" + for j in range(3): + obj_str += f" {t_pos_idx[i][j] + 1}/" + if v_tex is not None: + obj_str += f"{t_tex_idx[i][j] + 1}" + obj_str += "/" + if v_nrm is not None: + obj_str += f"{t_pos_idx[i][j] + 1}" + obj_str += "\n" + + save_path = self.get_save_path(filename) + with open(save_path, "w") as f: + f.write(obj_str) + return save_path + + def _save_mtl( + self, + filename, + matname, + Ka=(0.0, 0.0, 0.0), + Kd=(1.0, 1.0, 1.0), + Ks=(0.0, 0.0, 0.0), + map_Kd=None, + map_Ks=None, + map_Bump=None, + map_Pm=None, + map_Pr=None, + map_format="jpg", + step: Optional[int] = None, + ) -> List[str]: + mtl_save_path = self.get_save_path(filename) + save_paths = [mtl_save_path] + mtl_str = f"newmtl {matname}\n" + mtl_str += f"Ka {Ka[0]} {Ka[1]} {Ka[2]}\n" + if map_Kd is not None: + map_Kd_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_kd.{map_format}" + ) + mtl_str += f"map_Kd texture_kd.{map_format}\n" + self._save_rgb_image( + map_Kd_save_path, + map_Kd, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Kd", + step=step, + ) + save_paths.append(map_Kd_save_path) + else: + mtl_str += f"Kd {Kd[0]} {Kd[1]} {Kd[2]}\n" + if map_Ks is not None: + map_Ks_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_ks.{map_format}" + ) + mtl_str 
+= f"map_Ks texture_ks.{map_format}\n" + self._save_rgb_image( + map_Ks_save_path, + map_Ks, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Ks", + step=step, + ) + save_paths.append(map_Ks_save_path) + else: + mtl_str += f"Ks {Ks[0]} {Ks[1]} {Ks[2]}\n" + if map_Bump is not None: + map_Bump_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_nrm.{map_format}" + ) + mtl_str += f"map_Bump texture_nrm.{map_format}\n" + self._save_rgb_image( + map_Bump_save_path, + map_Bump, + data_format="HWC", + data_range=(0, 1), + name=f"{matname}_Bump", + step=step, + ) + save_paths.append(map_Bump_save_path) + if map_Pm is not None: + map_Pm_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_metallic.{map_format}" + ) + mtl_str += f"map_Pm texture_metallic.{map_format}\n" + self._save_grayscale_image( + map_Pm_save_path, + map_Pm, + data_range=(0, 1), + cmap=None, + name=f"{matname}_refl", + step=step, + ) + save_paths.append(map_Pm_save_path) + if map_Pr is not None: + map_Pr_save_path = os.path.join( + os.path.dirname(mtl_save_path), f"texture_roughness.{map_format}" + ) + mtl_str += f"map_Pr texture_roughness.{map_format}\n" + self._save_grayscale_image( + map_Pr_save_path, + map_Pr, + data_range=(0, 1), + cmap=None, + name=f"{matname}_Ns", + step=step, + ) + save_paths.append(map_Pr_save_path) + with open(self.get_save_path(filename), "w") as f: + f.write(mtl_str) + return save_paths + + def save_glb( + self, + filename: str, + mesh: Mesh, + save_mat: bool = False, + save_normal: bool = True, + save_uv: bool = True, + save_vertex_color: bool = True, + map_Kd: Optional[Float[Tensor, "H W 3"]] = None, + map_Ks: Optional[Float[Tensor, "H W 3"]] = None, + map_Bump: Optional[Float[Tensor, "H W 3"]] = None, + map_Pm: Optional[Float[Tensor, "H W 1"]] = None, + map_Pr: Optional[Float[Tensor, "H W 1"]] = None, + map_format: str = "jpg", + ) -> List[str]: + save_paths: List[str] = [] + if not filename.endswith(".glb"): + filename += 
    def save_glb(
        self,
        filename: str,
        mesh: Mesh,
        save_mat: bool = False,
        save_normal: bool = True,
        save_uv: bool = True,
        save_vertex_color: bool = True,
        map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
        map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
        map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
        map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
        map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
        map_format: str = "jpg",
    ) -> List[str]:
        """Export `mesh` as a binary glTF (.glb) file.

        NOTE(review): `save_mat` and the `map_*` texture arguments are accepted
        for signature parity with `save_obj` but are never used here — confirm
        whether material export for GLB was intended.
        """
        save_paths: List[str] = []
        if not filename.endswith(".glb"):
            filename += ".glb"
        v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data(
            mesh.t_pos_idx
        )
        v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None
        if save_normal:
            v_nrm = self.convert_data(mesh.v_nrm)
        if save_uv:
            v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data(
                mesh.t_tex_idx
            )
        if save_vertex_color:
            v_rgb = self.convert_data(mesh.v_rgb)

        obj_save_path = self._save_glb(
            filename,
            v_pos,
            t_pos_idx,
            v_nrm=v_nrm,
            v_tex=v_tex,
            t_tex_idx=t_tex_idx,
            v_rgb=v_rgb,
        )
        save_paths.append(obj_save_path)
        return save_paths

    def _save_glb(
        self,
        filename,
        v_pos,
        t_pos_idx,
        v_nrm=None,
        v_tex=None,
        t_tex_idx=None,
        v_rgb=None,
        matname=None,
        mtllib=None,
    ) -> str:
        """Build a trimesh.Trimesh from the arrays and export it as GLB.

        `matname`/`mtllib` are accepted but unused (OBJ-specific concepts).
        Returns the written file path.
        """
        import trimesh

        mesh = trimesh.Trimesh(
            vertices=v_pos, faces=t_pos_idx, vertex_normals=v_nrm, vertex_colors=v_rgb
        )
        # not tested
        # NOTE(review): `t_tex_idx` is not passed to TextureVisuals; UVs are
        # assumed to be per-vertex here — confirm against trimesh's API.
        if v_tex is not None:
            mesh.visual = trimesh.visual.TextureVisuals(uv=v_tex)

        save_path = self.get_save_path(filename)
        mesh.export(save_path)
        return save_path

    def save_file(self, filename, src_path, delete=False) -> str:
        """Copy `src_path` into the save directory; optionally delete the source."""
        save_path = self.get_save_path(filename)
        shutil.copyfile(src_path, save_path)
        if delete:
            os.remove(src_path)
        return save_path

    def save_json(self, filename, payload) -> str:
        """Serialize `payload` as JSON into the save directory and return the path."""
        save_path = self.get_save_path(filename)
        with open(save_path, "w") as f:
            f.write(json.dumps(payload))
        return save_path
def get_tetra_for_mesh(mesh_path, resolution=128):
    """Compute ground-truth SDF values for a mesh on a marching-tetrahedra grid.

    Loads the mesh at `mesh_path`, builds an SDF from its vertices/faces, and
    evaluates it at every vertex of the `resolution`-level tet grid loaded from
    `load/{resolution}_tets.npz`. Returns the resulting numpy array of SDF values.
    """
    # Local import: the module header does not import trimesh, so the original
    # code raised NameError here.
    import trimesh

    isosurface_helper = MarchingTetrahedraHelper(resolution, f"load/{resolution}_tets.npz")
    isosurface_helper.points_range = (-1, 1)
    mesh = trimesh.load(mesh_path)
    # Removed dead `np.load(f"")`: loading an empty path raises at runtime and
    # its result was never used.
    # NOTE(review): `SDF` is imported as a module (`import pySDF as SDF`) but
    # called here like a class — confirm the intended import is
    # `from pysdf import SDF` (or equivalent) in the module header.
    sdf = SDF(mesh.vertices, mesh.faces)
    sdf_gt = sdf(isosurface_helper.grid_vertices.numpy())
    return sdf_gt
0000000000000000000000000000000000000000..75ddc041ba44a7ea3e6c50a5741575a3b7160a44 --- /dev/null +++ b/3D_Stage/material/examples/1/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dc956a3a62dbd23b3eb4a8efcec71e30ea3d612ff1fa8e1cd01c3b59e97d1bfc +size 203314 diff --git a/3D_Stage/material/examples/1/3.png b/3D_Stage/material/examples/1/3.png new file mode 100644 index 0000000000000000000000000000000000000000..3a2aeb4899988e9bb53c503e608e7254a10f1e51 --- /dev/null +++ b/3D_Stage/material/examples/1/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:636e479fd74580b1cfaeb9c1f40b097afca4c4678b084dcda5cd648e2e415d9c +size 133351 diff --git a/3D_Stage/material/examples/1/4.png b/3D_Stage/material/examples/1/4.png new file mode 100644 index 0000000000000000000000000000000000000000..9ab5a320857692af16b09c502795cf32ff2c847f --- /dev/null +++ b/3D_Stage/material/examples/1/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f9fbaf91b5d15b95d6b14e0f2331b81b6c66388ab9e38a1ecc5b07094e69272 +size 131478 diff --git a/3D_Stage/material/meta.json b/3D_Stage/material/meta.json new file mode 100644 index 0000000000000000000000000000000000000000..9ca2e428f7cea46525623179a35168bbe4ff15c9 --- /dev/null +++ b/3D_Stage/material/meta.json @@ -0,0 +1,36 @@ +{ + "locations": [ + { + "transform_matrix": [ + [ 1.0, 0.0, 0.0, 0.0 ], + [ 0.0, 0.0, 1.0, 1.5 ], + [ 0.0, 1.0, 0.0, 0.0 ], + [ 0.0, 0.0, 0.0, 1.0 ] + ] + }, + { + "transform_matrix": [ + [ -1.0, 0.0, 0.0, 0.0 ], + [ 0.0, 0.0, -1.0, -1.5 ], + [ 0.0, 1.0, 0.0, 0.0 ], + [ 0.0, 0.0, 0.0, 1.0 ] + ] + }, + { + "transform_matrix": [ + [ 0.0, 0.0, 1.0, 1.5 ], + [ -1.0, 0.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0, 0.0 ], + [ 0.0, 0.0, 0.0, 1.0 ] + ] + }, + { + "transform_matrix": [ + [ 0.0, 0.0, -1.0, -1.5 ], + [ 1.0, 0.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0, 0.0 ], + [ 0.0, 0.0, 0.0, 1.0 ] + ] + } + ] +} \ No newline at end of file diff --git a/3D_Stage/refine.py 
def back_to_texture(glctx, look_at, pos, tri, tex, uv, uv_idx, idx, vn):
    """Back-project one view image `tex` into the mesh's UV texture space.

    Rasterizes the mesh from the view (clip-space `pos`), masks out fragments
    whose interpolated normals face sideways relative to the view axis, then
    rasterizes the UV layout and samples the view image at each texel's
    projected screen position. Returns (sampled color texture, validity mask,
    UV rasterization output).

    NOTE(review): `look_at` is accepted but unused; `depth_map`/`dmax` below
    are computed but never used — presumably leftovers from debugging.
    """
    # Rasterize at the view image's resolution so screen samples align 1:1.
    rast_out, rast_out_db = dr.rasterize(glctx, pos, tri, resolution=[tex.shape[0],tex.shape[1]])
    gb_normal, _ = dr.interpolate(vn[None], rast_out, tri)
    gb_normal = F.normalize(gb_normal, dim=-1)
    # Views 0/2 look along +-y, views 1/3 along +-x (see the camera matrices in
    # refine()); the filter axes pick the two directions orthogonal-ish to the
    # view so grazing surfaces get masked out.
    if idx == 2 or idx == 0:
        filter_camera = [torch.tensor([[[[1,0.,0.]]]]).cuda(), torch.tensor([[[[-1,0.,0.]]]]).cuda()]
    else:
        filter_camera = [torch.tensor([[[[0,-1.,0.]]]]).cuda(), torch.tensor([[[[0,1.,0.]]]]).cuda()]
    nmasks = []
    for fc in filter_camera:
        # Per-component threshold then sum: flags fragments whose normal
        # component along the filter axis exceeds 0.75.
        nmasks.append(((gb_normal * fc) > 0.75).int().sum(keepdim=True, dim=-1))
    # 1 where the normal is NOT near-parallel to either filter axis.
    gb_normal_mask = 1 - (nmasks[0] | nmasks[1])
    gb_mask = rast_out[...,3:4] > 0
    # rast_out[...,3] holds triangle_id+1 (0 = background); collect the set of
    # visible triangles, dropping the background entry.
    tri_list = torch.unique(rast_out[...,3:4].reshape(-1))
    tri_list = (tri_list[1:] - 1).to(torch.int32)
    pos = pos[0]

    depth_map = rast_out[...,3:4].clone()
    depth_map[depth_map > 0] = 1
    depth_map = depth_map.to(torch.float32)
    dmax = (rast_out[...,2:3] * gb_mask).max()
    # Lift UVs from [0,1]^2 into clip space (z=0, w=1) for a second rasterize
    # pass over the UV layout, restricted to the visible triangles.
    uv = torch.cat([uv * 2 - 1, torch.zeros(uv.shape[0], 1).cuda(), torch.ones(uv.shape[0], 1).cuda()], dim=1).unsqueeze(0)
    uv_idx = uv_idx[tri_list.to(torch.long)]
    rast_uv, rast_uv_db = dr.rasterize(glctx, uv, uv_idx, resolution=(1024, 1024))
    pos_clip = torch.cat([pos[...,:2], pos[...,3:]], -1)
    pos_2d, _ = dr.interpolate(pos_clip, rast_uv, tri[tri_list.to(torch.long)]) # pos (x, y, z, w)
    # Perspective divide, then map NDC [-1,1] to [0,1] screen coordinates.
    pos_coord = (pos_2d[...,:2] / (pos_2d[...,2:3] + 1e-6) + 1) / 2.
    texture_mask = (rast_uv[...,3:4] > 0).int()
    # Sample the view image (pre-masked by normals) at each texel's position.
    color = dr.texture(tex[None, ...] * gb_normal_mask, pos_coord, filter_mode='linear')
    color_mask = dr.texture(gb_normal_mask.to(torch.float32), pos_coord, filter_mode='linear')
    # Binarize the bilinearly-filtered mask; 0.82 is an empirical threshold.
    color_mask[color_mask > 0.82] = 1
    color_mask[color_mask <= 0.82] = 0
    color_mask = color_mask.to(torch.int32)
    texture_mask = texture_mask * color_mask
    return color, texture_mask, rast_uv

def perspective(fovy=0.6913, aspect=1.0, n=0.1, f=1000.0, device=None):
    """Build a 4x4 perspective projection matrix (OpenGL-style, Y flipped).

    NOTE(review): `device` is accepted but ignored — the result is always
    placed on the default CUDA device.
    """
    y = np.tan(fovy / 2)
    return torch.tensor([[1/(y*aspect), 0, 0, 0],
                         [ 0, 1/-y, 0, 0],
                         [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
                         [ 0, 0, -1, 0]]).to(torch.float32).cuda()
def rec_mvp(trans, h, w):
    """Compose the model-view-projection matrix for one camera.

    `trans` is the 4x4 world-to-camera matrix; `h`/`w` set the aspect ratio.
    The field of view is fixed at 40 degrees.
    """
    fov_rad = np.pi * (40. / 180.)
    projection = perspective(fov_rad, h / w, n=0.1, f=1000)
    return projection @ trans
def aggregate_texture(kd_map, textures, texture_masks, rast_uvs):
    """Merge per-view back-projected textures into the base kd texture.

    For each texel, picks the view whose projected color is closest to the
    current `kd_map` color, keeps it only where the combined masks agree, and
    blends the result over `kd_map`. Returns (updated kd_map, mask of texels
    left untouched).

    NOTE(review): `rast_uvs` is accepted but unused; the `texture` accumulator
    below is written in the loop but never read afterwards — presumably
    leftovers. Also writes "final_texture.png" to the current working
    directory as a debug artifact.
    """
    texture = torch.zeros_like(textures[0])
    texture_mask = torch.zeros_like(texture_masks[0])
    ctex = []
    for idx in range(len(textures)):
        # Masked-out texels are pushed to an out-of-range value (10) so the
        # distance argmin below never selects them.
        ctex.append(textures[idx] * texture_masks[idx] + 10 * (1 - texture_masks[idx]))
    cat_textures = torch.stack(ctex, dim=-2)
    # L1 color distance of each view's texel against the current kd texture.
    dis_measure = (cat_textures - kd_map.unsqueeze(-2)).abs().sum(-1)
    _, choose_idx = dis_measure.min(-1)

    choose_idx = choose_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 1, 3)
    # Per texel, gather the color from the closest view.
    final_texture_map = torch.gather(cat_textures, 3, choose_idx).squeeze(-2)
    # Reject near-black picks and picks that moved too far from kd_map.
    zero_mask = (final_texture_map.max(dim=-1, keepdim=True)[0] > 0.1)
    close_mask = ((final_texture_map[0] - kd_map).abs().sum(dim=-1, keepdim=True) < 1.0).int()
    for idx in range(len(textures)):
        texture += textures[idx] * texture_masks[idx]
        texture_mask |= texture_masks[idx]
    texture_mask = texture_mask * zero_mask * close_mask[None]
    # Texels NOT covered by any accepted view; caller treats these as the
    # region still needing optimization.
    optimize_mask = (texture_mask == 0).int()

    final_texture_map = final_texture_map[0] * texture_mask[0]
    Image.fromarray(np.rint(final_texture_map.cpu().numpy() * 255).astype(np.uint8)).save(f"final_texture.png")

    # Composite: keep kd_map where no view was accepted, else use the pick.
    kd_map = kd_map * (1 - texture_mask[0]) + final_texture_map
    return kd_map, optimize_mask
def refine(save_path, front_image, back_image, left_image, right_image):
    """Refine the generated texture by back-projecting four view images.

    Loads `{save_path}/model-00.obj`, smooths it, back-projects the four PIL
    view images onto the UV texture via nvdiffrast, and writes
    `refined_texture_kd.jpg` plus an updated `model-00.mtl` into `save_path`.

    NOTE(review): the cameras are ordered [front, left, back, right] and the
    images are re-ordered to match by the caller — confirm the mapping if the
    argument order changes.
    """
    ms = pymeshlab.MeshSet()
    mesh_path = f"{save_path}/model-00.obj"
    ms.load_new_mesh(mesh_path)
    ms.apply_coord_laplacian_smoothing(stepsmoothnum=10)
    # Read the raw OBJ for its vt records (pymeshlab's mesh is used for
    # geometry below). Context manager: the original leaked the file handle.
    with open(mesh_path, "r") as obj_file:
        tl = obj_file.readlines()
    tex_uv = []
    for line in tl:
        if line.startswith("vt"):
            uvs = line.split(" ")[1:3]
            # Flip V back: the exporter wrote 1 - v.
            tex_uv += [float(uvs[0]), 1.0 - float(uvs[1])]
    tex_uv = torch.from_numpy(np.array(tex_uv)).to(torch.float32).cuda().reshape(-1, 2)
    m = ms.current_mesh()
    v_matrix = m.vertex_matrix()
    f_matrix = m.face_matrix()
    vn = m.vertex_normal_matrix()
    # One UV per face corner, in file order: 3 consecutive indices per face.
    uv_idx = torch.arange(f_matrix.shape[0] * 3).reshape(-1, 3).to(torch.int32).cuda()
    vn = torch.tensor(vn).contiguous().cuda().to(torch.float32)

    # World-to-camera matrices for the four canonical views, 1.5 units out.
    front_camera = torch.tensor([[
        1,0,0,0,
        0,0,1,0,
        0,-1,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    back_camera = torch.tensor([[
        1,0,0,0,
        0,0,1,0,
        0,1,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    right_camera = torch.tensor([[
        0,-1,0,0,
        0,0,1,0,
        1,0,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    left_camera = torch.tensor([[
        0,1,0,0,
        0,0,1,0,
        -1,0,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    frames = [front_camera, left_camera, back_camera, right_camera]

    # Image order must match `frames` above.
    target_images = []
    for target_image in [front_image, left_image, back_image, right_image]:
        target_images.append(torch.from_numpy(np.asarray(target_image.convert("RGB"))).to(torch.float32).cuda() / 255.)

    pos = torch.tensor(v_matrix, dtype=torch.float32).contiguous().cuda()
    tri = torch.tensor(f_matrix, dtype=torch.int32).contiguous().cuda()

    kd_map = (torch.tensor(np.asarray(Image.open(f"{save_path}/texture_kd.jpg"))) / 255.).cuda()
    translate_tensor = torch.zeros((1,1,3)).cuda()
    # Homogeneous coordinates, batch dimension of 1 for nvdiffrast.
    pos = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], -1).unsqueeze(0)
    glctx = dr.RasterizeCudaContext()
    target_texture = []
    target_mask = []
    rast_uvs = []
    with torch.no_grad():
        for idx, trans in enumerate(frames):
            target_image = target_images[idx]
            # Camera forward axis in world space (unused by back_to_texture).
            look_at = -torch.linalg.inv(trans)[:3,2]
            mvp = rec_mvp(trans, h=target_images[0].shape[0], w=target_images[0].shape[1])
            trans_pos = pos.clone()
            trans_pos[...,:3] += translate_tensor
            view_pos = torch.matmul(mvp, trans_pos.unsqueeze(-1)).squeeze(-1)
            texture, mask, rast_uv = back_to_texture(glctx, look_at, view_pos, tri, target_image, tex_uv, uv_idx, idx, vn)
            target_texture.append(texture)
            target_mask.append(mask)
            rast_uvs.append(rast_uv)
    kd_map, opt_mask = aggregate_texture(kd_map, target_texture, target_mask, rast_uvs)
    opt_mask = opt_mask[0]
    Image.fromarray((np.clip(kd_map.detach().cpu().numpy() * 255, 0, 255)).astype(np.uint8)).save(f"{save_path}/refined_texture_kd.jpg")

    # Point the material at the refined texture.
    with open(f"{save_path}/model-00.mtl", "w") as f:
        f.write("newmtl default\nKa 0.0 0.0 0.0\nmap_Kd refined_texture_kd.jpg\nKs 0.0 0.0 0.0")
def traverse(path, back_proj):
    """Load the exported OBJ, smooth it, bake texture colors onto vertices.

    Rotates the mesh into the viewer's orientation, applies Laplacian
    smoothing via pymeshlab, exports a GLB, then samples the kd texture
    (refined if `back_proj`) at each vertex's UV to produce a vertex-colored
    trimesh, which is returned.
    """
    mesh = trimesh.load(f"{path}/model-00.obj")
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(90.0), [-1, 0, 0]))
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(180.0), [0, 1, 0]))

    cmesh = pymeshlab.Mesh(mesh.vertices, mesh.faces)
    ms = pymeshlab.MeshSet()
    ms.add_mesh(cmesh)
    ms.apply_coord_laplacian_smoothing(stepsmoothnum=4)
    mesh.vertices = ms.current_mesh().vertex_matrix()

    mesh.export(f'{path}/output.glb', file_type='glb')

    image = Image.open(f"{path}/{'refined_texture_kd.jpg' if back_proj else 'texture_kd.jpg'}")
    texture = np.array(image)
    vertex_colors = np.zeros((mesh.vertices.shape[0], 4), dtype=np.uint8)

    # Vectorized nearest-texel lookup (the original looped per vertex in
    # Python). astype(int) truncates toward zero exactly like int() did.
    uv = np.asarray(mesh.visual.uv)
    cols = (uv[:, 0] * (texture.shape[1] - 1)).astype(np.int64)
    rows = ((1.0 - uv[:, 1]) * (texture.shape[0] - 1)).astype(np.int64)
    n = uv.shape[0]  # NOTE(review): assumes one UV per vertex — confirm
    vertex_colors[:n, :3] = texture[rows, cols, :3]
    vertex_colors[:n, 3] = 255
    return trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, vertex_colors=vertex_colors)
class Inference_API:
    """Wraps the LRM system for the webui: four view images in, 3D mesh out."""

    def __init__(self):
        # Load config
        self.cfg = load_config("configs/infer.yaml", makedirs=False)
        # Load system
        print("Loading system")
        self.system = lrm.find(self.cfg.system_cls)(self.cfg.system).to(device)
        self.system.eval()

    def process_images(self, img_input0, img_input1, img_input2, img_input3, back_proj):
        """Run inference on four view images and export a textured mesh.

        Returns (output directory, OBJ path, GLB path) for the gradio outputs.
        When `back_proj` is set, the texture is additionally refined by
        back-projecting the input views.
        """
        # Context manager: the original `json.load(open(...))` leaked the handle.
        with open("material/meta.json") as meta_file:
            meta = json.load(meta_file)
        c2w_cond = [np.array(loc["transform_matrix"]) for loc in meta["locations"]]
        c2w_cond = torch.from_numpy(np.stack(c2w_cond, axis=0)).float()[None].to(device)

        # Prepare input data
        rgb_cond = []
        files = [img_input0, img_input1, img_input2, img_input3]
        new_image = []
        for file in files:
            image = np.array(file)
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            new_image.append(Image.fromarray(image.astype(np.uint8)).convert("RGB"))
            rgb = cv2.resize(image, (self.cfg.data.cond_width,
                                     self.cfg.data.cond_height)).astype(np.float32) / 255.0
            rgb_cond.append(rgb)
        assert len(rgb_cond) == 4, "Please provide 4 images"

        rgb_cond = torch.from_numpy(np.stack(rgb_cond, axis=0)).float()[None].to(device)

        # Run inference
        with torch.no_grad():
            scene_codes = self.system({"rgb_cond": rgb_cond, "c2w_cond": c2w_cond})
            exporter_output = self.system.exporter([f"{i:02d}" for i in range(rgb_cond.shape[0])], scene_codes)

        # Save output under a timestamped directory.
        save_dir = os.path.join("./outputs", datetime.now().strftime("@%Y%m%d-%H%M%S"))
        os.makedirs(save_dir, exist_ok=True)
        self.system.set_save_dir(save_dir)

        # Each exporter output names the saver method to dispatch to
        # (e.g. save_type "obj" -> self.system.save_obj).
        for out in exporter_output:
            save_func_name = f"save_{out.save_type}"
            save_func = getattr(self.system, save_func_name)
            save_func(f"{out.save_name}", **out.params)

        if back_proj:
            # Re-order UI inputs (back, front, right, left) into refine's
            # (front, back, left, right) signature.
            refine(save_dir, new_image[1], new_image[0], new_image[3], new_image[2])

        new_obj = traverse(save_dir, back_proj)
        new_obj.export(f'{save_dir}/output.obj', file_type='obj')

        # NOTE(review): baseColorFactor alpha of 100.0 is outside the usual
        # [0,1] range — presumably intentional for the viewer; confirm.
        gltf = GLTF2().load(f'{save_dir}/output.glb')
        for material in gltf.materials:
            if material.pbrMetallicRoughness:
                material.pbrMetallicRoughness.baseColorFactor = [1.0, 1.0, 1.0, 100.0]
                material.pbrMetallicRoughness.metallicFactor = 0.0
                material.pbrMetallicRoughness.roughnessFactor = 1.0
        gltf.save(f'{save_dir}/output.glb')

        return save_dir, f"{save_dir}/output.obj", f"{save_dir}/output.glb"
Download to get correct results.") + + submit_button.click(inferapi.process_images, inputs=[img_input0, img_input1, img_input2, img_input3, back_proj], + outputs=[output_dir, output_model_obj, output_model_glb]) + +# Run the interface +demo.launch() \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52 --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. 
+ + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. 
The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. 
+ + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. 
The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. 
This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. 
For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. 
Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. 
+ + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. 
Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. 
+ + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. 
+ + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. 
+ + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. 
There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/README.md b/README.md index 9241e43cf99951683b8885bf28be80f805c1c0ee..c6440c88ba3ba1c915f8dcf18fef946927ed3f7c 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,137 @@ --- +license: apache-2.0 title: CharacterGen -emoji: 👁 -colorFrom: red -colorTo: blue sdk: gradio -sdk_version: 6.6.0 -python_version: '3.12' -app_file: app.py +sdk_version: 5.14.0 +python_version: "3.10" +emoji: 🏃 +colorFrom: gray +colorTo: red pinned: false -license: apache-2.0 +short_description: Gradio demo of CharacterGen (SIGGRAPH 2024) +hardware: zero-gpu --- +# CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration + +This is the official codebase of SIGGRAPH'24 (TOG) [CharacterGen](https://charactergen.github.io/). + +![teaser](./materials/teaser.png) + +- [x] Rendering Script of VRM model, including blender and three-js. +- [x] Inference code for 2D generation stage. +- [x] Inference code for 3D generation stage. + +## Quick Start + +### 1. Prepare environment + +`pip install -r requirements.txt` + +### 2. Download the weight + +Install `huggingface-cli` first. + +```bash +huggingface-cli download --resume-download zjpshadow/CharacterGen --include 2D_Stage/* --local-dir . +huggingface-cli download --resume-download zjpshadow/CharacterGen --include 3D_Stage/* --local-dir . +``` + +If you find mistakes on download, you can download all the reporitory and move to the right folder. + +### 3. 
Run the script

#### Run the whole pipeline
```bash
python webui.py
```

#### Only Run 2D Stage

```bash
cd 2D_Stage
python webui.py
```

#### Only Run 3D Stage

```bash
cd 3D_Stage
python webui.py
```

## Get the Anime3D Dataset

Due to policy, we cannot redistribute the raw data of VRM-format 3D characters.
You can download the Vroid dataset following the [PAniC-3D](https://github.com/ShuhongChen/panic3d-anime-reconstruction) instructions.
Then you can render it with Blender or three-js using our released rendering scripts.

### Blender

First, you should install [Blender](https://www.blender.org/) and [the VRM addon for Blender](https://github.com/saturday06/VRM-Addon-for-Blender).

Then you can render the VRM and export the OBJ of the VRM under some FBX animation.

```bash
blender -b --python render_script/blender/render.py importVrmPath importFbxPath outputFolder [is_apose]
```

The last input argument indicates whether you use A-pose; if set, the output is in A-pose; otherwise, the output is the pose of an arbitrary frame of the FBX animation.

### [three-vrm](https://github.com/pixiv/three-vrm)

**Much quicker than the Blender VRM add-on.**

Install [Node.js](https://nodejs.org/) first to use the npm environment.

```bash
cd render_script/three-js
npm install three @pixiv/three-vrm
```

If you want to render depth-map images of VRM, you should replace three-vrm with [our version](render_script/three-js/src/three-vrm.js).

First, run the backend to receive the data from the frontend (the default port is `17070`); remember to change the folder path.

```bash
pip install fastapi uvicorn aiofiles pillow numpy
python up_backend.py
```

Second, run the frontend to render the images.

```bash
npm run dev
```

Then open the website http://localhost:5173/; it uses 2 threads to render the images, which takes about 1 day.

## Our Result

| Single Input Image | 2D Multi-View Images | 3D Character |
|-------|-------|-------|
| ![](./materials/input/1.png) | ![](./materials/ours_multiview/1.png) | threestudio |
| ![](./materials/input/2.png) | ![](./materials/ours_multiview/2.png) | threestudio |
| ![](./materials/input/3.png) | ![](./materials/ours_multiview/3.png) | threestudio |

# Acknowledgements

This project is built upon **[Tune-A-Video](https://github.com/showlab/Tune-A-Video)** and **[TripoSR](https://github.com/VAST-AI-Research/TripoSR)**.
The rendering scripts are built upon **[three-vrm](https://github.com/pixiv/three-vrm)** and **[VRM-Addon-for-Blender](https://github.com/saturday06/VRM-Addon-for-Blender)**.
Many thanks to the friends who generously helped with our work. We're extremely grateful to **[Yuanchen](https://github.com/bennyguo)**, **[Yangguang](https://scholar.google.com/citations?user=a7AMvgkAAAAJ)**, and **Yuan Liang** for their guidance on code details and ideas.
We thank all the authors for their great repos and help.
+ +# Citation + +If you find our code or paper helps, please consider citing: -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +```bibtex +@article{peng2024charactergen, + title ={CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Canonicalization}, + author ={Hao-Yang Peng and Jia-Peng Zhang and Meng-Hao Guo and Yan-Pei Cao and Shi-Min Hu}, + journal ={ACM Transactions on Graphics (TOG)}, + year ={2024}, + volume ={43}, + number ={4}, + doi ={10.1145/3658217} +} +``` \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5681c451056531b66ebcf070a569d5ae03711874 --- /dev/null +++ b/app.py @@ -0,0 +1,483 @@ +import gradio as gr +from PIL import Image +import glob + +import spaces + +import io +import argparse +import os +import random +from typing import Dict, Optional, Tuple +from omegaconf import OmegaConf +import numpy as np + +import torch +import torch.utils.checkpoint + +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.utils import check_min_version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection +from torchvision import transforms + +import sys + +sys.path.append("2D_Stage") +sys.path.append("3D_Stage") +from tuneavideo.models.unet_mv2d_condition import UNetMV2DConditionModel +from tuneavideo.models.unet_mv2d_ref import UNetMV2DRefModel +from tuneavideo.models.PoseGuider import PoseGuider +from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline +from tuneavideo.util import shifted_noise +from einops import rearrange +import PIL +from PIL import Image +from torchvision.utils import save_image +import json +import cv2 + +import lrm +import trimesh +from lrm.utils.config import load_config +from refine import 
refine +from datetime import datetime +import gradio as gr +from pygltflib import GLTF2 + +import onnxruntime as rt +from huggingface_hub import hf_hub_download, list_repo_files +from rm_anime_bg.cli import get_mask, SCALE +import pymeshlab + + +def download_model_files(): + """モデルファイルをHuggingFace Hubからダウンロードする(初期化時に実行)""" + repo_id = "zjpshadow/CharacterGen" + all_files = list_repo_files(repo_id, revision="main") + for file in all_files: + if os.path.exists(file): + continue + if file.startswith("2D_Stage") or file.startswith("3D_Stage"): + hf_hub_download(repo_id, file, local_dir=".") + + +check_min_version("0.24.0") + +logger = get_logger(__name__, log_level="INFO") + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def get_bg_color(bg_color): + if bg_color == 'white': + bg_color = np.array([1., 1., 1.], dtype=np.float32) + elif bg_color == 'black': + bg_color = np.array([0., 0., 0.], dtype=np.float32) + elif bg_color == 'gray': + bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32) + elif bg_color == 'random': + bg_color = np.random.rand(3) + elif isinstance(bg_color, float): + bg_color = np.array([bg_color] * 3, dtype=np.float32) + else: + raise NotImplementedError + return bg_color + +def process_image(image, totensor): + if not image.mode == "RGBA": + image = image.convert("RGBA") + + # Find non-transparent pixels + non_transparent = np.nonzero(np.array(image)[..., 3]) + min_x, max_x = non_transparent[1].min(), non_transparent[1].max() + min_y, max_y = non_transparent[0].min(), non_transparent[0].max() + image = image.crop((min_x, min_y, max_x, max_y)) + + # paste to center + max_dim = max(image.width, image.height) + max_height = max_dim + max_width = int(max_dim / 3 * 2) + new_image = Image.new("RGBA", (max_width, max_height)) + left = (max_width - image.width) // 2 + top = (max_height - image.height) // 2 + new_image.paste(image, (left, top)) + + image = 
new_image.resize((512, 768), resample=PIL.Image.BICUBIC) + image = np.array(image) + image = image.astype(np.float32) / 255. + assert image.shape[-1] == 4 # RGBA + alpha = image[..., 3:4] + bg_color = get_bg_color("gray") + image = image[..., :3] * alpha + bg_color * (1 - alpha) + # save image + new_image = Image.fromarray((image * 255).astype(np.uint8)) + new_image.save("input.png") + return totensor(image) + +class rm_bg_api: + + def __init__(self, force_cpu: Optional[bool] = True): + session_infer_path = hf_hub_download( + repo_id="skytnt/anime-seg", filename="isnetis.onnx", + ) + providers: list[str] = ["CPUExecutionProvider"] + if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers(): + providers = ["CUDAExecutionProvider"] + + self.session_infer = rt.InferenceSession( + session_infer_path, providers=providers, + ) + + def _remove_background_impl( + self, + imgs: list[np.ndarray], + alpha_min: float, + alpha_max: float, + ) -> list: + process_imgs = [] + for img in imgs: + img = np.array(img) + # CHANGE to RGB + if img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB) + mask = get_mask(self.session_infer, img) + + mask[mask < alpha_min] = 0.0 # type: ignore + mask[mask > alpha_max] = 1.0 # type: ignore + + img_after = (mask * img).astype(np.uint8) # type: ignore + mask = (mask * SCALE).astype(np.uint8) # type: ignore + img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8) + mask = mask.repeat(3, axis=2) + process_imgs.append(Image.fromarray(img_after)) + return process_imgs + + +class Inference2D_API: + + def __init__(self, + pretrained_model_path: str, + image_encoder_path: str, + ckpt_dir: str, + validation: Dict, + local_crossattn: bool = True, + unet_from_pretrained_kwargs=None, + unet_condition_type=None, + use_pose_guider=False, + use_shifted_noise=False, + use_noise=True, + device="cuda" + ): + self.validation = validation + self.use_noise = use_noise + self.use_shifted_noise = use_shifted_noise + 
self.unet_condition_type = unet_condition_type + image_encoder_path = image_encoder_path.replace("./", "./2D_Stage/") + ckpt_dir = ckpt_dir.replace("./", "./2D_Stage/") + + self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path) + feature_extractor = CLIPImageProcessor() + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs) + ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs) + if use_pose_guider: + pose_guider = PoseGuider(noise_latent_channels=4).to("cuda") + else: + pose_guider = None + + unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu") + if use_pose_guider: + pose_guider_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu") + ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_2.bin"), map_location="cpu") + pose_guider.load_state_dict(pose_guider_params) + else: + ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu") + unet.load_state_dict(unet_params) + ref_unet.load_state_dict(ref_unet_params) + + weight_dtype = torch.float16 + + text_encoder.to(device, dtype=weight_dtype) + image_encoder.to(device, dtype=weight_dtype) + vae.to(device, dtype=weight_dtype) + ref_unet.to(device, dtype=weight_dtype) + unet.to(device, dtype=weight_dtype) + + vae.requires_grad_(False) + unet.requires_grad_(False) + ref_unet.requires_grad_(False) + + noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + 
self.validation_pipeline = TuneAVideoPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=self.tokenizer, unet=unet, ref_unet=ref_unet,feature_extractor=feature_extractor,image_encoder=image_encoder, + scheduler=noise_scheduler + ) + self.validation_pipeline.enable_vae_slicing() + self.validation_pipeline.set_progress_bar_config(disable=True) + self.generator = torch.Generator(device=device) + + @torch.no_grad() + def _inference_impl(self, input_image, val_width, val_height, + use_shifted_noise=False, crop=False, seed=100, timestep=20): + set_seed(seed) + totensor = transforms.ToTensor() + + metas = json.load(open("./2D_Stage/material/pose.json", "r")) + cameras = [] + pose_images = [] + input_path = "./2D_Stage/material" + for lm in metas: + cameras.append(torch.tensor(np.array(lm[0]).reshape(4, 4).transpose(1,0)[:3, :4]).reshape(-1)) + if not crop: + pose_images.append(totensor(np.asarray(Image.open(os.path.join(input_path, lm[1])).resize( + (val_height, val_width), resample=PIL.Image.BICUBIC)).astype(np.float32) / 255.)) + else: + pose_image = Image.open(os.path.join(input_path, lm[1])) + crop_area = (128, 0, 640, 768) + pose_images.append(totensor(np.array(pose_image.crop(crop_area)).astype(np.float32)) / 255.) 
+ camera_matrixs = torch.stack(cameras).unsqueeze(0).to("cuda") + pose_imgs_in = torch.stack(pose_images).to("cuda") + prompts = "high quality, best quality" + prompt_ids = self.tokenizer( + prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ).input_ids[0] + + # (B*Nv, 3, H, W) + B = 1 + weight_dtype = torch.bfloat16 + imgs_in = process_image(input_image, totensor) + imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W") + + with torch.autocast("cuda", dtype=weight_dtype): + imgs_in = imgs_in.to("cuda") + # B*Nv images + out = self.validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=self.generator, + num_inference_steps=timestep, + camera_matrixs=camera_matrixs.to(weight_dtype), prompt_ids=prompt_ids, + height=val_height, width=val_width, unet_condition_type=self.unet_condition_type, + pose_guider=None, pose_image=pose_imgs_in, use_noise=self.use_noise, + use_shifted_noise=use_shifted_noise, **self.validation).videos + out = rearrange(out, "B C f H W -> (B f) C H W", f=self.validation.video_length) + + image_outputs = [] + for bs in range(4): + img_buf = io.BytesIO() + save_image(out[bs], img_buf, format='PNG') + img_buf.seek(0) + img = Image.open(img_buf) + image_outputs.append(img) + torch.cuda.empty_cache() + return image_outputs + + +def traverse(path, back_proj, smooth_iter): + ms = pymeshlab.MeshSet() + ms.load_new_mesh(f"{path}/model-00.obj") + image = Image.open(f"{path}/{'refined_texture_kd.jpg' if back_proj else 'texture_kd.jpg'}") + out_image_path = f"{path}/{'refined_texture_kd.png' if back_proj else 'texture_kd.png'}" + image.save(out_image_path, 'PNG') + ms.set_texture_per_mesh(textname=f"{path}/{'refined_texture_kd.png' if back_proj else 'texture_kd.png'}") + ms.meshing_merge_close_vertices() + ms.apply_coord_laplacian_smoothing(stepsmoothnum=smooth_iter) + ms.save_current_mesh(f"{path}/temp-00.obj", save_vertex_normal=False, 
save_wedge_normal=False, save_vertex_color=False) + + mesh = trimesh.load(f"{path}/temp-00.obj", process=False) + mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(90.0), [-1, 0, 0])) + mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(180.0), [0, 1, 0])) + + mesh.export(f'{path}/output.glb', file_type='glb') + + image = Image.open(f"{path}/{'refined_texture_kd.png' if back_proj else 'texture_kd.png'}") + texture = np.array(image) + vertex_colors = np.zeros((mesh.vertices.shape[0], 4), dtype=np.uint8) + + for vertex_index in range(len(mesh.visual.uv)): + uv = mesh.visual.uv[vertex_index] + x = int(uv[0] * (texture.shape[1] - 1)) + y = int((1 - uv[1]) * (texture.shape[0] - 1)) + + color = texture[y, x, :3] + vertex_colors[vertex_index] = [color[0], color[1], color[2], 255] + return trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, vertex_colors=vertex_colors, process=False) + +class Inference3D_API: + + def __init__(self, device="cuda"): + self.cfg = load_config("3D_Stage/configs/infer.yaml", makedirs=False) + print("Loading system") + self.device = device + self.cfg.system.weights = self.cfg.system.weights.replace("./", "./3D_Stage/") + self.cfg.system.image_tokenizer.pretrained_model_name_or_path = \ + self.cfg.system.image_tokenizer.pretrained_model_name_or_path.replace("./", "./3D_Stage/") + self.cfg.system.renderer.tet_dir = self.cfg.system.renderer.tet_dir.replace("./", "./3D_Stage/") + self.cfg.system.exporter.output_path = self.cfg.system.exporter.output_path.replace("./", "./3D_Stage/") + self.system = lrm.find(self.cfg.system_cls)(self.cfg.system).to(self.device) + self.system.eval() + + def _process_images_impl(self, img_input0, img_input1, img_input2, img_input3, back_proj, smooth_iter): + meta = json.load(open("./3D_Stage/material/meta.json")) + c2w_cond = [np.array(loc["transform_matrix"]) for loc in meta["locations"]] + c2w_cond = torch.from_numpy(np.stack(c2w_cond, 
axis=0)).float()[None].to(self.device) + + rgb_cond = [] + files = [img_input0, img_input1, img_input2, img_input3] + new_images = [] + for file in files: + image = np.array(file) + image = Image.fromarray(image) + if image.width != image.height: + max_dim = max(image.width, image.height) + new_image = Image.new("RGBA", (max_dim, max_dim)) + left = (max_dim - image.width) // 2 + top = (max_dim - image.height) // 2 + new_image.paste(image, (left, top)) + image = new_image + image.save("input_3D.png") + + image = cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB) + rgb = cv2.resize(image, (self.cfg.data.cond_width, + self.cfg.data.cond_height)).astype(np.float32) / 255.0 + new_images.append(Image.fromarray(image.astype(np.uint8)).convert("RGB")) + rgb_cond.append(rgb) + assert len(rgb_cond) == 4, "Please provide 4 images" + + rgb_cond = torch.from_numpy(np.stack(rgb_cond, axis=0)).float()[None].to(self.device) + + with torch.no_grad(): + scene_codes = self.system({"rgb_cond": rgb_cond, "c2w_cond": c2w_cond}) + exporter_output = self.system.exporter([f"{i:02d}" for i in range(rgb_cond.shape[0])], scene_codes) + + save_dir = os.path.join("./3D_Stage/outputs", datetime.now().strftime("@%Y%m%d-%H%M%S")) + os.makedirs(save_dir, exist_ok=True) + self.system.set_save_dir(save_dir) + + for out in exporter_output: + save_func_name = f"save_{out.save_type}" + save_func = getattr(self.system, save_func_name) + save_func(f"{out.save_name}", **out.params) + if back_proj: + refine(save_dir, new_images[1], new_images[0], new_images[3], new_images[2]) + + new_obj = traverse(save_dir, back_proj, smooth_iter) + new_obj.export(f'{save_dir}/output.obj', file_type='obj') + + gltf = GLTF2().load(f'{save_dir}/output.glb') + for material in gltf.materials: + if material.pbrMetallicRoughness: + material.pbrMetallicRoughness.baseColorFactor = [1.0, 1.0, 1.0, 100.0] + material.pbrMetallicRoughness.metallicFactor = 0.0 + material.pbrMetallicRoughness.roughnessFactor = 1.0 + 
gltf.save(f'{save_dir}/output.glb') + + return save_dir, f"{save_dir}/output.obj", f"{save_dir}/output.glb" + + +# モジュールレベルのシングルトンインスタンス(遅延初期化) +_remove_api = None +_infer2dapi = None +_infer3dapi = None + + +@spaces.GPU +def run_remove_background(imgs, alpha_min, alpha_max): + global _remove_api + if _remove_api is None: + _remove_api = rm_bg_api() + return _remove_api._remove_background_impl(imgs, alpha_min, alpha_max) + + +@spaces.GPU +def run_inference2d(input_image, val_width, val_height, + use_shifted_noise=False, crop=False, seed=100, timestep=20): + global _infer2dapi + if _infer2dapi is None: + download_model_files() + _infer2dapi = Inference2D_API(**OmegaConf.load("./2D_Stage/configs/infer.yaml")) + return _infer2dapi._inference_impl(input_image, val_width, val_height, + use_shifted_noise, crop, seed, timestep) + + +@spaces.GPU +def run_inference3d(img_input0, img_input1, img_input2, img_input3, back_proj, smooth_iter): + global _infer3dapi + if _infer3dapi is None: + download_model_files() + _infer3dapi = Inference3D_API() + return _infer3dapi._process_images_impl(img_input0, img_input1, img_input2, img_input3, + back_proj, smooth_iter) + + +@torch.no_grad() +def main(): + def gen4views(image, width, height, seed, timestep, remove_bg): + if remove_bg: + image = run_remove_background( + imgs=[np.array(image)], + alpha_min=0.1, + alpha_max=0.9, + )[0] + return run_remove_background( + imgs=run_inference2d( + image, width, height, crop=True, seed=seed, timestep=timestep + ), alpha_min=0.2, alpha_max=0.9) + + with gr.Blocks() as demo: + gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration") + with gr.Row(): + with gr.Column(variant="panel"): + img_input = gr.Image(type="pil", label="Upload Image(without background)", image_mode="RGBA", width=768, height=512) + gr.Examples( + label="Example Images", + examples=glob.glob("./2D_Stage/material/examples/*.png"), + inputs=[img_input] + ) + 
with gr.Row(): + width_input = gr.Number(label="Width", value=512) + height_input = gr.Number(label="Height", value=768) + seed_input = gr.Number(label="Seed", value=2333) + remove_bg = gr.Checkbox(label="Remove Background (with algorithm)", value=True) + with gr.Column(variant="panel"): + timestep = gr.Slider(minimum=10, maximum=70, step=1, value=40, label="Timesteps") + button1 = gr.Button(value="Generate 4 Views") + with gr.Row(): + img_input0 = gr.Image(type="pil", label="Back Image", image_mode="RGBA", width=256, height=384) + img_input1 = gr.Image(type="pil", label="Front Image", image_mode="RGBA", width=256, height=384) + with gr.Row(): + img_input2 = gr.Image(type="pil", label="Right Image", image_mode="RGBA", width=256, height=384) + img_input3 = gr.Image(type="pil", label="Left Image", image_mode="RGBA", width=256, height=384) + with gr.Column(variant="panel"): + smooth_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Laplacian Smoothing Iterations") + with gr.Row(): + back_proj = gr.Checkbox(label="Back Projection") + button2 = gr.Button(value="Generate 3D Mesh") + output_dir = gr.Textbox(label="Output Directory") + with gr.Row(): + with gr.Tab("GLB"): + output_model_glb = gr.Model3D( label="Output Model (GLB Format)", height=512) + gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.") + with gr.Tab("OBJ"): + output_model_obj = gr.Model3D( label="Output Model (OBJ Format)") + gr.Markdown("Note: The model shown here's texture is mapped to vertex. 
Download to get correct results.") + button1.click( + fn=gen4views, + inputs=[img_input, width_input, height_input, seed_input, timestep, remove_bg], + outputs=[img_input2, img_input0, img_input3, img_input1] + ) + button2.click( + run_inference3d, + inputs=[img_input0, img_input1, img_input2, img_input3, back_proj, smooth_iter], + outputs=[output_dir, output_model_obj, output_model_glb] + ) + demo.queue() + demo.launch() + +if __name__ == "__main__": + main() diff --git a/final_texture.png b/final_texture.png new file mode 100644 index 0000000000000000000000000000000000000000..4b3c4225e2d1fd0cc7d34f8950f05b6a8578e56a --- /dev/null +++ b/final_texture.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd3ef257ea3be83344ab7ddd0100c5ec4568a28b4aff05952a7b4fbc394fb240 +size 835726 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a3de257693a4ff57e1f1a655f9c45ecff2bf4b6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,31 @@ +gradio==5.14.0 +rm_anime_bg +# for 2D_stage +accelerate +transformers==4.32.1 +diffusers==0.24.0 +huggingface_hub==0.25.2 +ipdb +einops +imageio + +--extra-index-url https://download.pytorch.org/whl/cu121 +torch==2.4.0 +torchvision==0.19.0 +xformers==0.0.27.post2 + +onnxruntime +omegaconf +# for 3D_stage +pytorch_lightning +jaxtyping +wandb +lpips +https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true +ninja +open3d +trimesh +pymeshlab +pygltflib +omegaconf +typeguard==4.1.5 \ No newline at end of file