diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..d763c293f353e053bf107a3fd0a1f479a623bb55 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+2D_Stage/material/examples/1.png filter=lfs diff=lfs merge=lfs -text
+2D_Stage/material/examples/2.png filter=lfs diff=lfs merge=lfs -text
+2D_Stage/material/examples/3.png filter=lfs diff=lfs merge=lfs -text
+2D_Stage/material/examples/4.png filter=lfs diff=lfs merge=lfs -text
+2D_Stage/material/examples/5.png filter=lfs diff=lfs merge=lfs -text
+2D_Stage/material/examples/7.png filter=lfs diff=lfs merge=lfs -text
+2D_Stage/material/examples/8.png filter=lfs diff=lfs merge=lfs -text
+3D_Stage/material/examples/1/1.png filter=lfs diff=lfs merge=lfs -text
+3D_Stage/material/examples/1/2.png filter=lfs diff=lfs merge=lfs -text
+3D_Stage/material/examples/1/3.png filter=lfs diff=lfs merge=lfs -text
+3D_Stage/material/examples/1/4.png filter=lfs diff=lfs merge=lfs -text
+final_texture.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..01f427e77212f86512a5f6474878c0f4d27cf74b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,14 @@
+# Ignore Python bytecode files
+*.pyc
+*.pyo
+__pycache__/
+
+# Ignore virtual environment directory
+venv/
+
+/3D_Stage/outputs/
+input_3D.png
+input.png
+
+# LFS pointer files (large model files downloaded at runtime)
+3D_Stage/load/tets/*.npz
\ No newline at end of file
diff --git a/2D_Stage/configs/infer.yaml b/2D_Stage/configs/infer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee53fe601bb4599bbdea936acc7f4d462a8c4e11
--- /dev/null
+++ b/2D_Stage/configs/infer.yaml
@@ -0,0 +1,24 @@
+pretrained_model_path: "sd2-community/stable-diffusion-2-1"
+image_encoder_path: "./models/image_encoder"
+ckpt_dir: "./models/checkpoint"
+
+validation:
+ guidance_scale: 5.0
+ use_inv_latent: False
+ video_length: 4
+
+use_pose_guider: True
+use_noise: False
+use_shifted_noise: False
+unet_condition_type: image
+
+unet_from_pretrained_kwargs:
+ camera_embedding_type: 'e_de_da_sincos'
+ projection_class_embeddings_input_dim: 10 # modify
+ joint_attention: false # modify
+ num_views: 4
+ sample_size: 96
+ zero_init_conv_in: false
+ zero_init_camera_projection: false
+ in_channels: 4
+ use_safetensors: true
\ No newline at end of file
diff --git a/2D_Stage/material/examples/1.png b/2D_Stage/material/examples/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..64e5183ffec3b03a4aaa82e828d03be5f5d2b434
--- /dev/null
+++ b/2D_Stage/material/examples/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8fd677efa043cc71fbe0d78e30b93c1f49fd88fd8d2a00ae6946f6f0a06d3b2
+size 620509
diff --git a/2D_Stage/material/examples/2.png b/2D_Stage/material/examples/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..d29a82ce617bd746cfbfbaf31b0689c056750335
--- /dev/null
+++ b/2D_Stage/material/examples/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b342e025d7170708fdc19c1fa816835d70de61a315d6bde3e2afb390977882d0
+size 304610
diff --git a/2D_Stage/material/examples/3.png b/2D_Stage/material/examples/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..d31da2adddca3f6a1a3cfcf7a3a1f2c6852e1480
--- /dev/null
+++ b/2D_Stage/material/examples/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf15c6d6321c8ae8177aefd6ae1d3c41b3cbd51f1f0df13f1be53633273fa457
+size 187843
diff --git a/2D_Stage/material/examples/4.png b/2D_Stage/material/examples/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..930b0842d94851302b5407e282c6df08632c847b
--- /dev/null
+++ b/2D_Stage/material/examples/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7291130686044453c3c1a7fc7c9b66bb2424a825cf2406149d6a193cdcf7638c
+size 179319
diff --git a/2D_Stage/material/examples/5.png b/2D_Stage/material/examples/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..9bf77f6bc52e74fac152ef963acc9a3e7b4d0849
--- /dev/null
+++ b/2D_Stage/material/examples/5.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:385007de9cb47cafd6f63971afaf72e7d8ab6f7c146dc10cf0ad4788072d592c
+size 141777
diff --git a/2D_Stage/material/examples/6.png b/2D_Stage/material/examples/6.png
new file mode 100644
index 0000000000000000000000000000000000000000..ba5dc79b7d3283af18f57e3686836762e2aa03da
Binary files /dev/null and b/2D_Stage/material/examples/6.png differ
diff --git a/2D_Stage/material/examples/7.png b/2D_Stage/material/examples/7.png
new file mode 100644
index 0000000000000000000000000000000000000000..78d984a3f88770f95c8a5ba7063423431f9d3df8
--- /dev/null
+++ b/2D_Stage/material/examples/7.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b8e002752dbae219db8c7f6ecab6b79d62c31c06a5a36520effe82eb53418a4
+size 166635
diff --git a/2D_Stage/material/examples/8.png b/2D_Stage/material/examples/8.png
new file mode 100644
index 0000000000000000000000000000000000000000..2facade339f89502e74a7a9ebf6296e1fe8ea74d
--- /dev/null
+++ b/2D_Stage/material/examples/8.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd3def91bb231c82fb39a770a8b645078ab47d4a8f74e2068ccb8db0a0438837
+size 198307
diff --git a/2D_Stage/material/pose.json b/2D_Stage/material/pose.json
new file mode 100644
index 0000000000000000000000000000000000000000..40ed1ce28aa4cd809afd2d1f637887ed88b6c5fd
--- /dev/null
+++ b/2D_Stage/material/pose.json
@@ -0,0 +1,38 @@
+[
+ [
+ [
+ 0, 0, -1, 0,
+ 0, 1, 0, 0,
+ 1, 0, 0, 0,
+ 1.5, 0, 0, 1
+ ],
+ "pose0.png"
+ ],
+ [
+ [
+ 0, 0, 1, 0,
+ 0, 1, 0, 0,
+ -1, 0, 0, 0,
+ -1.5, 0, 0, 1
+ ],
+ "pose1.png"
+ ],
+ [
+ [
+ 0, 0, 1, 0,
+ 0, 1, 0, 0,
+ -1, 0, 0, 0,
+ -1.5, 0, 0, 1
+ ],
+ "pose2.png"
+ ],
+ [
+ [
+ -1, 0, 0, 0,
+ 0, 1, 0, 0,
+ 0, 0, -1, 0,
+ 0, 0, -1.5, 1
+ ],
+ "pose3.png"
+ ]
+]
\ No newline at end of file
diff --git a/2D_Stage/material/pose0.png b/2D_Stage/material/pose0.png
new file mode 100644
index 0000000000000000000000000000000000000000..7e044b3aaf9fce856e4a8fa50dc86fdc90fa41ac
Binary files /dev/null and b/2D_Stage/material/pose0.png differ
diff --git a/2D_Stage/material/pose1.png b/2D_Stage/material/pose1.png
new file mode 100644
index 0000000000000000000000000000000000000000..fd4093ce11b6f950853bb77ed95d6605db476784
Binary files /dev/null and b/2D_Stage/material/pose1.png differ
diff --git a/2D_Stage/material/pose2.png b/2D_Stage/material/pose2.png
new file mode 100644
index 0000000000000000000000000000000000000000..537d77bf217ac5ba9522062d3348e4cb0161883e
Binary files /dev/null and b/2D_Stage/material/pose2.png differ
diff --git a/2D_Stage/material/pose3.png b/2D_Stage/material/pose3.png
new file mode 100644
index 0000000000000000000000000000000000000000..c562d5f9f5db7d26818a3fc4512e2d168157cded
Binary files /dev/null and b/2D_Stage/material/pose3.png differ
diff --git a/2D_Stage/tuneavideo/models/PoseGuider.py b/2D_Stage/tuneavideo/models/PoseGuider.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d0e1e2b3843fa89272788c8ef9ac45ae7ac3fc9
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/PoseGuider.py
@@ -0,0 +1,59 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from einops import rearrange
+
+class PoseGuider(nn.Module):
+ def __init__(self, noise_latent_channels=4):
+ super(PoseGuider, self).__init__()
+
+ self.conv_layers = nn.Sequential(
+ nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
+ nn.ReLU()
+ )
+
+ # Final projection layer
+ self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
+
+ # Initialize layers
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ # Initialize weights with Gaussian distribution and zero out the final layer
+ for m in self.conv_layers:
+ if isinstance(m, nn.Conv2d):
+ init.normal_(m.weight, mean=0.0, std=0.02)
+ if m.bias is not None:
+ init.zeros_(m.bias)
+
+ init.zeros_(self.final_proj.weight)
+ if self.final_proj.bias is not None:
+ init.zeros_(self.final_proj.bias)
+
+ def forward(self, pose_image):
+ x = self.conv_layers(pose_image)
+ x = self.final_proj(x)
+
+ return x
+
+ @classmethod
+    def from_pretrained(cls, pretrained_model_path):
+ if not os.path.exists(pretrained_model_path):
+            raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
+ print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...")
+
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
+ model = PoseGuider(noise_latent_channels=4)
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+        params = [p.numel() for _, p in model.named_parameters()]
+ print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
+
+ return model
diff --git a/2D_Stage/tuneavideo/models/attention.py b/2D_Stage/tuneavideo/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4be16c609d63b2d2f9541e124c03cb572dfc736
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/attention.py
@@ -0,0 +1,344 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers import ModelMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
+
+from einops import rearrange, repeat
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_attn_temp: bool = False,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_attn_temp = use_attn_temp,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+            self.proj_out = nn.Linear(inner_dim, in_channels)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_attn_temp: bool = False
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ self.use_attn_temp = use_attn_temp
+ # SC-Attn
+ self.attn1 = SparseCausalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn2 = None
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ # Temp-Attn
+ if self.use_attn_temp:
+ self.attn_temp = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+ if not is_xformers_available():
+ print("Here is how to install it")
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ #self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+ # SparseCausal-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ # Temporal-Attention
+ if self.use_attn_temp:
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
+
+
+class SparseCausalAttention(CrossAttention):
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_full_attn=True):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ # query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ former_frame_index = torch.arange(video_length) - 1
+ former_frame_index[0] = 0
+
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+ if not use_full_attn:
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+ else:
+ # key = torch.cat([key[:, [0] * video_length], key[:, [1] * video_length], key[:, [2] * video_length], key[:, [3] * video_length]], dim=2)
+ key_video_length = [key[:, [i] * video_length] for i in range(video_length)]
+ key = torch.cat(key_video_length, dim=2)
+ key = rearrange(key, "b f d c -> (b f) d c")
+
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+ if not use_full_attn:
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+ else:
+ # value = torch.cat([value[:, [0] * video_length], value[:, [1] * video_length], value[:, [2] * video_length], value[:, [3] * video_length]], dim=2)
+ value_video_length = [value[:, [i] * video_length] for i in range(video_length)]
+ value = torch.cat(value_video_length, dim=2)
+ value = rearrange(value, "b f d c -> (b f) d c")
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
diff --git a/2D_Stage/tuneavideo/models/imageproj.py b/2D_Stage/tuneavideo/models/imageproj.py
new file mode 100644
index 0000000000000000000000000000000000000000..63e20527154594ef7a207b81c6520af2b07b8e50
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/imageproj.py
@@ -0,0 +1,118 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+import math
+
+import torch
+import torch.nn as nn
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+    # reshape keeps 4-D layout (bs, n_heads, length, dim_per_head); not flattened to (bs*n_heads, ...)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+class Resampler(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ ):
+ super().__init__()
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
\ No newline at end of file
diff --git a/2D_Stage/tuneavideo/models/refunet.py b/2D_Stage/tuneavideo/models/refunet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8808e3243eddd47ddaab63f35b56edff4ebd75ee
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/refunet.py
@@ -0,0 +1,125 @@
+import torch
+from einops import rearrange
+from typing import Any, Dict, Optional
+from diffusers.utils.import_utils import is_xformers_available
+from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
+class ReferenceOnlyAttnProc(torch.nn.Module):
+ def __init__(
+ self,
+ chained_proc,
+ enabled=False,
+ name=None
+ ) -> None:
+ super().__init__()
+ self.enabled = enabled
+ self.chained_proc = chained_proc
+ self.name = name
+
+ def __call__(
+ self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
+ mode="w", ref_dict: dict = None, is_cfg_guidance = False,num_views=4,
+ multiview_attention=True,
+ cross_domain_attention=False,
+ ) -> Any:
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ # print(self.enabled)
+ if self.enabled:
+ if mode == 'w':
+ ref_dict[self.name] = encoder_hidden_states
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=1,
+ multiview_attention=False,
+ cross_domain_attention=False,)
+ elif mode == 'r':
+ encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
+ if self.name in ref_dict:
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views,
+ multiview_attention=False,
+ cross_domain_attention=False,)
+ elif mode == 'm':
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
+ elif mode == 'n':
+ encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c-> b (t d) c', t=num_views)
+ encoder_hidden_states = torch.cat([encoder_hidden_states], dim=1).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views,
+ multiview_attention=False,
+ cross_domain_attention=False,)
+ else:
+ assert False, mode
+ else:
+ res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
+ return res
+
+class RefOnlyNoisedUNet(torch.nn.Module):
+ def __init__(self, unet, train_sched, val_sched) -> None:
+ super().__init__()
+ self.unet = unet
+ self.train_sched = train_sched
+ self.val_sched = val_sched
+
+ unet_lora_attn_procs = dict()
+ for name, _ in unet.attn_processors.items():
+ if is_xformers_available():
+ default_attn_proc = XFormersMVAttnProcessor()
+ else:
+ default_attn_proc = MVAttnProcessor()
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name)
+
+ self.unet.set_attn_processor(unet_lora_attn_procs)
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.unet, name)
+
+ def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
+ if is_cfg_guidance:
+ encoder_hidden_states = encoder_hidden_states[1:]
+ class_labels = class_labels[1:]
+ self.unet(
+ noisy_cond_lat, timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ class_labels=class_labels,
+ cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
+ **kwargs
+ )
+
+ def forward(
+ self, sample, timestep, encoder_hidden_states, class_labels=None,
+ *args, cross_attention_kwargs,
+ down_block_res_samples=None, mid_block_res_sample=None,
+ **kwargs
+ ):
+ cond_lat = cross_attention_kwargs['cond_lat']
+ is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
+ noise = torch.randn_like(cond_lat)
+ if self.training:
+ noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
+ noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
+ else:
+ noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
+ noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
+ ref_dict = {}
+ self.forward_cond(
+ noisy_cond_lat, timestep,
+ encoder_hidden_states, class_labels,
+ ref_dict, is_cfg_guidance, **kwargs
+ )
+ weight_dtype = self.unet.dtype
+ return self.unet(
+ sample, timestep,
+ encoder_hidden_states, *args,
+ class_labels=class_labels,
+ cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ] if down_block_res_samples is not None else None,
+ mid_block_additional_residual=(
+ mid_block_res_sample.to(dtype=weight_dtype)
+ if mid_block_res_sample is not None else None
+ ),
+ **kwargs
+ )
\ No newline at end of file
diff --git a/2D_Stage/tuneavideo/models/resnet.py b/2D_Stage/tuneavideo/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b8b0fc2737f3335ea4e313568c1890326900f01
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/resnet.py
@@ -0,0 +1,210 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+
class InflatedConv3d(nn.Conv2d):
    """nn.Conv2d applied independently to every frame of a (b, c, f, h, w) tensor."""

    def forward(self, x):
        batch, channels, frames, height, width = x.shape
        # (b, c, f, h, w) -> (b*f, c, h, w): fold frames into the batch axis so
        # the plain 2D convolution runs over each frame separately.
        flat = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        flat = super().forward(flat)
        # (b*f, c', h', w') -> (b, c', f, h', w'): restore the frame axis.
        out_channels, out_height, out_width = flat.shape[1:]
        return flat.reshape(batch, frames, out_channels, out_height, out_width).permute(0, 2, 1, 3, 4)
+
+
class Upsample3D(nn.Module):
    """2x spatial upsampling for (b, c, f, h, w) tensors; the frame dim is untouched.

    Nearest-neighbour interpolation with scale factors [1, 2, 2] over (f, h, w),
    optionally followed by an InflatedConv3d. Transposed-conv upsampling is not
    implemented.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

        # The attribute name depends on `name` so upstream diffusers checkpoints
        # (which use either `conv` or `Conv2d_0`) load without key remapping.
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            # [frames, height, width]: only the spatial dims are doubled.
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)
            else:
                hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
+
+
class Downsample3D(nn.Module):
    """Strided-conv downsampling over the spatial dims of a (b, c, f, h, w) tensor.

    Halves height and width with a stride-2 InflatedConv3d; the frame dimension
    is untouched. Only the convolutional path (`use_conv=True`) is implemented.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Average-pool downsampling is not supported in the inflated variant.
            raise NotImplementedError

        # The original code branched three ways on `name` but every branch set
        # `self.conv`; only the "conv" case additionally exposes the legacy
        # `Conv2d_0` alias so upstream diffusers checkpoints still load.
        if name == "conv":
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        # (Original had this assert duplicated; once is enough.)
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # Asymmetric pad-then-conv path is not implemented.
            raise NotImplementedError

        return self.conv(hidden_states)
+
+
class ResnetBlock3D(nn.Module):
    """Inflated (pseudo-3D) ResNet block: GroupNorm -> act -> conv, twice, with a
    timestep-embedding injection and an optional 1x1 shortcut projection.

    All convolutions are InflatedConv3d, i.e. 2D convs applied per frame of a
    (b, c, f, h, w) tensor. Adapted from diffusers' ResnetBlock2D.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        # Forced True to mirror upstream diffusers: only the pre-norm layout is
        # implemented here, regardless of the argument.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                # Twice the channels: one half for scale, one half for shift.
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            # 1x1 projection so the residual add matches out_channels.
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
            # Projects a per-frame embedding and moves channels in front of
            # frames to broadcast over (b, c, f, h, w).
            # NOTE(review): assumes temb is (batch, frames, temb_channels) — TODO confirm against caller.
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, :, None, None].permute(0,2,1,3,4)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            # Split the doubled channel dim into (scale, shift) halves.
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
+
+
class Mish(torch.nn.Module):
    """Mish activation (Misra, 2019): f(x) = x * tanh(softplus(x))."""

    def forward(self, hidden_states):
        # Smooth, self-gated nonlinearity; softplus(x) = log(1 + exp(x)).
        gate = torch.nn.functional.softplus(hidden_states).tanh()
        return hidden_states * gate
\ No newline at end of file
diff --git a/2D_Stage/tuneavideo/models/transformer_mv2d.py b/2D_Stage/tuneavideo/models/transformer_mv2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf9ca74289ce6f991ac151131d8b01c85c2fa510
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/transformer_mv2d.py
@@ -0,0 +1,1010 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import BaseOutput, deprecate
try:
    from diffusers.utils import maybe_allow_in_graph
except ImportError:
    # Symbol moved to diffusers.utils.torch_utils in newer releases. Catch only
    # ImportError so a bare `except` doesn't swallow KeyboardInterrupt/SystemExit
    # or unrelated errors raised while importing.
    from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange
import pdb
import random


# xformers is optional; downstream code checks `xformers is None` style guards.
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
+
@dataclass
class TransformerMV2DModelOutput(BaseOutput):
    """
    The output of [`TransformerMV2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if the input is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    # Denoised output tensor (or log-probabilities for discrete inputs).
    sample: torch.FloatTensor
+
+
class TransformerMV2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data, extended with multi-view attention.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        num_views (`int`, *optional*, defaults to 1):
            Number of views each multi-view transformer block attends over jointly.
        joint_attention (`bool`, *optional*, defaults to `False`):
            Adds an extra joint attention layer (with zero-initialized output) after feed-forward in each block.
        joint_attention_twice (`bool`, *optional*, defaults to `False`):
            Additionally applies joint attention right after self-attention in each block.
        multiview_attention (`bool`, *optional*, defaults to `True`):
            Whether the self-attention keys/values span all views (passed through to the blocks).
        cross_domain_attention (`bool`, *optional*, defaults to `False`):
            Passed through to the blocks' self-attention; presumably enables attention across
            modality domains — confirm against the attention processor.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
        num_views: int = 1,
        joint_attention: bool=False,
        joint_attention_twice: bool=False,
        multiview_attention: bool=True,
        cross_domain_attention: bool=False
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
            else:
                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        # Multi-view flags are forwarded to every BasicMVTransformerBlock.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicMVTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                    num_views=num_views,
                    joint_attention=joint_attention,
                    joint_attention_twice=joint_attention_twice,
                    multiview_attention=multiview_attention,
                    cross_domain_attention=cross_domain_attention
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
            else:
                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerMV2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                  * Mask `(batch, sequence_length)` True = keep, False = discard.
                  * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            # Residual connection around the whole transformer stack.
            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return TransformerMV2DModelOutput(sample=output)
+
+
@maybe_allow_in_graph
class BasicMVTransformerBlock(nn.Module):
    r"""
    A basic Transformer block extended with multi-view and optional joint attention.

    Layer order in `forward`: self-attention (multi-view aware) -> optional joint
    attention ("twice" variant) -> cross-attention -> feed-forward -> optional
    joint attention.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        num_views (`int`, *optional*, defaults to 1): Number of views the self-attention spans.
        joint_attention (`bool`, *optional*, defaults to `False`): Add joint attention after feed-forward.
        joint_attention_twice (`bool`, *optional*, defaults to `False`): Also add joint attention after self-attention.
        multiview_attention (`bool`, *optional*, defaults to `True`): Forwarded to the self-attention processor.
        cross_domain_attention (`bool`, *optional*, defaults to `False`): Forwarded to the self-attention processor.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        num_views: int = 1,
        joint_attention: bool = False,
        joint_attention_twice: bool = False,
        multiview_attention: bool = True,
        cross_domain_attention: bool = False
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.multiview_attention = multiview_attention
        self.cross_domain_attention = cross_domain_attention
        # import pdb;pdb.set_trace()
        # Self-attention uses the custom multi-view processor.
        self.attn1 = CustomAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            processor=MVAttnProcessor()
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

        self.num_views = num_views

        self.joint_attention = joint_attention

        if self.joint_attention:
            # Joint task -Attn
            self.attn_joint = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            # Zero-init the output projection so the layer starts as identity
            # (residual add contributes nothing until trained).
            nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
            self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)


        self.joint_attention_twice = joint_attention_twice

        if self.joint_attention_twice:
            print("joint twice")
            # Joint task -Attn
            self.attn_joint_twice = CustomJointAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
                processor=JointAttnProcessor()
            )
            # Zero-init as above.
            nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
            self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        assert attention_mask is None  # not supported yet
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        # Multi-view flags are forwarded to the MVAttnProcessor via kwargs.
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            num_views=self.num_views,
            multiview_attention=self.multiview_attention,
            cross_domain_attention=self.cross_domain_attention,
            **cross_attention_kwargs,
        )


        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # joint attention twice
        if self.joint_attention_twice:
            norm_hidden_states = (
                self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
            )
            hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        # 4. Optional joint attention after feed-forward.
        if self.joint_attention:
            norm_hidden_states = (
                self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
            )
            hidden_states = self.attn_joint(norm_hidden_states) + hidden_states

        return hidden_states
+
+
class CustomAttention(Attention):
    """Attention whose xformers toggle installs the multi-view xformers processor.

    NOTE(review): the `use_memory_efficient_attention_xformers` flag is ignored —
    the xformers processor is installed even when callers pass False. Confirm
    this is intentional before relying on disabling xformers for this layer.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        processor = XFormersMVAttnProcessor()
        self.set_processor(processor)
        # print("using xformers attention processor")
+
+
class CustomJointAttention(Attention):
    """Attention whose xformers toggle installs the joint-attention xformers processor.

    NOTE(review): as with `CustomAttention`, the boolean flag is ignored — the
    xformers processor is always installed. Confirm this is intentional.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        processor = XFormersJointAttnProcessor()
        self.set_processor(processor)
        # print("using xformers attention processor")
+
+class MVAttnProcessor:
+    r"""
+    Multi-view attention processor (plain PyTorch, no xformers).
+
+    The batch dimension is assumed to interleave views as ``(b * num_views)``
+    — TODO confirm against callers. When ``multiview_attention`` is on, each
+    view's queries attend over the concatenated key/value tokens of all views
+    of the same sample, instead of only their own view.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        num_views=1,
+        multiview_attention=True
+    ):
+        # Kept for the optional residual connection applied at the end.
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        # Conv-style (B, C, H, W) input is flattened to (B, H*W, C) token form
+        # and restored at the end.
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        # Self-attention when no encoder states are given; optional cross-norm
+        # otherwise (matches the stock diffusers processor contract).
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # multi-view self-attention: gather each sample's per-view tokens into
+        # one key/value sequence, then replicate it so every view row in the
+        # batch sees the full (t*d)-token sequence.
+        if multiview_attention:
+            if num_views <= 6:
+                # (b t) d c -> b (t d) c, then repeat per view back to (b t) rows.
+                key = rearrange(key, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
+                value = rearrange(value, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
+
+            else:# apply sparse attention
+                # NOTE(review): a fixed-index sparse-attention variant (each view
+                # attending to a hand-picked subset of views) was prototyped here
+                # and disabled; >6 views currently falls back to per-view attention.
+                pass
+
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        # Restore the original (B, C, H, W) layout if the input was 4-D.
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class XFormersMVAttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ num_views=1.,
+ multiview_attention=True,
+ cross_domain_attention=False,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ # from yuancheng; here attention_mask is None
+ if attention_mask is not None:
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key_raw = attn.to_k(encoder_hidden_states)
+ value_raw = attn.to_v(encoder_hidden_states)
+
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
+ # pdb.set_trace()
+ # multi-view self-attention
+ if multiview_attention:
+ key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
+ value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
+
+ if cross_domain_attention:
+ # memory efficient, cross domain attention
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
+ key_cross = torch.concat([key_1, key_0], dim=0)
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
+ key = torch.cat([key, key_cross], dim=1)
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
+ else:
+ # print("don't use multiview attention.")
+ key = key_raw
+ value = value_raw
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+
+class XFormersJointAttnProcessor:
+    r"""
+    Joint two-domain attention processor backed by xformers.
+
+    The batch is assumed to stack exactly two task domains along dim 0
+    (``num_tasks == 2``). Each query attends over the concatenated key/value
+    tokens of both domains, so the two halves of the batch share information.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+        num_tasks=2
+    ):
+
+        # Kept for the optional residual connection applied at the end.
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        # Conv-style (B, C, H, W) input is flattened to (B, H*W, C) token form
+        # and restored at the end.
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        # attention_mask is normally None here; when present it must be
+        # expanded because xformers does not broadcast the singleton
+        # query_tokens dimension of [batch*heads, 1, key_tokens] for us.
+        if attention_mask is not None:
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        assert num_tasks == 2  # only support two tasks now
+
+        # Concatenate both domains' tokens along the sequence dim, then
+        # replicate so each of the two batch halves attends to the joint set.
+        key_0, key_1 = torch.chunk(key, dim=0, chunks=2)  # keys shape (b t) d c
+        value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
+        key = torch.cat([key_0, key_1], dim=1)  # (b t) 2d c
+        value = torch.cat([value_0, value_1], dim=1)  # (b t) 2d c
+        key = torch.cat([key]*2, dim=0)  # ( 2 b t) 2d c
+        value = torch.cat([value]*2, dim=0)  # (2 b t) 2d c
+
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        # Restore the original (B, C, H, W) layout if the input was 4-D.
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class JointAttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ num_tasks=2
+ ):
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ assert num_tasks == 2 # only support two tasks now
+
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
+
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
\ No newline at end of file
diff --git a/2D_Stage/tuneavideo/models/unet.py b/2D_Stage/tuneavideo/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa156afe8f0f179cd5d4b2aef569073eafeb690c
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/unet.py
@@ -0,0 +1,497 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import os
+import json
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers import ModelMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from .unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+)
+from .resnet import InflatedConv3d
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+    """Output of :class:`UNet3DConditionModel`.
+
+    Carries a single field, ``sample``: the denoised/predicted sample tensor.
+    """
+
+    sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+    """Conditional 3D UNet with per-view camera conditioning.
+
+    Adapted from diffusers' ``UNet2DConditionModel``: down/mid/up blocks are
+    the 3D variants from ``unet_blocks.py``, and an extra camera-matrix MLP
+    adds a per-view embedding on top of the timestep embedding.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        center_input_sample: bool = False,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock3D",
+            "CrossAttnDownBlock3D",
+            "CrossAttnDownBlock3D",
+            "DownBlock3D",
+        ),
+        mid_block_type: str = "UNetMidBlock3DCrossAttn",
+        up_block_types: Tuple[str] = (
+            "UpBlock3D",
+            "CrossAttnUpBlock3D",
+            "CrossAttnUpBlock3D",
+            "CrossAttnUpBlock3D"
+        ),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: int = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        act_fn: str = "silu",
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: int = 1280,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        use_attn_temp: bool = False,
+        camera_input_dim: int = 12,
+        camera_hidden_dim: int = 320,
+        camera_output_dim: int = 1280,
+    ):
+        super().__init__()
+
+        self.sample_size = sample_size
+        time_embed_dim = block_out_channels[0] * 4
+
+        # input
+        self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+        # time
+        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+        timestep_input_dim = block_out_channels[0]
+
+        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+        # class embedding
+        if class_embed_type is None and num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+        elif class_embed_type == "timestep":
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+        elif class_embed_type == "identity":
+            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        else:
+            self.class_embedding = None
+
+        # camera matrix embedding: maps a flattened camera matrix
+        # (camera_input_dim values per view) to the timestep-embedding space
+        # so it can be added to the time embedding in `forward`.
+        # NOTE(review): camera_hidden_dim and camera_output_dim are accepted
+        # but unused here — the MLP uses time_embed_dim throughout; an earlier
+        # two-Linear variant that used them was removed.
+        self.camera_embedding = nn.Sequential(
+            nn.Linear(camera_input_dim, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.down_blocks = nn.ModuleList([])
+        self.mid_block = None
+        self.up_blocks = nn.ModuleList([])
+
+        # Broadcast scalar settings to one value per down block.
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim,
+                attn_num_head_channels=attention_head_dim[i],
+                downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                use_attn_temp=use_attn_temp
+            )
+            self.down_blocks.append(down_block)
+
+        # mid
+        if mid_block_type == "UNetMidBlock3DCrossAttn":
+            self.mid_block = UNetMidBlock3DCrossAttn(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim,
+                attn_num_head_channels=attention_head_dim[-1],
+                resnet_groups=norm_num_groups,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+            )
+        else:
+            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+        # count how many layers upsample the videos
+        self.num_upsamplers = 0
+
+        # up: mirror the down path with reversed channel/head configs.
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=layers_per_block + 1,
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                use_attn_temp=use_attn_temp,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+        self.conv_act = nn.SiLU()
+        self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        # Collect the head dims of every sliceable attention module, in
+        # depth-first child order.
+        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_slicable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_slicable_dims(module)
+
+        num_slicable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_slicable_layers * [1]
+
+        # A lone int applies to every sliceable layer.
+        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        # Each requested slice must fit inside its layer's head dim.
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        # Reversed because the recursion pops from the end of the list while
+        # visiting modules in the same order they were collected above.
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        camera_matrixs: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNet3DConditionOutput, Tuple]:
+        r"""
+        Args:
+            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            camera_matrixs (`torch.Tensor`, *optional*): per-view camera conditioning fed through
+                `self.camera_embedding` and added to the time embedding; presumably shaped
+                (batch, num_views, camera_input_dim) — TODO confirm against callers.
+            class_labels (`torch.Tensor`, *optional*): labels for the optional class embedding.
+            attention_mask (`torch.Tensor`, *optional*): 1-keep/0-drop mask converted to an additive bias.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+        """
+        # By default samples have to be AT least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # prepare attention_mask: convert a 1-keep/0-drop mask into an
+        # additive bias (0 for keep, -10000 for drop).
+        if attention_mask is not None:
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+        # time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=self.dtype)
+        emb = self.time_embedding(t_emb)
+        # Add a view axis so the embedding can be broadcast per camera view.
+        emb = torch.unsqueeze(emb, 1)
+        if camera_matrixs is not None:
+            # camera embedding: repeat the time embedding across views, then
+            # add the per-view camera embedding.
+            cam_emb = self.camera_embedding(camera_matrixs)
+            emb = emb.repeat(1,cam_emb.shape[1],1)
+            emb = emb + cam_emb
+
+        # Optional class conditioning, added on top of the (time + camera) embedding.
+        if self.class_embedding is not None:
+            if class_labels is not None:
+
+                if self.config.class_embed_type == "timestep":
+                    class_labels = self.time_proj(class_labels)
+
+                class_emb = self.class_embedding(class_labels)
+                emb = emb + class_emb
+
+        # pre-process
+        sample = self.conv_in(sample)
+
+        # down: collect every residual for the skip connections.
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # mid
+        sample = self.mid_block(
+            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+        )
+
+        # up: consume the residual stack from the end.
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+        # post-process
+        sample = self.conv_norm_out(sample)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet3DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
+ # model = cls.from_config(config)
+ # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ # if not os.path.isfile(model_file):
+ # raise RuntimeError(f"{model_file} does not exist")
+ # state_dict = torch.load(model_file, map_location="cpu")
+
+ import safetensors
+ model = cls.from_config(config)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ else:
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ else:
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ for k, v in model.state_dict().items():
+ if '_temp.' in k or 'camera_embedding' in k or 'class_embedding' in k:
+ state_dict.update({k: v})
+ for k in list(state_dict.keys()):
+ if 'camera_embedding_' in k:
+ v = state_dict.pop(k)
+ model.load_state_dict(state_dict)
+
+ return model
\ No newline at end of file
diff --git a/2D_Stage/tuneavideo/models/unet_blocks.py b/2D_Stage/tuneavideo/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac9b78d42fbd35261c23775e54cfa88f6507925d
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/unet_blocks.py
@@ -0,0 +1,596 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py

import torch
from torch import nn

# NOTE(review): Transformer3DModel is referenced by the CrossAttn*3D blocks
# below; this import was commented out — confirm it resolves in the project.
from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
+
+
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    use_attn_temp=False,
):
    """Factory for 3D encoder blocks.

    Maps a block-type name ("DownBlock3D" or "CrossAttnDownBlock3D",
    optionally prefixed with "UNetRes") to a constructed block.

    Raises:
        ValueError: For an unknown block type, or when
            ``cross_attention_dim`` is missing for a cross-attention block.
    """
    # Accept legacy names such as "UNetResDownBlock3D".
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_attn_temp=use_attn_temp,
        )
    raise ValueError(f"{down_block_type} does not exist.")
+
+
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    use_attn_temp=False,
):
    """Factory for 3D decoder blocks.

    Maps a block-type name ("UpBlock3D" or "CrossAttnUpBlock3D", optionally
    prefixed with "UNetRes") to a constructed block.

    Raises:
        ValueError: For an unknown block type, or when
            ``cross_attention_dim`` is missing for a cross-attention block.
    """
    # Accept legacy names such as "UNetResUpBlock3D".
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_attn_temp=use_attn_temp,
        )
    raise ValueError(f"{up_block_type} does not exist.")
+
+
class UNetMidBlock3DCrossAttn(nn.Module):
    """Mid block: one leading ResNet, then alternating (attention, ResNet) pairs."""

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # There is always at least one resnet (run before the first attention).
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Run the leading resnet, then each (attention, resnet) pair in order."""
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
+
+
class CrossAttnDownBlock3D(nn.Module):
    """Encoder block: (ResNet, cross-attention) pairs plus optional downsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        use_attn_temp=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Only the first resnet adapts in_channels -> out_channels.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_attn_temp=use_attn_temp,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Apply the (resnet, attention) pairs, then downsampling.

        Returns:
            Tuple of the final hidden states and a tuple of the per-stage
            hidden states (used for the decoder's skip connections).
        """
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:
                # Re-run each sub-module inside torch.utils.checkpoint to
                # trade compute for activation memory.
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
+
+
class DownBlock3D(nn.Module):
    """Encoder block made of plain ResNets (no attention) plus optional downsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            # Only the first resnet adapts in_channels -> out_channels.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Apply each resnet, then downsampling.

        Returns:
            Tuple of the final hidden states and a tuple of the per-stage
            hidden states (used for the decoder's skip connections).
        """
        output_states = ()

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                # Checkpoint the resnet to save activation memory during training.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
+
+
class CrossAttnUpBlock3D(nn.Module):
    """Decoder block: (ResNet, cross-attention) pairs over skip connections plus optional upsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        use_attn_temp=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Each resnet consumes the concatenation of the running hidden
            # states and one encoder skip connection.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_attn_temp=use_attn_temp,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        """Consume skip connections (last first), run (resnet, attention) pairs, then upsample."""
        for resnet, attn in zip(self.resnets, self.attentions):
            # Pop the most recent encoder skip connection and concatenate it
            # along the channel dimension.
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                # Checkpoint both sub-modules to save activation memory.
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
+
+
class UpBlock3D(nn.Module):
    """Decoder block made of plain ResNets over skip connections plus optional upsampling."""

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            # Each resnet consumes the concatenation of the running hidden
            # states and one encoder skip connection.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Consume skip connections (last first), run each resnet, then upsample."""
        for resnet in self.resnets:
            # Pop the most recent encoder skip connection and concatenate it
            # along the channel dimension.
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                # Checkpoint the resnet to save activation memory during training.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
diff --git a/2D_Stage/tuneavideo/models/unet_mv2d_blocks.py b/2D_Stage/tuneavideo/models/unet_mv2d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..009c774337feb5f5d94a7958174400a150e351f3
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/unet_mv2d_blocks.py
@@ -0,0 +1,926 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import is_torch_version, logging
+# from diffusers.models.attention import AdaGroupNorm
+from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+from diffusers.models.dual_transformer_2d import DualTransformer2DModel
+from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
+from tuneavideo.models.transformer_mv2d import TransformerMV2DModel
+
+from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
+from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    downsample_type=None,
    num_views=1,
    joint_attention: bool = False,
    joint_attention_twice: bool = False,
    multiview_attention: bool = True,
    cross_domain_attention: bool = False,
):
    """Factory for 2D/MV2D encoder blocks.

    Dispatches on ``down_block_type`` (a legacy "UNetRes" prefix is
    stripped) to the matching diffusers 2D block, or to the project's
    multi-view ``CrossAttnDownBlockMV2D`` which additionally receives the
    ``num_views`` / ``*_attention`` options.

    Raises:
        ValueError: For an unknown block type, or when
            ``cross_attention_dim`` is missing for a cross-attention block.
    """
    # If attn head dim is not defined, we default it to the number of heads.
    if attention_head_dim is None:
        # logger.warn is a deprecated alias of logger.warning.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "ResnetDownsampleBlock2D":
        return ResnetDownsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    elif down_block_type == "AttnDownBlock2D":
        if add_downsample is False:
            downsample_type = None
        else:
            downsample_type = downsample_type or "conv"  # default to 'conv'
        return AttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            downsample_type=downsample_type,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # custom MV2D attention block
    elif down_block_type == "CrossAttnDownBlockMV2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
        return CrossAttnDownBlockMV2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            num_views=num_views,
            joint_attention=joint_attention,
            joint_attention_twice=joint_attention_twice,
            multiview_attention=multiview_attention,
            cross_domain_attention=cross_domain_attention,
        )
    elif down_block_type == "SimpleCrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        return SimpleCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "KDownBlock2D":
        return KDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif down_block_type == "KCrossAttnDownBlock2D":
        return KCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            # K-blocks use self-attention only in the final (non-downsampling) stage.
            add_self_attention=not add_downsample,
        )
    raise ValueError(f"{down_block_type} does not exist.")
+
+
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    upsample_type=None,
    num_views=1,
    joint_attention: bool = False,
    joint_attention_twice: bool = False,
    multiview_attention: bool = True,
    cross_domain_attention: bool = False,
):
    """Factory for 2D/MV2D decoder blocks.

    Dispatches on ``up_block_type`` (a legacy "UNetRes" prefix is stripped)
    to the matching diffusers 2D block, or to the project's multi-view
    ``CrossAttnUpBlockMV2D`` which additionally receives the
    ``num_views`` / ``*_attention`` options.

    Raises:
        ValueError: For an unknown block type, or when
            ``cross_attention_dim`` is missing for a cross-attention block.
    """
    # If attn head dim is not defined, we default it to the number of heads.
    if attention_head_dim is None:
        # logger.warn is a deprecated alias of logger.warning.
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "ResnetUpsampleBlock2D":
        return ResnetUpsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    # custom MV2D attention block
    elif up_block_type == "CrossAttnUpBlockMV2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
        return CrossAttnUpBlockMV2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            num_views=num_views,
            joint_attention=joint_attention,
            joint_attention_twice=joint_attention_twice,
            multiview_attention=multiview_attention,
            cross_domain_attention=cross_domain_attention,
        )
    elif up_block_type == "SimpleCrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
        return SimpleCrossAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif up_block_type == "AttnUpBlock2D":
        if add_upsample is False:
            upsample_type = None
        else:
            upsample_type = upsample_type or "conv"  # default to 'conv'

        return AttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            upsample_type=upsample_type,
        )
    elif up_block_type == "SkipUpBlock2D":
        return SkipUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "AttnSkipUpBlock2D":
        return AttnSkipUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "UpDecoderBlock2D":
        return UpDecoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    elif up_block_type == "AttnUpDecoderBlock2D":
        return AttnUpDecoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    elif up_block_type == "KUpBlock2D":
        return KUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif up_block_type == "KCrossAttnUpBlock2D":
        return KCrossAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
        )

    raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlockMV2DCrossAttn(nn.Module):
+    """Mid block of the multi-view UNet: resnet -> (attn -> resnet) * num_layers.
+
+    Same layout as diffusers' ``UNetMidBlock2DCrossAttn`` but each attention
+    stage is a ``TransformerMV2DModel``, so self-attention can span
+    ``num_views`` views (and optionally the cross-domain branch).
+    ``dual_cross_attention=True`` is not implemented and raises.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads=1,
+        output_scale_factor=1.0,
+        cross_attention_dim=1280,
+        dual_cross_attention=False,
+        use_linear_projection=False,
+        upcast_attention=False,
+        num_views: int = 1,
+        joint_attention: bool = False,
+        joint_attention_twice: bool = False,
+        multiview_attention: bool = True,
+        cross_domain_attention: bool=False
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        # When no group count is given, pick one that divides the channel dim.
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        # Each layer adds one multi-view transformer and one trailing resnet;
+        # channel count stays constant (in_channels) throughout the mid block.
+        for _ in range(num_layers):
+            if not dual_cross_attention:
+                attentions.append(
+                    TransformerMV2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=transformer_layers_per_block,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
+                        num_views=num_views,
+                        joint_attention=joint_attention,
+                        joint_attention_twice=joint_attention_twice,
+                        multiview_attention=multiview_attention,
+                        cross_domain_attention=cross_domain_attention
+                    )
+                )
+            else:
+                raise NotImplementedError
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        """Run the leading resnet, then alternate attention/resnet pairs."""
+        hidden_states = self.resnets[0](hidden_states, temb)
+        # self.resnets has one more entry than self.attentions, so zipping with
+        # resnets[1:] pairs every attention with the resnet that follows it.
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
+class CrossAttnUpBlockMV2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ num_views: int = 1,
+ joint_attention: bool = False,
+ joint_attention_twice: bool = False,
+ multiview_attention: bool = True,
+ cross_domain_attention: bool=False
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ TransformerMV2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ )
+ else:
+ raise NotImplementedError
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+ if num_views == 4:
+ self.gradient_checkpointing = False
+ else:
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ # hidden_states = attn(
+ # hidden_states,
+ # encoder_hidden_states=encoder_hidden_states,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ # attention_mask=attention_mask,
+ # encoder_attention_mask=encoder_attention_mask,
+ # return_dict=False,
+ # )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class CrossAttnDownBlockMV2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ num_views: int = 1,
+ joint_attention: bool = False,
+ joint_attention_twice: bool = False,
+ multiview_attention: bool = True,
+ cross_domain_attention: bool=False
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ TransformerMV2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ )
+ else:
+ raise NotImplementedError
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+ if num_views == 4:
+ self.gradient_checkpointing = False
+ else:
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals=None,
+ ):
+ output_states = ()
+
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ None, # timestep
+ None, # class_labels
+ cross_attention_kwargs,
+ attention_mask,
+ encoder_attention_mask,
+ **ckpt_kwargs,
+ )[0]
+ else:
+ # import ipdb
+ # ipdb.set_trace()
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
diff --git a/2D_Stage/tuneavideo/models/unet_mv2d_condition.py b/2D_Stage/tuneavideo/models/unet_mv2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..c719b84c6b8cc1cdbcf81d449511e05eede20d92
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/unet_mv2d_condition.py
@@ -0,0 +1,1509 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+import os
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from einops import rearrange
+
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
+from diffusers.models.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ UpBlock2D,
+)
+from diffusers.utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ FLAX_WEIGHTS_NAME,
+ HF_HUB_OFFLINE,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ _add_variant,
+ _get_model_file,
+ deprecate,
+ is_accelerate_available,
+ is_torch_version,
+ logging,
+)
+from diffusers import __version__
+from tuneavideo.models.unet_mv2d_blocks import (
+ CrossAttnDownBlockMV2D,
+ CrossAttnUpBlockMV2D,
+ UNetMidBlockMV2DCrossAttn,
+ get_down_block,
+ get_up_block,
+)
+from diffusers.models.attention_processor import Attention, AttnProcessor
+from diffusers.utils.import_utils import is_xformers_available
+from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
+from tuneavideo.models.refunet import ReferenceOnlyAttnProc
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNetMV2DConditionOutput(BaseOutput):
+    """
+    The output of [`UNetMV2DConditionModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    # Predicted sample from the final layer of the UNet.
+    sample: torch.FloatTensor = None
+
+class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers is skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ num_views: int = 1,
+ joint_attention: bool = False,
+ joint_attention_twice: bool = False,
+ multiview_attention: bool = True,
+ cross_domain_attention: bool = False,
+ camera_input_dim: int = 12,
+ camera_hidden_dim: int = 320,
+ camera_output_dim: int = 1280,
+
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.camera_embedding = nn.Sequential(
+ nn.Linear(camera_input_dim, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ # custom MV2D attention block
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        # Broadcast a single stock ``AttnProcessor`` to every attention layer.
+        self.set_attn_processor(AttnProcessor())
+
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            # Collect the head dim of every module that supports slicing, in
+            # depth-first traversal order.
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        # A single int (or any non-list) is broadcast to every sliceable layer.
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        # Validate each requested slice size against the corresponding head dim.
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        # Reversed so that ``pop()`` (which takes from the end) hands out sizes
+        # in the same depth-first order the dims were collected in.
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ # def _set_gradient_checkpointing(self, module, value=False):
+ # if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
+ # module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        camera_matrixs: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNetMV2DConditionOutput, Tuple]:
+        r"""
+        The [`UNet2DConditionModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, height, width)`.
+                NOTE(review): this method rearranges `sample` with "b c f h w -> (b f) c h w" before `conv_in`, so
+                the input here is actually 5-D with a per-view axis `f` — confirm against callers.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+            camera_matrixs (`torch.Tensor`, *optional*):
+                Per-view camera conditioning projected by `self.camera_embedding` and added to the time embedding;
+                presumably shaped `(batch, num_views, camera_input_dim)` — confirm with caller.
+            encoder_attention_mask (`torch.Tensor`):
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+            added_cond_kwargs: (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.
+
+        Returns:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
+        """
+        # By default samples have to be AT least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #       (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+        emb = self.time_embedding(t_emb, timestep_cond)
+
+        # import pdb; pdb.set_trace()
+        # Multi-view camera conditioning: broadcast the shared time embedding
+        # across the view axis, add the per-view camera embedding, then flatten
+        # views into the batch dimension: (B, V, C) -> (B*V, C).
+        if camera_matrixs is not None:
+            emb = torch.unsqueeze(emb, 1)
+            # came emb
+            cam_emb = self.camera_embedding(camera_matrixs)
+            # cam_emb = self.camera_embedding_2(cam_emb)
+            # import ipdb
+            # ipdb.set_trace()
+            emb = emb.repeat(1,cam_emb.shape[1],1) #torch.Size([32, 4, 1280])
+            emb = emb + cam_emb
+            emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])
+
+        aug_emb = None
+
+        # NOTE(review): the outer condition already requires `class_labels is not None`,
+        # so the inner `class_labels is None` raise below is unreachable. Upstream
+        # diffusers raises when `class_embedding` is set but labels are missing;
+        # here that case is silently skipped — presumably intentional for
+        # camera-conditioned use, but confirm.
+        if self.class_embedding is not None and class_labels is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+            if self.config.class_embed_type == "timestep":
+                class_labels = self.time_proj(class_labels)
+
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb
+
+        # Optional additional conditioning embeddings, selected by config.
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+            aug_emb = self.add_embedding(text_embs, image_embs)
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            # The hint is concatenated to the sample along the channel axis.
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        # Optionally project encoder hidden states into the cross-attention space.
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kadinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        # 2. pre-process
+        # Fold the per-view axis into the batch so the 2D blocks see (B*V, C, H, W).
+        sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
+        sample = self.conv_in(sample)
+        # 3. down
+
+        # ControlNet supplies both residual kinds; T2I-Adapter supplies only down residuals.
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+                if is_adapter and len(down_block_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                    **additional_residuals,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+                if is_adapter and len(down_block_additional_residuals) > 0:
+                    sample += down_block_additional_residuals.pop(0)
+
+            down_block_res_samples += res_samples
+
+        if is_controlnet:
+            # Add the ControlNet residuals to the skip connections, position-wise.
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+        # print("after down: ", sample.mean(), emb.mean())
+        # 4. mid
+        if self.mid_block is not None:
+            sample = self.mid_block(
+                sample,
+                emb,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                cross_attention_kwargs=cross_attention_kwargs,
+                encoder_attention_mask=encoder_attention_mask,
+            )
+
+        if is_controlnet:
+            sample = sample + mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            # Each up block consumes as many skip connections as it has resnets.
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+
+        # 6. post-process
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNetMV2DConditionOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_2d(
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ camera_embedding_type: str, num_views: int, sample_size: int,
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
+ projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
+ joint_attention_twice: bool = False, multiview_attention: bool = True,
+ cross_domain_attention: bool = False,
+ in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
+ **kwargs
+ ):
+ r"""
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+ train the model, set it back in training mode with `model.train()`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`~ModelMixin.save_pretrained`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ from_flax (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a Flax checkpoint save file.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if `device_map` contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+ weights. If set to `False`, `safetensors` weights are not loaded.
+
+
+
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+ `huggingface-cli login`. You can also activate the special
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ firewalled environment.
+
+
+
+ Example:
+
+ ```py
+ from diffusers import UNet2DConditionModel
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ ```
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```bash
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ from_flax = kwargs.pop("from_flax", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+ device_map = kwargs.pop("device_map", None)
+ max_memory = kwargs.pop("max_memory", None)
+ offload_folder = kwargs.pop("offload_folder", None)
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ # if use_safetensors and not is_safetensors_available():
+ # raise ValueError(
+ # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
+ # )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ # use_safetensors = is_safetensors_available()
+ use_safetensors = False
+ allow_pickle = True
+
+ if device_map is not None and not is_accelerate_available():
+ raise NotImplementedError(
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
+ )
+
+ # Check if we can handle device_map and dispatching the weights
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `device_map=None`."
+ )
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # load config
+ config, unused_kwargs, commit_hash = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ max_memory=max_memory,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ user_agent=user_agent,
+ **kwargs,
+ )
+
+ # modify config
+ config["_class_name"] = cls.__name__
+ config['in_channels'] = in_channels
+ config['out_channels'] = out_channels
+ config['sample_size'] = sample_size # training resolution
+ config['num_views'] = num_views
+ config['joint_attention'] = joint_attention
+ config['joint_attention_twice'] = joint_attention_twice
+ config['multiview_attention'] = multiview_attention
+ config['cross_domain_attention'] = cross_domain_attention
+ config["down_block_types"] = [
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "DownBlock2D"
+ ]
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
+ config["up_block_types"] = [
+ "UpBlock2D",
+ "CrossAttnUpBlockMV2D",
+ "CrossAttnUpBlockMV2D",
+ "CrossAttnUpBlockMV2D"
+ ]
+ config['class_embed_type'] = 'projection'
+ if camera_embedding_type == 'e_de_da_sincos':
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
+ else:
+ raise NotImplementedError
+
+ # load model
+ model_file = None
+ if from_flax:
+ raise NotImplementedError
+ else:
+ if use_safetensors:
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+ except IOError as e:
+ if not allow_pickle:
+ raise e
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ model = cls.from_config(config, **unused_kwargs)
+ if local_crossattn:
+ unet_lora_attn_procs = dict()
+ for name, _ in model.attn_processors.items():
+ if not name.endswith("attn1.processor"):
+ default_attn_proc = AttnProcessor()
+ elif is_xformers_available():
+ default_attn_proc = XFormersMVAttnProcessor()
+ else:
+ default_attn_proc = MVAttnProcessor()
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
+ )
+ model.set_attn_processor(unet_lora_attn_procs)
+ state_dict = load_state_dict(model_file, variant=variant)
+ model._convert_deprecated_attention_blocks(state_dict)
+
+ conv_in_weight = state_dict['conv_in.weight']
+ conv_out_weight = state_dict['conv_out.weight']
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=True,
+ )
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
+ # initialize from the original SD structure
+ model.conv_in.weight.data[:,:4] = conv_in_weight
+
+ # whether to place all zero to new layers?
+ if zero_init_conv_in:
+ model.conv_in.weight.data[:,4:] = 0.
+
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
+ # initialize from the original SD structure
+ model.conv_out.weight.data[:,:4] = conv_out_weight
+ if out_channels == 8: # copy for the last 4 channels
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
+
+ if zero_init_camera_projection:
+ for p in model.class_embedding.parameters():
+ torch.nn.init.zeros_(p)
+
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ if output_loading_info:
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model_2d(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = list(state_dict.keys())
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
diff --git a/2D_Stage/tuneavideo/models/unet_mv2d_ref.py b/2D_Stage/tuneavideo/models/unet_mv2d_ref.py
new file mode 100644
index 0000000000000000000000000000000000000000..add370667c096523d47d77c3313e86ab2ab730e7
--- /dev/null
+++ b/2D_Stage/tuneavideo/models/unet_mv2d_ref.py
@@ -0,0 +1,1570 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+import os
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from einops import rearrange
+
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.lora import LoRALinearLayer
+
+from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
+from diffusers.models.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ UpBlock2D,
+)
+from diffusers.utils import (
+ CONFIG_NAME,
+ DIFFUSERS_CACHE,
+ FLAX_WEIGHTS_NAME,
+ HF_HUB_OFFLINE,
+ SAFETENSORS_WEIGHTS_NAME,
+ WEIGHTS_NAME,
+ _add_variant,
+ _get_model_file,
+ deprecate,
+ is_accelerate_available,
+ is_torch_version,
+ logging,
+)
+from diffusers import __version__
+from tuneavideo.models.unet_mv2d_blocks import (
+ CrossAttnDownBlockMV2D,
+ CrossAttnUpBlockMV2D,
+ UNetMidBlockMV2DCrossAttn,
+ get_down_block,
+ get_up_block,
+)
+from diffusers.models.attention_processor import Attention, AttnProcessor
+from diffusers.utils.import_utils import is_xformers_available
+from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
+from tuneavideo.models.refunet import ReferenceOnlyAttnProc
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
@dataclass
class UNetMV2DRefOutput(BaseOutput):
    """
    The output of [`UNetMV2DRefModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Denoised/predicted sample produced by the UNet forward pass.
    sample: torch.FloatTensor = None
+
class Identity(torch.nn.Module):
    r"""A no-op module that accepts and ignores any arguments.

    Useful as a drop-in placeholder wherever an ``nn.Module`` is expected
    but no transformation should be applied: ``forward`` hands back its
    first positional argument untouched.

    Args:
        scale: ignored, accepted only for call-site compatibility.
        args: any positional arguments (unused).
        kwargs: any keyword arguments (unused).

    Shape:
        - Input: :math:`(*)`, any number of dimensions.
        - Output: :math:`(*)`, identical to the input.

    Examples::

        >>> layer = Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> x = torch.randn(128, 20)
        >>> layer(x).size()
        torch.Size([128, 20])

    """

    def __init__(self, scale=None, *args, **kwargs) -> None:
        # Every constructor argument (including `scale`) is discarded.
        super().__init__()

    def forward(self, input, *args, **kwargs):
        # Identity mapping: extra positional/keyword arguments are ignored.
        return input
+
+
+
class _LoRACompatibleLinear(nn.Module):
    """No-op stand-in exposing the LoRA-compatible-linear interface.

    Despite the name, this is a plain ``nn.Module`` (not an ``nn.Linear``):
    ``forward`` returns its input unchanged and the fuse/unfuse hooks do
    nothing. It lets code written against diffusers' ``LoRACompatibleLinear``
    API run with the layer effectively disabled.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        # NOTE(review): nn.Module.__init__ takes no positional arguments, so
        # passing `*args` (e.g. in/out features meant for nn.Linear) would
        # raise. Presumably callers construct this without positional args —
        # TODO confirm against call sites.
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        # Kept for API parity; the stored layer is never applied in forward.
        self.lora_layer = lora_layer

    def _fuse_lora(self):
        # Intentionally a no-op: there are no base weights to fuse into.
        pass

    def _unfuse_lora(self):
        # Intentionally a no-op (counterpart of _fuse_lora).
        pass

    def forward(self, hidden_states, scale=None, lora_scale: int = 1):
        # Identity mapping: `scale` and `lora_scale` are accepted but ignored.
        return hidden_states
+
+class UNetMV2DRefModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers is skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ num_views: int = 1,
+ joint_attention: bool = False,
+ joint_attention_twice: bool = False,
+ multiview_attention: bool = True,
+ cross_domain_attention: bool = False,
+ camera_input_dim: int = 12,
+ camera_hidden_dim: int = 320,
+ camera_output_dim: int = 1280,
+
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.camera_embedding = nn.Sequential(
+ nn.Linear(camera_input_dim, time_embed_dim),
+ nn.SiLU(),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ # custom MV2D attention block
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ num_views=num_views,
+ joint_attention=joint_attention,
+ joint_attention_twice=joint_attention_twice,
+ multiview_attention=multiview_attention,
+ cross_domain_attention=cross_domain_attention
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ # if norm_num_groups is not None:
+ # self.conv_norm_out = nn.GroupNorm(
+ # num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ # )
+
+ # self.conv_act = get_activation(act_fn)
+
+ # else:
+ # self.conv_norm_out = None
+ # self.conv_act = None
+
+ # conv_out_padding = (conv_out_kernel - 1) // 2
+ # self.conv_out = nn.Conv2d(
+ # block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ # )
+
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
+ self.up_blocks[3].attentions[2].proj_out = Identity()
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "set_processor"):
+ processors[f"{name}.processor"] = module.processor
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        # Broadcast one stock `AttnProcessor` to every attention layer via
        # `set_attn_processor`, discarding any custom processors installed earlier.
        self.set_attn_processor(AttnProcessor())
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ camera_matrixs: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNetMV2DRefOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # import pdb; pdb.set_trace()
+ if camera_matrixs is not None:
+ emb = torch.unsqueeze(emb, 1)
+ # came emb
+ cam_emb = self.camera_embedding(camera_matrixs)
+ # cam_emb = self.camera_embedding_2(cam_emb)
+ emb = emb.repeat(1,cam_emb.shape[1],1) #torch.Size([32, 4, 1280])
+ emb = emb + cam_emb
+ emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])
+
+ aug_emb = None
+
+ if self.class_embedding is not None and class_labels is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kadinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
+ sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
+ sample = self.conv_in(sample)
+ # 3. down
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+ # print("after down: ", sample.mean(), emb.mean())
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # print("after mid: ", sample.mean())
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ # 6. post-process
+ # if self.conv_norm_out:
+ # sample = self.conv_norm_out(sample)
+ # sample = self.conv_act(sample)
+ # sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNetMV2DRefOutput(sample=sample)
+
+ @classmethod
+ def from_pretrained_2d(
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ camera_embedding_type: str, num_views: int, sample_size: int,
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
+ projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
+ joint_attention_twice: bool = False, multiview_attention: bool = True,
+ cross_domain_attention: bool = False,
+ in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
+ **kwargs
+ ):
+ r"""
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+ train the model, set it back in training mode with `model.train()`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`~ModelMixin.save_pretrained`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ from_flax (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a Flax checkpoint save file.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if `device_map` contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+ weights. If set to `False`, `safetensors` weights are not loaded.
+
+
+
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+ `huggingface-cli login`. You can also activate the special
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ firewalled environment.
+
+
+
+ Example:
+
+ ```py
+ from diffusers import UNet2DConditionModel
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ ```
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```bash
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ from_flax = kwargs.pop("from_flax", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+ device_map = kwargs.pop("device_map", None)
+ max_memory = kwargs.pop("max_memory", None)
+ offload_folder = kwargs.pop("offload_folder", None)
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
+ variant = kwargs.pop("variant", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ # if use_safetensors and not is_safetensors_available():
+ # raise ValueError(
+ # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
+ # )
+
+ allow_pickle = False
+ if use_safetensors is None:
+ # use_safetensors = is_safetensors_available()
+ use_safetensors = False
+ allow_pickle = True
+
+ if device_map is not None and not is_accelerate_available():
+ raise NotImplementedError(
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
+ )
+
+ # Check if we can handle device_map and dispatching the weights
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `device_map=None`."
+ )
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ # load config
+ config, unused_kwargs, commit_hash = cls.load_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ device_map=device_map,
+ max_memory=max_memory,
+ offload_folder=offload_folder,
+ offload_state_dict=offload_state_dict,
+ user_agent=user_agent,
+ **kwargs,
+ )
+
+ # modify config
+ config["_class_name"] = cls.__name__
+ config['in_channels'] = in_channels
+ config['out_channels'] = out_channels
+ config['sample_size'] = sample_size # training resolution
+ config['num_views'] = num_views
+ config['joint_attention'] = joint_attention
+ config['joint_attention_twice'] = joint_attention_twice
+ config['multiview_attention'] = multiview_attention
+ config['cross_domain_attention'] = cross_domain_attention
+ config["down_block_types"] = [
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "CrossAttnDownBlockMV2D",
+ "DownBlock2D"
+ ]
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
+ config["up_block_types"] = [
+ "UpBlock2D",
+ "CrossAttnUpBlockMV2D",
+ "CrossAttnUpBlockMV2D",
+ "CrossAttnUpBlockMV2D"
+ ]
+ config['class_embed_type'] = 'projection'
+ if camera_embedding_type == 'e_de_da_sincos':
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
+ else:
+ raise NotImplementedError
+
+ # load model
+ model_file = None
+ if from_flax:
+ raise NotImplementedError
+ else:
+ if use_safetensors:
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+ except IOError as e:
+ if not allow_pickle:
+ raise e
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ commit_hash=commit_hash,
+ )
+
+ model = cls.from_config(config, **unused_kwargs)
+ if local_crossattn:
+ unet_lora_attn_procs = dict()
+ for name, _ in model.attn_processors.items():
+ if not name.endswith("attn1.processor"):
+ default_attn_proc = AttnProcessor()
+ elif is_xformers_available():
+ default_attn_proc = XFormersMVAttnProcessor()
+ else:
+ default_attn_proc = MVAttnProcessor()
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
+ )
+ model.set_attn_processor(unet_lora_attn_procs)
+ state_dict = load_state_dict(model_file, variant=variant)
+ model._convert_deprecated_attention_blocks(state_dict)
+
+ conv_in_weight = state_dict['conv_in.weight']
+ conv_out_weight = state_dict['conv_out.weight']
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=True,
+ )
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
+ # initialize from the original SD structure
+ model.conv_in.weight.data[:,:4] = conv_in_weight
+
+ # whether to place all zero to new layers?
+ if zero_init_conv_in:
+ model.conv_in.weight.data[:,4:] = 0.
+
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
+ # initialize from the original SD structure
+ model.conv_out.weight.data[:,:4] = conv_out_weight
+ if out_channels == 8: # copy for the last 4 channels
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
+
+ if zero_init_camera_projection:
+ for p in model.class_embedding.parameters():
+ torch.nn.init.zeros_(p)
+
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+ if output_loading_info:
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model_2d(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = list(state_dict.keys())
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
diff --git a/2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py b/2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..59687723d4190ba283616e68c6e799b4d2bd6225
--- /dev/null
+++ b/2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py
@@ -0,0 +1,585 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+
+import tqdm
+
+import inspect
+from typing import Callable, List, Optional, Union
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+
+from diffusers.utils import is_accelerate_available
+from packaging import version
+from transformers import CLIPTextModel, CLIPTokenizer
+import torchvision.transforms.functional as TF
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL
+from diffusers import DiffusionPipeline
+from diffusers.schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from diffusers.utils import deprecate, logging, BaseOutput
+
+from einops import rearrange
+
+from ..models.unet import UNet3DConditionModel
+from torchvision.transforms import InterpolationMode
+
+import ipdb
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
@dataclass
class TuneAVideoPipelineOutput(BaseOutput):
    """Output container for `TuneAVideoPipeline` (dict-style access via BaseOutput)."""

    # Decoded frames, (batch, channels, frames, height, width) as produced by
    # `decode_latents`, values clamped to [0, 1]; a torch.Tensor or np.ndarray
    # depending on the pipeline's `output_type` argument.
    videos: Union[torch.Tensor, np.ndarray]
+
+class TuneAVideoPipeline(DiffusionPipeline):
+ _optional_components = []
+
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet3DConditionModel,

        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
        ref_unet = None,
        feature_extractor=None,
        image_encoder=None
    ):
        """Wire up the pipeline modules and normalize legacy scheduler/unet configs.

        NOTE(review): `ref_unet`, `feature_extractor` and `image_encoder` are kept
        as plain attributes and NOT passed to `register_modules` below, so they are
        not tracked by pipeline save/load or device moves — confirm this is intended.
        """
        super().__init__()
        self.ref_unet = ref_unet
        self.feature_extractor = feature_extractor
        self.image_encoder = image_encoder

        # Legacy-config fixups (copied from the official SD pipeline): each one
        # emits a deprecation warning and patches the frozen config in place.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        # Old (< diffusers 0.9) checkpoints stored sample_size=32; bump to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        # Spatial downscale factor of the VAE (8 for standard SD autoencoders).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ def enable_vae_slicing(self):
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ self.vae.disable_slicing()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+
+ @property
+ def _execution_device(self):
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance, img_proj=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ # image_pt = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
+ # image_pt = image_pt.to(device=device, dtype=dtype)
+ # image_embeddings = self.image_encoder(image_pt).image_embeds
+ # image_embeddings = image_embeddings.unsqueeze(1)
+
+ # # image encoding
+ clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:,None,None].to(device, dtype=torch.float32)
+ clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:,None,None].to(device, dtype=torch.float32)
+ imgs_in_proc = TF.resize(image_pil, (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']), interpolation=InterpolationMode.BICUBIC)
+ # do the normalization in float32 to preserve precision
+ imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(dtype)
+ if img_proj is None:
+ # (B*Nv, 1, 768)
+ image_embeddings = self.image_encoder(imgs_in_proc).image_embeds.unsqueeze(1)
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ # Note: repeat differently from official pipelines
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
+ if do_classifier_free_guidance:
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
+ else:
+ if do_classifier_free_guidance:
+ negative_image_proc = torch.zeros_like(imgs_in_proc)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ imgs_in_proc = torch.cat([negative_image_proc, imgs_in_proc])
+
+ image_embeds = image_encoder(imgs_in_proc, output_hidden_states=True).hidden_states[-2]
+ image_embeddings = img_proj(image_embeds)
+
+ # image_embeddings_unet = rearrange(image_embeddings_unet, 'B Nv d c -> (B Nv) d c')
+
+ # image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
+ # image_pil = image_pil * 2.0 - 1.0
+ image_latents = self.vae.encode(image_pil* 2.0 - 1.0).latent_dist.mode() * self.vae.config.scaling_factor
+
+ # Note: repeat differently from official pipelines
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
+
+ # if do_classifier_free_guidance:
+ # image_latents = torch.cat([torch.zeros_like(image_latents), image_latents])
+
+ return image_embeddings, image_latents
+
    def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
        """Encode `prompt` (and optional `negative_prompt`) into CLIP text embeddings.

        Embeddings are tiled `num_videos_per_prompt` times; with classifier-free
        guidance the unconditional embeddings are concatenated in front so a
        single batched forward serves both branches.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn (don't fail) when the prompt exceeds CLIP's context window.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings
+
+ def decode_latents(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / 0.18215 * latents
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ video = self.vae.decode(latents).sample
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = (video / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ video = video.cpu().float().numpy()
+ return video
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
def check_inputs(self, prompt, height, width, callback_steps):
    """Validate call arguments, raising ValueError on the first problem.

    Checks that ``prompt`` is a str or list, that ``height``/``width`` are
    multiples of 8 (VAE downsampling requirement), and that
    ``callback_steps`` is a positive integer.
    """
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    callback_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not callback_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
    """Create (or validate) the initial noise latents for sampling.

    Args:
        batch_size: effective batch size (videos * num_videos_per_prompt).
        num_channels_latents: latent channel count of the UNet.
        video_length: number of frames/views per sample.
        height, width: pixel dimensions; divided by ``self.vae_scale_factor``.
        dtype, device: target tensor dtype/device.
        generator: optional ``torch.Generator`` (or list, one per sample).
        latents: optional pre-made latents; validated and moved to ``device``.

    Returns:
        Latents of shape (b, c, f, h/vsf, w/vsf) scaled by the scheduler's
        ``init_noise_sigma``.

    Raises:
        ValueError: if a generator list length mismatches the batch size, or
            pre-made ``latents`` have the wrong shape.
    """
    shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        # MPS does not support generator-based sampling on device; sample on CPU.
        rand_device = "cpu" if device.type == "mps" else device

        if isinstance(generator, list):
            # One generator per sample: draw each sample independently.
            shape = (1,) + shape[1:]
            latents = [
                torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
                for i in range(batch_size)
            ]
            latents = torch.cat(latents, dim=0).to(device)
        else:
            # BUGFIX: the original called torch.randn(shape, dtype=dtype) here,
            # silently ignoring `generator` (and `rand_device`), which broke
            # seeded reproducibility for the common single-generator case.
            latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
    else:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
+
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[str, List[str]],
    video_length: Optional[int],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    camera_matrixs = None,
    class_labels = None,
    prompt_ids = None,
    unet_condition_type = None,
    pose_guider = None,
    pose_image = None,
    img_proj=None,
    use_noise=True,
    use_shifted_noise=False,
    rescale = 0.7,
    **kwargs,
):
    """Sample ``video_length`` camera views conditioned on a reference image.

    Runs classifier-free-guided DDIM/DDPM sampling with an optional
    reference UNet: each step first runs ``self.ref_unet`` in "write" mode to
    populate ``ref_dict`` with reference attention features, then runs
    ``self.unet`` in "read" mode consuming them.

    Args:
        prompt: text prompt(s); defines the effective batch size.
        image: reference image batch (list or tensor) fed to ``_encode_image``.
        video_length: number of views/frames to generate per sample.
        height/width: output resolution; default derived from the UNet config.
        guidance_scale: CFG weight ``w``; <= 1 disables guidance.
        camera_matrixs: per-view camera matrices passed to the main UNet.
        unet_condition_type: 'text' or 'image' — selects the cross-attention
            conditioning for the main UNet.
        pose_guider/pose_image: optional pose conditioning added to the latents.
        use_noise: renoise the reference latents to the current timestep.
        use_shifted_noise: also apply guidance rescale with weight ``rescale``.

    Returns:
        ``TuneAVideoPipelineOutput`` with ``videos`` (or the raw video when
        ``return_dict`` is False).
    """
    # Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps)
    if isinstance(image, list):
        batch_size = len(image)
    else:
        batch_size = image.shape[0]
    # NOTE(review): batch_size is immediately recomputed from `prompt` below,
    # so the image-derived value above is effectively dead; kept for parity.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # `guidance_scale` is defined as the guidance weight `w` of eq. (2) of the
    # Imagen paper (https://arxiv.org/pdf/2205.11487.pdf); w=1 means no CFG.
    do_classifier_free_guidance = guidance_scale > 1.0

    # Encode the reference image into CLIP embeddings and VAE latents.
    image_embeddings, image_latents = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, img_proj=img_proj)
    # Reference latents carry a singleton frame axis: (b, c, 1, h, w).
    image_latents = rearrange(image_latents, "(b f) c h w -> b c f h w", f=1)

    # Encode the text prompt (uncond + cond halves when CFG is active).
    text_embeddings = self._encode_prompt(
        prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # Prepare latent variables
    num_channels_latents = self.unet.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        video_length,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )
    latents_dtype = latents.dtype

    # Prepare extra step kwargs (eta/generator if the scheduler accepts them).
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
    # Duplicate camera matrices for the uncond/cond halves under CFG.
    if camera_matrixs is not None:
        camera_matrixs = torch.cat([camera_matrixs] * 2) if do_classifier_free_guidance else camera_matrixs
    # Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    if pose_guider is not None:
        if len(pose_image.shape) == 5:
            pose_embeds = pose_guider(rearrange(pose_image, "b f c h w -> (b f) c h w"))
            pose_embeds = rearrange(pose_embeds, "(b f) c h w-> b c f h w ", f=video_length)
        else:
            pose_embeds = pose_guider(pose_image).unsqueeze(0)
        pose_embeds = torch.cat([pose_embeds]*2, dim=0)
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(tqdm.tqdm(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            if pose_guider is not None:
                latent_model_input = latent_model_input + pose_embeds

            # Optionally renoise the reference latents to the current timestep
            # so the ref UNet sees inputs at matching noise levels.
            noise_cond = torch.randn_like(image_latents)
            if use_noise:
                cond_latents = self.scheduler.add_noise(image_latents, noise_cond, t)
            else:
                cond_latents = image_latents
            cond_latent_model_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents
            cond_latent_model_input = self.scheduler.scale_model_input(cond_latent_model_input, t)

            # Reference pass: mode "w" writes attention features into ref_dict.
            ref_dict = {}
            if self.ref_unet is not None:
                noise_pred_cond = self.ref_unet(
                    cond_latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings.to(torch.float32),
                    cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict)
                ).sample.to(dtype=latents_dtype)

            # Broadcast the per-video text/image embeddings to every view
            # before flattening (B, Nv, ...) -> (B*Nv, ...) for the main UNet.
            text_embeddings_unet = text_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1,1)
            text_embeddings_unet = rearrange(text_embeddings_unet, 'B Nv d c -> (B Nv) d c')
            image_embeddings_unet = image_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1, 1)
            image_embeddings_unet = rearrange(image_embeddings_unet, 'B Nv d c -> (B Nv) d c')

            if unet_condition_type == 'text':
                encoder_hidden_states_unet_cond = text_embeddings_unet
            elif unet_condition_type == 'image':
                encoder_hidden_states_unet_cond = image_embeddings_unet
            else:
                # BUGFIX: the original `raise('need unet_condition_type')` raised
                # a plain str, which is itself a TypeError in Python 3.
                raise ValueError('need unet_condition_type')

            # Main pass: mode "r" reads the reference features written above;
            # mode "n" runs without them when no ref UNet is present.
            if self.ref_unet is not None:
                noise_pred = self.unet(
                    latent_model_input.to(torch.float32),
                    t,
                    encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
                    camera_matrixs=camera_matrixs.to(torch.float32),
                    cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
                ).sample.to(dtype=latents_dtype)
            else:
                noise_pred = self.unet(
                    latent_model_input.to(torch.float32),
                    t,
                    encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
                    camera_matrixs=camera_matrixs.to(torch.float32),
                    cross_attention_kwargs=dict(mode="n", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
                ).sample.to(dtype=latents_dtype)
            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                if use_shifted_noise:
                    # Apply regular classifier-free guidance.
                    cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    # Guidance rescale: match the std of the CFG output to the
                    # conditional prediction, blended by `rescale`.
                    std_pos = noise_pred_text.std([1,2,3], keepdim=True)
                    std_cfg = cfg.std([1,2,3], keepdim=True)
                    factor = std_pos / std_cfg
                    factor = rescale * factor + (1 - rescale)
                    noise_pred = cfg * factor
                else:
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            noise_pred = rearrange(noise_pred, "(b f) c h w -> b c f h w", f=video_length)
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # Post-processing
    video = self.decode_latents(latents)

    # Convert to tensor
    if output_type == "tensor":
        video = torch.from_numpy(video)

    if not return_dict:
        return video

    return TuneAVideoPipelineOutput(videos=video)
diff --git a/2D_Stage/tuneavideo/util.py b/2D_Stage/tuneavideo/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c4daa393169f319084632b4a1f172d8ba981bfc
--- /dev/null
+++ b/2D_Stage/tuneavideo/util.py
@@ -0,0 +1,128 @@
+import os
+import imageio
+import numpy as np
+from typing import Union
+import cv2
+import torch
+import torchvision
+
+from tqdm import tqdm
+from einops import rearrange
+
def shifted_noise(betas, image_d=512, noise_d=256, shifted_noise=True):
    """Rescale a beta schedule for shifted-noise sampling.

    Shifts the cumulative alpha schedule by the squared resolution ratio
    ``(image_d / noise_d) ** 2`` (when ``shifted_noise`` is True), then
    renormalizes sqrt(alpha_bar) so the last timestep reaches zero while the
    first keeps its original value, and converts back to betas.

    Args:
        betas: 1-D tensor of per-step betas.
        image_d: target image resolution the schedule is adapted to.
        noise_d: reference resolution the schedule was designed for.
        shifted_noise: apply the resolution shift before renormalizing.

    Returns:
        1-D tensor of adjusted betas, same length as the input.
    """
    cum_alphas = torch.cumprod(1 - betas, dim=0)
    ratio = (image_d / noise_d) ** 2
    if shifted_noise:
        cum_alphas = cum_alphas / (ratio - (ratio - 1) * cum_alphas)
    root = torch.sqrt(cum_alphas)
    first = root[0].clone()
    last = root[-1].clone()
    # Shift so the last timestep is zero, then scale so the first timestep is
    # restored to its old value.
    root = (root - last) * (first / (first - last))

    # Convert sqrt(alpha_bar) back to per-step betas.
    cum_alphas = root ** 2
    step_alphas = torch.cat([cum_alphas[0:1], cum_alphas[1:] / cum_alphas[:-1]])
    return 1 - step_alphas
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    """Write a (b, c, t, h, w) video batch to ``path`` as an animated image.

    Each timestep becomes one grid frame (batch tiled with ``make_grid``);
    frames are saved with imageio at the requested ``fps``.
    """
    frames = []
    for frame in rearrange(videos, "b c t h w -> t b c h w"):
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        # CHW -> HWC for imageio.
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        frames.append((grid * 255).numpy().astype(np.uint8))

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, frames, duration=1000/fps)
+
def save_imgs_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
    """Save each temporal frame of ``videos`` as a PNG grid inside directory ``path``.

    Args:
        videos: tensor shaped (b, c, t, h, w); frame ``i`` is written as
            ``<path>/view_<i>.png``.
        path: output *directory* (created if missing).
        rescale: if True, map values from [-1, 1] to [0, 1] first.
        n_rows: grid width passed to ``make_grid``.
        fps: unused; kept for signature parity with ``save_videos_grid``.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    # BUGFIX: files are written inside `path` (os.path.join(path, ...)), so
    # create `path` itself — the original only created its *parent* via
    # os.path.dirname(path), leaving cv2.imwrite to fail silently. Also
    # hoisted out of the loop; it only needs to run once.
    os.makedirs(path, exist_ok=True)
    for i, x in enumerate(videos):
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # CHW -> HWC.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        # cv2 expects BGR, hence the channel reversal.
        cv2.imwrite(os.path.join(path, f'view_{i}.png'), x[:,:,::-1])
+
def imgs_grid(videos: torch.Tensor, rescale=False, n_rows=4, fps=8):
    """Convert a (b, c, t, h, w) video batch into a list of RGB uint8 grids.

    One grid per timestep; ``fps`` is unused and kept for signature parity
    with the save helpers.
    """
    image_list = []
    for frame in rearrange(videos, "b c t h w -> t b c h w"):
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        # CHW -> HWC.
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1
        image_list.append((grid * 255).numpy().astype(np.uint8))
    return image_list
+
+# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    """Encode ``prompt`` plus the empty string for CFG-style DDIM inversion.

    Returns the concatenation [uncond_embeddings, text_embeddings] along the
    batch axis, computed with the pipeline's tokenizer and text encoder.
    """
    tokenizer = pipeline.tokenizer
    max_len = tokenizer.model_max_length
    uncond_ids = tokenizer(
        [""], padding="max_length", max_length=max_len,
        return_tensors="pt"
    ).input_ids
    cond_ids = tokenizer(
        [prompt],
        padding="max_length",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    uncond_embeddings = pipeline.text_encoder(uncond_ids.to(pipeline.device))[0]
    text_embeddings = pipeline.text_encoder(cond_ids.to(pipeline.device))[0]
    return torch.cat([uncond_embeddings, text_embeddings])
+
+
+def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
+ timestep, next_timestep = min(
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
+ beta_prod_t = 1 - alpha_prod_t
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+ return next_sample
+
+
def get_noise_pred_single(latents, t, context, unet):
    """Run one UNet forward pass and return the predicted noise ("sample")."""
    unet_output = unet(latents, t, encoder_hidden_states=context)
    return unet_output["sample"]
+
+
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    """Invert ``latent`` for ``num_inv_steps`` DDIM steps under ``prompt``.

    Only the conditional half of the prompt embedding is used (no CFG during
    inversion). Returns the full trajectory, starting with the input latent.
    """
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    current = latent.clone().detach()
    timesteps = ddim_scheduler.timesteps
    for i in tqdm(range(num_inv_steps)):
        # Walk the schedule backwards: last timestep first.
        t = timesteps[len(timesteps) - i - 1]
        noise_pred = get_noise_pred_single(current.to(torch.float32), t, cond_embeddings.to(torch.float32), pipeline.unet)
        current = next_step(noise_pred, t, current, ddim_scheduler)
        all_latent.append(current)
    return all_latent
+
+
@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    """Public entry point for DDIM inversion; thin wrapper over ``ddim_loop``."""
    return ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
+ return ddim_latents
diff --git a/2D_Stage/webui.py b/2D_Stage/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..14e37c9a47a8c5a906407725a3a2f58f309bc3e7
--- /dev/null
+++ b/2D_Stage/webui.py
@@ -0,0 +1,323 @@
+import gradio as gr
+from PIL import Image
+import glob
+
+import io
+import argparse
+import inspect
+import os
+import random
+from typing import Dict, Optional, Tuple
+from omegaconf import OmegaConf
+import numpy as np
+
+import torch
+import torch.utils.checkpoint
+
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.utils import check_min_version
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
+from torchvision import transforms
+
+from tuneavideo.models.unet_mv2d_condition import UNetMV2DConditionModel
+from tuneavideo.models.unet_mv2d_ref import UNetMV2DRefModel
+from tuneavideo.models.PoseGuider import PoseGuider
+from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
+from tuneavideo.util import shifted_noise
+from einops import rearrange
+import PIL
+from PIL import Image
+from torchvision.utils import save_image
+import json
+import cv2
+
+import onnxruntime as rt
+from huggingface_hub.file_download import hf_hub_download
+from rm_anime_bg.cli import get_mask, SCALE
+
+from huggingface_hub import hf_hub_download, list_repo_files
+
repo_id = "zjpshadow/CharacterGen"
all_files = list_repo_files(repo_id, revision="main")

# Fetch any 2D-stage assets from the Hub that are not already present locally.
for remote_file in all_files:
    if remote_file.startswith("2D_Stage") and not os.path.exists("../" + remote_file):
        hf_hub_download(repo_id, remote_file, local_dir="../")
+
class rm_bg_api:
    """Anime-oriented background remover backed by the skytnt/anime-seg ONNX model."""

    def __init__(self, force_cpu: Optional[bool] = True):
        # Download (or reuse a cached copy of) the segmentation weights.
        session_infer_path = hf_hub_download(
            repo_id="skytnt/anime-seg", filename="isnetis.onnx",
        )
        providers: list[str] = ["CPUExecutionProvider"]
        if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
            providers = ["CUDAExecutionProvider"]

        self.session_infer = rt.InferenceSession(
            session_infer_path, providers=providers,
        )

    def remove_background(
        self,
        imgs: list[np.ndarray],
        alpha_min: float,
        alpha_max: float,
    ) -> list:
        """Segment each image and return RGBA PIL images with the matte as alpha.

        Args:
            imgs: RGBA numpy images (converted to RGB before inference).
            alpha_min: matte values below this are snapped to 0.
            alpha_max: matte values above this are snapped to 1.
        """
        process_imgs = []
        for img in imgs:
            # CHANGE to RGB
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            mask = get_mask(self.session_infer, img)

            # Binarize the soft matte outside the [alpha_min, alpha_max] band.
            mask[mask < alpha_min] = 0.0  # type: ignore
            mask[mask > alpha_max] = 1.0  # type: ignore

            # Composite onto a flat background and append the matte as alpha.
            img_after = (mask * img + SCALE * (1 - mask)).astype(np.uint8)  # type: ignore
            mask = (mask * SCALE).astype(np.uint8)  # type: ignore
            img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
            # NOTE: the original also computed `mask.repeat(3, axis=2)` here,
            # but the result was never used — removed as dead code.
            process_imgs.append(Image.fromarray(img_after))
        return process_imgs
+
# Fail fast if the installed diffusers is older than the APIs this script uses.
check_min_version("0.24.0")

# Accelerate-aware logger for this module.
logger = get_logger(__name__, log_level="INFO")
+
def set_seed(seed):
    """Seed every RNG used here (Python, NumPy, Torch CPU and all CUDA devices).

    Note: this shadows the ``set_seed`` imported from ``accelerate.utils``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
+
def get_bg_color(bg_color):
    """Resolve a background-color spec to an RGB float array.

    Accepts the names 'white' / 'black' / 'gray' / 'random', or a single
    float used for all three channels. Raises NotImplementedError otherwise.
    """
    named = {
        'white': np.array([1., 1., 1.], dtype=np.float32),
        'black': np.array([0., 0., 0.], dtype=np.float32),
        'gray': np.array([0.5, 0.5, 0.5], dtype=np.float32),
    }
    if isinstance(bg_color, str) and bg_color in named:
        return named[bg_color]
    if bg_color == 'random':
        return np.random.rand(3)
    if isinstance(bg_color, float):
        return np.array([bg_color] * 3, dtype=np.float32)
    raise NotImplementedError
+
def process_image(image, totensor):
    """Crop, center, and composite an RGBA character image for the pipeline.

    Crops to the non-transparent bounding box, centers the crop on a 2:3
    transparent canvas, resizes to 512x768, composites onto a gray
    background, and returns the RGB result via ``totensor``.
    """
    if not image.mode == "RGBA":
        image = image.convert("RGBA")

    # Find non-transparent pixels
    non_transparent = np.nonzero(np.array(image)[..., 3])
    min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
    min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
    # BUGFIX: PIL's crop box is right/bottom-EXCLUSIVE, so add 1 to keep the
    # last non-transparent row and column (the original dropped them).
    image = image.crop((min_x, min_y, max_x + 1, max_y + 1))

    # paste to center of a width:height = 2:3 transparent canvas
    max_dim = max(image.width, image.height)
    max_height = max_dim
    max_width = int(max_dim / 3 * 2)
    new_image = Image.new("RGBA", (max_width, max_height))
    left = (max_width - image.width) // 2
    top = (max_height - image.height) // 2
    new_image.paste(image, (left, top))

    image = new_image.resize((512, 768), resample=PIL.Image.BICUBIC)
    image = np.array(image)
    image = image.astype(np.float32) / 255.
    assert image.shape[-1] == 4  # RGBA
    alpha = image[..., 3:4]
    bg_color = get_bg_color("gray")
    # Alpha-composite onto the gray background; drops the alpha channel.
    image = image[..., :3] * alpha + bg_color * (1 - alpha)
    return totensor(image)
+
class Inference_API:
    """Lazily builds the TuneAVideo pipeline and runs 4-view generation.

    The pipeline is constructed on the first ``inference`` call and cached on
    the instance for subsequent calls. Assumes a CUDA device is available.
    """

    def __init__(self):
        # Built on first use; later calls reuse it (and its scheduler).
        self.validation_pipeline = None

    @torch.no_grad()
    def inference(self, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
                  pose_guider=None, use_noise=True, use_shifted_noise=False, noise_d=256, crop=False, seed=100, timestep=20):
        """Generate four character views from ``input_image``.

        Returns a list of 4 PIL images (PNG round-tripped).
        NOTE(review): ``pose_guider`` is accepted but the pipeline below is
        called with ``pose_guider=None`` — confirm whether that is intended.
        """
        set_seed(seed)
        # Get the validation pipeline
        if self.validation_pipeline is None:
            noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
            if use_shifted_noise:
                # Adapt the beta schedule to the target resolution.
                print(f"enable shifted noise for {val_height} to {noise_d}")
                betas = shifted_noise(noise_scheduler.betas, image_d=val_height, noise_d=noise_d)
                noise_scheduler.betas = betas
                noise_scheduler.alphas = 1 - betas
                noise_scheduler.alphas_cumprod = torch.cumprod(noise_scheduler.alphas, dim=0)
            self.validation_pipeline = TuneAVideoPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,feature_extractor=feature_extractor,image_encoder=image_encoder,
                scheduler=noise_scheduler
            )
            self.validation_pipeline.enable_vae_slicing()
            self.validation_pipeline.set_progress_bar_config(disable=True)

        totensor = transforms.ToTensor()

        # Fixed camera poses and pose-guide images shipped with the repo.
        metas = json.load(open("./material/pose.json", "r"))
        cameras = []
        pose_images = []
        input_path = "./material"
        for lm in metas:
            # lm[0]: flattened 4x4 camera matrix; keep the top 3x4 after transpose.
            cameras.append(torch.tensor(np.array(lm[0]).reshape(4, 4).transpose(1,0)[:3, :4]).reshape(-1))
            if not crop:
                pose_images.append(totensor(np.asarray(Image.open(os.path.join(input_path, lm[1])).resize(
                    (val_height, val_width), resample=PIL.Image.BICUBIC)).astype(np.float32) / 255.))
            else:
                # Fixed crop tuned to the bundled 768x768 pose images.
                pose_image = Image.open(os.path.join(input_path, lm[1]))
                crop_area = (128, 0, 640, 768)
                pose_images.append(totensor(np.array(pose_image.crop(crop_area)).astype(np.float32)) / 255.)
        camera_matrixs = torch.stack(cameras).unsqueeze(0).to("cuda")
        pose_imgs_in = torch.stack(pose_images).to("cuda")
        prompts = "high quality, best quality"
        prompt_ids = tokenizer(
            prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
            return_tensors="pt"
        ).input_ids[0]

        # (B*Nv, 3, H, W)
        B = 1
        weight_dtype = torch.bfloat16
        imgs_in = process_image(input_image, totensor)
        imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

        with torch.autocast("cuda", dtype=weight_dtype):
            imgs_in = imgs_in.to("cuda")
            # B*Nv images
            out = self.validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
                                           num_inference_steps=timestep,
                                           camera_matrixs=camera_matrixs.to(weight_dtype), prompt_ids=prompt_ids,
                                           height=val_height, width=val_width, unet_condition_type=unet_condition_type,
                                           pose_guider=None, pose_image=pose_imgs_in, use_noise=use_noise,
                                           use_shifted_noise=use_shifted_noise, **validation).videos
        out = rearrange(out, "B C f H W -> (B f) C H W", f=validation.video_length)

        # Round-trip each view through an in-memory PNG to get PIL images.
        image_outputs = []
        for bs in range(4):
            img_buf = io.BytesIO()
            save_image(out[bs], img_buf, format='PNG')
            img_buf.seek(0)
            img = Image.open(img_buf)
            image_outputs.append(img)
        torch.cuda.empty_cache()
        return image_outputs
+
@torch.no_grad()
def main(
    pretrained_model_path: str,
    image_encoder_path: str,
    ckpt_dir: str,
    validation: Dict,
    local_crossattn: bool = True,
    unet_from_pretrained_kwargs=None,
    unet_condition_type=None,
    use_pose_guider=False,
    use_noise=True,
    use_shifted_noise=False,
    noise_d=256
):
    """Load the 2D-stage models and launch the Gradio demo.

    Arguments mirror the keys of ``configs/infer.yaml`` (passed via
    ``main(**OmegaConf.load(...))``). Requires a CUDA device.
    """
    # Snapshot of the call arguments (kept from the original script; `config`
    # is not used further in this function).
    *_, config = inspect.getargvalues(inspect.currentframe())

    device = "cuda"

    # Base SD components plus the two custom multi-view UNets.
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)
    feature_extractor = CLIPImageProcessor()
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
    if use_pose_guider:
        pose_guider = PoseGuider(noise_latent_channels=4).to("cuda")
    else:
        pose_guider = None

    # Restore fine-tuned weights. The checkpoint file numbering depends on
    # whether a pose guider was trained (it shifts the ref-unet weights to
    # pytorch_model_2.bin).
    unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
    if use_pose_guider:
        pose_guider_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
        ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_2.bin"), map_location="cpu")
        pose_guider.load_state_dict(pose_guider_params)
    else:
        ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
    unet.load_state_dict(unet_params)
    ref_unet.load_state_dict(ref_unet_params)

    # Inference runs in half precision on the GPU.
    weight_dtype = torch.float16

    text_encoder.to(device, dtype=weight_dtype)
    image_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    ref_unet.to(device, dtype=weight_dtype)
    unet.to(device, dtype=weight_dtype)

    vae.requires_grad_(False)
    unet.requires_grad_(False)
    ref_unet.requires_grad_(False)

    generator = torch.Generator(device="cuda")
    inferapi = Inference_API()
    remove_api = rm_bg_api()
    def gen4views(image, width, height, seed, timestep, remove_bg):
        # Gradio callback: optionally strip the background, then generate
        # the four character views.
        if remove_bg:
            image = remove_api.remove_background(
                imgs=[np.array(image)],
                alpha_min=0.1,
                alpha_max=0.9,
            )[0]
        return inferapi.inference(
            image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, pretrained_model_path,
            generator, validation, width, height, unet_condition_type,
            pose_guider=pose_guider, use_noise=use_noise, use_shifted_noise=use_shifted_noise, noise_d=noise_d,
            crop=True, seed=seed, timestep=timestep
        )

    # UI layout: image + options on the left, generate button + gallery right.
    with gr.Blocks() as demo:
        gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration")
        gr.Markdown("# 2D Stage: One Image to Four Views of Character Image")
        gr.Markdown("**Please Upload the Image without background, and the pictures uploaded should preferably be full-body frontal photos.**")
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Upload Image(without background)", image_mode="RGBA", width=768, height=512)
                gr.Examples(
                    label="Example Images",
                    examples=glob.glob("./material/examples/*.png"),
                    inputs=[img_input]
                )
                with gr.Row():
                    width_input = gr.Number(label="Width", value=512)
                    height_input = gr.Number(label="Height", value=768)
                    seed_input = gr.Number(label="Seed", value=2333)
                with gr.Row():
                    remove_bg = gr.Checkbox(label="Remove Background (with algorithm)", value=False)
                    timestep = gr.Slider(minimum=10, maximum=70, step=1, value=40, label="Timesteps")
            with gr.Column():
                button = gr.Button(value="Generate")
                output = gr.Gallery(label="4 views of Character Image")

        button.click(
            fn=gen4views,
            inputs=[img_input, width_input, height_input, seed_input, timestep, remove_bg],
            outputs=[output]
        )

    demo.launch()
+
if __name__ == "__main__":
    # Parse the config path and forward the YAML's top-level keys to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/infer.yaml")
    cli_args = parser.parse_args()
    main(**OmegaConf.load(cli_args.config))
\ No newline at end of file
diff --git a/3D_Stage/configs/infer.yaml b/3D_Stage/configs/infer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c349dbdffe7355daef1209311a7efda83be54b79
--- /dev/null
+++ b/3D_Stage/configs/infer.yaml
@@ -0,0 +1,104 @@
+system_cls: lrm.systems.multiview_lrm.MultiviewLRM
+data:
+ cond_width: 504
+ cond_height: 504
+
+system:
+ weights: ./models/lrm.ckpt
+
+ weights_ignore_modules:
+ - decoder.heads.density
+
+ check_train_every_n_steps: 100
+
+ camera_embedder_cls: lrm.models.camera.LinearCameraEmbedder
+ camera_embedder:
+ in_channels: 16
+ out_channels: 768
+ conditions:
+ - c2w_cond
+
+ # image tokenizer transforms input images to tokens
+ image_tokenizer_cls: lrm.models.tokenizers.image.DINOV2SingleImageTokenizer
+ image_tokenizer:
+ pretrained_model_name_or_path: "./models/base"
+ freeze_backbone_params: false
+ enable_memory_efficient_attention: true
+ enable_gradient_checkpointing: true
+ # camera modulation to the DINO transformer layers
+ modulation: true
+ modulation_zero_init: true
+ modulation_single_layer: true
+ modulation_cond_dim: ${system.camera_embedder.out_channels}
+
+ # tokenizer gives a tokenized representation for the 3D scene
+ # triplane tokens in this case
+ tokenizer_cls: lrm.models.tokenizers.triplane.TriplaneLearnablePositionalEmbedding
+ tokenizer:
+ plane_size: 32
+ num_channels: 512
+
+ # backbone network is a transformer that takes scene tokens (potentially with conditional image tokens)
+ # and outputs scene tokens of the same size
+ backbone_cls: lrm.models.transformers.transformer_1d.Transformer1D
+ backbone:
+ in_channels: ${system.tokenizer.num_channels}
+ num_attention_heads: 16
+ attention_head_dim: 64
+ num_layers: 12
+ cross_attention_dim: 768 # hard-code, =DINO feature dim
+ # camera modulation to the transformer layers
+ # if not needed, set norm_type=layer_norm and do not specify cond_dim_ada_norm_continuous
+ norm_type: "layer_norm"
+ enable_memory_efficient_attention: true
+ gradient_checkpointing: true
+
+ # post processor takes scene tokens and outputs the final scene parameters that will be used for rendering
+ # in this case, triplanes are upsampled and the features are condensed
+ post_processor_cls: lrm.models.networks.TriplaneUpsampleNetwork
+ post_processor:
+ in_channels: 512
+ out_channels: 80
+
+ renderer_cls: lrm.models.renderers.triplane_dmtet.TriplaneDMTetRenderer
+ renderer:
+ radius: 0.6 # slightly larger than 0.5
+ feature_reduction: concat
+ sdf_bias: -2.
+ tet_dir: "./load/tets/"
+ isosurface_resolution: 256
+ enable_isosurface_grid_deformation: false
+ sdf_activation: negative
+
+ decoder_cls: lrm.models.networks.MultiHeadMLP
+ decoder:
+ in_channels: 240 # 3 * 80
+ n_neurons: 64
+ n_hidden_layers_share: 8
+ heads:
+ - name: sdf
+ out_channels: 1
+ n_hidden_layers: 1
+ output_activation: null
+ - name: features
+ out_channels: 3
+ n_hidden_layers: 1
+ output_activation: null # activate in material
+ activation: silu
+ chunk_mode: deferred
+ chunk_size: 131072
+
+ exporter:
+ fmt: "obj"
+ #visual: "vertex"
+ visual: "uv"
+ save_uv: True
+ save_texture: True
+ uv_unwrap_method: "open3d"
+ output_path: "./outputs"
+
+ material_cls: lrm.models.materials.no_material.NoMaterial
+
+ background_cls: lrm.models.background.solid_color_background.SolidColorBackground
+ background:
+ color: [0.5, 0.5, 0.5]
\ No newline at end of file
diff --git a/3D_Stage/load/tets/generate_tets.py b/3D_Stage/load/tets/generate_tets.py
new file mode 100644
index 0000000000000000000000000000000000000000..424d852d25105fae2d03d6a0d269bc7c2797c6bd
--- /dev/null
+++ b/3D_Stage/load/tets/generate_tets.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+
+import numpy as np
+
+"""
+This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
+to generate a tet grid
+1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
+2) Run the function below to generate a file `cube_32_tet.tet`
+"""
+
+
def generate_tetrahedron_grid_file(res=32, root=".."):
    """Invoke the Quartet binary to tetrahedralize the unit cube.

    Args:
        res: grid resolution; the Quartet cell size is 1/res.
        root: directory containing the compiled ``quartet`` executable
            and the ``meshes/`` folder (NOTE(review): assumed layout —
            confirm against your Quartet checkout).
    """
    cell_size = 1.0 / res
    command = f"cd {root}; ./quartet meshes/cube.obj {cell_size} meshes/cube_{res}_tet.tet -s meshes/cube_boundary_{res}.obj"
    os.system(command)
+
+
+"""
+This code segment shows how to convert from a quartet .tet file to compressed npz file
+"""
+
+
def convert_from_quartet_to_npz(quartetfile="cube_32_tet.tet", npzfile="32_tets"):
    """Convert a Quartet ``.tet`` file into a compressed ``.npz`` archive.

    The ``.tet`` header line is ``tet <numvertices> <numtets>``, followed by
    ``numvertices`` rows of xyz coordinates and ``numtets`` rows of four
    vertex indices.  The result is saved with keys ``vertices`` and
    ``indices``.

    Args:
        quartetfile: path to the Quartet-produced ``.tet`` file.
        npzfile: output path (``.npz`` suffix added by numpy).
    """
    # Use a context manager so the header file handle is always closed
    # (the original left it open for the lifetime of the process).
    with open(quartetfile, "r") as fp:
        header = fp.readline()
    numvertices = int(header.split(" ")[1])
    numtets = int(header.split(" ")[2])
    print(numvertices, numtets)

    # Vertex block: `numvertices` float rows immediately after the header.
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)
    print(vertices.shape)

    # Index block: `numtets` rows of 4 vertex indices after the vertices.
    indices = np.loadtxt(
        quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets
    )
    print(indices.shape)

    np.savez_compressed(npzfile, vertices=vertices, indices=indices)
+
+
if __name__ == "__main__":
    # Guard the (expensive) shell-outs behind a main check so importing
    # this module does not run Quartet on a hard-coded local path.
    root = "/home/gyc/quartet"
    for res in [300, 350, 400]:
        generate_tetrahedron_grid_file(res, root)
        convert_from_quartet_to_npz(
            os.path.join(root, f"meshes/cube_{res}_tet.tet"), npzfile=f"{res}_tets"
        )
diff --git a/3D_Stage/lrm/__init__.py b/3D_Stage/lrm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f5951a4991d24895d3eba9c7f285cf7ec2e3e6
--- /dev/null
+++ b/3D_Stage/lrm/__init__.py
@@ -0,0 +1,29 @@
+import importlib
+
+
def find(cls_string):
    """Resolve a dotted ``package.module.Name`` string to the named object."""
    module_name, _, attr_name = cls_string.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
+
+
### grammar sugar for logging utilities ###
import logging

# Reuse Lightning's logger so messages share its handlers/formatting.
logger = logging.getLogger("pytorch_lightning")

from pytorch_lightning.utilities.rank_zero import (
    rank_zero_debug,
    rank_zero_info,
    rank_zero_only,
)

# Short aliases: lrm.debug / lrm.info log from the rank-zero process only.
debug = rank_zero_debug
info = rank_zero_info
+
@rank_zero_only
def warn(*args, **kwargs):
    """Log a warning from the rank-zero process only."""
    # `Logger.warn` is a deprecated alias of `Logger.warning`.
    logger.warning(*args, **kwargs)
diff --git a/3D_Stage/lrm/models/__init__.py b/3D_Stage/lrm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/models/background/__init__.py b/3D_Stage/lrm/models/background/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/models/background/base.py b/3D_Stage/lrm/models/background/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..3900de842573a60b6343f4590229aacdd3e0ecbb
--- /dev/null
+++ b/3D_Stage/lrm/models/background/base.py
@@ -0,0 +1,24 @@
+import random
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import lrm
+from ...utils.base import BaseModule
+from ...utils.typing import *
+
+
class BaseBackground(BaseModule):
    """Abstract base class for background color models."""

    @dataclass
    class Config(BaseModule.Config):
        pass

    cfg: Config

    def configure(self):
        # The abstract base has nothing to set up.
        pass

    def forward(self, dirs: Float[Tensor, "B H W 3"]) -> Float[Tensor, "B H W Nc"]:
        """Map per-pixel view directions to background color values."""
        raise NotImplementedError
diff --git a/3D_Stage/lrm/models/background/solid_color_background.py b/3D_Stage/lrm/models/background/solid_color_background.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fad75512bcce180ae1c0c45830b696460bb098d
--- /dev/null
+++ b/3D_Stage/lrm/models/background/solid_color_background.py
@@ -0,0 +1,58 @@
+import random
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import lrm
+from .base import BaseBackground
+from ...utils.typing import *
+
+
class SolidColorBackground(BaseBackground):
    """Background that renders a single constant color.

    The color can be fixed, learned as a parameter, or randomly replaced
    during training as an augmentation.
    """

    @dataclass
    class Config(BaseBackground.Config):
        n_output_dims: int = 3  # number of color channels produced
        color: Tuple = (1.0, 1.0, 1.0)  # the solid background color
        learned: bool = False  # optimize the color as a parameter
        random_aug: bool = False  # randomly replace the color while training
        random_aug_prob: float = 0.5  # probability of that replacement

    cfg: Config

    def configure(self) -> None:
        """Create the color as a trainable parameter or a fixed buffer."""
        self.env_color: Float[Tensor, "Nc"]
        if self.cfg.learned:
            self.env_color = nn.Parameter(
                torch.as_tensor(self.cfg.color, dtype=torch.float32)
            )
        else:
            self.register_buffer(
                "env_color", torch.as_tensor(self.cfg.color, dtype=torch.float32)
            )

    def forward(
        self,
        dirs: Float[Tensor, "B H W 3"],
        color_spec: Optional[Float[Tensor, "Nc"]] = None,
    ) -> Float[Tensor, "B H W Nc"]:
        """Return a per-pixel background image.

        `dirs` is used only for its shape/device/dtype; `color_spec`, when
        given, overrides the configured color.  Note the random training
        augmentation below overrides `color_spec` as well.
        """
        color = torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) * (
            color_spec if color_spec is not None else self.env_color
        )
        if (
            self.training
            and self.cfg.random_aug
            and random.random() < self.cfg.random_aug_prob
        ):
            # use random background color with probability random_aug_prob
            # color = color * 0 + (
            #     torch.ones(*dirs.shape[:-1], self.cfg.n_output_dims).to(dirs) *
            #     torch.rand(self.cfg.n_output_dims).to(dirs)
            # )
            # One random color per batch element, broadcast over H and W.
            color = color * 0 + ( # prevent checking for unused parameters in DDP
                torch.rand(dirs.shape[0], 1, 1, self.cfg.n_output_dims)
                .to(dirs)
                .expand(*dirs.shape[:-1], -1)
            )
        return color
diff --git a/3D_Stage/lrm/models/camera.py b/3D_Stage/lrm/models/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..62461c3aadf73d1cd3f138eda5cd494ffe48d40a
--- /dev/null
+++ b/3D_Stage/lrm/models/camera.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn as nn
+
+from ..utils.base import BaseModule
+from ..utils.typing import *
+
+
class LinearCameraEmbedder(BaseModule):
    """Embeds flattened camera conditions with a single linear layer."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 0
        out_channels: int = 0
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        super().configure()
        # One shared projection applied to the concatenated conditions.
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        """Concatenate the configured condition tensors and project them.

        Each condition tensor has shape (B, Nv, ...); trailing dimensions
        are flattened before concatenation along the last axis.
        """
        flattened = []
        for name in self.cfg.conditions:
            assert name in kwargs
            tensor = kwargs[name]
            flattened.append(tensor.view(*tensor.shape[:2], -1))
        stacked = torch.cat(flattened, dim=-1)
        assert stacked.shape[-1] == self.cfg.in_channels
        return self.linear(stacked)
diff --git a/3D_Stage/lrm/models/exporters/__init__.py b/3D_Stage/lrm/models/exporters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/models/exporters/base.py b/3D_Stage/lrm/models/exporters/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3779ff65a1839ce7cba056d24159991743f6b89
--- /dev/null
+++ b/3D_Stage/lrm/models/exporters/base.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+
+import lrm
+from ..renderers.base import BaseRenderer
+from ...utils.base import BaseObject
+from ...utils.typing import *
+
+
@dataclass
class ExporterOutput:
    """A single artifact produced by an exporter.

    Attributes are consumed by the saving routine downstream:
    ``save_name`` is the output file name (with extension), ``save_type``
    the format tag (e.g. "obj"), and ``params`` the keyword arguments
    forwarded to the writer.
    """

    save_name: str
    save_type: str
    params: Dict[str, Any]
+
+
class Exporter(BaseObject):
    """Base class for exporters that turn renderer output into artifacts."""

    @dataclass
    class Config(BaseObject.Config):
        save_video: bool = False  # whether exporters should also save a video

    cfg: Config

    def configure(self, renderer: BaseRenderer) -> None:
        # Exporters query this renderer for geometry and appearance.
        self.renderer = renderer

    def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
        """Produce the list of artifacts to save; implemented by subclasses."""
        raise NotImplementedError
+
+
class DummyExporter(Exporter):
    """No-op exporter used when nothing should be written to disk."""

    def __call__(self, *args, **kwargs) -> List[ExporterOutput]:
        # DummyExporter does not export anything
        return []
diff --git a/3D_Stage/lrm/models/exporters/mesh_exporter.py b/3D_Stage/lrm/models/exporters/mesh_exporter.py
new file mode 100644
index 0000000000000000000000000000000000000000..236fa48aa2e067f67b1df77b8b9963312b042a25
--- /dev/null
+++ b/3D_Stage/lrm/models/exporters/mesh_exporter.py
@@ -0,0 +1,263 @@
+from dataclasses import dataclass, field
+import tempfile
+import os
+
+import cv2
+import numpy as np
+import torch
+
+import lrm
+from ..renderers.base import BaseRenderer
+from .base import Exporter, ExporterOutput
+from ..mesh import Mesh
+from ...utils.rasterize import NVDiffRasterizerContext
+from ...utils.typing import *
+from ...utils.misc import time_recorder as tr, time_recorder_enabled
+
+
def uv_padding_cpu(image, hole_mask, padding):
    """Inpaint masked texture regions on the CPU using OpenCV (TELEA).

    Converts the tensors to uint8 numpy images, inpaints the holes, and
    returns a float tensor on the same device/dtype as ``image``.
    """
    img_u8 = (image.detach().cpu().numpy() * 255).astype(np.uint8)
    mask_u8 = (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8)
    filled = cv2.inpaint(img_u8, mask_u8, padding, cv2.INPAINT_TELEA) / 255.0
    return torch.from_numpy(filled).to(image)
+
+
def uv_padding_cvc(image, hole_mask, padding):
    """Inpaint masked texture regions on the GPU via CV-CUDA.

    Raises if the optional ``cvcuda`` package is missing; callers are
    expected to fall back to the CPU path.
    """
    import cvcuda

    # Wrap the uint8 torch tensors as CV-CUDA tensors without copies.
    image_cvc = cvcuda.as_tensor((image.detach() * 255).to(torch.uint8), "HWC")
    mask_cvc = cvcuda.as_tensor((hole_mask.detach() * 255).to(torch.uint8), "HW")
    inpainted_cvc = cvcuda.inpaint(image_cvc, mask_cvc, padding)
    # Back to a torch tensor, rescaled to [0, 1] and matched to `image`.
    inpainted = torch.as_tensor(inpainted_cvc.cuda()) / 255.0
    return inpainted.to(image)
+
+
def uv_padding(image, hole_mask, padding):
    """Inpaint UV-island seams, preferring the GPU (CV-CUDA) path.

    Falls back to the OpenCV CPU implementation when CV-CUDA is not
    installed or fails at runtime.
    """
    try:
        return uv_padding_cvc(image, hole_mask, padding)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed by the fallback.
        lrm.info("CVCUDA not available, fallback to CPU UV padding.")
        return uv_padding_cpu(image, hole_mask, padding)
+
+
class MeshExporter(Exporter):
    """Exports a scene code as a textured triangle mesh (OBJ/GLB).

    Geometry is extracted from the renderer's isosurface, optionally
    post-processed and UV-unwrapped; textures are baked by rasterizing the
    UV layout and querying the renderer/material at the covered texels.
    """

    @dataclass
    class Config(Exporter.Config):
        fmt: str = "obj"  # in ['obj', 'glb']
        visual: str = "uv"  # in ['uv', 'vertex']
        save_name: str = "model"
        save_normal: bool = False
        save_uv: bool = True
        save_texture: bool = True
        texture_size: int = 1024
        texture_format: str = "jpg"
        uv_unwrap_method: str = "xatlas"
        xatlas_chart_options: dict = field(default_factory=dict)
        xatlas_pack_options: dict = field(default_factory=dict)
        smartuv_options: dict = field(default_factory=dict)
        uv_padding_size: int = 2
        subdivide: bool = False
        post_process: bool = False
        post_process_options: dict = field(default_factory=dict)
        context_type: str = "gl"
        output_path: str = "outputs"

    cfg: Config

    def configure(self, renderer: BaseRenderer) -> None:
        """Bind the renderer, create the rasterizer, and normalize legacy cfg."""
        super().configure(renderer)
        self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device)
        if self.cfg.fmt == "obj-mtl":
            # Fixed typo in the user-facing message ("us" -> "use").
            lrm.warn(
                f"fmt=obj-mtl is deprecated, please use fmt=obj and visual=uv instead."
            )
            self.cfg.fmt = "obj"
            self.cfg.visual = "uv"

        if self.cfg.fmt == "glb":
            assert self.cfg.visual in [
                "vertex",
                "uv-blender",
            ], "GLB format only supports visual=vertex and visual=uv-blender!"

    def get_geometry(self, scene_code: torch.Tensor) -> Mesh:
        """Extract the isosurface mesh for one scene code."""
        tr.start("Surface extraction")
        mesh: Mesh = self.renderer.isosurface(scene_code)
        tr.end("Surface extraction")
        return mesh

    def get_texture_maps(
        self, scene_code: torch.Tensor, mesh: Mesh
    ) -> Dict[str, torch.Tensor]:
        """Bake texture maps by rasterizing the UV layout of ``mesh``.

        Returns a dict of OBJ/MTL map names (map_Kd, map_Pm, map_Pr,
        map_Bump) for whichever channels the material exports.
        """
        assert mesh.has_uv
        # clip space transform
        uv_clip = mesh.v_tex * 2.0 - 1.0
        # pad to four component coordinate
        uv_clip4 = torch.cat(
            (
                uv_clip,
                torch.zeros_like(uv_clip[..., 0:1]),
                torch.ones_like(uv_clip[..., 0:1]),
            ),
            dim=-1,
        )
        # rasterize
        rast, _ = self.ctx.rasterize_one(
            uv_clip4,
            mesh.t_tex_idx,
            (self.cfg.texture_size, self.cfg.texture_size),
        )

        # Texels not covered by any UV triangle are holes to be inpainted.
        hole_mask = ~(rast[:, :, 3] > 0)

        # Interpolate world space position
        gb_pos, _ = self.ctx.interpolate_one(
            mesh.v_pos, rast[None, ...], mesh.t_pos_idx
        )
        gb_pos = gb_pos[0]

        # Sample out textures from MLP
        tr.start("Query color")
        geo_out = self.renderer.query(scene_code, points=gb_pos)
        tr.end("Query color")
        mat_out = self.renderer.material.export(points=gb_pos, **geo_out)

        textures = {}
        tr.start("UV padding")
        if "albedo" in mat_out:
            textures["map_Kd"] = uv_padding(
                mat_out["albedo"], hole_mask, self.cfg.uv_padding_size
            )
        else:
            lrm.warn(
                "save_texture is True but no albedo texture found, using default white texture"
            )
        if "metallic" in mat_out:
            textures["map_Pm"] = uv_padding(
                mat_out["metallic"], hole_mask, self.cfg.uv_padding_size
            )
        if "roughness" in mat_out:
            textures["map_Pr"] = uv_padding(
                mat_out["roughness"], hole_mask, self.cfg.uv_padding_size
            )
        if "bump" in mat_out:
            textures["map_Bump"] = uv_padding(
                mat_out["bump"], hole_mask, self.cfg.uv_padding_size
            )
        tr.end("UV padding")
        return textures

    def __call__(self, names, scene_codes) -> List[ExporterOutput]:
        """Export one mesh per (name, scene_code) pair."""
        outputs = []
        for name, scene_code in zip(names, scene_codes):
            mesh = self.get_geometry(scene_code)
            if self.cfg.post_process:
                tr.start("Mesh post-processing")
                mesh = mesh.post_process(self.cfg.post_process_options)
                tr.end("Mesh post-processing")
            if self.cfg.visual == "uv":
                output = self.export_model_with_mtl(
                    name, self.cfg.fmt, scene_code, mesh
                )
            elif self.cfg.visual == "vertex":
                output = self.export_model(name, self.cfg.fmt, scene_code, mesh)
            elif self.cfg.visual == "uv-blender":
                # NOTE(review): export_model_blender is not defined in this
                # file — confirm it is provided elsewhere before enabling
                # visual="uv-blender".
                output = self.export_model_blender(name, self.cfg.fmt, scene_code, mesh)
            else:
                raise ValueError(f"Unsupported visual format: {self.cfg.visual}")
            outputs.append(output)
        return outputs

    def export_model_with_mtl(
        self, name: str, fmt: str, scene_code: torch.Tensor, mesh: Mesh
    ) -> ExporterOutput:
        """Export with a material: UV-unwrap, bake textures, save UV layout."""
        params = {
            "mesh": mesh,
            "save_mat": True,
            "save_normal": self.cfg.save_normal,
            "save_uv": self.cfg.save_uv,
            "save_vertex_color": False,
            "map_Kd": None,  # Base Color
            "map_Ks": None,  # Specular
            "map_Bump": None,  # Normal
            # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
            "map_Pm": None,  # Metallic
            "map_Pr": None,  # Roughness
            "map_format": self.cfg.texture_format,
        }

        if self.cfg.save_uv:
            mesh.unwrap_uv(
                self.cfg.uv_unwrap_method,
                self.cfg.xatlas_chart_options,
                self.cfg.xatlas_pack_options,
                self.cfg.smartuv_options,
            )

        if self.cfg.save_texture:
            lrm.info("Exporting textures ...")
            assert self.cfg.save_uv, "save_uv must be True when save_texture is True"

            with time_recorder_enabled():
                textures = self.get_texture_maps(scene_code, mesh)
            params.update(textures)
        # Side effect: persist the UV layout so textures can be re-baked later.
        os.makedirs(self.cfg.output_path, exist_ok=True)
        np.savez(f"{self.cfg.output_path}/tex_info.npz", v_tex=mesh.v_tex.cpu().numpy(), t_tex_idx=mesh.t_tex_idx.cpu().numpy())
        return ExporterOutput(
            save_name=f"{self.cfg.save_name}-{name}.{fmt}", save_type=fmt, params=params
        )

    def export_model(
        self, name: str, fmt: str, scene_code, mesh: Mesh
    ) -> ExporterOutput:
        """Export without a material, baking appearance into vertex colors."""
        params = {
            "mesh": mesh,
            "save_mat": False,
            "save_normal": self.cfg.save_normal,
            "save_uv": self.cfg.save_uv,
            "save_vertex_color": False,
            "map_Kd": None,  # Base Color
            "map_Ks": None,  # Specular
            "map_Bump": None,  # Normal
            # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
            "map_Pm": None,  # Metallic
            "map_Pr": None,  # Roughness
            "map_format": self.cfg.texture_format,
        }

        if self.cfg.save_uv:
            mesh.unwrap_uv(
                self.cfg.uv_unwrap_method,
                self.cfg.xatlas_chart_options,
                self.cfg.xatlas_pack_options,
                self.cfg.smartuv_options,
            )

        if self.cfg.save_texture:
            lrm.info("Exporting textures ...")
            # Query appearance at the vertices and store it as vertex color.
            geo_out = self.renderer.query(scene_code, points=mesh.v_pos)
            mat_out = self.renderer.material.export(points=mesh.v_pos, **geo_out)

            if "albedo" in mat_out:
                mesh.set_vertex_color(mat_out["albedo"])
                params["save_vertex_color"] = True
            else:
                lrm.warn(
                    "save_texture is True but no albedo texture found, not saving vertex color"
                )

        return ExporterOutput(
            save_name=f"{self.cfg.save_name}-{name}.{fmt}", save_type=fmt, params=params
        )
diff --git a/3D_Stage/lrm/models/isosurface.py b/3D_Stage/lrm/models/isosurface.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ece3ac42cc573d057460ddcea4a61d0b9281c27
--- /dev/null
+++ b/3D_Stage/lrm/models/isosurface.py
@@ -0,0 +1,272 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import lrm
+from ..models.mesh import Mesh
+from ..utils.typing import *
+from ..utils.ops import scale_tensor
+
+
class IsosurfaceHelper(nn.Module):
    """Base class for isosurface extractors (marching cubes / tetrahedra)."""

    # Canonical coordinate range in which grid vertices are expressed.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Positions of all grid query points, in `points_range` coordinates."""
        raise NotImplementedError
+
+
class MarchingCubeCPUHelper(IsosurfaceHelper):
    """Extracts a mesh from a dense SDF grid with CPU marching cubes (PyMCubes).

    Grid vertices are kept on the CPU so very large resolutions fit in
    memory; the extracted mesh is moved to the module's device afterwards.
    """

    def __init__(self, resolution: int) -> None:
        super().__init__()
        self.resolution = resolution
        import mcubes

        self.mc_func: Callable = mcubes.marching_cubes
        self._grid_vertices: Optional[Float[Tensor, "N3 3"]] = None
        # Zero-sized buffer whose only purpose is to track the module device.
        self._dummy: Float[Tensor, "..."]
        self.register_buffer(
            "_dummy", torch.zeros(0, dtype=torch.float32), persistent=False
        )

    @property
    def grid_vertices(self) -> Float[Tensor, "N3 3"]:
        """Lazily build the (resolution^3, 3) grid of query points."""
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            x, y, z = (
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
            )
            x, y, z = torch.meshgrid(x, y, z, indexing="ij")
            verts = torch.cat(
                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
            ).reshape(-1, 3)
            self._grid_vertices = verts
        return self._grid_vertices

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Run marching cubes on the level set; `deformation` is ignored."""
        if deformation is not None:
            lrm.warn(
                f"{self.__class__.__name__} does not support deformation. Ignoring."
            )
        # Flip the sign so mcubes extracts the intended 0-level set.
        level = -level.view(self.resolution, self.resolution, self.resolution)
        # Removed a leftover debug print of the level stats — it was noisy
        # and forced a full-tensor min/max reduction on every call.
        v_pos, t_pos_idx = self.mc_func(
            level.detach().cpu().numpy(), 0.0
        )  # transform to numpy
        v_pos, t_pos_idx = (
            torch.from_numpy(v_pos).float().to(self._dummy.device),
            torch.from_numpy(t_pos_idx.astype(np.int64)).long().to(self._dummy.device),
        )  # transform back to torch tensor on CUDA
        # Rescale integer grid coordinates into the [0, 1] points_range.
        v_pos = v_pos / (self.resolution - 1.0)
        return Mesh(v_pos=v_pos, t_pos_idx=t_pos_idx)
+
+
def get_center_boundary_index(verts):
    """Find the vertex nearest the origin and all boundary vertices.

    Assumes ``verts`` spans [-1, 1].  Returns ``(center_idx, boundary_idx)``
    where ``center_idx`` has shape (1,) and ``boundary_idx`` lists every
    vertex with at least one coordinate at the global min or max.
    """
    # NOTE(review): this is the L2 norm of the *squared* coordinates,
    # i.e. sqrt(sum v_i^4) — possibly intended to be the plain Euclidean
    # norm of `verts`.  Kept as-is to preserve behavior.
    dist = torch.linalg.norm(verts**2, ord=2, dim=-1, keepdim=False)
    center_idx = torch.argmin(dist)
    # center_idx = torch.where(length_ < 0.1)[0]

    # A vertex lies on the boundary when any coordinate hits the extremes.
    on_max = verts == verts.max()
    on_min = verts == verts.min()
    on_boundary = torch.bitwise_or(on_min, on_max)
    boundary_count = torch.sum(on_boundary.float(), dim=-1)
    boundary_idx = torch.nonzero(boundary_count)
    return center_idx.unsqueeze(0), boundary_idx.squeeze(dim=-1)
+
+
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Differentiable marching tetrahedra over a precomputed tet grid.

    Loads grid vertices and tet indices from ``tets_path`` (an ``.npz``
    with keys "vertices" and "indices") and extracts a triangle mesh from
    a per-grid-vertex SDF, optionally with per-vertex grid deformation.
    """

    def __init__(self, resolution: int, tets_path: str):
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # Lookup table: for each of the 16 tet occupancy codes, up to two
        # triangles indexed over the tet's 6 edges (-1 = unused slot).
        self.triangle_table: Float[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of emitted triangles for each of the 16 occupancy codes.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # The 6 edges of a tetrahedron as flattened (vertex, vertex) pairs.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        # Precomputed tet grid: vertex positions and (Nt, 4) tet indices.
        tets = np.load(self.tets_path)
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer(
            "_grid_vertices",
            torch.from_numpy(tets["vertices"]).float(),
            persistent=False,
        )
        self.indices: Integer[Tensor, "..."]
        self.register_buffer(
            "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
        )

        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.center_indices, self.boundary_indices = get_center_boundary_index(
            scale_tensor(self.grid_vertices, self.points_range, (-1.0, 1.0))
        )

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Bound per-vertex offsets to roughly one tet cell via tanh."""
        return (
            (self.points_range[1] - self.points_range[0])
            / (self.resolution)  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed tet grid vertex positions."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique, sorted edges of the tet grid (computed lazily)."""
        if self._all_edges is None:
            # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
            edges = torch.tensor(
                [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
                dtype=torch.long,
                device=self.indices.device,
            )
            _all_edges = self.indices[:, edges].reshape(-1, 2)
            _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
            _all_edges = torch.unique(_all_edges_sorted, dim=0)
            self._all_edges = _all_edges
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Order each (v0, v1) edge so the smaller index comes first."""
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Core marching tetrahedra.

        Places a vertex on every edge whose endpoint SDF signs differ
        (linear interpolation to the zero crossing, differentiable w.r.t.
        ``pos_nx3`` and ``sdf_n``) and assembles triangles from the
        per-tet occupancy lookup tables.
        """
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Keep only tets crossed by the surface (mixed signs).
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        # Negate one endpoint so the weights below sum to the SDF gap.
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # Occupancy code per tet: 4-bit number from the vertex signs.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract a mesh from the SDF ``level``, with optional grid deformation."""
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh
diff --git a/3D_Stage/lrm/models/lpips.py b/3D_Stage/lrm/models/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..40c70e278b1c157b3cf4105b707c6679d0f7e1e6
--- /dev/null
+++ b/3D_Stage/lrm/models/lpips.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+from typing import Any
+import lpips
+
+from ..utils.ops import scale_tensor
+from ..utils.misc import get_device
+
+
class LPIPS:
    """Thin wrapper around the LPIPS (VGG) perceptual distance."""

    def __init__(self):
        self.model = lpips.LPIPS(net="vgg").to(get_device())
        self.model.eval()
        # Used as a fixed metric/loss: freeze all network weights.
        for params in self.model.parameters():
            params.requires_grad = False
        # Inputs are rescaled to this range before being passed to the model.
        self.model_input_range = (-1, 1)

    def __call__(self, x1, x2, return_layers=False, input_range=(0, 1)):
        """Compute LPIPS(x1, x2), rescaling inputs from ``input_range``."""
        x1 = scale_tensor(x1, input_range, self.model_input_range)
        x2 = scale_tensor(x2, input_range, self.model_input_range)
        return self.model(x1, x2, retPerLayer=return_layers, normalize=False)
diff --git a/3D_Stage/lrm/models/materials/__init__.py b/3D_Stage/lrm/models/materials/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/models/materials/base.py b/3D_Stage/lrm/models/materials/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..968e4e2ed771227f0db347f6858965a6e4e4f5c0
--- /dev/null
+++ b/3D_Stage/lrm/models/materials/base.py
@@ -0,0 +1,29 @@
+import random
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import lrm
+from ...utils.base import BaseModule
+from ...utils.typing import *
+
+
class BaseMaterial(BaseModule):
    """Abstract base class for material models."""

    @dataclass
    class Config(BaseModule.Config):
        pass

    cfg: Config
    # Overridden by subclasses whose shading needs normals / tangents.
    requires_normal: bool = False
    requires_tangent: bool = False

    def configure(self):
        # The abstract base has nothing to set up.
        pass

    def forward(self, *args, **kwargs) -> Float[Tensor, "*B 3"]:
        """Compute shaded colors; implemented by subclasses."""
        raise NotImplementedError

    def export(self, *args, **kwargs) -> Dict[str, Any]:
        """Return texture maps for mesh export; empty by default."""
        return {}
diff --git a/3D_Stage/lrm/models/materials/no_material.py b/3D_Stage/lrm/models/materials/no_material.py
new file mode 100644
index 0000000000000000000000000000000000000000..754fec75a42ff1835e9df5e2c9fcdefcc5099814
--- /dev/null
+++ b/3D_Stage/lrm/models/materials/no_material.py
@@ -0,0 +1,60 @@
+import random
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import lrm
+from .base import BaseMaterial
+from ..networks import get_encoding, get_mlp
+from ...utils.ops import dot, get_activation
+from ...utils.typing import *
+
+
class NoMaterial(BaseMaterial):
    """Material that maps raw features to colors with no shading model.

    Features either pass straight through an activation, or are first
    projected by an optional MLP when one is configured.
    """

    @dataclass
    class Config(BaseMaterial.Config):
        n_output_dims: int = 3
        color_activation: str = "sigmoid"
        input_feature_dims: Optional[int] = None
        mlp_network_config: Optional[dict] = None
        requires_normal: bool = False

    cfg: Config

    def configure(self) -> None:
        # The MLP head is built only when both its input size and its
        # architecture are configured; otherwise features are used directly.
        has_network = (
            self.cfg.input_feature_dims is not None
            and self.cfg.mlp_network_config is not None
        )
        self.use_network = has_network
        if has_network:
            self.network = get_mlp(
                self.cfg.input_feature_dims,
                self.cfg.n_output_dims,
                self.cfg.mlp_network_config,
            )
        self.requires_normal = self.cfg.requires_normal

    def forward(
        self, features: Float[Tensor, "B ... Nf"], **kwargs
    ) -> Float[Tensor, "B ... Nc"]:
        """Turn per-point features into activated colors."""
        activation = get_activation(self.cfg.color_activation)
        if self.use_network:
            flat = features.view(-1, features.shape[-1])
            projected = self.network(flat).view(
                *features.shape[:-1], self.cfg.n_output_dims
            )
            return activation(projected)
        assert (
            features.shape[-1] == self.cfg.n_output_dims
        ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input."
        return activation(features)

    def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]:
        """Bake the color into an ``albedo`` map for mesh export."""
        color = self(features, **kwargs).clamp(0, 1)
        assert color.shape[-1] >= 3, "Output color must have at least 3 channels"
        if color.shape[-1] > 3:
            lrm.warn("Output color has >3 channels, treating the first 3 as RGB")
        return {"albedo": color[..., :3]}
diff --git a/3D_Stage/lrm/models/mesh.py b/3D_Stage/lrm/models/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..9765a8131347b9abe316608aa7a5bbd509db7566
--- /dev/null
+++ b/3D_Stage/lrm/models/mesh.py
@@ -0,0 +1,471 @@
+from __future__ import annotations
+
+import functools
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+import lrm
+from ..utils.ops import dot
+from ..utils.typing import *
+from ..utils.misc import time_recorder as tr, time_recorder_enabled
+
+
class Mesh:
    """Triangle mesh with lazily-computed per-vertex attributes.

    Stores vertex positions and face indices as torch tensors; normals,
    tangents, UVs and edges are computed on first property access and cached.
    Arbitrary auxiliary data can be attached via keyword args / `add_extra`.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazily-computed caches, populated by the corresponding properties.
        # (annotation fixed: UV coordinates are 2D, not 3D)
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None
        self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

    def add_extra(self, k, v) -> None:
        """Attach arbitrary auxiliary data under key `k`."""
        self.extras[k] = v

    def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh:
        """Drop small connected components and return a new Mesh.

        `outlier_n_faces_threshold` is either an absolute face count (int) or
        a fraction (float) of the largest component's face count. No-op for
        differentiable meshes; extras are inherited unchanged.
        """
        if self.requires_grad:
            lrm.debug("Mesh is differentiable, not removing outliers")
            return self

        # use trimesh to first split the mesh into connected components
        # then remove the components with less than n_face_threshold faces
        import trimesh

        # construct a trimesh object
        mesh = trimesh.Trimesh(
            vertices=self.v_pos.detach().cpu().numpy(),
            faces=self.t_pos_idx.detach().cpu().numpy(),
        )

        # split the mesh into connected components
        components = mesh.split(only_watertight=False)
        # log the number of faces in each component
        lrm.debug(
            "Mesh has {} components, with faces: {}".format(
                len(components), [c.faces.shape[0] for c in components]
            )
        )

        n_faces_threshold: int
        if isinstance(outlier_n_faces_threshold, float):
            # set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
            n_faces_threshold = int(
                max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
            )
        else:
            # set the threshold directly to outlier_n_faces_threshold
            n_faces_threshold = outlier_n_faces_threshold

        # log the threshold
        lrm.debug(
            "Removing components with less than {} faces".format(n_faces_threshold)
        )

        # remove the components with less than n_face_threshold faces
        components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]

        # log the number of faces in each component after removing outliers
        lrm.debug(
            "Mesh has {} components after removing outliers, with faces: {}".format(
                len(components), [c.faces.shape[0] for c in components]
            )
        )
        # merge the components
        mesh = trimesh.util.concatenate(components)

        # convert back to our mesh format
        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)

        clean_mesh = Mesh(v_pos, t_pos_idx)
        # keep the extras unchanged

        if len(self.extras) > 0:
            clean_mesh.extras = self.extras
            lrm.debug(
                f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
            )
        return clean_mesh

    def subdivide(self):
        """Return a Loop-subdivided copy of the mesh (no-op if differentiable)."""
        if self.requires_grad:
            lrm.debug("Mesh is differentiable, not performing subdivision")
            return self

        import trimesh

        mesh = trimesh.Trimesh(
            vertices=self.v_pos.detach().cpu().numpy(),
            faces=self.t_pos_idx.detach().cpu().numpy(),
        )

        # FIX: trimesh's subdivide_loop returns a NEW mesh rather than
        # modifying in place; the old code discarded the result, making
        # subdivision a silent no-op.
        mesh = mesh.subdivide_loop()

        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)

        subdivided_mesh = Mesh(v_pos, t_pos_idx)

        if len(self.extras) > 0:
            subdivided_mesh.extras = self.extras
            lrm.debug(
                f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
            )

        return subdivided_mesh

    def post_process(self, options):
        """Run the external mesh-processing pipeline and return a new Mesh
        (no-op if differentiable); extras are inherited unchanged."""
        if self.requires_grad:
            lrm.debug("Mesh is differentiable, not performing post processing")
            return self

        from extern.mesh_process.MeshProcess import process_mesh

        v_pos, t_pos_idx = process_mesh(
            vertices=self.v_pos.detach().cpu().numpy(),
            faces=self.t_pos_idx.detach().cpu().numpy(),
            **options,
        )

        v_pos = torch.from_numpy(v_pos).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(t_pos_idx).to(self.t_pos_idx).contiguous()

        processed_mesh = Mesh(v_pos, t_pos_idx)

        if len(self.extras) > 0:
            processed_mesh.extras = self.extras
            lrm.debug(
                f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
            )

        return processed_mesh

    @property
    def requires_grad(self):
        # The mesh is considered differentiable iff its vertices carry grad.
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Per-vertex normals (computed lazily)."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Per-vertex tangents (computed lazily; triggers UV unwrapping)."""
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UV coordinates; runs default unwrapping if unset."""
        if self._v_tex is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._v_tex

    @property
    def t_tex_idx(self):
        """Per-face UV indices; runs default unwrapping if unset."""
        if self._t_tex_idx is None:
            self._v_tex, self._t_tex_idx = self._unwrap_uv()
        return self._t_tex_idx

    @property
    def v_rgb(self):
        """Optional per-vertex colors (None unless set_vertex_color was called)."""
        return self._v_rgb

    @property
    def edges(self):
        """Unique undirected edges as sorted index pairs (computed lazily)."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Face-normal splatting: accumulate per-face normals onto vertices,
        then normalize (degenerate vertices fall back to +z)."""
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        # FIX: pass dim explicitly — relying on torch.cross's default dim
        # (first size-3 dimension) is deprecated and fragile.
        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Per-vertex tangents from UV gradients, orthogonalized against normals."""
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(0, 3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = self.v_tex[self.t_tex_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        uve1 = tex[1] - tex[0]
        uve2 = tex[2] - tex[0]
        pe1 = pos[1] - pos[0]
        pe2 = pos[2] - pos[0]

        nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
        denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates
        tang = nom / torch.where(
            denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
        )

        # Update all 3 vertices
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def _unwrap_uv_open3d(
        self
    ):
        """UV-unwrap via Open3D's UVAtlas; returns (per-corner uvs, corner indices).

        NOTE(review): result tensors are moved with .cuda(), assuming a GPU is
        available regardless of where v_pos lives — confirm intended.
        """
        import open3d as o3d
        mesh = o3d.t.geometry.TriangleMesh()
        mesh.vertex.positions = o3d.core.Tensor(self.v_pos.detach().cpu().numpy())
        mesh.triangle.indices = o3d.core.Tensor(self.t_pos_idx.cpu().numpy())
        mesh.compute_uvatlas(size=1024)
        texture_uvs = torch.from_numpy(mesh.triangle.texture_uvs.numpy()).reshape(-1, 2).cuda()
        indices = torch.arange(self.t_pos_idx.shape[0] * 3).reshape(-1, 3).to(torch.int64).cuda()
        # Add a wood texture and visualize
        return texture_uvs, indices

    def _unwrap_uv_xatlas(
        self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
    ):
        """UV-unwrap via xatlas; returns (uvs, per-face uv indices)."""
        lrm.info("Using xatlas to perform UV unwrapping, may take a while ...")

        import xatlas

        atlas = xatlas.Atlas()
        atlas.add_mesh(
            self.v_pos.detach().cpu().numpy(),
            self.t_pos_idx.cpu().numpy(),
        )
        co = xatlas.ChartOptions()
        po = xatlas.PackOptions()
        for k, v in xatlas_chart_options.items():
            setattr(co, k, v)
        for k, v in xatlas_pack_options.items():
            setattr(po, k, v)
        atlas.generate(co, po)
        vmapping, indices, uvs = atlas.get_mesh(0)
        # xatlas returns uint arrays; reinterpret as int64 for torch indexing.
        vmapping = (
            torch.from_numpy(
                vmapping.astype(np.uint64, casting="same_kind").view(np.int64)
            )
            .to(self.v_pos.device)
            .long()
        )
        uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
        indices = (
            torch.from_numpy(
                indices.astype(np.uint64, casting="same_kind").view(np.int64)
            )
            .to(self.v_pos.device)
            .long()
        )
        return uvs, indices

    def _unwrap_uv_smartuv(self, options: dict = {}):
        """UV-unwrap via Blender's Smart UV Project (through bpy).

        Returns a 4-tuple (v_pos, t_pos_idx, v_tex, t_tex_idx) — unlike the
        other unwrappers — because UVs are produced per face corner.
        """
        from extern.mesh_process.MeshProcess import (
            mesh_to_bpy,
            get_uv_from_bpy,
            bpy_context,
            bpy_export,
        )
        # (removed an unused local re-import of time_recorder that shadowed
        # the module-level alias)

        v_pos, t_pos_idx = self.v_pos.cpu().numpy(), self.t_pos_idx.cpu().numpy()
        with bpy_context():
            mesh_bpy = mesh_to_bpy("_", v_pos, t_pos_idx)
            v_tex = get_uv_from_bpy(mesh_bpy, **options).astype(np.float32)

        assert v_tex.shape[0] == self.t_pos_idx.shape[0] * 3

        t_tex_idx = torch.arange(
            self.t_pos_idx.shape[0] * 3, device=self.t_pos_idx.device, dtype=torch.long
        ).reshape(-1, 3)

        """
        # super efficient de-duplication
        v_tex_u_uint32 = v_tex[..., 0].view(np.uint32)
        v_tex_v_uint32 = v_tex[..., 1].view(np.uint32)
        v_hashed = (v_tex_u_uint32.astype(np.uint64) << 32) | v_tex_v_uint32
        v_hashed = torch.from_numpy(v_hashed.view(np.int64)).to(self.v_pos.device)

        v_tex = torch.from_numpy(v_tex).to(
            device=self.v_pos.device, dtype=torch.float32
        )
        t_pos_idx_f3 = torch.arange(
            self.t_pos_idx.shape[0] * 3, device=self.t_pos_idx.device, dtype=torch.long
        ).reshape(-1, 3)
        v_pos_f3 = self.v_pos[self.t_pos_idx].reshape(-1, 3)

        # super efficient de-duplication
        v_hashed_dedup, inverse_indices = torch.unique(v_hashed, return_inverse=True)
        dedup_size, full_size = v_hashed_dedup.shape[0], inverse_indices.shape[0]
        indices = torch.scatter_reduce(
            torch.full(
                [dedup_size],
                fill_value=full_size,
                device=inverse_indices.device,
                dtype=torch.long,
            ),
            index=inverse_indices,
            src=torch.arange(
                full_size, device=inverse_indices.device, dtype=torch.int64
            ),
            dim=0,
            reduce="amin",
        )
        v_tex = v_tex[indices]
        t_tex_idx = inverse_indices.reshape(-1, 3)

        v_pos = v_pos_f3[indices]
        t_pos_idx = inverse_indices[t_pos_idx_f3]
        """

        return self.v_pos, self.t_pos_idx, v_tex, t_tex_idx

    def _unwrap_uv(self):
        # Default unwrapping used by the lazy `v_tex`/`t_tex_idx` properties.
        # FIX: this method was previously missing entirely, so accessing those
        # properties before an explicit unwrap_uv() call raised AttributeError.
        return self._unwrap_uv_xatlas()

    def unwrap_uv(
        self,
        method: str,
        xatlas_chart_options: dict = {},
        xatlas_pack_options: dict = {},
        smartuv_options: dict = {},
    ):
        """Explicitly unwrap UVs with the chosen backend, timing the run.

        "smartuv" also replaces v_pos/t_pos_idx (per-corner layout).
        """
        if method == "xatlas":
            with time_recorder_enabled():
                tr.start("UV unwrapping xatlas")
                self._v_tex, self._t_tex_idx = self._unwrap_uv_xatlas(
                    xatlas_chart_options, xatlas_pack_options
                )
                tr.end("UV unwrapping xatlas")
        elif method == "open3d":
            with time_recorder_enabled():
                tr.start("UV unwrapping o3d")
                self._v_tex, self._t_tex_idx = self._unwrap_uv_open3d()
                tr.end("UV unwrapping o3d")
        elif method == "smartuv":
            with time_recorder_enabled():
                tr.start("UV unwrapping smartuv")
                (
                    self.v_pos,
                    self.t_pos_idx,
                    self._v_tex,
                    self._t_tex_idx,
                ) = self._unwrap_uv_smartuv(smartuv_options)
                tr.end("UV unwrapping smartuv")
        else:
            raise NotImplementedError

    def set_vertex_color(self, v_rgb):
        """Attach per-vertex colors (must match the vertex count)."""
        assert v_rgb.shape[0] == self.v_pos.shape[0]
        self._v_rgb = v_rgb

    def set_uv(self, v_tex, t_tex_idx):
        """Attach externally computed UVs and per-face UV indices."""
        self._v_tex = v_tex
        self._t_tex_idx = t_tex_idx

    @property
    def has_uv(self):
        return self._v_tex is not None and self._t_tex_idx is not None

    def _compute_edges(self):
        # Compute edges: collect all three edges of every face, sort each
        # pair so direction is canonical, then de-duplicate.
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges

    def normal_consistency(self) -> Float[Tensor, ""]:
        """Mean (1 - cos) disagreement of normals across each edge; a
        smoothness regularizer."""
        edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
        nc = (
            1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
        ).mean()
        return nc

    def _laplacian_uniform(self):
        """Sparse uniform Laplacian (V x V) of the mesh graph."""
        # from stable-dreamfusion
        # https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
        verts, faces = self.v_pos, self.t_pos_idx

        V = verts.shape[0]
        # (removed an unused `F = faces.shape[0]` that shadowed the
        # torch.nn.functional alias)

        # Neighbor indices
        ii = faces[:, [1, 2, 0]].flatten()
        jj = faces[:, [2, 0, 1]].flatten()
        adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
            dim=1
        )
        adj_values = torch.ones(adj.shape[1]).to(verts)

        # Diagonal indices
        diag_idx = adj[0]

        # Build the sparse matrix
        idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
        values = torch.cat((-adj_values, adj_values))

        # The coalesce operation sums the duplicate indices, resulting in the
        # correct diagonal
        return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()

    def laplacian(self) -> Float[Tensor, ""]:
        """Mean norm of the uniform-Laplacian of vertex positions (smoothness
        loss; the Laplacian matrix itself is built without grad)."""
        with torch.no_grad():
            L = self._laplacian_uniform()
        loss = L.mm(self.v_pos)
        loss = loss.norm(dim=1)
        loss = loss.mean()
        return loss
diff --git a/3D_Stage/lrm/models/networks.py b/3D_Stage/lrm/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ffc4052bd089dc9190ac9b9b98b8ec58fba85d
--- /dev/null
+++ b/3D_Stage/lrm/models/networks.py
@@ -0,0 +1,390 @@
+from dataclasses import dataclass, field
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ..utils.base import BaseModule
+from ..utils.ops import get_activation
+from ..utils.typing import *
+
+
class TriplaneUpsampleNetwork(BaseModule):
    """Doubles the spatial resolution of each of the three triplane feature
    maps with a single transposed convolution shared across planes.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 80

    cfg: Config

    def configure(self) -> None:
        super().configure()
        # kernel_size == stride == 2 gives an exact 2x upsampling.
        self.upsample = nn.ConvTranspose2d(
            self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
        )

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        # Fold the plane axis into the batch so one conv call handles all
        # three planes, then split it back out.
        planes_flat = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
        upsampled = self.upsample(planes_flat)
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
+
+
class MLP(nn.Module):
    """Plain fully-connected network.

    Layout: Linear(dim_in -> n_neurons) + activation, then
    (n_hidden_layers - 1) further Linear(n_neurons -> n_neurons) + activation
    pairs, then Linear(n_neurons -> dim_out), finally `output_activation`
    (resolved via get_activation; None means identity).
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
        weight_init: Optional[str] = "kaiming_uniform",
        bias_init: Optional[str] = None,
    ):
        super().__init__()

        def linear(d_in, d_out, first, last):
            # Shorthand forwarding the shared init options.
            return self.make_linear(
                d_in,
                d_out,
                is_first=first,
                is_last=last,
                bias=bias,
                weight_init=weight_init,
                bias_init=bias_init,
            )

        modules = [linear(dim_in, n_neurons, True, False)]
        modules.append(self.make_activation(activation))
        for _ in range(n_hidden_layers - 1):
            modules.append(linear(n_neurons, n_neurons, False, False))
            modules.append(self.make_activation(activation))
        modules.append(linear(n_neurons, dim_out, False, True))

        self.layers = nn.Sequential(*modules)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        return self.output_activation(self.layers(x))

    def make_linear(
        self,
        dim_in,
        dim_out,
        is_first,
        is_last,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        """Create an nn.Linear with the configured weight/bias initialization."""
        layer = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
        elif weight_init is not None:
            raise NotImplementedError

        if bias:
            if bias_init == "zero":
                torch.nn.init.zeros_(layer.bias)
            elif bias_init is not None:
                raise NotImplementedError

        return layer

    def make_activation(self, activation):
        """Instantiate the hidden-layer activation by name."""
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError
+
+
@dataclass
class HeadSpec:
    """Specification of one output head for MultiHeadMLP."""

    name: str  # key under which this head's output is returned
    out_channels: int  # output feature dimension of the head
    n_hidden_layers: int  # hidden layers before the head's output layer
    output_activation: Optional[str] = None  # activation applied to the head output
+
+
class MultiHeadMLP(BaseModule):
    """MLP with a shared trunk and multiple named output heads.

    `forward` returns a dict mapping each head's name to its output, passed
    through that head's `output_activation`. The shared trunk can optionally
    be evaluated over batch chunks ("deferred" / "checkpointing" modes) to
    bound peak memory.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 0
        n_neurons: int = 0
        # Number of hidden layers in the shared trunk.
        n_hidden_layers_share: int = 0
        heads: List[HeadSpec] = field(default_factory=lambda: [])
        activation: str = "relu"
        bias: bool = True
        weight_init: Optional[str] = "kaiming_uniform"
        bias_init: Optional[str] = None
        # None: run the trunk directly; "deferred": DeferredFunc autograd
        # wrapper; "checkpointing": torch activation checkpointing. Both
        # chunked modes split the flattened batch into chunks of chunk_size.
        chunk_mode: Optional[str] = None
        chunk_size: int = -1

    cfg: Config

    def configure(self) -> None:
        """Build the shared trunk and one sub-network per head spec."""
        super().configure()
        shared_layers = [
            self.make_linear(
                self.cfg.in_channels,
                self.cfg.n_neurons,
                bias=self.cfg.bias,
                weight_init=self.cfg.weight_init,
                bias_init=self.cfg.bias_init,
            ),
            self.make_activation(self.cfg.activation),
        ]
        for i in range(self.cfg.n_hidden_layers_share - 1):
            shared_layers += [
                self.make_linear(
                    self.cfg.n_neurons,
                    self.cfg.n_neurons,
                    bias=self.cfg.bias,
                    weight_init=self.cfg.weight_init,
                    bias_init=self.cfg.bias_init,
                ),
                self.make_activation(self.cfg.activation),
            ]
        self.shared_layers = nn.Sequential(*shared_layers)

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for i in range(head.n_hidden_layers):
                head_layers += [
                    self.make_linear(
                        self.cfg.n_neurons,
                        self.cfg.n_neurons,
                        bias=self.cfg.bias,
                        weight_init=self.cfg.weight_init,
                        bias_init=self.cfg.bias_init,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            head_layers += [
                self.make_linear(
                    self.cfg.n_neurons,
                    head.out_channels,
                    bias=self.cfg.bias,
                    weight_init=self.cfg.weight_init,
                    bias_init=self.cfg.bias_init,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

        if self.cfg.chunk_mode is not None:
            assert self.cfg.chunk_size > 0

    def make_linear(
        self,
        dim_in,
        dim_out,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        """Create an nn.Linear with the configured initialization scheme."""
        layer = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init is None:
            pass
        elif weight_init == "kaiming_uniform":
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
        else:
            raise NotImplementedError

        if bias:
            if bias_init is None:
                pass
            elif bias_init == "zero":
                torch.nn.init.zeros_(layer.bias)
            else:
                raise NotImplementedError

        return layer

    def make_activation(self, activation):
        """Instantiate the hidden-layer activation by name."""
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Run the trunk and selected heads on `x`.

        `x` may have any leading shape; the last dim is the feature dim.
        `include` / `exclude` (mutually exclusive) select heads by name.
        Returns {head_name: activated head output}.
        """
        # Flatten leading dims for the trunk; restored after.
        inp_shape = x.shape[:-1]
        x = x.reshape(-1, x.shape[-1])

        if self.cfg.chunk_mode is None:
            shared_features = self.shared_layers(x)
        elif self.cfg.chunk_mode == "deferred":
            shared_features = DeferredFunc.apply(
                self.shared_layers, x, self.cfg.chunk_size
            )
        elif self.cfg.chunk_mode == "checkpointing":
            shared_features = apply_batch_checkpointing(
                self.shared_layers, x, self.cfg.chunk_size
            )
        else:
            raise NotImplementedError

        shared_features = shared_features.reshape(*inp_shape, -1)

        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        # NOTE: heads always run un-chunked; the disabled variant below that
        # chunked the heads is kept for reference.
        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](shared_features)
            )
            for head in heads
        }
        """
        # TypeError
        if self.cfg.chunk_mode is None:
            out = {
                head.name: get_activation(head.output_activation)(
                    self.heads[head.name](shared_features)
                )
                for head in heads
            }
        elif self.cfg.chunk_mode == "deferred":
            out = {
                head.name: get_activation(head.output_activation)(
                    DeferredFunc.apply(self.heads[head.name], shared_features, self.cfg.chunk_size)
                )
                for head in heads
            }
        else:
            raise NotImplementedError
        """
        return out
+
+
class DeferredFunc(torch.autograd.Function):
    """Memory-saving module application: `apply(model, x, chunk_size)` behaves
    like `model(x)` over dim 0, but stores no activations in forward; backward
    re-runs the model chunk-by-chunk to recover input and parameter grads.
    """

    @staticmethod
    def forward(ctx, model, x, chunk_size):
        # FIX: the old code deep-copied the whole model (and flipped
        # requires_grad flags on the copy) just to run inference — running the
        # original model under no_grad produces identical outputs without the
        # copy's time/memory cost.
        outputs = []
        with torch.no_grad():
            for chunk in torch.split(x, chunk_size, dim=0):
                outputs.append(model(chunk))

        ctx.model = model
        # chunk_size is stashed as a tensor so it can ride along in saved_tensors.
        ctx.save_for_backward(x.detach(), torch.as_tensor(chunk_size))

        return torch.cat(outputs, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        model = ctx.model
        x, chunk_size = ctx.saved_tensors
        chunk_size = chunk_size.item()

        # Recompute on a copy so per-chunk grad accumulation cannot disturb
        # any grads already present on the live model's parameters.
        model_copy = deepcopy(model)

        x_split = torch.split(x, chunk_size, dim=0)
        grad_output_split = torch.split(grad_output, chunk_size, 0)
        grad_input_split = []

        with torch.set_grad_enabled(True):
            model_copy.requires_grad_(True)
            model_copy.zero_grad()
            for cur_x, cur_grad_output in zip(x_split, grad_output_split):
                cur_x.requires_grad_(True)
                cur_y = model_copy(cur_x)
                # Accumulates parameter grads on model_copy across chunks.
                cur_y.backward(cur_grad_output)

                grad_input_split.append(cur_x.grad.clone())

        grad_input = torch.cat(grad_input_split, dim=0)

        # Transfer the accumulated parameter grads onto the live model.
        for param, param_copy in zip(model.parameters(), model_copy.parameters()):
            if param.grad is None:
                param.grad = param_copy.grad.clone()
            else:
                param.grad.add_(param_copy.grad)

        # No grad for the model object or for chunk_size.
        return None, grad_input, None
+
+
def apply_batch_checkpointing(func, x, chunk_size):
    """Evaluate `func` over `x` in chunks along dim 0 under activation
    checkpointing, so intermediates are recomputed in backward rather than
    stored. Numerically equivalent to `func(x)`.
    """
    if chunk_size >= len(x):
        # Single chunk: checkpoint the whole call at once.
        return torch.utils.checkpoint.checkpoint(func, x, use_reentrant=False)

    def extend_with(partial, chunk):
        # Append func(chunk) to the accumulated output.
        return torch.cat([partial, func(chunk)])

    chunks = torch.split(x, chunk_size, dim=0)
    # The first chunk is evaluated directly; each later chunk is folded in
    # through a checkpointed concat-and-apply step.
    result = func(chunks[0])
    for chunk in chunks[1:]:
        result = torch.utils.checkpoint.checkpoint(
            extend_with, result, chunk, use_reentrant=False
        )

    return result
+
+
def get_encoding(n_input_dims: int, config) -> nn.Module:
    """Factory for input encodings; intentionally unimplemented in this build."""
    raise NotImplementedError
+
def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module:
    """Factory for generic MLPs; intentionally unimplemented in this build."""
    raise NotImplementedError
diff --git a/3D_Stage/lrm/models/renderers/__init__.py b/3D_Stage/lrm/models/renderers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/models/renderers/base.py b/3D_Stage/lrm/models/renderers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a733d00b8a8726cf23006a37f81df8781b7728e
--- /dev/null
+++ b/3D_Stage/lrm/models/renderers/base.py
@@ -0,0 +1,68 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+
+import lrm
+from ..networks import MultiHeadMLP
+from ..background.base import BaseBackground
+from ..materials.base import BaseMaterial
+from ...utils.base import BaseModule
+from ...utils.typing import *
+
+
class BaseRenderer(BaseModule):
    """Base class for renderers combining a triplane decoder, a material and
    a background model, bounded by a cube of half-extent `radius`.
    """

    @dataclass
    class Config(BaseModule.Config):
        radius: float = 1.0  # half side length of the axis-aligned scene bound

    cfg: Config

    def configure(
        self,
        decoder: MultiHeadMLP,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Attach collaborators and register the scene bounding-box buffer."""
        super().configure()

        # Collaborators are attached via register_non_module (see setters);
        # presumably this keeps them out of this module's parameters/state —
        # confirm in BaseModule.
        self.set_decoder(decoder)
        self.set_material(material)
        self.set_background(background)

        # set up bounding box
        self.bbox: Float[Tensor, "2 3"]
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
                    [self.cfg.radius, self.cfg.radius, self.cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        # Subclasses implement the actual rendering.
        raise NotImplementedError

    @property
    def decoder(self) -> MultiHeadMLP:
        return self.non_module("decoder")

    @property
    def material(self) -> BaseMaterial:
        return self.non_module("material")

    @property
    def background(self) -> BaseBackground:
        return self.non_module("background")

    def set_decoder(self, decoder: MultiHeadMLP) -> None:
        self.register_non_module("decoder", decoder)

    def set_material(self, material: BaseMaterial) -> None:
        self.register_non_module("material", material)

    def set_background(self, background: BaseBackground) -> None:
        self.register_non_module("background", background)
diff --git a/3D_Stage/lrm/models/renderers/triplane_dmtet.py b/3D_Stage/lrm/models/renderers/triplane_dmtet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b898e3d584ef18307fb9ca2df0b015bc01f0430
--- /dev/null
+++ b/3D_Stage/lrm/models/renderers/triplane_dmtet.py
@@ -0,0 +1,369 @@
+import os
+from dataclasses import dataclass, field
+from collections import defaultdict
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, reduce
+
+import lrm
+from ..renderers.base import BaseRenderer
+from ..isosurface import MarchingTetrahedraHelper
+from ...utils.ops import (
+ get_activation,
+ scale_tensor,
+ dot,
+ chunk_batch
+)
+from ...utils.rasterize import NVDiffRasterizerContext
+from ..mesh import Mesh
+from ...utils.typing import *
+
+
+class TriplaneDMTetRenderer(BaseRenderer):
    @dataclass
    class Config(BaseRenderer.Config):
        # How per-plane features are combined in query_triplane:
        # "concat" (3*Cp dims) or "mean" (Cp dims).
        feature_reduction: str = "concat"
        # Optional activation applied to the bias-shifted raw SDF output.
        sdf_activation: Optional[str] = None
        # SDF prior: "ellipsoid"/"sphere" (parameterized by sdf_bias_params)
        # or a constant float offset; see get_shifted_sdf.
        sdf_bias: Union[str, float] = 0.0
        sdf_bias_params: Any = None
        inside_out: bool = False

        # Marching-tetrahedra grid resolution; the matching precomputed grid
        # "{isosurface_resolution}_tets.npz" is loaded from tet_dir.
        isosurface_resolution: int = 128
        tet_dir: str = "load/tets/"
        enable_isosurface_grid_deformation: bool = False
        eval_chunk_size: int = 262144
        # Passed to NVDiffRasterizerContext.
        context_type: str = "gl"

    cfg: Config
+
    def configure(self, *args, **kwargs) -> None:
        """Set up the rasterizer context and the marching-tetrahedra helper.

        Loads the precomputed tet grid from
        `{tet_dir}/{isosurface_resolution}_tets.npz`.
        """
        super().configure(*args, **kwargs)

        assert self.cfg.feature_reduction in ["concat", "mean"]

        self.ctx = NVDiffRasterizerContext(self.cfg.context_type, self.device)
        self.isosurface_helper = MarchingTetrahedraHelper(
            self.cfg.isosurface_resolution,
            os.path.join(self.cfg.tet_dir, f"{self.cfg.isosurface_resolution}_tets.npz"),
        ).to(self.device)
+
    def query_triplane(
        self,
        positions: Float[Tensor, "*B N 3"],
        triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
    ) -> Dict[str, Tensor]:
        """Sample triplane features at 3D points and decode them.

        Points (assumed in [-radius, radius]) are projected onto the xy/xz/yz
        planes; features are bilinearly sampled per plane, reduced according
        to cfg.feature_reduction, and decoded. The mandatory "sdf" output is
        bias-shifted then activated. Batch dim is optional and restored on
        output.
        """
        batched = positions.ndim == 3
        if not batched:
            # no batch dimension
            triplanes = triplanes[None, ...]
            positions = positions[None, ...]
        # import ipdb
        # ipdb.set_trace()
        assert triplanes.ndim == 5 and positions.ndim == 3

        # assume positions in [-1, 1]
        # normalized to (-1, 1) for grid sample
        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        # One 2D sampling grid per plane: xy, xz, yz projections.
        indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
            (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
            dim=-3,
        )
        # Fold the plane axis into the batch so one grid_sample covers all planes.
        out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
            rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3),
            rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3),
            align_corners=False,
            mode="bilinear",
        )
        if self.cfg.feature_reduction == "concat":
            out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
        elif self.cfg.feature_reduction == "mean":
            out = reduce(out, "(B Np) Cp () N -> B N Cp", Np=3, reduction="mean")
        else:
            raise NotImplementedError

        net_out: Dict[str, Float[Tensor, "B N ..."]] = self.decoder(out)
        assert "sdf" in net_out
        # NOTE(review): positions here are already rescaled to (-1, 1), so
        # sphere/ellipsoid bias parameters are in normalized units rather
        # than world units — confirm intended.
        net_out["sdf"] = get_activation(self.cfg.sdf_activation)(
            self.get_shifted_sdf(positions, net_out["sdf"])
        )

        if not batched:
            net_out = {k: v.squeeze(0) for k, v in net_out.items()}

        return net_out
+
+ def get_shifted_sdf(
+ self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"]
+ ) -> Float[Tensor, "*N 1"]:
+ sdf_bias: Union[float, Float[Tensor, "*N 1"]]
+ if self.cfg.sdf_bias == "ellipsoid":
+ assert (
+ isinstance(self.cfg.sdf_bias_params, Sized)
+ and len(self.cfg.sdf_bias_params) == 3
+ )
+ size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
+ sdf_bias = ((points / size) ** 2).sum(
+ dim=-1, keepdim=True
+ ).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
+ elif self.cfg.sdf_bias == "sphere":
+ assert isinstance(self.cfg.sdf_bias_params, float)
+ radius = self.cfg.sdf_bias_params
+ sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
+ elif isinstance(self.cfg.sdf_bias, float):
+ sdf_bias = self.cfg.sdf_bias
+ else:
+ raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
+ return sdf + sdf_bias
+
+ def forward_single(
+ self,
+ triplane: Float[Tensor, "3 Cp Hp Wp"],
+ mvp_mtx: Float[Tensor, "Nv 4 4"],
+ camera_positions: Float[Tensor, "Nv 3"],
+ height: int,
+ width: int,
+ background_color: Optional[Float[Tensor, "3"]],
+ extra_sdf_query: Any = None,
+ ) -> Dict[str, Tensor]:
+ Nv = mvp_mtx.shape[0]
+
+ out = {}
+
+ query_vertices = []
+ query_sizes = []
+
+ grid_vertices = scale_tensor(
+ self.isosurface_helper.grid_vertices,
+ self.isosurface_helper.points_range,
+ self.bbox,
+ )
+
+ query_vertices.append(grid_vertices)
+ query_sizes.append(len(grid_vertices))
+
+ if extra_sdf_query is not None:
+ query_vertices.append(extra_sdf_query)
+ query_sizes.append(len(extra_sdf_query))
+
+ query_vertices = torch.cat(query_vertices, dim=0)
+ triplane_out = self.query_triplane(query_vertices, triplane)
+
+ all_sdf = triplane_out["sdf"]
+ if extra_sdf_query is not None:
+ sdf, sdf_ex_query = torch.split(all_sdf, query_sizes)
+ else:
+ sdf, sdf_ex_query = all_sdf, None
+
+ out.update({"sdf_ex_query": sdf_ex_query})
+
+ if self.cfg.enable_isosurface_grid_deformation:
+ all_deformation = triplane_out["deformation"]
+ if extra_sdf_query is not None:
+ deformation, _ = torch.split(all_deformation, query_sizes)
+ else:
+ deformation, _ = all_deformation, None
+ else:
+ deformation = None
+
+ # Fix some sdf if we observe empty shape (full positive or full negative)
+ pos_shape = torch.sum((sdf.squeeze(dim=-1) > 0).int(), dim=-1)
+ neg_shape = torch.sum((sdf.squeeze(dim=-1) < 0).int(), dim=-1)
+ zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
+ if torch.sum(zero_surface).item() > 0:
+ lrm.warn("Empty mesh! Fixing by adding fake faces.")
+ sdf = torch.nan_to_num(sdf, nan=0.0, posinf=1.0, neginf=-1.0)
+ update_sdf = torch.zeros_like(sdf)
+ max_sdf = sdf.max()
+ min_sdf = sdf.min()
+ update_sdf[self.isosurface_helper.center_indices] += (
+ -1.0 - max_sdf
+ ) # greater than zero
+ update_sdf[self.isosurface_helper.boundary_indices] += (
+ 1.0 - min_sdf
+ ) # smaller than zero
+ new_sdf = sdf.clone().detach()
+ if zero_surface:
+ new_sdf += update_sdf
+ update_mask = (update_sdf == 0).float()
+ # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative)
+ sdf_reg_loss = sdf.abs().mean()
+ sdf_reg_loss = sdf_reg_loss * zero_surface.float()
+ sdf = sdf * update_mask + new_sdf * (1 - update_mask)
+ lrm.debug(
+ "max sdf: {}, min sdf: {}".format(sdf.max().item(), sdf.min().item())
+ )
+ out.update({"sdf_reg": sdf_reg_loss})
+
+ # Here we remove the gradient for the bad sdf (full positive or full negative)
+ if zero_surface:
+ sdf = sdf.detach()
+
+ mesh: Mesh = self.isosurface_helper(sdf, deformation=deformation)
+
+ mesh.v_pos = scale_tensor(
+ mesh.v_pos, self.isosurface_helper.points_range, self.bbox
+ ) # scale to bbox as the grid vertices are in [0, 1]
+ # import ipdb
+ # ipdb.set_trace()
+ v_pos_clip: Float[Tensor, "Nv V 4"] = self.ctx.vertex_transform(
+ mesh.v_pos, mvp_mtx
+ )
+ rast, _ = self.ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width))
+ mask = rast[..., 3:] > 0
+ mask_aa = self.ctx.antialias(mask.float(), rast, v_pos_clip, mesh.t_pos_idx)
+
+ out.update({"opacity": mask_aa, "mesh": mesh})
+
+ gb_normal, _ = self.ctx.interpolate_one(mesh.v_nrm, rast, mesh.t_pos_idx)
+ gb_normal = F.normalize(gb_normal, dim=-1)
+ gb_normal_aa = torch.lerp(
+ torch.zeros_like(gb_normal), (gb_normal + 1.0) / 2.0, mask.float()
+ )
+ gb_normal_aa = self.ctx.antialias(
+ gb_normal_aa, rast, v_pos_clip, mesh.t_pos_idx
+ )
+ out.update({"comp_normal": gb_normal_aa}) # in [0, 1]
+
+ gb_pos, _ = self.ctx.interpolate_one(mesh.v_pos, rast, mesh.t_pos_idx)
+
+ # FIXME: this depth corresponds to the one provided in the dataset, but assumes camera looking at scene center
+ gb_depth = dot(
+ gb_pos - camera_positions[:, None, None, :],
+ F.normalize(-camera_positions[:, None, None, :], dim=-1),
+ )
+
+ gb_depth = torch.lerp(torch.zeros_like(gb_depth), gb_depth, mask.float())
+ out.update({"depth": gb_depth})
+
+ gb_viewdirs = F.normalize(gb_pos - camera_positions[:, None, None, :], dim=-1)
+ gb_rgb_fg = torch.zeros(
+ (Nv, height, width, 3), device=self.device, dtype=torch.float32
+ )
+ gb_rgb_bg = self.background(dirs=gb_viewdirs, color_spec=background_color)
+
+ selector = mask[..., 0]
+ if selector.sum() > 0:
+ positions = gb_pos[selector]
+ geo_out = self.query_triplane(positions, triplane)
+
+ extra_geo_info = {}
+ if self.material.requires_normal:
+ extra_geo_info["shading_normal"] = gb_normal[selector]
+ if self.material.requires_tangent:
+ gb_tangent, _ = self.ctx.interpolate_one(
+ mesh.v_tng, rast, mesh.t_pos_idx
+ )
+ gb_tangent = F.normalize(gb_tangent, dim=-1)
+ extra_geo_info["tangent"] = gb_tangent[selector]
+
+ rgb_fg = self.material(
+ viewdirs=gb_viewdirs[selector],
+ positions=positions,
+ **extra_geo_info,
+ **geo_out,
+ )
+
+ gb_rgb_fg[selector] = rgb_fg.to(
+ gb_rgb_fg.dtype
+ ) # TODO: don't know if this is correct
+
+ gb_rgb = torch.lerp(gb_rgb_bg, gb_rgb_fg, mask.float())
+ gb_rgb_aa = self.ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx)
+
+ out.update(
+ {"comp_rgb": gb_rgb_aa, "comp_rgb_fg": gb_rgb_fg, "comp_rgb_bg": gb_rgb_bg}
+ )
+
+ return out
+
+ def forward(
+ self,
+ triplanes: Float[Tensor, "B 3 Cp Hp Wp"],
+ mvp_mtx: Float[Tensor, "B Nv 4 4"],
+ camera_positions: Float[Tensor, "B Nv 3"],
+ height: int,
+ width: int,
+ background_color: Optional[Float[Tensor, "B 3"]] = None,
+ extra_sdf_query: Optional[List[Float[Tensor, "N 3"]]] = None,
+ **kwargs,
+ ) -> Dict[str, Tensor]:
+ batch_size = triplanes.shape[0]
+ out_list = []
+ for b in range(batch_size):
+ out_list.append(
+ self.forward_single(
+ triplanes[b],
+ mvp_mtx[b],
+ camera_positions[b],
+ height,
+ width,
+ background_color=background_color[b]
+ if background_color is not None
+ else None,
+ extra_sdf_query=extra_sdf_query[b]
+ if extra_sdf_query is not None
+ else None,
+ )
+ )
+
+ out = defaultdict(list)
+ for out_ in out_list:
+ for k, v in out_.items():
+ out[k].append(v)
+
+ for k, v in out.items():
+ # some properties cannot be batched
+ if isinstance(v[0], torch.Tensor) and (
+ all([vv.ndim == 0 for vv in v])
+ or all([vv.shape[0] == v[0].shape[0] for vv in v])
+ ):
+ out[k] = torch.stack(v, dim=0)
+ else:
+ out[k] = v
+
+ return out
+
+ def isosurface(self, triplane: Float[Tensor, "3 Cp Hp Wp"]):
+ grid_vertices = scale_tensor(
+ self.isosurface_helper.grid_vertices,
+ self.isosurface_helper.points_range,
+ self.bbox,
+ )
+ triplane_out = chunk_batch(
+ partial(self.query_triplane, triplanes=triplane), self.cfg.eval_chunk_size, grid_vertices,
+ )
+
+ sdf = triplane_out["sdf"]
+
+ if self.cfg.inside_out:
+ sdf = -sdf
+
+ if self.cfg.enable_isosurface_grid_deformation:
+ deformation = triplane_out["deformation"]
+ else:
+ deformation = None
+
+ mesh: Mesh = self.isosurface_helper(sdf, deformation=deformation)
+
+ mesh.v_pos = scale_tensor(
+ mesh.v_pos, self.isosurface_helper.points_range, self.bbox
+ )
+
+ return mesh
+
+ def query(
+ self, triplane: Float[Tensor, "3 Cp Hp Wp"], points: Float[Tensor, "*N 3"]
+ ):
+ input_shape = points.shape[:-1]
+ triplane_out = chunk_batch(
+ partial(self.query_triplane, triplanes=triplane), self.cfg.eval_chunk_size, points.view(-1, 3)
+ )
+ triplane_out = {
+ k: v.view(*input_shape, v.shape[-1]) for k, v in triplane_out.items()
+ }
+ return triplane_out
diff --git a/3D_Stage/lrm/models/tokenizers/__init__.py b/3D_Stage/lrm/models/tokenizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/models/tokenizers/dinov2.py b/3D_Stage/lrm/models/tokenizers/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d7eae6382313cb06ed423c09f0537c5bd446fc
--- /dev/null
+++ b/3D_Stage/lrm/models/tokenizers/dinov2.py
@@ -0,0 +1,1223 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch DINOv2 model."""
+
+
+import collections.abc
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+from dataclasses import dataclass
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import (
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.utils.backbone_utils import BackboneMixin
+from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
+import xformers
+
+from ..transformers.attention import MemoryEfficientAttentionMixin
+from ...utils.typing import *
+
+
# Module-level logger via transformers' logging utilities.
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Dinov2Config"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"


# Reference checkpoints compatible with this (locally modified) implementation.
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/dinov2-base",
    # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
]
+
+
class Dinov2Embeddings(nn.Module):
    """
    Construct the CLS token, mask token, position and patch embeddings.

    Produces the token sequence consumed by the encoder: patch embeddings
    (optionally masked), a prepended [CLS] token, interpolated position
    embeddings, and dropout.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # register as mask token as it's not used in optimization
        # to avoid the use of find_unused_parameters_true
        # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
        self.patch_embeddings = Dinov2PatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # +1 position for the [CLS] token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        # Fast path: token count matches the pretrained grid and the image is
        # square — no interpolation needed.
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        # Convert image size (pixels) to patch-grid size.
        height = height // self.config.patch_size
        width = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1
        # Pretrained positions are assumed to form a square grid.
        patch_pos_embed = patch_pos_embed.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (
            int(height) != patch_pos_embed.shape[-2]
            or int(width) != patch_pos_embed.shape[-1]
        ):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        # Back to (1, num_tokens, dim) and re-attach the [CLS] position.
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        patch_embeddings = self.patch_embeddings(pixel_values)
        embeddings = patch_embeddings

        # Replace masked patches with the (non-trainable) mask token.
        if bool_masked_pos is not None:
            embeddings = torch.where(
                bool_masked_pos.unsqueeze(-1),
                self.mask_token.to(embeddings.dtype).unsqueeze(0),
                embeddings,
            )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(
            embeddings, height, width
        )

        embeddings = self.dropout(embeddings)

        return embeddings
+
+
class Dinov2PatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Accept a single int or an (h, w) iterable.
            return (
                value
                if isinstance(value, collections.abc.Iterable)
                else (value, value)
            )

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )

        # Non-overlapping patch projection: stride == kernel == patch size.
        self.projection = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: the upstream channel-count sanity check is intentionally
        # disabled here so the projection can be widened to extra input
        # channels at runtime.
        # (batch, C, H, W) -> (batch, num_patches, hidden_size)
        patches = self.projection(pixel_values)
        return patches.flatten(2).transpose(1, 2)
+
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
class Dinov2SelfAttention(nn.Module):
    """Multi-head self-attention with three interchangeable backends.

    ``forward`` picks, in order: xformers memory-efficient attention (when
    explicitly enabled via ``set_use_memory_efficient_attention_xformers``),
    torch's fused ``F.scaled_dot_product_attention`` (when available), or a
    plain matmul/softmax fallback. Only the fallback supports ``head_mask``
    and ``output_attentions``.
    """

    def __init__(self, config: "Dinov2Config") -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            # Fixed malformed f-string: `{config.hidden_size,}` rendered a tuple.
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.use_memory_efficient_attention_xformers: bool = False

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (..., all_head_size) -> (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        # BUGFIX: the functional kernels (xformers / SDPA) apply dropout
        # unconditionally; the original passed the raw probability, so
        # attention dropout leaked into eval/inference. Gate on self.training,
        # matching what nn.Dropout does in the fallback path.
        attn_dropout_p = self.attention_probs_dropout_prob if self.training else 0.0

        if self.use_memory_efficient_attention_xformers:
            assert head_mask is None and not output_attentions
            new_size = hidden_states.size()[:-1] + (
                self.num_attention_heads,
                self.attention_head_size,
            )
            # xformers expects (batch, seq, heads, head_size).
            key_layer = self.key(hidden_states).view(new_size)
            value_layer = self.value(hidden_states).view(new_size)
            query_layer = mixed_query_layer.view(new_size)
            context_layer = xformers.ops.memory_efficient_attention(
                query_layer, key_layer, value_layer, p=attn_dropout_p
            )
            context_layer = context_layer.view(*hidden_states.size()[:-1], -1)
        elif hasattr(F, "scaled_dot_product_attention"):
            assert head_mask is None and not output_attentions
            new_size = hidden_states.size()[:-1] + (
                self.num_attention_heads,
                self.attention_head_size,
            )
            # SDPA expects (batch, heads, seq, head_size).
            key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
            value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
            query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
            context_layer = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                dropout_p=attn_dropout_p,
                is_causal=False,
            )
            context_layer = context_layer.transpose(1, 2).reshape(
                *hidden_states.size()[:-1], -1
            )
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            query_layer = self.transpose_for_scores(mixed_query_layer)

            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

            attention_scores = attention_scores / math.sqrt(self.attention_head_size)

            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)

            # Mask heads if we want to
            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            context_layer = torch.matmul(attention_probs, value_layer)

            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        return outputs

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ):
        # `attention_op` is accepted for interface compatibility but unused.
        self.use_memory_efficient_attention_xformers = valid
+
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
class Dinov2SelfOutput(nn.Module):
    """
    The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        # `input_tensor` is deliberately unused: the residual add happens in
        # Dinov2Layer, but the signature mirrors the upstream ViT interface.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
+
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
class Dinov2Attention(nn.Module):
    """Self-attention plus its output projection, with head-pruning support."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given heads and shrink the linear layers accordingly."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.attention.num_attention_heads,
            self.attention.attention_head_size,
            self.pruned_heads,
        )

        # Shrink q/k/v (output dim) and the projection (input dim).
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping consistent with the new layer sizes.
        remaining_heads = self.attention.num_attention_heads - len(heads)
        self.attention.num_attention_heads = remaining_heads
        self.attention.all_head_size = (
            remaining_heads * self.attention.attention_head_size
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        # Append attention probabilities when they were requested.
        return (attention_output,) + self_outputs[1:]
+
+
class Dinov2LayerScale(nn.Module):
    """Learned per-channel scaling (LayerScale) applied to a residual branch."""

    def __init__(self, config) -> None:
        super().__init__()
        # One learnable gain per hidden channel, initialized to a constant.
        initial_gain = config.layerscale_value * torch.ones(config.hidden_size)
        self.lambda1 = nn.Parameter(initial_gain)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.lambda1 * hidden_state
+
+
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(
    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Identity at inference time or when `drop_prob` is zero. During training,
    each sample in the batch is independently zeroed with probability
    `drop_prob` and survivors are rescaled by 1 / (1 - drop_prob) so the
    expected value is unchanged. Same mechanism as EfficientNet's
    "DropConnect"; renamed to "drop path" to avoid confusion — see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.
    """
    if not training or drop_prob == 0.0:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims
    # (works for any tensor rank, not just 2D ConvNet activations).
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    mask.add_(keep_prob).floor_()  # binarize: 1 with probability keep_prob
    return input.div(keep_prob) * mask
+
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class Dinov2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form; active only in training mode.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
+
+
class Dinov2MLP(nn.Module):
    """Standard transformer feed-forward block: linear -> activation -> linear."""

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        expanded = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(width, expanded, bias=True)
        # `hidden_act` may be a key into the activation registry or a callable.
        self.activation = (
            ACT2FN[config.hidden_act]
            if isinstance(config.hidden_act, str)
            else config.hidden_act
        )
        self.fc2 = nn.Linear(expanded, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(hidden_state)))
+
+
class Dinov2SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: a gated SiLU variant of the transformer MLP."""

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        hidden = int(config.hidden_size * config.mlp_ratio)
        # Shrink by 2/3 to keep the parameter count comparable to a plain MLP,
        # then round up to a multiple of 8 for hardware-friendly sizes.
        hidden = (int(hidden * 2 / 3) + 7) // 8 * 8

        # Single projection producing both the gate and the value halves.
        self.weights_in = nn.Linear(width, 2 * hidden, bias=True)
        self.weights_out = nn.Linear(hidden, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        gate, value = self.weights_in(hidden_state).chunk(2, dim=-1)
        return self.weights_out(nn.functional.silu(gate) * value)
+
+
class Dinov2Layer(nn.Module, MemoryEfficientAttentionMixin):
    """This corresponds to the Block class in the original implementation.

    Pre-norm transformer block: LayerNorm -> attention -> LayerScale ->
    residual add, then LayerNorm -> MLP (or SwiGLU) -> LayerScale ->
    residual add. Optional AdaLN-style modulation can be attached to either
    norm via `register_ada_norm_modulation`.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Set later via register_ada_norm_modulation; None means plain LayerNorm.
        self.norm1_modulation = None
        self.attention = Dinov2Attention(config)
        self.layer_scale1 = Dinov2LayerScale(config)
        # NOTE(review): drop_path1/drop_path2 are constructed here but never
        # applied in `forward`, so stochastic depth is effectively disabled —
        # confirm whether this is intentional.
        self.drop_path1 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2_modulation = None

        if config.use_swiglu_ffn:
            self.mlp = Dinov2SwiGLUFFN(config)
        else:
            self.mlp = Dinov2MLP(config)
        self.layer_scale2 = Dinov2LayerScale(config)
        self.drop_path2 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        hidden_states_norm = self.norm1(hidden_states)
        # Optional conditioning of the normalized activations (AdaLN-style).
        if self.norm1_modulation is not None:
            assert modulation_cond is not None
            hidden_states_norm = self.norm1_modulation(
                hidden_states_norm, modulation_cond
            )
        self_attention_outputs = self.attention(
            hidden_states_norm,  # in Dinov2, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        if self.norm2_modulation is not None:
            assert modulation_cond is not None
            layer_output = self.norm2_modulation(layer_output, modulation_cond)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = layer_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs

    def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
        # Attach conditioning modules applied right after norm1/norm2;
        # once set, `forward` requires a non-None `modulation_cond`.
        self.norm1_modulation = norm1_mod
        self.norm2_modulation = norm2_mod
+
+
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
class Dinov2Encoder(nn.Module, MemoryEfficientAttentionMixin):
    """Stack of Dinov2Layer blocks with optional gradient checkpointing.

    Adds a `modulation_cond` pass-through (not in upstream transformers) so
    every layer can receive an external conditioning signal.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
        )
        # Toggled by Dinov2PreTrainedModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            # Hidden states are recorded *before* each layer; the final state
            # is appended after the loop.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Wrap the layer so `output_attentions` (a non-tensor) is
                # captured by closure instead of passed through checkpoint.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
                    modulation_cond,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, layer_head_mask, modulation_cond, output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple form: drop the None entries, keep ordering.
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
+
+
class Dinov2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm: identity-scale, zero-shift initialization.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2Embeddings):
            # Position and CLS embeddings get the same truncated-normal init
            # (in fp32 for the same half-precision reason as above).
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

    def _set_gradient_checkpointing(
        self, module: Dinov2Encoder, value: bool = False
    ) -> None:
        # Called by transformers' enable/disable gradient checkpointing hooks.
        if isinstance(module, Dinov2Encoder):
            module.gradient_checkpointing = value
+
+
# Docstring templates consumed by the `add_start_docstrings*` decorators below.
DINOV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DINOV2_BASE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
            pre-training.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Same as above, minus `bool_masked_pos` (for heads that don't support masking).
DINOV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@dataclass
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
    # Extends the standard pooled output with the raw patch embeddings
    # (tokens before the encoder) so downstream consumers can reuse them.
    patch_embeddings: Optional[torch.FloatTensor] = None
+
+
+@add_start_docstrings(
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
+ DINOV2_START_DOCSTRING,
+)
+class Dinov2Model(Dinov2PreTrainedModel, MemoryEfficientAttentionMixin):
    def __init__(self, config: Dinov2Config):
        """Build embeddings, encoder, and final layernorm from the config."""
        super().__init__(config)
        self.config = config

        # Patch/CLS embedding stage followed by the transformer encoder.
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        # Final layernorm applied to the encoder output.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()
+
    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        # Expose the patch projection as the model's input embedding module
        # (transformers PreTrainedModel convention).
        return self.embeddings.patch_embeddings
+
+ def expand_input_channels(self, extra_input_channels: int) -> None:
+ if extra_input_channels == 0:
+ return
+ conv_old = self.embeddings.patch_embeddings.projection
+ conv_new = nn.Conv2d(
+ self.config.num_channels + extra_input_channels,
+ self.config.hidden_size,
+ kernel_size=self.config.patch_size,
+ stride=self.config.patch_size,
+ ).to(self.device)
+ with torch.no_grad():
+ conv_new.weight[:, :3] = conv_old.weight
+ conv_new.bias = conv_old.bias
+ self.embeddings.patch_embeddings.projection = conv_new
+ del conv_old
+
    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Delegates to Dinov2Attention.prune_heads per encoder layer.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ modulation_cond: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ modulation_cond=modulation_cond,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = sequence_output[:, 0, :]
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output)
+ return head_outputs + encoder_outputs[1:]
+
+ return CustomBaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ patch_embeddings=embedding_output,
+ )
+
+ def set_gradient_checkpointing(self, value: bool = False) -> None:
+ self._set_gradient_checkpointing(self.encoder, value)
+
+
@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # The head consumes the [CLS] token concatenated with the mean patch
        # token, hence the doubled input width.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_size * 2, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor):
        """Infer `config.problem_type` once, then apply the matching criterion."""
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            criterion = MSELoss()
            if self.num_labels == 1:
                return criterion(logits.squeeze(), labels.squeeze())
            return criterion(logits, labels)
        if self.config.problem_type == "single_label_classification":
            criterion = CrossEntropyLoss()
            return criterion(logits.view(-1, self.num_labels), labels.view(-1))
        if self.config.problem_type == "multi_label_classification":
            criterion = BCEWithLogitsLoss()
            return criterion(logits, labels)
        return None

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.dinov2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # (batch_size, sequence_length, hidden_size)
        sequence_output = outputs[0]

        # Head input: [CLS] embedding concatenated with the mean of the patch
        # embeddings.
        cls_token = sequence_output[:, 0]
        mean_patch = sequence_output[:, 1:].mean(dim=1)
        logits = self.classifier(torch.cat([cls_token, mean_patch], dim=1))

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        # One feature width per stage: the embeddings plus every encoder layer.
        self.num_features = [
            config.hidden_size for _ in range(config.num_hidden_layers + 1)
        ]
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )

        embedding_output = self.embeddings(pixel_values)

        # Hidden states of every layer are always needed to build feature maps.
        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    # Fix: patch tokens are laid out row-major (height first),
                    # so the spatial grid must be restored as (H/ps, W/ps).
                    # The previous (W/ps, H/ps) order silently scrambled
                    # feature maps for non-square inputs (same fix as upstream
                    # transformers).
                    hidden_state = hidden_state[:, 1:, :].reshape(
                        batch_size, height // patch_size, width // patch_size, -1
                    )
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )
+
+
class CustomPatchEmbeddings(nn.Module):
    """
    Turn `pixel_values` of shape `(batch_size, num_channels, height, width)` into patch embeddings of shape
    `(batch_size, seq_length, hidden_size)` suitable as Transformer input, via a strided convolution.
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ):
        super().__init__()

        # Normalize scalar sizes to (height, width) pairs.
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        # Number of patches in the (possibly rectangular) grid.
        self.num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )

        # Each kernel application yields one patch embedding.
        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        channels_in = pixel_values.shape[1]
        if channels_in != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {channels_in}."
            )
        patches = self.projection(pixel_values)
        # (B, D, H', W') -> (B, H'*W', D)
        return patches.flatten(2).transpose(1, 2)
+
+
class CustomEmbeddings(nn.Module):
    """
    Construct the CLS token, position embeddings, and patch embeddings.
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ) -> None:
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size

        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))

        self.patch_embeddings = CustomPatchEmbeddings(
            image_size, patch_size, num_channels, hidden_size
        )
        # One position per patch plus one for the CLS token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, self.hidden_size)
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Interpolate the pre-trained position encodings so the model can run on
        higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        # Fast path: native resolution (square image) needs no interpolation.
        if num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        grid = int(math.sqrt(num_positions))
        # A small offset avoids floating point error in the interpolation,
        # see https://github.com/facebookresearch/dino/issues/8
        target_h = height // self.patch_size + 0.1
        target_w = width // self.patch_size + 0.1

        patch_pos_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                target_h / math.sqrt(num_positions),
                target_w / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (
            int(target_h) != patch_pos_embed.shape[-2]
            or int(target_w) != patch_pos_embed.shape[-1]
        ):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        tokens = self.patch_embeddings(pixel_values)

        # Prepend the [CLS] token to the patch tokens.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # Add a positional encoding to every token.
        return tokens + self.interpolate_pos_encoding(tokens, height, width)
diff --git a/3D_Stage/lrm/models/tokenizers/image.py b/3D_Stage/lrm/models/tokenizers/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..1663470778c514654eaa63e104bc511e4783ad47
--- /dev/null
+++ b/3D_Stage/lrm/models/tokenizers/image.py
@@ -0,0 +1,268 @@
+from dataclasses import dataclass
+import random
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ...utils.base import BaseModule
+from .dinov2 import Dinov2Model, CustomEmbeddings
+from ..transformers.attention import Modulation
+from ...utils.typing import *
+
+
class NaiveImageTokenizer(BaseModule):
    """Placeholder tokenizer emitting random tokens of the configured shape."""

    @dataclass
    class Config(BaseModule.Config):
        num_tokens: int = 1024
        num_channels: int = 768

    cfg: Config

    def configure(self) -> None:
        super().configure()

    def forward(self, images: Float[Tensor, "B N C H W"]) -> Float[Tensor, "B Ct Nt"]:
        # Ignore image content; only the batch size, device, and dtype matter.
        batch_size = images.shape[0]
        token_shape = (batch_size, self.cfg.num_channels, self.cfg.num_tokens)
        return torch.rand(token_shape, device=images.device, dtype=images.dtype)

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
+
+
class DINOV2SingleImageTokenizer(BaseModule):
    """Tokenize (multi-view) images into per-view DINOv2 feature tokens.

    Wraps a pretrained DINOv2 backbone; optionally adds per-layer modulation
    conditioned on an external vector, appends Plucker-ray channels to the
    input, and (during training) randomly drops all but the first view.
    """

    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-base"
        width: int = 224
        height: int = 224
        # If True, insert Modulation layers before each encoder block's norms.
        modulation: bool = False
        modulation_zero_init: bool = False
        modulation_single_layer: bool = False
        modulation_cond_dim: int = 16
        # If True the backbone is stored as a non-module (excluded from
        # state_dict/optimizer) and kept frozen in eval mode.
        freeze_backbone_params: bool = True
        enable_memory_efficient_attention: bool = False
        enable_gradient_checkpointing: bool = False
        # If True, fuse the embedding-layer output into the final tokens.
        use_patch_embeddings: bool = False
        patch_embeddings_aggr_method: str = "concat"
        # If True, the backbone input is expanded by 6 ray channels.
        append_plucker_rays: bool = False
        # Probability of applying view dropping per training step.
        drop_rate: float = 0.0
        drop_type: str = "all_but_first"

    cfg: Config

    def configure(self) -> None:
        super().configure()
        model: Dinov2Model

        if self.cfg.freeze_backbone_params:
            # freeze dino backbone parameters
            self.register_non_module(
                "model",
                Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path).to(
                    self.device
                ),
            )

            model = self.non_module("model")
            for p in model.parameters():
                p.requires_grad_(False)
            model.eval()
        else:
            self.model = Dinov2Model.from_pretrained(
                self.cfg.pretrained_model_name_or_path
            ).to(self.device)
            model = self.model

        if self.cfg.append_plucker_rays:
            # 6 extra channels for the per-pixel Plucker ray representation.
            model.expand_input_channels(6)

        model.set_use_memory_efficient_attention_xformers(
            self.cfg.enable_memory_efficient_attention
        )
        model.set_gradient_checkpointing(self.cfg.enable_gradient_checkpointing)

        # add modulation
        if self.cfg.modulation:
            modulations = []
            # One modulation per norm layer (two per encoder block).
            for layer in model.encoder.layer:
                norm1_modulation = Modulation(
                    model.config.hidden_size,
                    self.cfg.modulation_cond_dim,
                    zero_init=self.cfg.modulation_zero_init,
                    single_layer=self.cfg.modulation_single_layer,
                )
                norm2_modulation = Modulation(
                    model.config.hidden_size,
                    self.cfg.modulation_cond_dim,
                    zero_init=self.cfg.modulation_zero_init,
                    single_layer=self.cfg.modulation_single_layer,
                )
                layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
                modulations += [norm1_modulation, norm2_modulation]
            # Keep modulations in a ModuleList so they are trainable even when
            # the backbone itself is registered as a non-module.
            self.modulations = nn.ModuleList(modulations)

        # ImageNet normalization statistics, broadcastable over (B, N, C, H, W).
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        plucker_rays: Optional[Float[Tensor, "B *N 6 H W"]],
        **kwargs
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """Return per-view local feature tokens of shape (B, [N,] Ct, Nt).

        A 4-D `images` input (single view) is treated as N=1 and the view axis
        is squeezed out of the result ("packed" mode).
        """
        model: Dinov2Model
        if self.cfg.freeze_backbone_params:
            model = self.non_module("model")
        else:
            model = self.model

        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)
            if plucker_rays is not None:
                assert plucker_rays.ndim == 4
                plucker_rays = plucker_rays.unsqueeze(1)

        # Stochastic view dropping (training-time regularization/augmentation).
        if (
            self.training
            and self.cfg.drop_rate > 0
            and random.random() < self.cfg.drop_rate
        ):
            if self.cfg.drop_type == "all_but_first":
                drop_func = lambda x: x if x is None else x[:, 0:1]
                images = drop_func(images)
                modulation_cond = drop_func(modulation_cond)
                plucker_rays = drop_func(plucker_rays)
            else:
                raise NotImplementedError

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        if self.cfg.append_plucker_rays and plucker_rays is not None:
            # Rays are appended along the channel axis (unnormalized).
            images = torch.cat([images, plucker_rays], dim=2)
        # Views are folded into the batch axis for the backbone pass.
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features, global_features = out.last_hidden_state, out.pooler_output
        if self.cfg.use_patch_embeddings:
            patch_embeddings = out.patch_embeddings
            if self.cfg.patch_embeddings_aggr_method == "concat":
                # Concatenate along the token axis.
                local_features = torch.cat([local_features, patch_embeddings], dim=1)
            elif self.cfg.patch_embeddings_aggr_method == "add":
                local_features = local_features + patch_embeddings
            else:
                raise NotImplementedError
        # (B*N, Nt, Ct) -> (B, N, Ct, Nt)
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
+
+
class DINOV2CustomLowLevelSingleImageTokenizer(DINOV2SingleImageTokenizer):
    """DINOv2 tokenizer that fuses extra low-level (embedding-layer style)
    features, produced by a trainable `CustomEmbeddings` copy, into the tokens.
    """

    @dataclass
    class Config(DINOV2SingleImageTokenizer.Config):
        custom_embeddings_aggr_method: str = "concat"
        custom_embeddings_scale: float = 1.0

    cfg: Config

    def configure(self) -> None:
        super().configure()
        # Bug fix: with the default `freeze_backbone_params=True` the backbone
        # is stored via `register_non_module`, so `self.model` does not exist
        # and the original unconditional `self.model` access raised
        # AttributeError. Resolve the backbone the same way `forward` does.
        model: Dinov2Model
        if self.cfg.freeze_backbone_params:
            model = self.non_module("model")
        else:
            model = self.model
        self.custom_embeddings = CustomEmbeddings(
            model.config.image_size,
            model.config.patch_size,
            model.config.num_channels,
            model.config.hidden_size,
        )
        # Warm-start from the backbone's embedding weights; `strict=False`
        # because `CustomEmbeddings` has no mask token.
        self.custom_embeddings.load_state_dict(
            model.embeddings.state_dict(), strict=False
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs
    ) -> Float[Tensor, "B *N Ct Nt"]:
        """Return per-view tokens with custom low-level embeddings fused in."""
        model: Dinov2Model
        if self.cfg.freeze_backbone_params:
            model = self.non_module("model")
        else:
            model = self.model

        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        images = rearrange(images, "B N C H W -> (B N) C H W")
        out = model(
            images,
            modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None,
        )
        local_features, global_features = out.last_hidden_state, out.pooler_output
        if self.cfg.use_patch_embeddings:
            patch_embeddings = out.patch_embeddings
            if self.cfg.patch_embeddings_aggr_method == "concat":
                local_features = torch.cat([local_features, patch_embeddings], dim=1)
            elif self.cfg.patch_embeddings_aggr_method == "add":
                local_features = local_features + patch_embeddings
            else:
                raise NotImplementedError

        # Fuse the (scaled) low-level embeddings computed on the same
        # normalized images.
        custom_embeddings = (
            self.custom_embeddings(images) * self.cfg.custom_embeddings_scale
        )
        if self.cfg.custom_embeddings_aggr_method == "concat":
            local_features = torch.cat([local_features, custom_embeddings], dim=1)
        elif self.cfg.custom_embeddings_aggr_method == "add":
            local_features = local_features + custom_embeddings
        else:
            raise NotImplementedError

        # (B*N, Nt, Ct) -> (B, N, Ct, Nt)
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features
diff --git a/3D_Stage/lrm/models/tokenizers/triplane.py b/3D_Stage/lrm/models/tokenizers/triplane.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfbbba59683fb30998c204eaf89536e1cbe06e01
--- /dev/null
+++ b/3D_Stage/lrm/models/tokenizers/triplane.py
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from ...utils.base import BaseModule
+from ...utils.typing import *
+
+
class TriplaneLearnablePositionalEmbedding(BaseModule):
    """Learnable triplane embedding exposed as a flat token sequence."""

    @dataclass
    class Config(BaseModule.Config):
        plane_size: int = 32
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        super().configure()
        # Three planes of learnable features, scaled so the initial standard
        # deviation is 1/sqrt(num_channels).
        init = torch.randn(
            (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
            dtype=torch.float32,
        )
        self.embeddings = nn.Parameter(init * 1 / math.sqrt(self.cfg.num_channels))

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        """Tile the planes over the batch and flatten them into tokens."""
        tiled = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size)
        return rearrange(tiled, "B Np Ct Hp Wp -> B Ct (Np Hp Wp)")

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        """Reshape flat tokens back into the three spatial planes."""
        _, channels, num_tokens = tokens.shape
        assert num_tokens == self.cfg.plane_size**2 * 3
        assert channels == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
diff --git a/3D_Stage/lrm/models/transformers/__init__.py b/3D_Stage/lrm/models/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/models/transformers/attention.py b/3D_Stage/lrm/models/transformers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b5891fcf694d94aa11385e057300841f2f4f326
--- /dev/null
+++ b/3D_Stage/lrm/models/transformers/attention.py
@@ -0,0 +1,669 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
+
+from ...utils.typing import *
+
+
class MemoryEfficientAttentionMixin:
    """Mixin that propagates the xFormers memory-efficient-attention toggle to
    every descendant module exposing `set_use_memory_efficient_attention_xformers`."""

    def enable_xformers_memory_efficient_attention(
        self, attention_op: Optional[Callable] = None
    ):
        r"""
        Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).

        When enabled you should observe lower GPU memory usage and a potential
        inference speed up; training speed up is not guaranteed. If sliced
        attention is also enabled, memory efficient attention takes precedence.

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as the `op` argument to xFormers'
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function.
        """
        self.set_use_memory_efficient_attention_xformers(True, attention_op)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        def propagate(module: torch.nn.Module) -> None:
            # Notify any module that exposes the setter, then keep walking so
            # nested attention modules are reached as well.
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
            for grandchild in module.children():
                propagate(grandchild)

        for child in self.children():
            if isinstance(child, torch.nn.Module):
                propagate(child)
+
+
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that fuses object features into visual features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # Object features must be projected into the visual feature space
        # before they can be concatenated with the visual tokens.
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Learnable gates, initialized so both branches start closed.
        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        fused = torch.cat([x, self.linear(objs)], dim=1)
        # Attend over visual+object tokens, keep only the visual slots.
        attn_out = self.attn(self.norm1(fused))[:, :n_visual, :]
        x = x + self.alpha_attn.tanh() * attn_out
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module, MemoryEfficientAttentionMixin):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ """
+
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        cond_dim_ada_norm_continuous: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        attention_type: str = "default",
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Resolve which adaptive-norm flavor this block uses. Each flag requires
        # both the matching `norm_type` string AND its conditioning argument.
        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (
            num_embeds_ada_norm is not None
        ) and norm_type == "ada_norm"
        self.use_ada_layer_norm_continuous = (
            cond_dim_ada_norm_continuous is not None
        ) and norm_type == "ada_norm_continuous"

        # At most one adaptive-norm variant may be active at a time.
        assert (
            int(self.use_ada_layer_norm)
            + int(self.use_ada_layer_norm_continuous)
            + int(self.use_ada_layer_norm_zero)
            <= 1
        )

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_continuous:
            self.norm1 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        # When `only_cross_attention` is set, attn1 itself attends to the
        # encoder states instead of acting as self-attention.
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            if self.use_ada_layer_norm:
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif self.use_ada_layer_norm_continuous:
                self.norm2 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
            else:
                self.norm2 = nn.LayerNorm(
                    dim, elementwise_affine=norm_elementwise_affine
                )

            # `double_self_attention` turns this second block into a second
            # self-attention by dropping the cross-attention dimension.
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim
                if not double_self_attention
                else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if self.use_ada_layer_norm_continuous:
            self.norm3 = AdaLayerNormContinuous(dim, cond_dim_ada_norm_continuous)
        else:
            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # 4. Fuser (GLIGEN-style gated attention; only built when requested,
        # so `self.fuser` does not exist for the "default" attention type).
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(
                dim, cross_attention_dim, num_attention_heads, attention_head_dim
            )

        # let chunk size default to None (feed-forward runs unchunked)
        self._chunk_size = None
        self._chunk_dim = 0
+
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable (or disable) chunked feed-forward execution to save memory.

        Args:
            chunk_size: Number of slices the feed-forward input is split into
                along `dim`; `None` disables chunking.
            dim: Tensor dimension along which to chunk.
        """
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim
+
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        modulation_cond: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Apply self-attention, optional cross-attention and feed-forward, each with a residual.

        Args:
            hidden_states: Input features; the `.unsqueeze(1)` gating below assumes
                shape (batch, seq, dim).
            attention_mask: Bias/mask added to the self-attention scores.
            encoder_hidden_states: Context sequence for cross-attention (also used
                by attn1 when `only_cross_attention` is set).
            encoder_attention_mask: Bias/mask for the cross-attention scores.
            timestep: Conditioning for `AdaLayerNorm`/`AdaLayerNormZero`.
            modulation_cond: Continuous conditioning for `AdaLayerNormContinuous`.
            cross_attention_kwargs: Extra kwargs forwarded to the attention
                processors; a "gligen" entry, if present, is consumed here.
            class_labels: Class conditioning for `AdaLayerNormZero`.

        Returns:
            Updated hidden states with the same shape as the input.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm1(hidden_states, modulation_cond)
        elif self.use_ada_layer_norm_zero:
            # adaLN-Zero also emits the gate/shift/scale used further below.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = (
            cross_attention_kwargs.get("scale", 1.0)
            if cross_attention_kwargs is not None
            else 1.0
        )

        # 2. Prepare GLIGEN inputs
        # Copy before popping so the caller's dict is not mutated.
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states
            if self.only_cross_attention
            else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2.5 GLIGEN Control
        # NOTE(review): `self.fuser` only exists for gated attention types;
        # passing gligen kwargs with attention_type="default" would raise.
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
        # 2.5 ends

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_continuous:
                norm_hidden_states = self.norm2(hidden_states, modulation_cond)
            else:
                norm_hidden_states = self.norm2(hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if self.use_ada_layer_norm_continuous:
            norm_hidden_states = self.norm3(hidden_states, modulation_cond)
        else:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
+
+
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"` or `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = nn.Linear

        # Fix: the original used two independent `if` statements for "gelu" and
        # "gelu-approximate" instead of an `elif` chain, and an unknown
        # `activation_fn` left `act_fn` unbound (UnboundLocalError later).
        # Now every branch is exclusive and unknown names fail fast.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """Apply the feed-forward stack to `hidden_states`.

        `scale` is accepted for API compatibility with LoRA-aware callers but is
        intentionally ignored here: only plain `nn.Linear` layers are used.
        """
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
+
+
class GELU(nn.Module):
    r"""
    Linear projection followed by GELU, with optional tanh approximation
    (`approximate="tanh"`).

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16 — compute in fp32, cast back
            upcast = gate.to(dtype=torch.float32)
            return F.gelu(upcast, approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        return self.gelu(self.proj(hidden_states))
+
+
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # One projection produces both the value half and the gate half.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16 — compute in fp32, cast back
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, scale: float = 1.0):
        # `scale` is accepted for call-site compatibility but unused here.
        value, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return value * self.gelu(gate)
+
+
class ApproximateGELU(nn.Module):
    r"""
    The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        # Sigmoid approximation of GELU: x * sigmoid(1.702 * x)
        return projected * torch.sigmoid(1.702 * projected)
+
+
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the dictionary of embeddings.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        # timestep -> learned embedding -> (scale, shift) modulation parameters
        modulation = self.linear(self.silu(self.emb(timestep)))
        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
class AdaLayerNormContinuous(nn.Module):
    r"""
    Norm layer modified to incorporate arbitrary continuous embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        condition_dim (`int`): The size of the conditioning vector.
    """

    def __init__(self, embedding_dim: int, condition_dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear1 = nn.Linear(condition_dim, condition_dim)
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # Two-layer MLP maps the condition to per-channel (scale, shift).
        modulation = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
class Modulation(nn.Module):
    r"""
    Scale/shift modulation of features by a continuous condition (no normalization).

    Parameters:
        embedding_dim (`int`): Channel size of the modulated features.
        condition_dim (`int`): Channel size of the conditioning vector.
        zero_init (`bool`, *optional*, defaults to False): Zero-initialize the last
            linear layer so the module starts out as an identity map.
        single_layer (`bool`, *optional*, defaults to False): Replace the first
            linear layer with an identity.
    """

    def __init__(self, embedding_dim: int, condition_dim: int, zero_init: bool = False, single_layer: bool = False):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear1 = nn.Identity() if single_layer else nn.Linear(condition_dim, condition_dim)
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Only zero init the last linear layer
        if zero_init:
            nn.init.zeros_(self.linear2.weight)
            nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        modulation = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = modulation.chunk(2, dim=1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the dictionary of embeddings.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Fused timestep+label embedding drives all six modulation parameters.
        conditioning = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        params = self.linear(self.silu(conditioning))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=1)
        # Only the MSA shift/scale are applied here; the gates and the MLP
        # shift/scale are returned for the caller to apply.
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
class AdaGroupNorm(nn.Module):
    r"""
    GroupNorm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        out_dim (`int`): The number of channels being normalized.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """

    def __init__(
        self,
        embedding_dim: int,
        out_dim: int,
        num_groups: int,
        act_fn: Optional[str] = None,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        # Optional nonlinearity applied to the embedding before projection.
        self.act = None if act_fn is None else get_activation(act_fn)
        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        if self.act:
            emb = self.act(emb)
        # Project to per-channel (scale, shift) and broadcast over H, W.
        modulation = self.linear(emb)[:, :, None, None]
        scale, shift = modulation.chunk(2, dim=1)
        normalized = F.group_norm(x, self.num_groups, eps=self.eps)
        return normalized * (1 + scale) + shift
diff --git a/3D_Stage/lrm/models/transformers/transformer_1d.py b/3D_Stage/lrm/models/transformers/transformer_1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d03a488b6ffb6e5a5c87c3dd9c0130baf743472
--- /dev/null
+++ b/3D_Stage/lrm/models/transformers/transformer_1d.py
@@ -0,0 +1,252 @@
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
+
+import lrm
+from ...utils.base import BaseModule
+from .attention import (
+ BasicTransformerBlock,
+ MemoryEfficientAttentionMixin,
+)
+from ...utils.typing import *
+
+
class Transformer1D(BaseModule, MemoryEfficientAttentionMixin):
    """
    A 1D Transformer model for sequence data.

    Takes channel-first sequences of shape `(batch, channels, seq_len)`, projects
    them to the transformer width, runs a stack of `BasicTransformerBlock`s, and
    projects back, with a residual connection around the whole stack.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @dataclass
    class Config(BaseModule.Config):
        # inner transformer width = num_attention_heads * attention_head_dim
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        # NOTE(review): out_channels is computed in configure() but proj_out
        # still maps back to in_channels (the residual add requires it).
        out_channels: Optional[int] = None
        num_layers: int = 1
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False
        activation_fn: str = "geglu"
        # Adaptive-norm conditioning; which one is used depends on norm_type.
        num_embeds_ada_norm: Optional[int] = None
        cond_dim_ada_norm_continuous: Optional[int] = None
        only_cross_attention: bool = False
        double_self_attention: bool = False
        upcast_attention: bool = False
        norm_type: str = "layer_norm"
        norm_elementwise_affine: bool = True
        attention_type: str = "default"
        enable_memory_efficient_attention: bool = False
        gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the input norm/projection, transformer blocks, and output projection from `self.cfg`."""
        super().configure()

        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        linear_cls = nn.Linear

        # Plain layer_norm cannot consume either conditioning input; reject
        # configs that supply one anyway.
        if self.cfg.norm_type == "layer_norm" and (
            self.cfg.num_embeds_ada_norm is not None
            or self.cfg.cond_dim_ada_norm_continuous is not None
        ):
            raise ValueError("Incorrect norm_type.")

        # 2. Define input layers
        self.in_channels = self.cfg.in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=self.cfg.norm_num_groups,
            num_channels=self.cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    dropout=self.cfg.dropout,
                    cross_attention_dim=self.cfg.cross_attention_dim,
                    activation_fn=self.cfg.activation_fn,
                    num_embeds_ada_norm=self.cfg.num_embeds_ada_norm,
                    cond_dim_ada_norm_continuous=self.cfg.cond_dim_ada_norm_continuous,
                    attention_bias=self.cfg.attention_bias,
                    only_cross_attention=self.cfg.only_cross_attention,
                    double_self_attention=self.cfg.double_self_attention,
                    upcast_attention=self.cfg.upcast_attention,
                    norm_type=self.cfg.norm_type,
                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
                    attention_type=self.cfg.attention_type,
                )
                for d in range(self.cfg.num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = (
            self.cfg.in_channels
            if self.cfg.out_channels is None
            else self.cfg.out_channels
        )

        # NOTE(review): projects back to in_channels, not self.out_channels —
        # the residual add in forward() requires matching channel counts.
        self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)

        self.gradient_checkpointing = self.cfg.gradient_checkpointing

        self.set_use_memory_efficient_attention_xformers(
            self.cfg.enable_memory_efficient_attention
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        modulation_cond: Optional[torch.FloatTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        The [`Transformer1DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, seq_len)`):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            modulation_cond ( `torch.FloatTensor`, *optional*):
                Continuous conditioning consumed by `AdaLayerNormContinuous` blocks.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.

        Returns:
            `torch.Tensor`: The transformed sequence with the same shape as the input
            (residual connection around the transformer stack).
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (
                1 - encoder_attention_mask.to(hidden_states.dtype)
            ) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input: (batch, channels, seq) -> (batch, seq, inner_dim)
        batch, _, seq_len = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 1).reshape(
            batch, seq_len, inner_dim
        )
        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                # Positional argument order must match BasicTransformerBlock.forward.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    modulation_cond,
                    cross_attention_kwargs,
                    class_labels,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    modulation_cond=modulation_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output: project back and restore (batch, channels, seq) layout
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states.reshape(batch, seq_len, inner_dim)
            .permute(0, 2, 1)
            .contiguous()
        )

        # Residual connection around the whole transformer stack.
        output = hidden_states + residual

        return output
diff --git a/3D_Stage/lrm/models/transformers/transformer_2d.py b/3D_Stage/lrm/models/transformers/transformer_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c31de6efebf0f3d1565d96e076870eb6979396a
--- /dev/null
+++ b/3D_Stage/lrm/models/transformers/transformer_2d.py
@@ -0,0 +1,414 @@
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
+
+import lrm
+from ...utils.base import BaseModule
+from .attention import (
+ BasicTransformerBlock,
+ MemoryEfficientAttentionMixin,
+)
+from ...utils.typing import *
+
+
+class Transformer2D(BaseModule, MemoryEfficientAttentionMixin):
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
    @dataclass
    class Config(BaseModule.Config):
        # Attention geometry: inner_dim = num_attention_heads * attention_head_dim.
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        # Continuous input: set in_channels (and leave patch_size unset).
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 1
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False
        # Discrete input: set sample_size and num_vector_embeds.
        sample_size: Optional[int] = None
        num_vector_embeds: Optional[int] = None
        # Patched input: set in_channels, sample_size and patch_size.
        patch_size: Optional[int] = None
        activation_fn: str = "geglu"
        num_embeds_ada_norm: Optional[int] = None
        use_linear_projection: bool = False
        only_cross_attention: bool = False
        double_self_attention: bool = False
        upcast_attention: bool = False
        norm_type: str = "layer_norm"
        norm_elementwise_affine: bool = True
        attention_type: str = "default"
        enable_memory_efficient_attention: bool = False
        gradient_checkpointing: bool = False

    # Populated by BaseModule from the configuration above.
    cfg: Config
+
    def configure(self) -> None:
        """Build input embedding/projection, transformer blocks, and output head from `self.cfg`.

        Exactly one input mode must be configured: continuous (`in_channels`),
        vectorized (`num_vector_embeds` + `sample_size`), or patched
        (`in_channels` + `patch_size` + `sample_size`).
        """
        super().configure()

        self.use_linear_projection = self.cfg.use_linear_projection
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        conv_cls = nn.Conv2d
        linear_cls = nn.Linear

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (self.cfg.in_channels is not None) and (
            self.cfg.patch_size is None
        )
        self.is_input_vectorized = self.cfg.num_vector_embeds is not None
        self.is_input_patches = (
            self.cfg.in_channels is not None and self.cfg.patch_size is not None
        )

        # Legacy-config compatibility: old checkpoints used "layer_norm" with
        # num_embeds_ada_norm set; silently upgrade to "ada_norm" with a warning.
        if (
            self.cfg.norm_type == "layer_norm"
            and self.cfg.num_embeds_ada_norm is not None
        ):
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            lrm.warn(deprecation_message)
            self.cfg.norm_type = "ada_norm"

        # The three input modes are mutually exclusive; at least one is required.
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {self.cfg.in_channels} and `num_vector_embeds`: {self.cfg.num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {self.cfg.num_vector_embeds} and `patch_size`: {self.cfg.patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif (
            not self.is_input_continuous
            and not self.is_input_vectorized
            and not self.is_input_patches
        ):
            raise ValueError(
                f"Has to define `in_channels`: {self.cfg.in_channels}, `num_vector_embeds`: {self.cfg.num_vector_embeds}, or patch_size:"
                f" {self.cfg.patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = self.cfg.in_channels

            self.norm = torch.nn.GroupNorm(
                num_groups=self.cfg.norm_num_groups,
                num_channels=self.cfg.in_channels,
                eps=1e-6,
                affine=True,
            )
            # Either a 1x1 conv or a linear layer projects channels to inner_dim.
            if self.cfg.use_linear_projection:
                self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
            else:
                self.proj_in = conv_cls(
                    self.cfg.in_channels, inner_dim, kernel_size=1, stride=1, padding=0
                )
        elif self.is_input_vectorized:
            assert (
                self.cfg.sample_size is not None
            ), "Transformer2DModel over discrete input must provide sample_size"
            assert (
                self.cfg.num_vector_embeds is not None
            ), "Transformer2DModel over discrete input must provide num_embed"

            self.height = self.cfg.sample_size
            self.width = self.cfg.sample_size
            self.num_vector_embeds = self.cfg.num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=self.cfg.num_vector_embeds,
                embed_dim=inner_dim,
                height=self.height,
                width=self.width,
            )
        elif self.is_input_patches:
            assert (
                self.cfg.sample_size is not None
            ), "Transformer2DModel over patched input must provide sample_size"

            self.height = self.cfg.sample_size
            self.width = self.cfg.sample_size

            self.patch_size = self.cfg.patch_size
            self.pos_embed = PatchEmbed(
                height=self.cfg.sample_size,
                width=self.cfg.sample_size,
                patch_size=self.cfg.patch_size,
                in_channels=self.cfg.in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    dropout=self.cfg.dropout,
                    cross_attention_dim=self.cfg.cross_attention_dim,
                    activation_fn=self.cfg.activation_fn,
                    num_embeds_ada_norm=self.cfg.num_embeds_ada_norm,
                    attention_bias=self.cfg.attention_bias,
                    only_cross_attention=self.cfg.only_cross_attention,
                    double_self_attention=self.cfg.double_self_attention,
                    upcast_attention=self.cfg.upcast_attention,
                    norm_type=self.cfg.norm_type,
                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
                    attention_type=self.cfg.attention_type,
                )
                for d in range(self.cfg.num_layers)
            ]
        )

        # 4. Define output layers (one head per input mode)
        self.out_channels = (
            self.cfg.in_channels
            if self.cfg.out_channels is None
            else self.cfg.out_channels
        )
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if self.cfg.use_linear_projection:
                self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
            else:
                self.proj_out = conv_cls(
                    inner_dim, self.cfg.in_channels, kernel_size=1, stride=1, padding=0
                )
        elif self.is_input_vectorized:
            # Logits over vector embeddings, excluding the masked-pixel class.
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            # AdaLN-style output head: modulation projection + per-patch pixels.
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(
                inner_dim, self.cfg.patch_size * self.cfg.patch_size * self.out_channels
            )

        self.gradient_checkpointing = self.cfg.gradient_checkpointing

        self.set_use_memory_efficient_attention_xformers(
            self.cfg.enable_memory_efficient_attention
        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        The [`Transformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input `hidden_states`.
+            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            timestep ( `torch.LongTensor`, *optional*):
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
+            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            attention_mask ( `torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                * Mask `(batch, sequence_length)` True = keep, False = discard.
+                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
+
+        Returns:
+            `torch.Tensor`: the transformed sample — the residual-added feature map for continuous
+            input, per-pixel class log-probabilities for vectorized input, or the unpatchified image
+            tensor for patched input. NOTE(review): unlike upstream diffusers, this port has no
+            `return_dict` parameter and always returns a plain tensor, never a
+            `Transformer2DModelOutput`.
+        """
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (
+                1 - encoder_attention_mask.to(hidden_states.dtype)
+            ) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input: project/tokenize the input into a (batch, tokens, inner_dim) sequence,
+        # depending on which of the three mutually-exclusive input modes this model was built for.
+        if self.is_input_continuous:
+            batch, _, height, width = hidden_states.shape
+            residual = hidden_states
+
+            hidden_states = self.norm(hidden_states)
+            if not self.use_linear_projection:
+                # conv projection operates on NCHW, so project first, then flatten to tokens
+                hidden_states = self.proj_in(hidden_states)
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                    batch, height * width, inner_dim
+                )
+            else:
+                # linear projection operates on the channel-last token sequence, so flatten first
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                    batch, height * width, inner_dim
+                )
+                hidden_states = self.proj_in(hidden_states)
+
+        elif self.is_input_vectorized:
+            hidden_states = self.latent_image_embedding(hidden_states)
+        elif self.is_input_patches:
+            hidden_states = self.pos_embed(hidden_states)
+
+        # 2. Blocks: run the token sequence through the stacked BasicTransformerBlocks,
+        # optionally under activation checkpointing to trade compute for memory while training.
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    block,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    cross_attention_kwargs,
+                    class_labels,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    timestep=timestep,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    class_labels=class_labels,
+                )
+
+        # 3. Output: undo the input projection/tokenization for the matching input mode.
+        if self.is_input_continuous:
+            # mirror the input path: conv proj_out needs NCHW, linear proj_out needs tokens
+            if not self.use_linear_projection:
+                hidden_states = (
+                    hidden_states.reshape(batch, height, width, inner_dim)
+                    .permute(0, 3, 1, 2)
+                    .contiguous()
+                )
+                hidden_states = self.proj_out(hidden_states)
+            else:
+                hidden_states = self.proj_out(hidden_states)
+                hidden_states = (
+                    hidden_states.reshape(batch, height, width, inner_dim)
+                    .permute(0, 3, 1, 2)
+                    .contiguous()
+                )
+
+            output = hidden_states + residual
+        elif self.is_input_vectorized:
+            hidden_states = self.norm_out(hidden_states)
+            logits = self.out(hidden_states)
+            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+            logits = logits.permute(0, 2, 1)
+
+            # log(p(x_0)); softmax in float64 for numerical stability, then back to float32
+            output = F.log_softmax(logits.double(), dim=1).float()
+        elif self.is_input_patches:
+            # TODO: cleanup!
+            # AdaLayerNorm-style conditioning: shift/scale derived from timestep + class embedding
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = (
+                self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            )
+            hidden_states = self.proj_out_2(hidden_states)
+
+            # unpatchify: assumes a square token grid (num_tokens is a perfect square)
+            height = width = int(hidden_states.shape[1] ** 0.5)
+            hidden_states = hidden_states.reshape(
+                shape=(
+                    -1,
+                    height,
+                    width,
+                    self.patch_size,
+                    self.patch_size,
+                    self.out_channels,
+                )
+            )
+            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+            output = hidden_states.reshape(
+                shape=(
+                    -1,
+                    self.out_channels,
+                    height * self.patch_size,
+                    width * self.patch_size,
+                )
+            )
+
+        return output
diff --git a/3D_Stage/lrm/systems/__init__.py b/3D_Stage/lrm/systems/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/systems/base.py b/3D_Stage/lrm/systems/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..885bef8328064bd3151f1883d7246bcd330ac6c0
--- /dev/null
+++ b/3D_Stage/lrm/systems/base.py
@@ -0,0 +1,283 @@
+import os
+from dataclasses import dataclass, field
+
+import pytorch_lightning as pl
+import torch.nn.functional as F
+
+import lrm
+from .utils import parse_optimizer, parse_scheduler
+from ..utils.base import (
+ Updateable,
+ update_end_if_possible,
+ update_if_possible,
+)
+from ..models.exporters.base import Exporter, ExporterOutput
+from ..utils.config import parse_structured
+from ..utils.misc import C, cleanup, get_device, load_module_weights
+from ..utils.saving import SaverMixin
+from ..utils.typing import *
+
+
+@dataclass
+class BaseLossConfig:
+    """Empty base container for loss weights; systems subclass it to declare their lambdas."""
+
+    pass
+
+
+class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
+    """Base LightningModule for lrm systems.
+
+    Provides structured-config parsing, optional checkpoint preloading,
+    optimizer/scheduler construction, and the hooks that keep step/epoch-
+    dependent state (via ``Updateable``) in sync with the datasets across
+    the train / validation / test / predict loops.
+    """
+
+    @dataclass
+    class Config:
+        # NOTE(review): an eq=True dataclass instance is unhashable, so this default
+        # raises ValueError on Python >= 3.11; prefer field(default_factory=BaseLossConfig).
+        loss: BaseLossConfig = BaseLossConfig()
+        optimizer: dict = field(default_factory=dict)  # consumed by parse_optimizer
+        scheduler: Optional[dict] = None  # consumed by parse_scheduler when set
+        weights: Optional[str] = None  # optional checkpoint path preloaded in __init__
+        weights_ignore_modules: Optional[List[str]] = None  # module names skipped on load
+        weights_mapping: Optional[List[Dict[str, str]]] = None  # key-rename rules on load
+        check_train_every_n_steps: int = 0  # 0 disables periodic train visualization
+        check_val_limit_rank: int = 8  # only ranks below this save validation outputs
+        cleanup_after_validation_step: bool = False  # free VRAM after each val batch
+        cleanup_after_test_step: bool = False  # free VRAM after each test/predict batch
+
+        exporter_cls: str = "lrm.models.exporters.mesh_exporter.MeshExporter"
+        exporter: dict = field(default_factory=lambda: {"fmt": "obj", "save_uv": False})
+
+    cfg: Config
+
+    def __init__(self, cfg, resumed=False) -> None:
+        """Parse *cfg* into ``self.Config``, run ``configure()``, then optionally load weights.
+
+        Args:
+            cfg: raw (dict/omegaconf) config parsed via ``parse_structured``.
+            resumed: whether this system was restored from a trainer checkpoint.
+        """
+        super().__init__()
+        self.cfg = parse_structured(self.Config, cfg)
+        self._save_dir: Optional[str] = None
+        self._resumed: bool = resumed
+        self._resumed_eval: bool = False
+        # epoch/step to report when evaluating a resumed checkpoint (see set_resume_status)
+        self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}
+
+        self.configure()
+        # NOTE(review): debug prints left in; consider lrm.info/debug instead
+        print(self.cfg.weights)
+        if self.cfg.weights is not None:
+            self.load_weights(
+                self.cfg.weights,
+                self.cfg.weights_ignore_modules,
+                self.cfg.weights_mapping,
+            )
+            print("finish loading!!")
+        self.post_configure()
+
+    def load_weights(
+        self,
+        weights: str,
+        ignore_modules: Optional[List[str]] = None,
+        mapping: Optional[List[Dict[str, str]]] = None,
+    ):
+        """Load a checkpoint file non-strictly and replay step-dependent updates."""
+        state_dict, epoch, global_step = load_module_weights(
+            weights,
+            ignore_modules=ignore_modules,
+            mapping=mapping,
+            map_location="cpu",
+        )
+        self.load_state_dict(state_dict, strict=False)
+        # restore step-dependent states
+        self.do_update_step(epoch, global_step, on_load_weights=True)
+
+    def set_resume_status(self, current_epoch: int, global_step: int):
+        # restore correct epoch and global step in eval
+        self._resumed_eval = True
+        self._resumed_eval_status["current_epoch"] = current_epoch
+        self._resumed_eval_status["global_step"] = global_step
+
+    @property
+    def resumed(self):
+        # whether from resumed checkpoint
+        return self._resumed
+
+    @property
+    def true_global_step(self):
+        # the checkpoint's step when evaluating a resumed run, else the live trainer step
+        if self._resumed_eval:
+            return self._resumed_eval_status["global_step"]
+        else:
+            return self.global_step
+
+    @property
+    def true_current_epoch(self):
+        # the checkpoint's epoch when evaluating a resumed run, else the live trainer epoch
+        if self._resumed_eval:
+            return self._resumed_eval_status["current_epoch"]
+        else:
+            return self.current_epoch
+
+    def configure(self) -> None:
+        """Build submodules; executed in __init__ before weights are loaded."""
+        pass
+
+    def post_configure(self) -> None:
+        """
+        executed after weights are loaded
+        """
+        pass
+
+    def C(self, value: Any) -> float:
+        """Resolve a possibly step/epoch-scheduled config value to a float (see utils.misc.C)."""
+        return C(value, self.true_current_epoch, self.true_global_step)
+
+    def configure_optimizers(self):
+        """Lightning hook: build the optimizer (and scheduler, if configured) from cfg."""
+        optim = parse_optimizer(self.cfg.optimizer, self)
+        ret = {
+            "optimizer": optim,
+        }
+        if self.cfg.scheduler is not None:
+            ret.update(
+                {
+                    "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
+                }
+            )
+        return ret
+
+    def on_fit_start(self) -> None:
+        if self._save_dir is not None:
+            lrm.info(f"Validation results will be saved to {self._save_dir}")
+        else:
+            lrm.warn(
+                f"Saving directory not set for the system, visualization results will not be saved"
+            )
+
+    def training_step(self, batch, batch_idx):
+        raise NotImplementedError
+
+    def check_train(self, batch, outputs, **kwargs):
+        """Invoke on_check_train on rank 0 every check_train_every_n_steps steps (0 = never)."""
+        if (
+            self.global_rank == 0
+            and self.cfg.check_train_every_n_steps > 0
+            and self.true_global_step % self.cfg.check_train_every_n_steps == 0
+        ):
+            self.on_check_train(batch, outputs, **kwargs)
+
+    def on_check_train(self, batch, outputs, **kwargs):
+        """Subclass hook for periodic training visualization; no-op by default."""
+        pass
+
+    def validation_step(self, batch, batch_idx):
+        raise NotImplementedError
+
+    def on_validation_epoch_end(self):
+        pass
+
+    def test_step(self, batch, batch_idx):
+        raise NotImplementedError
+
+    def on_test_epoch_end(self):
+        pass
+
+    def on_test_end(self) -> None:
+        if self._save_dir is not None:
+            lrm.info(f"Test results saved to {self._save_dir}")
+
+    def on_predict_start(self) -> None:
+        pass
+
+    def predict_step(self, batch, batch_idx):
+        """Export assets for each scene whose first rendered view index is 0.
+
+        NOTE(review): relies on ``self.exporter`` being created by a subclass's
+        ``configure()`` (e.g. MultiviewLRM) — confirm before reusing this base alone.
+        """
+        batch_size = batch["index"].shape[0]
+        scene_codes = self(batch)
+        for b in range(batch_size):
+            # export each scene only once, when its first view comes through
+            if batch["view_index"][b, 0] == 0:
+                exporter_output: List[ExporterOutput] = self.exporter(
+                    batch["index"][b][None], scene_codes[b][None]
+                )
+                for out in exporter_output:
+                    # dispatch to SaverMixin.save_<type> based on the exporter's declared type
+                    save_func_name = f"save_{out.save_type}"
+                    if not hasattr(self, save_func_name):
+                        raise ValueError(
+                            f"{save_func_name} not supported by the SaverMixin"
+                        )
+                    save_func = getattr(self, save_func_name)
+                    save_func(
+                        f"it{self.true_global_step}-export/{out.save_name}",
+                        **out.params,
+                    )
+        if self.exporter.cfg.save_video:
+            # reuse the test path to render video frames alongside the export
+            self.test_step(batch, batch_idx)
+
+    def on_predict_epoch_end(self) -> None:
+        if self.exporter.cfg.save_video:
+            self.on_test_epoch_end()
+
+    def on_predict_end(self) -> None:
+        if self._save_dir is not None:
+            lrm.info(f"Export assets saved to {self._save_dir}")
+
+    def preprocess_data(self, batch, stage):
+        """Subclass hook to transform the batch before each step; *stage* names the loop."""
+        pass
+
+    """
+    Implementing on_after_batch_transfer of DataModule does the same.
+    But on_after_batch_transfer does not support DP.
+    """
+
+    def on_train_batch_start(self, batch, batch_idx, unused=0):
+        self.preprocess_data(batch, "train")
+        self.dataset = self.trainer.train_dataloader.dataset
+        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
+        self.do_update_step(self.true_current_epoch, self.true_global_step)
+
+    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
+        self.preprocess_data(batch, "validation")
+        self.dataset = self.trainer.val_dataloaders.dataset
+        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
+        self.do_update_step(self.true_current_epoch, self.true_global_step)
+
+    def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
+        self.preprocess_data(batch, "test")
+        self.dataset = self.trainer.test_dataloaders.dataset
+        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
+        self.do_update_step(self.true_current_epoch, self.true_global_step)
+
+    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
+        self.preprocess_data(batch, "predict")
+        self.dataset = self.trainer.predict_dataloaders.dataset
+        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
+        self.do_update_step(self.true_current_epoch, self.true_global_step)
+
+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        self.dataset = self.trainer.train_dataloader.dataset
+        update_end_if_possible(
+            self.dataset, self.true_current_epoch, self.true_global_step
+        )
+        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
+
+    def on_validation_batch_end(self, outputs, batch, batch_idx):
+        self.dataset = self.trainer.val_dataloaders.dataset
+        update_end_if_possible(
+            self.dataset, self.true_current_epoch, self.true_global_step
+        )
+        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
+        if self.cfg.cleanup_after_validation_step:
+            # cleanup to save vram
+            cleanup()
+
+    def on_test_batch_end(self, outputs, batch, batch_idx):
+        self.dataset = self.trainer.test_dataloaders.dataset
+        update_end_if_possible(
+            self.dataset, self.true_current_epoch, self.true_global_step
+        )
+        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
+        if self.cfg.cleanup_after_test_step:
+            # cleanup to save vram
+            cleanup()
+
+    def on_predict_batch_end(self, outputs, batch, batch_idx):
+        self.dataset = self.trainer.predict_dataloaders.dataset
+        update_end_if_possible(
+            self.dataset, self.true_current_epoch, self.true_global_step
+        )
+        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
+        if self.cfg.cleanup_after_test_step:
+            # cleanup to save vram
+            cleanup()
+
+    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
+        pass
+
+    def on_before_optimizer_step(self, optimizer):
+        """
+        # some gradient-related debugging goes here, example:
+        from lightning.pytorch.utilities import grad_norm
+        norms = grad_norm(self.geometry, norm_type=2)
+        print(norms)
+        for name, p in self.named_parameters():
+            if p.grad is None:
+                lrm.info(f"{name} does not receive gradients!")
+        """
+        pass
diff --git a/3D_Stage/lrm/systems/multiview_lrm.py b/3D_Stage/lrm/systems/multiview_lrm.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bec99b677f9402bc453e866cf7d517132f2d80e
--- /dev/null
+++ b/3D_Stage/lrm/systems/multiview_lrm.py
@@ -0,0 +1,335 @@
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+import lrm
+from lrm.systems.base import BaseLossConfig, BaseSystem
+from lrm.utils.ops import binary_cross_entropy, get_plucker_rays
+from lrm.utils.typing import *
+from lrm.models.lpips import LPIPS
+from lrm.utils.misc import time_recorder as tr
+
+
+@dataclass
+class MultiviewLRMLossConfig(BaseLossConfig):
+    """Loss weights for MultiviewLRM; each may be a constant or a C()-style schedule.
+
+    The ``_coarse`` variants weight the same losses applied to the renderer's
+    coarse pass outputs (keys suffixed ``_coarse`` in the render dict).
+    """
+
+    lambda_mse: Any = 0.0  # pixel MSE on rendered RGB
+    lambda_mse_coarse: Any = 0.0
+    lambda_smooth_l1: Any = 0.0  # smooth-L1 (beta=0.1) on rendered RGB
+    lambda_smooth_l1_coarse: Any = 0.0
+    lambda_lpips: Any = 0.0  # perceptual LPIPS loss
+    lambda_lpips_coarse: Any = 0.0
+    lambda_mask: Any = 0.0  # BCE between rendered opacity and GT mask
+    lambda_mask_coarse: Any = 0.0
+
+
+class MultiviewLRM(BaseSystem):
+    """Multiview Large Reconstruction Model system.
+
+    Tokenizes posed input images, runs a transformer backbone conditioned on the
+    image tokens, decodes the result into scene codes, and renders them with a
+    NeRF-style renderer for supervision and export. All submodules are built
+    dynamically from ``*_cls`` dotted paths via ``lrm.find``.
+    """
+
+    @dataclass
+    class Config(BaseSystem.Config):
+        # NOTE(review): unhashable dataclass default — breaks on Python >= 3.11;
+        # prefer field(default_factory=MultiviewLRMLossConfig).
+        loss: MultiviewLRMLossConfig = MultiviewLRMLossConfig()
+
+        # each pair below is (dotted class path, config dict passed to that class)
+        camera_embedder_cls: str = ""
+        camera_embedder: dict = field(default_factory=dict)
+
+        image_tokenizer_cls: str = ""
+        image_tokenizer: dict = field(default_factory=dict)
+
+        tokenizer_cls: str = ""
+        tokenizer: dict = field(default_factory=dict)
+
+        backbone_cls: str = ""
+        backbone: dict = field(default_factory=dict)
+
+        post_processor_cls: str = ""
+        post_processor: dict = field(default_factory=dict)
+
+        decoder_cls: str = ""
+        decoder: dict = field(default_factory=dict)
+
+        material_cls: str = ""
+        material: dict = field(default_factory=dict)
+
+        background_cls: str = ""
+        background: dict = field(default_factory=dict)
+
+        renderer_cls: str = ""
+        renderer: dict = field(default_factory=dict)
+
+        resume_ckpt_path: str = ""
+
+    cfg: Config
+
+    def configure(self):
+        """Instantiate all submodules from their configured dotted class paths."""
+        super().configure()
+        self.image_tokenizer = lrm.find(self.cfg.image_tokenizer_cls)(
+            self.cfg.image_tokenizer
+        )
+        # camera embedder only exists when the image tokenizer supports modulation
+        if self.cfg.image_tokenizer.modulation:
+            self.camera_embedder = lrm.find(self.cfg.camera_embedder_cls)(
+                self.cfg.camera_embedder
+            )
+        self.tokenizer = lrm.find(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
+        self.backbone = lrm.find(self.cfg.backbone_cls)(self.cfg.backbone)
+        self.post_processor = lrm.find(self.cfg.post_processor_cls)(
+            self.cfg.post_processor
+        )
+        self.decoder = lrm.find(self.cfg.decoder_cls)(self.cfg.decoder)
+        self.material = lrm.find(self.cfg.material_cls)(self.cfg.material)
+        self.background = lrm.find(self.cfg.background_cls)(self.cfg.background)
+        self.renderer = lrm.find(self.cfg.renderer_cls)(
+            self.cfg.renderer, self.decoder, self.material, self.background
+        )
+
+        self.exporter = lrm.find(self.cfg.exporter_cls)(
+            self.cfg.exporter, self.renderer
+        )
+
+    def on_fit_start(self):
+        # LPIPS is only needed for training losses, so it is built lazily here
+        super().on_fit_start()
+        self.lpips_loss_fn = LPIPS()
+
+    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
+        """Encode conditioning views into per-scene codes.
+
+        Returns the post-processed scene codes produced by the tokenizer/backbone
+        pipeline (one code per batch element).
+        """
+        # batch["rgb_cond"]: B, N_cond, H, W, 3
+        # batch["rgb"]: B, N_render, H, W, 3
+        # batch["c2w_cond"]: B, N_cond, 4, 4
+        # for single image input (like LRM), N_cond = 1
+
+        batch_size, n_input_views = batch["rgb_cond"].shape[:2]
+
+        # Camera modulation
+        camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
+        if self.cfg.image_tokenizer.modulation:
+            camera_embeds = self.camera_embedder(**batch)
+        else:
+            camera_embeds = None
+
+        tr.start("image tokenizer")
+        input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
+            rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
+            modulation_cond=camera_embeds,
+            # Plucker rays are an optional extra per-pixel camera conditioning channel
+            plucker_rays=rearrange(
+                get_plucker_rays(batch["rays_o_cond"], batch["rays_d_cond"]),
+                "B Nv H W C -> B Nv C H W",
+            )
+            if "rays_o_cond" in batch
+            else None,
+        )
+        tr.end("image tokenizer")
+
+        # flatten tokens from all views into one cross-attention context sequence
+        input_image_tokens = rearrange(
+            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
+        )
+
+        tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
+
+        tr.start("backbone")
+        tokens = self.backbone(
+            tokens,
+            encoder_hidden_states=input_image_tokens,
+            modulation_cond=None,
+        )
+        tr.end("backbone")
+
+        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
+        return scene_codes
+
+    def forward_renderer_nerf(
+        self, batch: Dict[str, Any], scene_codes
+    ) -> Dict[str, Any]:
+        """Render *scene_codes* for the batch's target views; timed via the time recorder."""
+        tr.start("render")
+        render_out = self.renderer(scene_codes, **batch)
+        tr.end("render")
+        return render_out
+
+    def training_step(self, batch, batch_idx):
+        """Compute the weighted sum of MSE / smooth-L1 / LPIPS / mask losses over
+        both fine and (if present) coarse render outputs."""
+        scene_codes = self(batch)
+        out = self.forward_renderer_nerf(batch, scene_codes)
+
+        loss = 0.0
+
+        for suffix in ["", "_coarse"]:
+            if not f"comp_rgb{suffix}" in out:
+                continue
+
+            comp_rgb: Float[Tensor, "B Nv H W 3"] = out["comp_rgb{}".format(suffix)]
+            gt_rgb: Float[Tensor, "B Nv H W 3"] = batch["rgb"]
+
+            self.log(f"train/comp_rgb_min{suffix}", comp_rgb.min())
+
+            loss_mse = F.mse_loss(comp_rgb, gt_rgb, reduction="mean")
+            self.log(f"train/loss_mse{suffix}", loss_mse)
+            loss += loss_mse * self.C(self.cfg.loss[f"lambda_mse{suffix}"])
+
+            loss_smooth_l1 = F.smooth_l1_loss(
+                comp_rgb, gt_rgb, beta=0.1, reduction="mean"
+            )
+            self.log(f"train/loss_smooth_l1{suffix}", loss_smooth_l1)
+            loss += loss_smooth_l1 * self.C(self.cfg.loss[f"lambda_smooth_l1{suffix}"])
+
+            if self.C(self.cfg.loss[f"lambda_lpips{suffix}"]) > 0:
+                loss_lpips = self.lpips_loss_fn(
+                    rearrange(comp_rgb, "B Nv H W C -> (B Nv) C H W"),
+                    rearrange(gt_rgb, "B Nv H W C -> (B Nv) C H W"),
+                    input_range=(0, 1),
+                ).mean()
+                self.log(f"train/loss_lpips{suffix}", loss_lpips)
+                loss += loss_lpips * self.C(self.cfg.loss[f"lambda_lpips{suffix}"])
+
+            loss_mask = binary_cross_entropy(
+                out[f"opacity{suffix}"].clamp(1e-5, 1 - 1e-5), batch["mask"]
+            )
+            self.log(f"train/loss_mask{suffix}", loss_mask)
+            loss += loss_mask * self.C(self.cfg.loss[f"lambda_mask{suffix}"])
+
+        for name, value in self.cfg.loss.items():
+            self.log(f"train_params/{name}", self.C(value))
+
+        # will execute self.on_check_train every self.cfg.check_train_every_n_steps steps
+        # NOTE(review): loss_lpips is unbound here when lambda_lpips <= 0 (the branch
+        # above is skipped), which would raise UnboundLocalError — confirm/guard.
+        self.check_train(
+            batch,
+            out,
+            extra=f"m{loss_mse:.2f}_l{loss_smooth_l1:.2f}_p{loss_lpips:.2f}_ma{loss_mask:.2f}",
+        )
+
+        return {"loss": loss}
+
+    def get_input_visualizations(self, batch):
+        """Image-grid spec (for SaverMixin.save_image_grid) showing the conditioning views."""
+        return [
+            {
+                "type": "rgb",
+                "img": rearrange(batch["rgb_cond"], "B N H W C -> (B H) (N W) C"),
+                "kwargs": {"data_format": "HWC"},
+            }
+        ]
+
+    def get_output_visualizations(self, batch, outputs):
+        """Image-grid spec comparing GT rgb/mask against rendered rgb/opacity/depth
+        for both fine and coarse passes (when present)."""
+        out = outputs
+        images = []
+        if "rgb" in batch:
+            images += [
+                {
+                    "type": "rgb",
+                    "img": rearrange(batch["rgb"], "B N H W C -> (B H) (N W) C"),
+                    "kwargs": {"data_format": "HWC"},
+                },
+                {
+                    "type": "grayscale",
+                    "img": rearrange(batch["mask"], "B N H W C -> (B H) (N W) C")[
+                        ..., 0
+                    ],
+                    "kwargs": {"cmap": None, "data_range": None},
+                },
+            ]
+        for suffix in ["", "_coarse"]:
+            if not f"comp_rgb{suffix}" in out:
+                continue
+            images += [
+                {
+                    "type": "rgb",
+                    "img": rearrange(
+                        out[f"comp_rgb{suffix}"], "B N H W C -> (B H) (N W) C"
+                    ),
+                    "kwargs": {"data_format": "HWC"},
+                },
+                {
+                    "type": "grayscale",
+                    "img": rearrange(
+                        out[f"opacity{suffix}"], "B N H W C -> (B H) (N W) C"
+                    )[..., 0],
+                    "kwargs": {"cmap": None, "data_range": None},
+                },
+                {
+                    "type": "grayscale",
+                    "img": rearrange(
+                        out[f"depth{suffix}"], "B N H W C -> (B H) (N W) C"
+                    )[..., 0],
+                    "kwargs": {"cmap": None, "data_range": None},
+                },
+            ]
+        return images
+
+    # def check_train(self, batch, outputs, **kwargs):
+    #     self.on_check_train(batch, outputs, **kwargs)
+
+    def on_check_train(self, batch, outputs, extra=""):
+        """Save a periodic train-step visualization grid (rank 0 only, via check_train)."""
+        self.save_image_grid(
+            f"it{self.true_global_step}-train.jpg",
+            self.get_output_visualizations(batch, outputs),
+            name="train_step_output",
+            step=self.true_global_step,
+        )
+        # self.save_image_grid(
+        #     f"debug/it{self.true_global_step}-{self.global_rank}-{extra}.jpg",
+        #     self.get_output_visualizations(batch, outputs),
+        #     name="train_step_output",
+        #     step=self.true_global_step,
+        # )
+        # self.save_json(
+        #     f"debug_list/it{self.true_global_step}-{self.global_rank}-ids.json",
+        #     batch["scene_id"],
+        # )
+
+    def validation_step(self, batch, batch_idx):
+        """Render the batch and, on low ranks only, save input/output comparison grids."""
+        scene_codes = self(batch)
+        out = self.forward_renderer_nerf(batch, scene_codes)
+        if (
+            self.cfg.check_val_limit_rank > 0
+            and self.global_rank < self.cfg.check_val_limit_rank
+        ):
+            self.save_image_grid(
+                f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}-input.jpg",
+                self.get_input_visualizations(batch),
+                name=f"validation_step_input_{self.global_rank}_{batch_idx}",
+                step=self.true_global_step,
+            )
+            self.save_image_grid(
+                f"it{self.true_global_step}-validation-{self.global_rank}_{batch_idx}.jpg",
+                self.get_output_visualizations(batch, out),
+                name=f"validation_step_output_{self.global_rank}_{batch_idx}",
+                step=self.true_global_step,
+            )
+
+    def test_step(self, batch, batch_idx):
+        # not saved to wandb
+        scene_codes = self(batch)
+        out = self.forward_renderer_nerf(batch, scene_codes)
+        batch_size = batch["index"].shape[0]
+        for b in range(batch_size):
+            # save the conditioning views once per scene (when its first view arrives)
+            if batch["view_index"][b, 0] == 0:
+                self.save_image_grid(
+                    f"it{self.true_global_step}-test/{batch['index'][b]}-input.jpg",
+                    [
+                        {
+                            "type": "rgb",
+                            "img": rearrange(
+                                batch["rgb_cond"][b], "N H W C -> H (N W) C"
+                            ),
+                            "kwargs": {"data_format": "HWC"},
+                        },
+                    ],
+                )
+            # one numbered frame per rendered view; stitched into mp4 in on_test_end
+            self.save_image_grid(
+                f"it{self.true_global_step}-test/{batch['index'][b]}/{batch['view_index'][b,0]}.png",
+                [
+                    {
+                        "type": "rgb",
+                        "img": out["comp_rgb"][b][0],
+                        "kwargs": {"data_format": "HWC"},
+                    },
+                    {
+                        "type": "grayscale",
+                        "img": out["depth"][b][0, ..., 0],
+                        "kwargs": {"cmap": None, "data_range": None},
+                    },
+                ],
+            )
+
+    def on_test_end(self):
+        """Stitch the per-view test frames into mp4 videos (rank 0 only)."""
+        if self.global_rank == 0:
+            # NOTE(review): non-raw regex string — "\d" is an invalid escape sequence
+            # (SyntaxWarning on recent Pythons); should ideally be r"(\d+)\.png".
+            self.save_img_sequences(
+                f"it{self.true_global_step}-test",
+                "(\d+)\.png",
+                save_format="mp4",
+                fps=30,
+            )
diff --git a/3D_Stage/lrm/systems/utils.py b/3D_Stage/lrm/systems/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..91e534849bcb6e627df383c84696264808382c66
--- /dev/null
+++ b/3D_Stage/lrm/systems/utils.py
@@ -0,0 +1,100 @@
+import torch
+import torch.nn as nn
+from torch.optim import lr_scheduler
+
+import lrm
+
+
+def get_scheduler(name):
+    """Resolve *name* to a scheduler class from ``torch.optim.lr_scheduler``.
+
+    Raises:
+        NotImplementedError: if no scheduler with that name exists.
+    """
+    if hasattr(lr_scheduler, name):
+        return getattr(lr_scheduler, name)
+    else:
+        raise NotImplementedError
+
+
+def getattr_recursive(m, attr):
+    """Resolve a dotted attribute path on *m*, e.g. "backbone.layers" -> m.backbone.layers."""
+    for name in attr.split("."):
+        m = getattr(m, name)
+    return m
+
+
+def get_parameters(model, name):
+    """Return the trainable parameters addressed by dotted path *name* on *model*.
+
+    Modules yield their ``.parameters()`` iterator; anything else yields [].
+    NOTE(review): for an nn.Parameter this returns the bare Parameter (not a
+    list); torch optimizers expect an iterable of params — confirm callers
+    never hit this branch, or it should be ``return [module]``.
+    """
+    module = getattr_recursive(model, name)
+    if isinstance(module, nn.Module):
+        return module.parameters()
+    elif isinstance(module, nn.Parameter):
+        return module
+    return []
+
+
+def parse_optimizer(config, model):
+    """Build an optimizer for *model* from a config with ``name``/``args`` and
+    optional per-module ``params`` overrides (dotted paths -> extra arg dicts).
+    """
+    if hasattr(config, "params"):
+        # per-parameter-group settings: each entry becomes its own param group
+        params = [
+            {"params": get_parameters(model, name), "name": name, **args}
+            for name, args in config.params.items()
+        ]
+        lrm.debug(f"Specify optimizer params: {config.params}")
+    else:
+        params = model.parameters()
+    if config.name in ["FusedAdam"]:
+        # optional dependency, imported lazily only when requested
+        import apex
+
+        optim = getattr(apex.optimizers, config.name)(params, **config.args)
+    elif config.name in ["Adam8bit", "AdamW8bit"]:
+        # optional dependency, imported lazily only when requested
+        import bitsandbytes as bnb
+
+        # NOTE(review): "AdamW8bit" also instantiates Adam8bit here — if decoupled
+        # weight decay is intended, this should be bnb.optim.AdamW8bit; confirm.
+        optim = bnb.optim.Adam8bit(params, **config.args)
+    else:
+        optim = getattr(torch.optim, config.name)(params, **config.args)
+    return optim
+
+
+def parse_scheduler_to_instance(config, optimizer):
+    """Recursively build a bare LR-scheduler instance (no Lightning dict wrapper).
+
+    Supports nested ``ChainedScheduler``/``Sequential`` configs; any other name
+    is looked up directly on ``torch.optim.lr_scheduler``.
+    """
+    if config.name == "ChainedScheduler":
+        schedulers = [
+            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
+        ]
+        scheduler = lr_scheduler.ChainedScheduler(schedulers)
+    elif config.name == "Sequential":
+        schedulers = [
+            parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers
+        ]
+        scheduler = lr_scheduler.SequentialLR(
+            optimizer, schedulers, milestones=config.milestones
+        )
+    else:
+        scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args)
+    return scheduler
+
+
+def parse_scheduler(config, optimizer):
+    """Build a Lightning ``lr_scheduler`` dict ({"scheduler", "interval"}) from config.
+
+    ``interval`` defaults to "epoch"; "SequentialLR"/"ChainedScheduler" configs are
+    expanded recursively from their ``schedulers`` sub-configs.
+    """
+    interval = config.get("interval", "epoch")
+    assert interval in ["epoch", "step"]
+    if config.name == "SequentialLR":
+        scheduler = {
+            "scheduler": lr_scheduler.SequentialLR(
+                optimizer,
+                [
+                    parse_scheduler(conf, optimizer)["scheduler"]
+                    for conf in config.schedulers
+                ],
+                milestones=config.milestones,
+            ),
+            "interval": interval,
+        }
+    elif config.name == "ChainedScheduler":
+        scheduler = {
+            "scheduler": lr_scheduler.ChainedScheduler(
+                [
+                    parse_scheduler(conf, optimizer)["scheduler"]
+                    for conf in config.schedulers
+                ]
+            ),
+            "interval": interval,
+        }
+    else:
+        scheduler = {
+            "scheduler": get_scheduler(config.name)(optimizer, **config.args),
+            "interval": interval,
+        }
+    return scheduler
diff --git a/3D_Stage/lrm/utils/__init__.py b/3D_Stage/lrm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3D_Stage/lrm/utils/base.py b/3D_Stage/lrm/utils/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78adda249e57664ee2ee8b38b41caf5e80a780f
--- /dev/null
+++ b/3D_Stage/lrm/utils/base.py
@@ -0,0 +1,123 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from .config import parse_structured
+from .misc import get_device, load_module_weights
+from .typing import *
+
+
class Configurable:
    """Mixin that parses a raw config dict into a structured ``self.cfg``.

    Subclasses declare their schema by overriding the nested ``Config``
    dataclass.
    """

    @dataclass
    class Config:
        pass

    def __init__(self, cfg: Optional[dict] = None) -> None:
        super().__init__()
        # validate/complete the raw dict against this class's Config schema
        self.cfg = parse_structured(self.Config, cfg)
+
+
class Updateable:
    """Mixin for objects that react to training progress.

    ``do_update_step`` / ``do_update_step_end`` recursively notify every
    non-private ``Updateable`` attribute, then invoke this object's own hooks.
    Subclasses override ``update_step`` / ``update_step_end``.
    """

    def do_update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ):
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                module = getattr(self, attr)
            except Exception:
                # Some attributes (e.g. properties) may raise on access; skip
                # them. This was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                continue
            if isinstance(module, Updateable):
                module.do_update_step(
                    epoch, global_step, on_load_weights=on_load_weights
                )
        self.update_step(epoch, global_step, on_load_weights=on_load_weights)

    def do_update_step_end(self, epoch: int, global_step: int):
        for attr in self.__dir__():
            if attr.startswith("_"):
                continue
            try:
                module = getattr(self, attr)
            except Exception:
                # skip attributes that raise on access (e.g. failing properties)
                continue
            if isinstance(module, Updateable):
                module.do_update_step_end(epoch, global_step)
        self.update_step_end(epoch, global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # Override to implement custom update logic.
        # When on_load_weights is True, be careful doing things related to
        # model evaluation: models and tensors are not guaranteed to be on the
        # same device.
        pass

    def update_step_end(self, epoch: int, global_step: int):
        # Override for logic that must run after all sub-module updates.
        pass
+
+
def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Forward a begin-of-step update to *module* if it is Updateable; else no-op."""
    if not isinstance(module, Updateable):
        return
    module.do_update_step(epoch, global_step)
+
+
def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
    """Forward an end-of-step update to *module* if it is Updateable; else no-op."""
    if not isinstance(module, Updateable):
        return
    module.do_update_step_end(epoch, global_step)
+
+
class BaseObject(Updateable):
    """Non-module base: parses config, binds the default device, runs configure()."""

    @dataclass
    class Config:
        pass

    # add this to every subclass of BaseObject to enable static type checking
    cfg: Config

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        # configure() runs last so cfg and device are available to subclasses
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        # subclass hook: build sub-components here
        pass
+
+
class BaseModule(nn.Module, Updateable):
    """nn.Module base that parses config, runs configure(), and can restore weights."""

    @dataclass
    class Config:
        # optional "path/to/weights:module_name" spec loaded at construction
        weights: Optional[str] = None

    # add this to every subclass of BaseModule to enable static type checking
    cfg: Config

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.device = get_device()
        self._non_modules = {}
        self.configure(*args, **kwargs)
        if self.cfg.weights is not None:
            # format: path/to/weights:module_name — rsplit on the LAST colon so
            # paths that themselves contain ':' (e.g. Windows drive letters)
            # still parse instead of raising an unpack error
            weights_path, module_name = self.cfg.weights.rsplit(":", 1)
            state_dict, epoch, global_step = load_module_weights(
                weights_path, module_name=module_name, map_location="cpu"
            )
            self.load_state_dict(state_dict)
            self.do_update_step(
                epoch, global_step, on_load_weights=True
            )  # restore states

    def configure(self, *args, **kwargs) -> None:
        # subclass hook: build sub-modules here
        pass

    def register_non_module(self, name: str, module: nn.Module) -> None:
        # non-modules won't be treated as model parameters (stored in a plain
        # dict, so nn.Module's attribute registration never sees them)
        self._non_modules[name] = module

    def non_module(self, name: str):
        """Return a previously registered non-module, or None if absent."""
        return self._non_modules.get(name, None)
diff --git a/3D_Stage/lrm/utils/callbacks.py b/3D_Stage/lrm/utils/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4cbd68e513e2ce6c905b1bb95842cdad8a7551
--- /dev/null
+++ b/3D_Stage/lrm/utils/callbacks.py
@@ -0,0 +1,156 @@
+import os
+import shutil
+import subprocess
+
+import pytorch_lightning
+
+from .config import dump_config
+from .misc import parse_version
+
+if parse_version(pytorch_lightning.__version__) > parse_version("1.8"):
+ from pytorch_lightning.callbacks import Callback
+else:
+ from pytorch_lightning.callbacks.base import Callback
+
+from pytorch_lightning.callbacks.progress import TQDMProgressBar
+from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
+
+
class VersionedCallback(Callback):
    """Base callback writing outputs under ``save_root/version_<n>``.

    The version is either given explicitly or auto-incremented from existing
    ``version_*`` directories found in ``save_root``.
    """

    def __init__(self, save_root, version=None, use_version=True):
        self.save_root = save_root
        self._version = version
        self.use_version = use_version

    @property
    def version(self) -> int:
        """Get the experiment version.

        Returns:
            The experiment version if specified else the next version.
        """
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self):
        # Scan save_root for "version_<int>" entries and return max + 1.
        # Non-numeric suffixes (e.g. "version_backup") are skipped instead of
        # crashing the scan with ValueError as the previous int() cast did.
        existing_versions = []
        if os.path.isdir(self.save_root):
            for f in os.listdir(self.save_root):
                bn = os.path.basename(f)
                if bn.startswith("version_"):
                    dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "")
                    if dir_ver.isdigit():
                        existing_versions.append(int(dir_ver))
        if len(existing_versions) == 0:
            return 0
        return max(existing_versions) + 1

    @property
    def savedir(self):
        # string versions are used verbatim; int versions get the version_ prefix
        if not self.use_version:
            return self.save_root
        return os.path.join(
            self.save_root,
            self.version
            if isinstance(self.version, str)
            else f"version_{self.version}",
        )
+
+
class CodeSnapshotCallback(VersionedCallback):
    """Copies the repository's source files into the experiment dir at fit start."""

    def __init__(self, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)

    def get_file_list(self):
        """Union of git-tracked files (excluding load/) and untracked, non-ignored files."""
        return [
            b.decode()
            for b in set(
                subprocess.check_output(
                    'git ls-files -- ":!:load/*"', shell=True
                ).splitlines()
            )
            | set(  # hard code, TODO: use config to exclude folders or files
                subprocess.check_output(
                    "git ls-files --others --exclude-standard", shell=True
                ).splitlines()
            )
        ]

    @rank_zero_only
    def save_code_snapshot(self):
        os.makedirs(self.savedir, exist_ok=True)
        for f in self.get_file_list():
            # skip deleted-but-still-listed paths and directories
            if not os.path.exists(f) or os.path.isdir(f):
                continue
            os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True)
            shutil.copyfile(f, os.path.join(self.savedir, f))

    def on_fit_start(self, trainer, pl_module):
        try:
            self.save_code_snapshot()
        except Exception:
            # Best-effort: snapshot failure must not abort training. This was a
            # bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            rank_zero_warn(
                "Code snapshot is not saved. Please make sure you have git installed and are in a git repository."
            )
+
+
class ConfigSnapshotCallback(VersionedCallback):
    """Saves both the parsed and the raw experiment config at fit start."""

    def __init__(self, config_path, config, save_root, version=None, use_version=True):
        super().__init__(save_root, version, use_version)
        self.config_path = config_path
        self.config = config

    @rank_zero_only
    def save_config_snapshot(self):
        os.makedirs(self.savedir, exist_ok=True)
        # the fully parsed/resolved config...
        dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config)
        # ...plus a verbatim copy of the file the user supplied
        shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml"))

    def on_fit_start(self, trainer, pl_module):
        self.save_config_snapshot()
+
+
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides Lightning's version-number metric."""

    def get_metrics(self, *args, **kwargs):
        metrics = super().get_metrics(*args, **kwargs)
        metrics.pop("v_num", None)  # don't show the version number
        return metrics
+
+
class ProgressCallback(Callback):
    """Writes a one-line, human-readable progress message to a file.

    The file is rewritten in place on every update so external tools can poll
    it for the latest status.
    """

    def __init__(self, save_path):
        super().__init__()
        self.save_path = save_path
        self._file_handle = None

    @property
    def file_handle(self):
        # opened lazily so non-zero ranks (which never write) hold no handle
        if self._file_handle is None:
            self._file_handle = open(self.save_path, "w")
        return self._file_handle

    @rank_zero_only
    def write(self, msg: str) -> None:
        # overwrite the whole file with the latest message
        self.file_handle.seek(0)
        self.file_handle.truncate()
        self.file_handle.write(msg)
        self.file_handle.flush()

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.write(
            f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%"
        )

    @rank_zero_only
    def on_validation_start(self, trainer, pl_module):
        self.write("Rendering validation image ...")

    @rank_zero_only
    def on_test_start(self, trainer, pl_module):
        self.write("Rendering video ...")

    @rank_zero_only
    def on_predict_start(self, trainer, pl_module):
        self.write("Exporting mesh assets ...")

    def teardown(self, trainer, pl_module, stage=None):
        # close the handle explicitly — it was previously never closed (leak)
        if self._file_handle is not None:
            self._file_handle.close()
            self._file_handle = None
diff --git a/3D_Stage/lrm/utils/config.py b/3D_Stage/lrm/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c490d30efe9fdc0a104d8e0854eba3474518e4e
--- /dev/null
+++ b/3D_Stage/lrm/utils/config.py
@@ -0,0 +1,140 @@
+import os
+from dataclasses import dataclass, field
+from datetime import datetime
+
+from omegaconf import OmegaConf
+
+import lrm
+from .typing import *
+
# ============ Register OmegaConf Resolvers ============= #
# Small arithmetic/string helpers usable inside YAML configs as ${name:args}.
OmegaConf.register_new_resolver(
    "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
)
OmegaConf.register_new_resolver("add", lambda a, b: a + b)
OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("div", lambda a, b: a / b)
OmegaConf.register_new_resolver("idiv", lambda a, b: a // b)  # integer division
OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub))
OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)])  # scalar -> pair
OmegaConf.register_new_resolver("gt0", lambda s: s > 0)
OmegaConf.register_new_resolver("not", lambda s: not s)
+
+
def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8):
    """Estimate total optimizer steps: (truncated) steps-per-epoch times max_epochs."""
    samples_per_step = num_nodes * num_cards * batch_size
    steps_per_epoch = int(num_data / samples_per_step)
    return steps_per_epoch * max_epochs
+
+
+OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps)
+
+# ======================================================= #
+
+
+# ============== Automatic Name Resolvers =============== #
def get_naming_convention(cfg):
    """Derive an automatic experiment name from the config (currently only layer count)."""
    # TODO: fold more hyper-parameters into the name
    num_layers = cfg.system.backbone.num_layers
    return f"lrm_{num_layers}"
+
+
+# ======================================================= #
+
+
@dataclass
class ExperimentConfig:
    """Top-level structured schema for one experiment run."""

    name: str = "default"  # experiment name; "auto" triggers get_naming_convention()
    description: str = ""
    tag: str = ""  # short run identifier; required when use_timestamp is False
    seed: int = 0
    use_timestamp: bool = True  # append "@YYYYmmdd-HHMMSS" to the trial name
    timestamp: Optional[str] = None
    exp_root_dir: str = "outputs"

    ### these shouldn't be set manually
    exp_dir: str = "outputs/default"  # derived in load_config: exp_root_dir/name
    trial_name: str = "exp"  # derived in load_config: tag + timestamp
    trial_dir: str = "outputs/default/exp"  # derived: exp_dir/trial_name
    n_gpus: int = 1
    ###

    resume: Optional[str] = None  # checkpoint path to resume from

    data_cls: str = ""  # import path of the data module class
    data: dict = field(default_factory=dict)

    system_cls: str = ""  # import path of the system class
    system: dict = field(default_factory=dict)

    # accept pytorch-lightning trainer parameters
    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api
    trainer: dict = field(default_factory=dict)

    # accept pytorch-lightning checkpoint callback parameters
    # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint
    checkpoint: dict = field(default_factory=dict)
+
+
def load_config(
    *yamls: str, cli_args: Optional[list] = None, from_string=False, makedirs=True, **kwargs
) -> Any:
    """Load and merge YAML config(s), CLI overrides and kwargs into an ExperimentConfig.

    Each YAML may name a base file via an ``extends`` key (one level deep).
    Returns a structured ExperimentConfig with trial naming and directories
    resolved; creates the trial directory unless ``makedirs`` is False.
    """
    # avoid a mutable default argument; the previous `cli_args: list = []` was
    # shared across calls
    cli_args = [] if cli_args is None else cli_args
    parse_func = OmegaConf.create if from_string else OmegaConf.load
    yaml_confs = []
    for y in yamls:
        conf = parse_func(y)
        extends = conf.pop("extends", None)
        if extends:
            # NOTE(review): only one level of `extends` is honored; a base
            # file's own `extends` key is ignored — confirm this is intended
            assert os.path.exists(extends), f"File {extends} does not exist."
            yaml_confs.append(OmegaConf.load(extends))
        yaml_confs.append(conf)
    cli_conf = OmegaConf.from_cli(cli_args)
    cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg)

    # post processing
    # auto naming
    if scfg.name == "auto":
        scfg.name = get_naming_convention(scfg)
    # a run must be identifiable: either an explicit tag or a timestamp
    if not scfg.tag and not scfg.use_timestamp:
        raise ValueError("Either tag is specified or use_timestamp is True.")
    scfg.trial_name = scfg.tag
    # if resume from an existing config, scfg.timestamp should not be None
    if scfg.timestamp is None:
        scfg.timestamp = ""
        if scfg.use_timestamp:
            if scfg.n_gpus > 1:
                lrm.warn(
                    "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag."
                )
            else:
                scfg.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S")
    # make directories
    scfg.trial_name += scfg.timestamp
    scfg.exp_dir = os.path.join(scfg.exp_root_dir, scfg.name)
    scfg.trial_dir = os.path.join(scfg.exp_dir, scfg.trial_name)

    if makedirs:
        os.makedirs(scfg.trial_dir, exist_ok=True)

    return scfg
+
+
def config_to_primitive(config, resolve: bool = True) -> Any:
    """Convert an OmegaConf node to plain Python containers (resolving interpolations)."""
    return OmegaConf.to_container(config, resolve=resolve)
+
+
def dump_config(path: str, config) -> None:
    """Serialize an OmegaConf config to a YAML file at *path* (overwrites)."""
    with open(path, "w") as fp:
        OmegaConf.save(config=config, f=fp)
+
+
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Validate *cfg* against the structured (dataclass) schema *fields* via merge."""
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
    return scfg
diff --git a/3D_Stage/lrm/utils/misc.py b/3D_Stage/lrm/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..80bfa1e9ba4c1a46e3e268d691be6e8a3562cccc
--- /dev/null
+++ b/3D_Stage/lrm/utils/misc.py
@@ -0,0 +1,211 @@
+import gc
+import os
+import re
+import time
+from collections import defaultdict
+from contextlib import contextmanager
+
+import torch
+from packaging import version
+
+import lrm
+from .config import config_to_primitive
+from .typing import *
+
+
def parse_version(ver: str):
    """Parse a version string into a comparable ``packaging`` Version object."""
    return version.parse(ver)
+
+
def get_rank():
    """Return this process's rank from the environment, defaulting to 0.

    LOCAL_RANK is checked before SLURM_PROCID because SLURM_PROCID can be set
    even when SLURM is not managing the multiprocessing.
    """
    for key in ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"):
        value = os.environ.get(key)
        if value is not None:
            return int(value)
    return 0
+
+
def get_device():
    # Bind this process to the CUDA device matching its rank.
    # NOTE(review): assumes CUDA is available; CPU-only runs would need a fallback.
    return torch.device(f"cuda:{get_rank()}")
+
+
def load_module_weights(
    path, module_name=None, ignore_modules=None, mapping=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load a checkpoint and return ``(state_dict, epoch, global_step)``.

    At most one of ``module_name`` / ``ignore_modules`` may be given:
    ``module_name`` keeps only keys under ``"<module_name>."`` (prefix
    stripped); ``ignore_modules`` drops keys under any listed prefix.
    ``mapping`` entries ``{"from", "to"}`` rename key prefixes beforehand.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if mapping is not None:
        state_dict_to_load = {}
        for k, v in state_dict.items():
            # keys already under a destination prefix are dropped so the
            # renamed entries below win
            if any(k.startswith(m["to"]) for m in mapping):
                pass
            else:
                state_dict_to_load[k] = v
        for k, v in state_dict.items():
            for m in mapping:
                if k.startswith(m["from"]):
                    k_dest = k.replace(m["from"], m["to"])
                    lrm.info(f"Mapping {k} => {k_dest}")
                    state_dict_to_load[k_dest] = v.clone()
        state_dict = state_dict_to_load

    state_dict_to_load = state_dict

    if ignore_modules is not None:
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v

    if module_name is not None:
        state_dict_to_load = {}
        # re.escape so module names containing regex metacharacters (e.g. ".")
        # match literally instead of acting as wildcards; compiled once,
        # outside the loop
        prefix = re.compile(rf"^{re.escape(module_name)}\.(.*)$")
        for k, v in state_dict.items():
            m = prefix.match(k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
+
+
def C(value: Any, epoch: int, global_step: int) -> float:
    """Resolve a possibly-scheduled scalar.

    A plain number is returned as-is. A list spec is
    ``[start_step, start_value, end_value, end_step]`` (a 3-element form gets
    ``start_step = 0``): the value is linearly interpolated and clamped. An int
    ``end_step`` interpolates over ``global_step``; a float over ``epoch``.
    """
    if isinstance(value, (int, float)):
        return value
    # accept plain Python lists directly; only config nodes need conversion
    if not isinstance(value, list):
        value = config_to_primitive(value)
    if not isinstance(value, list):
        raise TypeError("Scalar specification only supports list, got", type(value))
    if len(value) == 3:
        value = [0] + value
    assert len(value) == 4
    start_step, start_value, end_value, end_step = value
    if isinstance(end_step, int):
        current_step = global_step
    elif isinstance(end_step, float):
        current_step = epoch
    else:
        # unrecognized end_step type: fall through unchanged (previous behavior)
        return value
    # single interpolation path (previously duplicated across the two branches)
    frac = max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
    return start_value + (end_value - start_value) * frac
+
+
def cleanup():
    """Free Python, CUDA-cache and (if installed) tinycudann temporary memory."""
    gc.collect()
    torch.cuda.empty_cache()
    try:
        import tinycudann as tcnn

        tcnn.free_temporary_memory()
    except Exception:
        # tinycudann is optional; this was a bare `except:`, which would also
        # have swallowed KeyboardInterrupt/SystemExit
        pass
+
+
def finish_with_cleanup(func: Callable):
    """Wrap *func* so that cleanup() runs after every successful call."""
    import functools

    @functools.wraps(func)  # preserve the wrapped callable's name/docstring
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper
+
+
def _distributed_available():
    """True iff torch.distributed is usable AND a process group is initialized."""
    if not torch.distributed.is_available():
        return False
    return torch.distributed.is_initialized()
+
+
def barrier():
    """Block until all distributed ranks arrive; no-op in single-process runs."""
    if _distributed_available():
        torch.distributed.barrier()
+
+
def broadcast(tensor, src=0):
    """Broadcast *tensor* from rank *src* to all ranks; identity when not distributed."""
    if _distributed_available():
        torch.distributed.broadcast(tensor, src=src)
    return tensor
+
+
def enable_gradient(model, enabled: bool = True) -> None:
    """Toggle requires_grad on every parameter of *model* (in place)."""
    for p in model.parameters():
        p.requires_grad_(enabled)
+
+
class TimeRecorder:
    """Singleton utility for timing named code sections.

    Disabled by default; call ``enable(True)`` to activate. ``start``/``end``
    bracket a named section (synchronizing CUDA around each measurement);
    ``accumulate=True`` folds repeated sections into a running list reported
    via ``get_accumulation``.
    """

    _instance = None

    def __new__(cls):
        # Singleton: construct AND initialize the shared instance exactly once.
        # State is initialized here rather than in __init__, because __init__
        # re-runs on every TimeRecorder() call and previously wiped the
        # recorder's items/accumulations each time.
        if cls._instance is None:
            inst = super().__new__(cls)
            inst.items = {}                         # name -> start timestamp (s)
            inst.accumulations = defaultdict(list)  # name -> list of durations (s)
            inst.time_scale = 1000.0                # seconds -> report unit
            inst.time_unit = "ms"
            inst.enabled = False
            cls._instance = inst
        return cls._instance

    def enable(self, enabled: bool) -> None:
        self.enabled = enabled

    def start(self, name: str) -> None:
        if not self.enabled:
            return
        # synchronize so pending CUDA work is not attributed to this section
        torch.cuda.synchronize()
        self.items[name] = time.time()

    def end(self, name: str, accumulate: bool = False) -> float:
        # returns None when disabled or start() was never called for this name
        if not self.enabled or name not in self.items:
            return
        torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        lrm.info(f"{name}: {t:.2f}{self.time_unit}")

    def get_accumulation(self, name: str, average: bool = False) -> float:
        # pops (consumes) the accumulated durations for `name` and logs them
        if not self.enabled or name not in self.accumulations:
            return
        acc = self.accumulations.pop(name)
        total = sum(acc)
        if average:
            t = total / len(acc) * self.time_scale
        else:
            t = total * self.time_scale
        lrm.info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")


### global time recorder
time_recorder = TimeRecorder()
+
+
@contextmanager
def time_recorder_enabled():
    """Context manager: force the global time recorder on, then restore its prior state."""
    previous = time_recorder.enabled
    time_recorder.enable(enabled=True)
    try:
        yield
    finally:
        time_recorder.enable(enabled=previous)
diff --git a/3D_Stage/lrm/utils/ops.py b/3D_Stage/lrm/utils/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b1e06d3d7136575ed96dad1764b72374d5b5529
--- /dev/null
+++ b/3D_Stage/lrm/utils/ops.py
@@ -0,0 +1,435 @@
+from collections import defaultdict
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+import lrm
+from .typing import *
+
+
def dot(x, y):
    """Batched dot product over the last dimension, keeping that dimension."""
    return (x * y).sum(dim=-1, keepdim=True)
+
+
def reflect(x, n):
    """Reflect direction x about (unit) normal n: 2(x.n)n - x."""
    x_dot_n = (x * n).sum(dim=-1, keepdim=True)
    return 2 * x_dot_n * n - x
+
+
+ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
+
+
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly remap *dat* from the input range into the target range.

    Either range may be a (lo, hi) pair or a (2, D) tensor; None means (0, 1).
    """
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    # normalize to [0, 1] w.r.t. the input range, then stretch to the target range
    normalized = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    return normalized * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
+
+
class _TruncExp(Function):  # pylint: disable=abstract-method
    """exp() with a clamped backward pass for numerical stability.

    Implementation from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):  # pylint: disable=arguments-differ
        (x,) = ctx.saved_tensors
        # clamp the exponent so huge activations cannot produce inf gradients
        return grad_out * torch.exp(torch.clamp(x, max=15))


trunc_exp = _TruncExp.apply
+
+
def get_activation(name) -> Callable:
    """Look up an activation function by name.

    Supports a set of custom activations plus any ``torch.nn.functional``
    callable; None or "none" yields identity. Raises ValueError for unknown
    names.
    """
    identity = lambda x: x
    if name is None:
        return identity
    name = name.lower()
    if name == "none":
        return identity
    if name == "lin2srgb":
        # linear -> sRGB conversion, clamped to [0, 1]
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    if name == "exp":
        return lambda x: torch.exp(x)
    if name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    if name == "trunc_exp":
        return trunc_exp
    if name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    if name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    if name == "tanh":
        return lambda x: torch.tanh(x)
    if name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    if name == "scale_-11_01":
        return lambda x: x * 0.5 + 0.5
    if name == "negative":
        return lambda x: -x
    try:
        # fall back to torch.nn.functional (e.g. "relu", "gelu")
        return getattr(F, name)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {name}")
+
+
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Apply *func* to tensor arguments in chunks of ``chunk_size`` along dim 0,
    then concatenate the per-chunk results.

    Non-tensor arguments are passed through unchanged to every chunk. The
    return shape of func (tensor, list/tuple of tensors, dict of tensors, or
    None) is preserved. ``chunk_size <= 0`` disables chunking entirely.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)
    # infer the batch size from the first tensor found in args/kwargs
    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."
    out = defaultdict(list)
    out_type = None
    # max(1, B) to support B == 0
    for i in range(0, max(1, B), chunk_size):
        out_chunk = func(
            *[
                arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        # normalize every return shape to a dict for uniform accumulation
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
            chunk_length = len(out_chunk)
            out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
        elif isinstance(out_chunk, dict):
            pass
        else:
            print(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
            exit(1)
        for k, v in out_chunk.items():
            # detach when grads are globally disabled to avoid retaining graphs
            v = v if torch.is_grad_enabled() else v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    # concatenate each accumulated key back along the batch dimension
    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all([vv is None for vv in v]):
            # allow None in return value
            out_merged[k] = None
        elif all([isinstance(vv, torch.Tensor) for vv in v]):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    # restore the original return shape of func
    if out_type is torch.Tensor:
        return out_merged[0]
    elif out_type in [tuple, list]:
        return out_type([out_merged[i] for i in range(chunk_length)])
    elif out_type is dict:
        return out_merged
+
+
def get_ray_directions(
    H: int,
    W: int,
    focal: Union[float, Tuple[float, float]],
    principal: Optional[Tuple[float, float]] = None,
    use_pixel_centers: bool = True,
    normalize: bool = True,
) -> Float[Tensor, "H W 3"]:
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
    ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    pixel_center = 0.5 if use_pixel_centers else 0

    # accept int focal too (an integer previously fell into the tuple branch
    # and failed to unpack); a scalar focal implies a centered principal point
    if isinstance(focal, (int, float)):
        fx, fy = focal, focal
        cx, cy = W / 2, H / 2
    else:
        fx, fy = focal
        assert principal is not None
        cx, cy = principal

    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32) + pixel_center,
        torch.arange(H, dtype=torch.float32) + pixel_center,
        indexing="xy",
    )

    # x right, y up, camera looks along -z
    directions: Float[Tensor, "H W 3"] = torch.stack(
        [(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1
    )

    if normalize:
        directions = F.normalize(directions, dim=-1)

    return directions
+
+
def get_rays(
    directions: Float[Tensor, "... 3"],
    c2w: Float[Tensor, "... 4 4"],
    keepdim=False,
    noise_scale=0.0,
    normalize=False,
) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]:
    """Transform camera-space ray directions into world-space (origins, directions).

    Supported shape combinations: (N,3) with (4,4)/(N,4,4); (H,W,3) with (4,4)
    or broadcast over (B,4,4); (B,H,W,3) with (B,4,4). Results are flattened
    to (N,3) unless keepdim is True.
    """
    # Rotate ray directions from camera coordinate to the world coordinate
    assert directions.shape[-1] == 3

    if directions.ndim == 2:  # (N_rays, 3)
        if c2w.ndim == 2:  # (4, 4)
            c2w = c2w[None, :, :]
        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
        # row-vector times rotation block, summed over the last axis
        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 3:  # (H, W, 3)
        assert c2w.ndim in [2, 3]
        if c2w.ndim == 2:  # (4, 4)
            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
                -1
            )  # (H, W, 3)
            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
        elif c2w.ndim == 3:  # (B, 4, 4)
            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
                -1
            )  # (B, H, W, 3)
            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 4:  # (B, H, W, 3)
        assert c2w.ndim == 3  # (B, 4, 4)
        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
            -1
        )  # (B, H, W, 3)
        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)

    # add camera noise to avoid grid-like artifect
    # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
    if noise_scale > 0:
        rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
        rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale

    if normalize:
        rays_d = F.normalize(rays_d, dim=-1)
    if not keepdim:
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    return rays_o, rays_d
+
+
def get_projection_matrix(
    fovy: Union[float, Float[Tensor, "B"]], aspect_wh: float, near: float, far: float
) -> Float[Tensor, "*B 4 4"]:
    """OpenGL-style perspective projection from vertical FOV (radians).

    Accepts a scalar fovy (returns 4x4) or a batch tensor (returns Bx4x4). The
    y entry carries a negative sign because the y axis is flipped in
    nvdiffrast output.
    """
    if isinstance(fovy, float):
        half_tan = math.tan(fovy / 2.0)
        proj = torch.zeros(4, 4, dtype=torch.float32)
        proj[0, 0] = 1.0 / (half_tan * aspect_wh)
        proj[1, 1] = -1.0 / half_tan  # negative: y flipped in nvdiffrast output
        proj[2, 2] = -(far + near) / (far - near)
        proj[2, 3] = -2.0 * far * near / (far - near)
        proj[3, 2] = -1.0
    else:
        half_tan = torch.tan(fovy / 2.0)
        proj = torch.zeros(fovy.shape[0], 4, 4, dtype=torch.float32)
        proj[:, 0, 0] = 1.0 / (half_tan * aspect_wh)
        proj[:, 1, 1] = -1.0 / half_tan  # negative: y flipped in nvdiffrast output
        proj[:, 2, 2] = -(far + near) / (far - near)
        proj[:, 2, 3] = -2.0 * far * near / (far - near)
        proj[:, 3, 2] = -1.0
    return proj
+
+
def get_mvp_matrix(
    c2w: Float[Tensor, "*B 4 4"], proj_mtx: Float[Tensor, "*B 4 4"]
) -> Float[Tensor, "*B 4 4"]:
    """Compose model-view-projection: analytically invert c2w, then apply proj.

    For a rigid c2w = [R|t], the inverse is w2c = [R^T | -R^T t], which is
    mathematically equivalent to (c2w)^-1 without a general matrix inverse.
    """
    if c2w.ndim == 2:
        assert proj_mtx.ndim == 2
        rot_t = c2w[:3, :3].permute(1, 0)
        w2c: Float[Tensor, "4 4"] = torch.zeros(4, 4).to(c2w)
        w2c[:3, :3] = rot_t
        w2c[:3, 3:] = -rot_t @ c2w[:3, 3:]
        w2c[3, 3] = 1.0
    else:
        rot_t = c2w[:, :3, :3].permute(0, 2, 1)
        w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
        w2c[:, :3, :3] = rot_t
        w2c[:, :3, 3:] = -rot_t @ c2w[:, :3, 3:]
        w2c[:, 3, 3] = 1.0
    # mvp = proj @ w2c (the model-view matrix)
    return proj_mtx @ w2c
+
+
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Pinhole intrinsics (3x3 torch tensor) from a FOV in radians.

    The principal point is the image center; focal = 0.5*H/tan(fov/2). With
    bs > 0, the matrix is tiled to shape (bs, 3, 3).
    """
    focal = 0.5 * H / np.tan(0.5 * fov)
    K = np.identity(3, dtype=np.float32)
    K[0, 0] = K[1, 1] = focal
    K[0, 2] = W / 2.0
    K[1, 2] = H / 2.0

    if bs > 0:
        K = K[None].repeat(bs, axis=0)

    return torch.from_numpy(K)
+
+
def binary_cross_entropy(input, target):
    """
    F.binary_cross_entropy is not numerically stable in mixed-precision training.
    """
    log_likelihood = target * torch.log(input) + (1 - target) * torch.log(1 - input)
    return -log_likelihood.mean()
+
+
def tet_sdf_diff(
    vert_sdf: Float[Tensor, "Nv 1"], tet_edges: Integer[Tensor, "Ne 2"]
) -> Float[Tensor, ""]:
    """Sign-consistency loss on tet edges whose endpoint SDFs disagree.

    For every edge crossing the surface (endpoint signs differ), each
    endpoint's logit is pushed toward the sign of the opposite endpoint via
    binary cross-entropy with logits.
    """
    edge_sdf = vert_sdf[:, 0][tet_edges.reshape(-1)].reshape(-1, 2)
    crossing = torch.sign(edge_sdf[..., 0]) != torch.sign(edge_sdf[..., 1])
    edge_sdf = edge_sdf[crossing]
    loss_a = F.binary_cross_entropy_with_logits(
        edge_sdf[..., 0], (edge_sdf[..., 1] > 0).float()
    )
    loss_b = F.binary_cross_entropy_with_logits(
        edge_sdf[..., 1], (edge_sdf[..., 0] > 0).float()
    )
    return loss_a + loss_b
+
+
def validate_empty_rays(ray_indices, t_start, t_end):
    """Substitute a single dummy ray when the sampler returned no intersections."""
    if ray_indices.nelement() > 0:
        return ray_indices, t_start, t_end
    lrm.warn("Empty rays_indices!")
    ray_indices = torch.LongTensor([0]).to(ray_indices)
    t_start = torch.Tensor([0]).to(ray_indices)
    t_end = torch.Tensor([0]).to(ray_indices)
    return ray_indices, t_start, t_end
+
+
def rays_intersect_bbox(
    rays_o: Float[Tensor, "N 3"],
    rays_d: Float[Tensor, "N 3"],
    radius: Float,
    near: Float = 0.0,
    valid_thresh: Float = 0.01,
    background: bool = False,
):
    """Slab-test ray/AABB intersection against a box of half-extent *radius*.

    Returns (t_near, t_far, rays_valid); invalid rays get t_near = t_far = 0.
    NOTE(review): `background` is unused in this function — confirm whether it
    is vestigial or intended for callers.
    """
    input_shape = rays_o.shape[:-1]
    rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
    # clamp tiny direction components to avoid division by ~0 below
    rays_d_valid = torch.where(
        rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
    )
    if type(radius) in [int, float]:
        radius = torch.FloatTensor(
            [[-radius, radius], [-radius, radius], [-radius, radius]]
        ).to(rays_o.device)
    radius = (
        1.0 - 1.0e-3
    ) * radius  # tighten the radius to make sure the intersection point lies in the bounding box
    # per-axis entry/exit parameters (slab method)
    interx0 = (radius[..., 1] - rays_o) / rays_d_valid
    interx1 = (radius[..., 0] - rays_o) / rays_d_valid
    t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
    t_far = torch.maximum(interx0, interx1).amin(dim=-1)

    # check wheter a ray intersects the bbox or not
    rays_valid = t_far - t_near > valid_thresh

    # NOTE(review): global_near/global_far are computed but never returned, and
    # .min()/.max() raise when no ray is valid — confirm intended
    t_near_valid, t_far_valid = t_near[rays_valid], t_far[rays_valid]
    global_near = t_near_valid.min().item()
    global_far = t_far_valid.max().item()

    t_near[torch.where(~rays_valid)] = 0.0
    t_far[torch.where(~rays_valid)] = 0.0

    t_near = t_near.view(*input_shape, 1)
    t_far = t_far.view(*input_shape, 1)
    rays_valid = rays_valid.view(*input_shape)

    return t_near, t_far, rays_valid
+
+
def get_plucker_rays(
    rays_o: Float[Tensor, "*N 3"], rays_d: Float[Tensor, "*N 3"]
) -> Float[Tensor, "*N 6"]:
    """Encode rays as 6-D Plucker-style coordinates: (o x d, d) after normalization.

    NOTE(review): rays_o is normalized before the cross product (inherited
    behavior), which rescales the moment vector — confirm downstream consumers
    expect this.
    """
    rays_o = F.normalize(rays_o, dim=-1)
    rays_d = F.normalize(rays_d, dim=-1)
    # dim=-1 is required: Tensor.cross without dim picks the FIRST dimension of
    # size 3, which silently crossed over a batch axis whenever N == 3
    moment = torch.cross(rays_o, rays_d, dim=-1)
    return torch.cat([moment, rays_d], dim=-1)
+
+
def c2w_to_polar(c2w: Float[Tensor, "4 4"]) -> Tuple[float, float, float]:
    """Convert a camera-to-world pose to (elevation, azimuth, distance).

    Elevation is measured from the xy-plane toward +z; azimuth lies in [0, 2*pi).
    Near-vertical poses (x and y both ~0) get azimuth 0.
    """
    cam_pos = c2w[:3, 3]
    x, y, z = cam_pos.tolist()
    distance = cam_pos.norm().item()
    elevation = math.asin(z / distance)
    if max(abs(x), abs(y)) < 1.0e-5:
        azimuth = 0
    else:
        azimuth = math.atan2(y, x) % (2 * math.pi)
    return elevation, azimuth, distance
+
+
def polar_to_c2w(
    elevation: float, azimuth: float, distance: float
) -> Float[Tensor, "4 4"]:
    """Build a camera-to-world matrix looking at the origin from spherical coords.

    Look-at construction:
      forward = normalize(origin - eye);  side = normalize(forward x up);
      true_up = side x forward;  rotation columns = [side, true_up, -forward].
    """
    z = distance * math.sin(elevation)
    xy = distance * math.cos(elevation)
    x = xy * math.cos(azimuth)
    y = xy * math.sin(azimuth)
    forward = F.normalize(-torch.as_tensor([x, y, z]).float(), dim=0)
    world_up = torch.as_tensor([0.0, 0.0, 1.0]).float()
    side = F.normalize(forward.cross(world_up), dim=0)
    true_up = side.cross(forward)
    c2w = torch.zeros((4, 4), dtype=torch.float32)
    c2w[:3, :3] = torch.stack([side, true_up, -forward], dim=0).T
    c2w[:3, 3] = torch.as_tensor([x, y, z])
    c2w[3, 3] = 1
    return c2w
diff --git a/3D_Stage/lrm/utils/rasterize.py b/3D_Stage/lrm/utils/rasterize.py
new file mode 100644
index 0000000000000000000000000000000000000000..448f37ea4432b9d351dee39f77e82665a4a78cbb
--- /dev/null
+++ b/3D_Stage/lrm/utils/rasterize.py
@@ -0,0 +1,81 @@
+import nvdiffrast.torch as dr
+import torch
+
+from .typing import *
+
+
class NVDiffRasterizerContext:
    """Thin wrapper around nvdiffrast rasterization primitives.

    Holds a single GL or CUDA rasterizer context and exposes helpers for
    clip-space vertex transformation, rasterization, antialiasing, and
    attribute interpolation.
    """

    def __init__(self, context_type: str, device: torch.device) -> None:
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device: torch.device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        """Create the nvdiffrast context selected by *context_type*.

        Fix: the original unconditionally overwrote ``context_type`` with
        "cuda", making the "gl" branch unreachable dead code; the
        parameter is now honored (pass "cuda" explicitly if a CUDA
        context is required, e.g. on headless machines without GL).
        """
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        elif context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        else:
            raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        """Transform vertices to clip space with the given MVP matrices.

        Runs in full precision (autocast disabled) to keep the matrix
        product numerically stable under AMP.
        """
        with torch.cuda.amp.autocast(enabled=False):
            verts_homo = torch.cat(
                [verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1
            )
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        """Rasterize in instance mode (a single shared topology)."""
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        """Rasterize one single mesh under a single viewpoint (no batch dim)."""
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        """Apply nvdiffrast's edge antialiasing to a rendered color buffer."""
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Float[Tensor, "B H W C"]:
        """Interpolate per-vertex attributes over the rasterized pixels."""
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Float[Tensor, "B H W C"]:
        """Single-mesh convenience wrapper around :meth:`interpolate`."""
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
diff --git a/3D_Stage/lrm/utils/saving.py b/3D_Stage/lrm/utils/saving.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8fa010ad1dc816f9d46d7c131df5e2ca7e856c
--- /dev/null
+++ b/3D_Stage/lrm/utils/saving.py
@@ -0,0 +1,725 @@
+import json
+import os
+import re
+import shutil
+
+import cv2
+import imageio
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import wandb
+from matplotlib import cm
+from matplotlib.colors import LinearSegmentedColormap
+from PIL import Image, ImageDraw
+from pytorch_lightning.loggers import WandbLogger
+
+import lrm
+from ..models.mesh import Mesh
+from ..utils.typing import *
+
+
class SaverMixin:
    """Mixin providing helpers that save images, videos, meshes, and raw
    data under a configurable directory, optionally mirroring to W&B."""

    # Root output directory; set via set_save_dir() before any save_* call.
    _save_dir: Optional[str] = None
    # Optional Weights & Biases logger used by the *name*/*step* hooks.
    _wandb_logger: Optional[WandbLogger] = None
+
    def set_save_dir(self, save_dir: str):
        """Set the root directory that all save_* helpers write under."""
        self._save_dir = save_dir
+
+ def get_save_dir(self):
+ if self._save_dir is None:
+ raise ValueError("Save dir is not set")
+ return self._save_dir
+
+ def convert_data(self, data):
+ if data is None:
+ return None
+ elif isinstance(data, np.ndarray):
+ return data
+ elif isinstance(data, torch.Tensor):
+ if data.dtype in [torch.float16, torch.bfloat16]:
+ data = data.float()
+ return data.detach().cpu().numpy()
+ elif isinstance(data, list):
+ return [self.convert_data(d) for d in data]
+ elif isinstance(data, dict):
+ return {k: self.convert_data(v) for k, v in data.items()}
+ else:
+ raise TypeError(
+ "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
+ type(data),
+ )
+
+ def get_save_path(self, filename):
+ save_path = os.path.join(self.get_save_dir(), filename)
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ return save_path
+
    # Default keyword arguments for the image-rendering helpers below;
    # per-call "kwargs" in save_image_grid specs override these.
    DEFAULT_RGB_KWARGS = {"data_format": "HWC", "data_range": (0, 1)}
    DEFAULT_UV_KWARGS = {
        "data_format": "HWC",
        "data_range": (0, 1),
        "cmap": "checkerboard",
    }
    # data_range=None means normalize by the image's own min/max.
    DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"}
    DEFAULT_GRID_KWARGS = {"align": "max"}
+
    def get_rgb_image_(self, img, data_format, data_range, rgba=False):
        """Convert an RGB(A) image to a uint8 BGR(A) array for cv2.imwrite.

        If the channel count exceeds 3 (or 4 with *rgba*), the channels
        are split into groups and tiled horizontally; a short final group
        is zero-padded to full width.
        """
        img = self.convert_data(img)
        assert data_format in ["CHW", "HWC"]
        if data_format == "CHW":
            img = img.transpose(1, 2, 0)
        # Float inputs are clipped to data_range and quantized to uint8;
        # uint8 inputs are assumed already in [0, 255].
        if img.dtype != np.uint8:
            img = img.clip(min=data_range[0], max=data_range[1])
            img = (
                (img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0
            ).astype(np.uint8)
        nc = 4 if rgba else 3
        # Split trailing channels into groups of nc and tile side by side.
        imgs = [img[..., start : start + nc] for start in range(0, img.shape[-1], nc)]
        imgs = [
            img_
            if img_.shape[-1] == nc
            else np.concatenate(
                [
                    img_,
                    np.zeros(
                        (img_.shape[0], img_.shape[1], nc - img_.shape[2]),
                        dtype=img_.dtype,
                    ),
                ],
                axis=-1,
            )
            for img_ in imgs
        ]
        img = np.concatenate(imgs, axis=1)
        # cv2 expects BGR(A) channel order.
        if rgba:
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
        else:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img
+
+ def _save_rgb_image(
+ self,
+ filename,
+ img,
+ data_format,
+ data_range,
+ name: Optional[str] = None,
+ step: Optional[int] = None,
+ ):
+ img = self.get_rgb_image_(img, data_format, data_range)
+ cv2.imwrite(filename, img)
+ if name and self._wandb_logger:
+ self._wandb_logger.log_image(
+ key=name, images=[self.get_save_path(filename)], step=step
+ )
+
    def save_rgb_image(
        self,
        filename,
        img,
        data_format=DEFAULT_RGB_KWARGS["data_format"],
        data_range=DEFAULT_RGB_KWARGS["data_range"],
        name: Optional[str] = None,
        step: Optional[int] = None,
    ) -> str:
        """Save an RGB image under the save dir and return its full path.

        *name*/*step* optionally mirror the image to wandb.
        """
        save_path = self.get_save_path(filename)
        self._save_rgb_image(save_path, img, data_format, data_range, name, step)
        return save_path
+
    def get_uv_image_(self, img, data_format, data_range, cmap):
        """Visualize a 2-channel UV map as a BGR uint8 image.

        cmap "checkerboard" paints alternating magenta/white tiles
        (n_grid tiles per axis); cmap "color" maps U to red and V to green.
        """
        img = self.convert_data(img)
        assert data_format in ["CHW", "HWC"]
        if data_format == "CHW":
            img = img.transpose(1, 2, 0)
        img = img.clip(min=data_range[0], max=data_range[1])
        img = (img - data_range[0]) / (data_range[1] - data_range[0])
        assert cmap in ["checkerboard", "color"]
        if cmap == "checkerboard":
            n_grid = 64
            mask = (img * n_grid).astype(int)
            # Tile parity of (u + v) decides the checker color.
            mask = (mask[..., 0] + mask[..., 1]) % 2 == 0
            img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255
            img[mask] = np.array([255, 0, 255], dtype=np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        elif cmap == "color":
            img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
            img_[..., 0] = (img[..., 0] * 255).astype(np.uint8)
            img_[..., 1] = (img[..., 1] * 255).astype(np.uint8)
            img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR)
            img = img_
        return img
+
+ def save_uv_image(
+ self,
+ filename,
+ img,
+ data_format=DEFAULT_UV_KWARGS["data_format"],
+ data_range=DEFAULT_UV_KWARGS["data_range"],
+ cmap=DEFAULT_UV_KWARGS["cmap"],
+ ) -> str:
+ save_path = self.get_save_path(filename)
+ img = self.get_uv_image_(img, data_format, data_range, cmap)
+ cv2.imwrite(save_path, img)
+ return save_path
+
+ def get_grayscale_image_(self, img, data_range, cmap):
+ img = self.convert_data(img)
+ img = np.nan_to_num(img)
+ if data_range is None:
+ img = (img - img.min()) / (img.max() - img.min())
+ else:
+ img = img.clip(data_range[0], data_range[1])
+ img = (img - data_range[0]) / (data_range[1] - data_range[0])
+ assert cmap in [None, "jet", "magma", "spectral"]
+ if cmap == None:
+ img = (img * 255.0).astype(np.uint8)
+ img = np.repeat(img[..., None], 3, axis=2)
+ elif cmap == "jet":
+ img = (img * 255.0).astype(np.uint8)
+ img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
+ elif cmap == "magma":
+ img = 1.0 - img
+ base = cm.get_cmap("magma")
+ num_bins = 256
+ colormap = LinearSegmentedColormap.from_list(
+ f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins
+ )(np.linspace(0, 1, num_bins))[:, :3]
+ a = np.floor(img * 255.0)
+ b = (a + 1).clip(max=255.0)
+ f = img * 255.0 - a
+ a = a.astype(np.uint16).clip(0, 255)
+ b = b.astype(np.uint16).clip(0, 255)
+ img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None]
+ img = (img * 255.0).astype(np.uint8)
+ elif cmap == "spectral":
+ colormap = plt.get_cmap("Spectral")
+
+ def blend_rgba(image):
+ image = image[..., :3] * image[..., -1:] + (
+ 1.0 - image[..., -1:]
+ ) # blend A to RGB
+ return image
+
+ img = colormap(img)
+ img = blend_rgba(img)
+ img = (img * 255).astype(np.uint8)
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ return img
+
+ def _save_grayscale_image(
+ self,
+ filename,
+ img,
+ data_range,
+ cmap,
+ name: Optional[str] = None,
+ step: Optional[int] = None,
+ ):
+ img = self.get_grayscale_image_(img, data_range, cmap)
+ cv2.imwrite(filename, img)
+ if name and self._wandb_logger:
+ self._wandb_logger.log_image(
+ key=name, images=[self.get_save_path(filename)], step=step
+ )
+
+ def save_grayscale_image(
+ self,
+ filename,
+ img,
+ data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"],
+ cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"],
+ name: Optional[str] = None,
+ step: Optional[int] = None,
+ ) -> str:
+ save_path = self.get_save_path(filename)
+ self._save_grayscale_image(save_path, img, data_range, cmap, name, step)
+ return save_path
+
    def get_image_grid_(self, imgs, align):
        """Assemble a grid image from a (possibly nested) list of specs.

        Each spec is a dict with "type" ("rgb" | "uv" | "grayscale"),
        "img", and "kwargs" overriding that type's DEFAULT_* settings
        (the "kwargs" key is required). A nested list produces rows
        stacked vertically; a flat list produces one row. *align* sets
        the common row height: "max", "min", or a pixel count.
        """
        if isinstance(imgs[0], list):
            # Recurse row by row, then stack the rendered rows vertically.
            return np.concatenate(
                [self.get_image_grid_(row, align) for row in imgs], axis=0
            )
        cols = []
        for col in imgs:
            assert col["type"] in ["rgb", "uv", "grayscale"]
            if col["type"] == "rgb":
                rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy()
                rgb_kwargs.update(col["kwargs"])
                cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs))
            elif col["type"] == "uv":
                uv_kwargs = self.DEFAULT_UV_KWARGS.copy()
                uv_kwargs.update(col["kwargs"])
                cols.append(self.get_uv_image_(col["img"], **uv_kwargs))
            elif col["type"] == "grayscale":
                grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy()
                grayscale_kwargs.update(col["kwargs"])
                cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs))

        if align == "max":
            h = max([col.shape[0] for col in cols])
        elif align == "min":
            h = min([col.shape[0] for col in cols])
        elif isinstance(align, int):
            h = align
        else:
            raise ValueError(
                f"Unsupported image grid align: {align}, should be min, max, or int"
            )

        # Resize every column to the common height, preserving aspect ratio.
        for i in range(len(cols)):
            if cols[i].shape[0] != h:
                w = int(cols[i].shape[1] * h / cols[i].shape[0])
                cols[i] = cv2.resize(cols[i], (w, h), interpolation=cv2.INTER_CUBIC)
        return np.concatenate(cols, axis=1)
+
    def save_image_grid(
        self,
        filename,
        imgs,
        align=DEFAULT_GRID_KWARGS["align"],
        name: Optional[str] = None,
        step: Optional[int] = None,
        texts: Optional[List[Any]] = None,
    ):
        """Save an assembled image grid; optionally overlay one text label
        per row and mirror the result to wandb.

        *texts* entries are formatted with f-strings, so any printable
        value works (annotation widened from List[float] accordingly).
        """
        save_path = self.get_save_path(filename)
        img = self.get_image_grid_(imgs, align=align)

        if texts is not None:
            img = Image.fromarray(img)
            draw = ImageDraw.Draw(img)
            black, white = (0, 0, 0), (255, 255, 255)
            for i, text in enumerate(texts):
                # Draw the label four times in white around the target
                # pixel, then once in black, for a crude outline effect.
                draw.text((2, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
                draw.text((0, (img.size[1] // len(texts)) * i + 1), f"{text}", white)
                draw.text((2, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
                draw.text((0, (img.size[1] // len(texts)) * i - 1), f"{text}", white)
                draw.text((1, (img.size[1] // len(texts)) * i), f"{text}", black)
            img = np.asarray(img)

        cv2.imwrite(save_path, img)
        if name and self._wandb_logger:
            self._wandb_logger.log_image(key=name, images=[save_path], step=step)
        return save_path
+
+ def save_image(self, filename, img) -> str:
+ save_path = self.get_save_path(filename)
+ img = self.convert_data(img)
+ assert img.dtype == np.uint8 or img.dtype == np.uint16
+ if img.ndim == 3 and img.shape[-1] == 3:
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ elif img.ndim == 3 and img.shape[-1] == 4:
+ img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
+ cv2.imwrite(save_path, img)
+ return save_path
+
    def save_cubemap(self, filename, img, data_range=(0, 1), rgba=False) -> str:
        """Save a cubemap (6, H, H, C) laid out as an unfolded cross.

        Channel groups beyond the first three are tiled horizontally,
        mirroring get_rgb_image_'s behavior for multi-group images.
        """
        save_path = self.get_save_path(filename)
        img = self.convert_data(img)
        assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2]

        imgs_full = []
        for start in range(0, img.shape[-1], 3):
            img_ = img[..., start : start + 3]
            img_ = np.stack(
                [
                    self.get_rgb_image_(img_[i], "HWC", data_range, rgba=rgba)
                    for i in range(img_.shape[0])
                ],
                axis=0,
            )
            size = img_.shape[1]
            # NOTE(review): placeholder is float32 while the faces are
            # uint8, so the concatenated canvas is promoted to float and
            # cv2.imwrite casts on write — confirm this is intended.
            placeholder = np.zeros((size, size, 3), dtype=np.float32)
            # Cross layout: face 2 on top, faces 1/4/0/5 in the middle
            # strip, face 3 on the bottom.
            img_full = np.concatenate(
                [
                    np.concatenate(
                        [placeholder, img_[2], placeholder, placeholder], axis=1
                    ),
                    np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1),
                    np.concatenate(
                        [placeholder, img_[3], placeholder, placeholder], axis=1
                    ),
                ],
                axis=0,
            )
            imgs_full.append(img_full)

        imgs_full = np.concatenate(imgs_full, axis=1)
        cv2.imwrite(save_path, imgs_full)
        return save_path
+
+ def save_data(self, filename, data) -> str:
+ data = self.convert_data(data)
+ if isinstance(data, dict):
+ if not filename.endswith(".npz"):
+ filename += ".npz"
+ save_path = self.get_save_path(filename)
+ np.savez(save_path, **data)
+ else:
+ if not filename.endswith(".npy"):
+ filename += ".npy"
+ save_path = self.get_save_path(filename)
+ np.save(save_path, data)
+ return save_path
+
    def save_state_dict(self, filename, data) -> str:
        """Save *data* (e.g. a model state dict) with torch.save."""
        save_path = self.get_save_path(filename)
        torch.save(data, save_path)
        return save_path
+
+ def save_img_sequence(
+ self,
+ filename,
+ img_dir,
+ matcher,
+ save_format="mp4",
+ fps=30,
+ name: Optional[str] = None,
+ step: Optional[int] = None,
+ ) -> str:
+ assert save_format in ["gif", "mp4"]
+ if not filename.endswith(save_format):
+ filename += f".{save_format}"
+ save_path = self.get_save_path(filename)
+ matcher = re.compile(matcher)
+ img_dir = os.path.join(self.get_save_dir(), img_dir)
+ imgs = []
+ for f in os.listdir(img_dir):
+ if matcher.search(f):
+ imgs.append(f)
+ imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0]))
+ imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs]
+
+ if save_format == "gif":
+ imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
+ imageio.mimsave(save_path, imgs, fps=fps, palettesize=256)
+ elif save_format == "mp4":
+ imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs]
+ imageio.mimsave(save_path, imgs, fps=fps)
+ if name and self._wandb_logger:
+ lrm.warn("Wandb logger does not support video logging yet!")
+ return save_path
+
+ def save_img_sequences(
+ self,
+ seq_dir,
+ matcher,
+ save_format="mp4",
+ fps=30,
+ delete=True,
+ name: Optional[str] = None,
+ step: Optional[int] = None,
+ ):
+ seq_dir_ = os.path.join(self.get_save_dir(), seq_dir)
+ for f in os.listdir(seq_dir_):
+ img_dir_ = os.path.join(seq_dir_, f)
+ if not os.path.isdir(img_dir_):
+ continue
+ try:
+ self.save_img_sequence(
+ os.path.join(seq_dir, f),
+ os.path.join(seq_dir, f),
+ matcher,
+ save_format=save_format,
+ fps=fps,
+ name=f"{name}_{f}",
+ step=step,
+ )
+ if delete:
+ shutil.rmtree(img_dir_)
+ except:
+ lrm.warn(f"Video saving for directory {seq_dir_} failed!")
+
    def save_mesh(self, filename, v_pos, t_pos_idx, v_tex=None, t_tex_idx=None) -> str:
        """Export a bare vertex/face mesh with trimesh.

        NOTE(review): v_tex/t_tex_idx are accepted but currently unused —
        kept for call-site compatibility; confirm whether UV export was
        intended here.
        """
        import trimesh

        save_path = self.get_save_path(filename)
        v_pos = self.convert_data(v_pos)
        t_pos_idx = self.convert_data(t_pos_idx)
        mesh = trimesh.Trimesh(vertices=v_pos, faces=t_pos_idx)
        mesh.export(save_path)
        return save_path
+
    def save_obj(
        self,
        filename: str,
        mesh: Mesh,
        save_mat: bool = False,
        save_normal: bool = False,
        save_uv: bool = False,
        save_vertex_color: bool = True,
        map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
        map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
        map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
        map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
        map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
        map_format: str = "jpg",
    ) -> List[str]:
        """Export *mesh* as an OBJ, optionally with an MTL and texture maps.

        The save_* flags select which mesh attributes are serialized; the
        map_* tensors are PBR texture images forwarded to _save_mtl when
        save_mat is set. Returns the list of files written (MTL + textures
        first, then the OBJ).
        """
        save_paths: List[str] = []
        if not filename.endswith(".obj"):
            filename += ".obj"
        v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data(
            mesh.t_pos_idx
        )
        v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None
        if save_normal:
            v_nrm = self.convert_data(mesh.v_nrm)
        if save_uv:
            v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data(
                mesh.t_tex_idx
            )
        if save_vertex_color:
            v_rgb = self.convert_data(mesh.v_rgb)
        matname, mtllib = None, None
        if save_mat:
            # A single material named "default" referencing the sibling MTL.
            matname = "default"
            mtl_filename = filename.replace(".obj", ".mtl")
            mtllib = os.path.basename(mtl_filename)
            mtl_save_paths = self._save_mtl(
                mtl_filename,
                matname,
                map_Kd=self.convert_data(map_Kd),
                map_Ks=self.convert_data(map_Ks),
                map_Bump=self.convert_data(map_Bump),
                map_Pm=self.convert_data(map_Pm),
                map_Pr=self.convert_data(map_Pr),
                map_format=map_format,
            )
            save_paths += mtl_save_paths
        obj_save_path = self._save_obj(
            filename,
            v_pos,
            t_pos_idx,
            v_nrm=v_nrm,
            v_tex=v_tex,
            t_tex_idx=t_tex_idx,
            v_rgb=v_rgb,
            matname=matname,
            mtllib=mtllib,
        )
        save_paths.append(obj_save_path)
        return save_paths
+
+ def _save_obj(
+ self,
+ filename,
+ v_pos,
+ t_pos_idx,
+ v_nrm=None,
+ v_tex=None,
+ t_tex_idx=None,
+ v_rgb=None,
+ matname=None,
+ mtllib=None,
+ ) -> str:
+ obj_str = ""
+ if matname is not None:
+ obj_str += f"mtllib {mtllib}\n"
+ obj_str += f"g object\n"
+ obj_str += f"usemtl {matname}\n"
+ for i in range(len(v_pos)):
+ obj_str += f"v {v_pos[i][0]} {v_pos[i][1]} {v_pos[i][2]}"
+ if v_rgb is not None:
+ obj_str += f" {v_rgb[i][0]} {v_rgb[i][1]} {v_rgb[i][2]}"
+ obj_str += "\n"
+ if v_nrm is not None:
+ for v in v_nrm:
+ obj_str += f"vn {v[0]} {v[1]} {v[2]}\n"
+ if v_tex is not None:
+ for v in v_tex:
+ obj_str += f"vt {v[0]} {1.0 - v[1]}\n"
+
+ for i in range(len(t_pos_idx)):
+ obj_str += "f"
+ for j in range(3):
+ obj_str += f" {t_pos_idx[i][j] + 1}/"
+ if v_tex is not None:
+ obj_str += f"{t_tex_idx[i][j] + 1}"
+ obj_str += "/"
+ if v_nrm is not None:
+ obj_str += f"{t_pos_idx[i][j] + 1}"
+ obj_str += "\n"
+
+ save_path = self.get_save_path(filename)
+ with open(save_path, "w") as f:
+ f.write(obj_str)
+ return save_path
+
    def _save_mtl(
        self,
        filename,
        matname,
        Ka=(0.0, 0.0, 0.0),
        Kd=(1.0, 1.0, 1.0),
        Ks=(0.0, 0.0, 0.0),
        map_Kd=None,
        map_Ks=None,
        map_Bump=None,
        map_Pm=None,
        map_Pr=None,
        map_format="jpg",
        step: Optional[int] = None,
    ) -> List[str]:
        """Write an MTL file plus any provided texture maps next to it.

        Each map_* image, when given, is written as texture_*.{map_format}
        and referenced from the MTL; otherwise the scalar Ka/Kd/Ks values
        are emitted instead. Returns the list of files written (MTL first).

        NOTE(review): the wandb log name for map_Pm is "{matname}_refl"
        and for map_Pr "{matname}_Ns" — confirm these labels are intended.
        """
        mtl_save_path = self.get_save_path(filename)
        save_paths = [mtl_save_path]
        mtl_str = f"newmtl {matname}\n"
        mtl_str += f"Ka {Ka[0]} {Ka[1]} {Ka[2]}\n"
        if map_Kd is not None:
            # Diffuse (albedo) map.
            map_Kd_save_path = os.path.join(
                os.path.dirname(mtl_save_path), f"texture_kd.{map_format}"
            )
            mtl_str += f"map_Kd texture_kd.{map_format}\n"
            self._save_rgb_image(
                map_Kd_save_path,
                map_Kd,
                data_format="HWC",
                data_range=(0, 1),
                name=f"{matname}_Kd",
                step=step,
            )
            save_paths.append(map_Kd_save_path)
        else:
            mtl_str += f"Kd {Kd[0]} {Kd[1]} {Kd[2]}\n"
        if map_Ks is not None:
            # Specular map.
            map_Ks_save_path = os.path.join(
                os.path.dirname(mtl_save_path), f"texture_ks.{map_format}"
            )
            mtl_str += f"map_Ks texture_ks.{map_format}\n"
            self._save_rgb_image(
                map_Ks_save_path,
                map_Ks,
                data_format="HWC",
                data_range=(0, 1),
                name=f"{matname}_Ks",
                step=step,
            )
            save_paths.append(map_Ks_save_path)
        else:
            mtl_str += f"Ks {Ks[0]} {Ks[1]} {Ks[2]}\n"
        if map_Bump is not None:
            # Normal map.
            map_Bump_save_path = os.path.join(
                os.path.dirname(mtl_save_path), f"texture_nrm.{map_format}"
            )
            mtl_str += f"map_Bump texture_nrm.{map_format}\n"
            self._save_rgb_image(
                map_Bump_save_path,
                map_Bump,
                data_format="HWC",
                data_range=(0, 1),
                name=f"{matname}_Bump",
                step=step,
            )
            save_paths.append(map_Bump_save_path)
        if map_Pm is not None:
            # Metallic map (grayscale).
            map_Pm_save_path = os.path.join(
                os.path.dirname(mtl_save_path), f"texture_metallic.{map_format}"
            )
            mtl_str += f"map_Pm texture_metallic.{map_format}\n"
            self._save_grayscale_image(
                map_Pm_save_path,
                map_Pm,
                data_range=(0, 1),
                cmap=None,
                name=f"{matname}_refl",
                step=step,
            )
            save_paths.append(map_Pm_save_path)
        if map_Pr is not None:
            # Roughness map (grayscale).
            map_Pr_save_path = os.path.join(
                os.path.dirname(mtl_save_path), f"texture_roughness.{map_format}"
            )
            mtl_str += f"map_Pr texture_roughness.{map_format}\n"
            self._save_grayscale_image(
                map_Pr_save_path,
                map_Pr,
                data_range=(0, 1),
                cmap=None,
                name=f"{matname}_Ns",
                step=step,
            )
            save_paths.append(map_Pr_save_path)
        with open(self.get_save_path(filename), "w") as f:
            f.write(mtl_str)
        return save_paths
+
    def save_glb(
        self,
        filename: str,
        mesh: Mesh,
        save_mat: bool = False,
        save_normal: bool = True,
        save_uv: bool = True,
        save_vertex_color: bool = True,
        map_Kd: Optional[Float[Tensor, "H W 3"]] = None,
        map_Ks: Optional[Float[Tensor, "H W 3"]] = None,
        map_Bump: Optional[Float[Tensor, "H W 3"]] = None,
        map_Pm: Optional[Float[Tensor, "H W 1"]] = None,
        map_Pr: Optional[Float[Tensor, "H W 1"]] = None,
        map_format: str = "jpg",
    ) -> List[str]:
        """Export *mesh* as a GLB via trimesh; returns the files written.

        NOTE(review): save_mat and the map_* arguments are accepted for
        signature parity with save_obj but are not forwarded to
        _save_glb — confirm whether material export is still planned.
        """
        save_paths: List[str] = []
        if not filename.endswith(".glb"):
            filename += ".glb"
        v_pos, t_pos_idx = self.convert_data(mesh.v_pos), self.convert_data(
            mesh.t_pos_idx
        )
        v_nrm, v_tex, t_tex_idx, v_rgb = None, None, None, None
        if save_normal:
            v_nrm = self.convert_data(mesh.v_nrm)
        if save_uv:
            v_tex, t_tex_idx = self.convert_data(mesh.v_tex), self.convert_data(
                mesh.t_tex_idx
            )
        if save_vertex_color:
            v_rgb = self.convert_data(mesh.v_rgb)

        obj_save_path = self._save_glb(
            filename,
            v_pos,
            t_pos_idx,
            v_nrm=v_nrm,
            v_tex=v_tex,
            t_tex_idx=t_tex_idx,
            v_rgb=v_rgb,
        )
        save_paths.append(obj_save_path)
        return save_paths
+
    def _save_glb(
        self,
        filename,
        v_pos,
        t_pos_idx,
        v_nrm=None,
        v_tex=None,
        t_tex_idx=None,
        v_rgb=None,
        matname=None,  # unused; kept for parity with _save_obj
        mtllib=None,  # unused; kept for parity with _save_obj
    ) -> str:
        """Build a trimesh from the raw arrays and export it as GLB.

        t_tex_idx is not used: trimesh's TextureVisuals takes per-vertex
        UVs only.
        """
        import trimesh

        mesh = trimesh.Trimesh(
            vertices=v_pos, faces=t_pos_idx, vertex_normals=v_nrm, vertex_colors=v_rgb
        )
        # not tested
        if v_tex is not None:
            mesh.visual = trimesh.visual.TextureVisuals(uv=v_tex)

        save_path = self.get_save_path(filename)
        mesh.export(save_path)
        return save_path
+
    def save_file(self, filename, src_path, delete=False) -> str:
        """Copy *src_path* into the save dir (optionally deleting the source)."""
        save_path = self.get_save_path(filename)
        shutil.copyfile(src_path, save_path)
        if delete:
            os.remove(src_path)
        return save_path
+
+ def save_json(self, filename, payload) -> str:
+ save_path = self.get_save_path(filename)
+ with open(save_path, "w") as f:
+ f.write(json.dumps(payload))
+ return save_path
diff --git a/3D_Stage/lrm/utils/sdf.py b/3D_Stage/lrm/utils/sdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b463a0c9f7482eb73b42205b92fbaadbc8ffde5f
--- /dev/null
+++ b/3D_Stage/lrm/utils/sdf.py
@@ -0,0 +1,15 @@
+import pySDF as SDF
+import cv2
+import numpy as np
+import torch
+from lrm.models.isosurface import MarchingTetrahedraHelper
+
def get_tetra_for_mesh(mesh_path, resolution=128):
    """Compute ground-truth SDF values for *mesh_path* on the marching-
    tetrahedra grid of the given resolution.

    Returns:
        numpy array of signed distances, one per grid vertex.
    """
    import trimesh  # fix: trimesh was used below without being imported

    isosurface_helper = MarchingTetrahedraHelper(resolution, f"load/{resolution}_tets.npz")
    isosurface_helper.points_range = (-1, 1)
    mesh = trimesh.load(mesh_path)
    # Fix: removed dead `dmtet = np.load(f"")` — loading an empty path
    # always raises FileNotFoundError and the variable was never used.
    sdf = SDF(mesh.vertices, mesh.faces)
    sdf_gt = sdf(isosurface_helper.grid_vertices.numpy())
    return sdf_gt
+
diff --git a/3D_Stage/lrm/utils/typing.py b/3D_Stage/lrm/utils/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..dee9f967c21f94db1ad939d7dead156d86748752
--- /dev/null
+++ b/3D_Stage/lrm/utils/typing.py
@@ -0,0 +1,40 @@
+"""
+This module contains type annotations for the project, using
+1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
+2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
+
+Two types of typing checking can be used:
+1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
+2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
+"""
+
+# Basic types
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ NamedTuple,
+ NewType,
+ Optional,
+ Sized,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+# Tensor dtype
+# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
+from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
+
+# Config type
+from omegaconf import DictConfig
+
+# PyTorch Tensor type
+from torch import Tensor
+
+# Runtime type checking decorator
+from typeguard import typechecked as typechecker
diff --git a/3D_Stage/material/examples/1/1.png b/3D_Stage/material/examples/1/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..25719db8520ec47127497a5e1f5b826b5caed95c
--- /dev/null
+++ b/3D_Stage/material/examples/1/1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24eab85189be0a9f4977b561807e5da1458c70f7f72b156e6af77d0ed9b25334
+size 186833
diff --git a/3D_Stage/material/examples/1/2.png b/3D_Stage/material/examples/1/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..75ddc041ba44a7ea3e6c50a5741575a3b7160a44
--- /dev/null
+++ b/3D_Stage/material/examples/1/2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc956a3a62dbd23b3eb4a8efcec71e30ea3d612ff1fa8e1cd01c3b59e97d1bfc
+size 203314
diff --git a/3D_Stage/material/examples/1/3.png b/3D_Stage/material/examples/1/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a2aeb4899988e9bb53c503e608e7254a10f1e51
--- /dev/null
+++ b/3D_Stage/material/examples/1/3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:636e479fd74580b1cfaeb9c1f40b097afca4c4678b084dcda5cd648e2e415d9c
+size 133351
diff --git a/3D_Stage/material/examples/1/4.png b/3D_Stage/material/examples/1/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..9ab5a320857692af16b09c502795cf32ff2c847f
--- /dev/null
+++ b/3D_Stage/material/examples/1/4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f9fbaf91b5d15b95d6b14e0f2331b81b6c66388ab9e38a1ecc5b07094e69272
+size 131478
diff --git a/3D_Stage/material/meta.json b/3D_Stage/material/meta.json
new file mode 100644
index 0000000000000000000000000000000000000000..9ca2e428f7cea46525623179a35168bbe4ff15c9
--- /dev/null
+++ b/3D_Stage/material/meta.json
@@ -0,0 +1,36 @@
+{
+ "locations": [
+ {
+ "transform_matrix": [
+ [ 1.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 1.0, 1.5 ],
+ [ 0.0, 1.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 1.0 ]
+ ]
+ },
+ {
+ "transform_matrix": [
+ [ -1.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, -1.0, -1.5 ],
+ [ 0.0, 1.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 1.0 ]
+ ]
+ },
+ {
+ "transform_matrix": [
+ [ 0.0, 0.0, 1.0, 1.5 ],
+ [ -1.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 1.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 1.0 ]
+ ]
+ },
+ {
+ "transform_matrix": [
+ [ 0.0, 0.0, -1.0, -1.5 ],
+ [ 1.0, 0.0, 0.0, 0.0 ],
+ [ 0.0, 1.0, 0.0, 0.0 ],
+ [ 0.0, 0.0, 0.0, 1.0 ]
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/3D_Stage/refine.py b/3D_Stage/refine.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d87c44f2c9ba8052aeb534b28df67ae26733898
--- /dev/null
+++ b/3D_Stage/refine.py
@@ -0,0 +1,185 @@
+import os
+import numpy as np
+import torch
+import nvdiffrast.torch as dr
+import json
+import torch.nn.functional as F
+from PIL import Image
+import pymeshlab
+import cv2
+
def back_to_texture(glctx, look_at, pos, tri, tex, uv, uv_idx, idx, vn):
    """Back-project one rendered view image into the mesh's UV texture space.

    Args:
        glctx: nvdiffrast rasterizer context.
        look_at: camera forward vector (currently unused here).
        pos: clip-space vertex positions, batched (leading dim 1).
        tri: int32 triangle indices.
        tex: the target view image; its H/W set the rasterization size.
        uv, uv_idx: texture coordinates and their per-face indices.
        idx: view index (0..3); selects which axis pair is treated as
            "sideways" for normal-based masking.
        vn: per-vertex normals.

    Returns:
        (color, texture_mask, rast_uv): back-projected texel colors in a
        1024x1024 UV layout, an int validity mask, and the UV-space
        rasterizer output.
    """
    rast_out, rast_out_db = dr.rasterize(glctx, pos, tri, resolution=[tex.shape[0],tex.shape[1]])
    gb_normal, _ = dr.interpolate(vn[None], rast_out, tri)
    gb_normal = F.normalize(gb_normal, dim=-1)
    # Mask out pixels whose normals point strongly along the camera's
    # sideways axes — their colors would be heavily distorted.
    if idx == 2 or idx == 0:
        filter_camera = [torch.tensor([[[[1,0.,0.]]]]).cuda(), torch.tensor([[[[-1,0.,0.]]]]).cuda()]
    else:
        filter_camera = [torch.tensor([[[[0,-1.,0.]]]]).cuda(), torch.tensor([[[[0,1.,0.]]]]).cuda()]
    nmasks = []
    for fc in filter_camera:
        # 0.75 ~= cos(41 deg): normals within that cone of a side axis are rejected.
        nmasks.append(((gb_normal * fc) > 0.75).int().sum(keepdim=True, dim=-1))
    gb_normal_mask = 1 - (nmasks[0] | nmasks[1])
    #Image.fromarray(np.clip(gb_normal_mask[0,...,0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)).save(f"mask_normal_{idx}.png")
    gb_mask = rast_out[...,3:4] > 0
    # Channel 3 of the rasterizer output holds 1-based triangle ids
    # (0 = background); collect the visible triangles.
    tri_list = torch.unique(rast_out[...,3:4].reshape(-1))
    tri_list = (tri_list[1:] - 1).to(torch.int32)
    pos = pos[0]

    depth_map = rast_out[...,3:4].clone()
    depth_map[depth_map > 0] = 1
    depth_map = depth_map.to(torch.float32)
    dmax = (rast_out[...,2:3] * gb_mask).max()
    # Rasterize the UV layout itself (UVs mapped to clip space [-1, 1])
    # restricted to the triangles visible in this view.
    uv = torch.cat([uv * 2 - 1, torch.zeros(uv.shape[0], 1).cuda(), torch.ones(uv.shape[0], 1).cuda()], dim=1).unsqueeze(0)
    uv_idx = uv_idx[tri_list.to(torch.long)]
    rast_uv, rast_uv_db = dr.rasterize(glctx, uv, uv_idx, resolution=(1024, 1024))
    # Interpolate (x, y, w) per UV texel, then perspective-divide to get
    # the screen coordinate each texel projects to in [0, 1].
    pos_clip = torch.cat([pos[...,:2], pos[...,3:]], -1)
    pos_2d, _ = dr.interpolate(pos_clip, rast_uv, tri[tri_list.to(torch.long)]) # pos (x, y, z, w)
    pos_coord = (pos_2d[...,:2] / (pos_2d[...,2:3] + 1e-6) + 1) / 2.
    texture_mask = (rast_uv[...,3:4] > 0).int()
    # Sample the (normal-masked) view image at those screen coordinates.
    color = dr.texture(tex[None, ...] * gb_normal_mask, pos_coord, filter_mode='linear')
    color_mask = dr.texture(gb_normal_mask.to(torch.float32), pos_coord, filter_mode='linear')
    # 0.82 threshold binarizes the bilinearly-filtered mask (tuned value).
    color_mask[color_mask > 0.82] = 1
    color_mask[color_mask <= 0.82] = 0
    color_mask = color_mask.to(torch.int32)
    #Image.fromarray(np.clip(color_mask[0].repeat(1,1,3).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)).save(f"depth_{idx}.png")
    texture_mask = texture_mask * color_mask
    #Image.fromarray(np.clip(color[0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)).save(f"{idx}.png")
    #Image.fromarray(np.clip(texture_mask[0].repeat(1,1,3).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)).convert("RGB").save(f"mask-{idx}.png")
    return color, texture_mask, rast_uv
+
def perspective(fovy=0.6913, aspect=1.0, n=0.1, f=1000.0, device=None):
    """Build a 4x4 OpenGL-style perspective projection matrix.

    Args:
        fovy: vertical field of view in radians.
        aspect: aspect ratio dividing the x focal term.
        n, f: near / far clip plane distances.
        device: target device. Defaults to CUDA for backward
            compatibility — the original accepted this parameter but
            ignored it and always called .cuda().

    Note the negated y focal term (1/-y), which flips the image
    vertically to match the texture convention used by the callers.
    """
    y = np.tan(fovy / 2)
    proj = torch.tensor([[1/(y*aspect),    0,            0,              0],
                         [           0, 1/-y,            0,              0],
                         [           0,    0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
                         [           0,    0,           -1,              0]],
                        dtype=torch.float32)
    return proj.cuda() if device is None else proj.to(device)
+
def rec_mvp(trans, h, w):
    """Compose the model-view-projection matrix for a camera transform.

    Uses a fixed 40-degree vertical FOV; *h*/*w* set the aspect term
    passed to perspective().
    """
    fov_rad = 40. / 180. * np.pi
    return perspective(fov_rad, h / w, n=0.1, f=1000) @ trans
+
def aggregate_texture(kd_map, textures, texture_masks, rast_uvs):
    """Merge per-view back-projected textures into the base kd texture.

    For each UV texel, picks the source view whose color is closest to
    the existing kd color, then blends the chosen colors into kd_map
    where the combined masks are valid.

    Args:
        kd_map: base texture in [0, 1].
        textures: per-view back-projected texel colors.
        texture_masks: per-view int validity masks.
        rast_uvs: per-view UV rasterization outputs (currently unused).

    Returns:
        (kd_map, optimize_mask): updated texture and the int mask of
        texels that received no projected color.
    """
    texture = torch.zeros_like(textures[0])
    texture_mask = torch.zeros_like(texture_masks[0])
    ctex = []
    for idx in range(len(textures)):
        # Push invalid texels far away (value 10) so the argmin below
        # never selects them.
        ctex.append(textures[idx] * texture_masks[idx] + 10 * (1 - texture_masks[idx]))
    cat_textures = torch.stack(ctex, dim=-2)
    # L1 color distance of each candidate view to the current kd color;
    # choose the closest view per texel.
    dis_measure = (cat_textures - kd_map.unsqueeze(-2)).abs().sum(-1)
    _, choose_idx = dis_measure.min(-1)

    choose_idx = choose_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 1, 3)
    final_texture_map = torch.gather(cat_textures, 3, choose_idx).squeeze(-2)
    #cv2.imwrite("final_texture_map.png", cv2.cvtColor((final_texture_map[0].cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_BGR2RGB))
    #cv2.imwrite("final_texture_mask.png", (texture_mask[0].cpu().numpy() * 255).astype(np.uint8))
    # Keep only texels that are non-black and not too far from the
    # existing kd color (L1 distance < 1.0).
    zero_mask = (final_texture_map.max(dim=-1, keepdim=True)[0] > 0.1)
    close_mask = ((final_texture_map[0] - kd_map).abs().sum(dim=-1, keepdim=True) < 1.0).int()
    for idx in range(len(textures)):
        texture += textures[idx] * texture_masks[idx]
        texture_mask |= texture_masks[idx]
    texture_mask = texture_mask * zero_mask * close_mask[None]
    optimize_mask = (texture_mask == 0).int()

    #import pdb; pdb.set_trace()
    #mask = (texture_mask[0].cpu().numpy() * 255).astype(np.uint8)
    #cv2.imwrite("mask.png", mask)
    #kernel = np.ones((5,5), np.uint8)
    #dilated = cv2.dilate(mask, kernel, iterations=1)
    #cv2.imwrite("di_mask.png", dilated)
    #texture_mask[0] = torch.from_numpy(dilated).unsqueeze(-1).to(torch.float32) / 255.

    final_texture_map = final_texture_map[0] * texture_mask[0]
    # Debug artifact written to the working directory.
    Image.fromarray(np.rint(final_texture_map.cpu().numpy() * 255).astype(np.uint8)).save(f"final_texture.png")

    #cv2.imwrite("kd_map.png", cv2.cvtColor((kd_map.cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_BGR2RGB))
    #cv2.imwrite("texture_map.png", cv2.cvtColor((final_texture_map.cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_BGR2RGB))
    #result = cv2.seamlessClone((final_texture_map.cpu().numpy() * 255).astype(np.uint8), (kd_map.cpu().numpy() * 255).astype(np.uint8), mask, (mask.shape[1]//2, mask.shape[0]//2), cv2.NORMAL_CLONE)
    #cv2.imwrite("result.png", cv2.cvtColor(result * 255, cv2.COLOR_BGR2RGB))

    # Overwrite masked texels of the base texture with the projections.
    kd_map = kd_map * (1 - texture_mask[0]) + final_texture_map
    return kd_map, optimize_mask
+
def refine(save_path, front_image, back_image, left_image, right_image):
    """Back-project four orthogonal view images onto the mesh texture.

    Loads {save_path}/model-00.obj, smooths it, projects each view's
    pixels into UV space via back_to_texture, merges them with the
    existing texture_kd.jpg via aggregate_texture, and writes
    refined_texture_kd.jpg plus an MTL that references it.
    """
    ms = pymeshlab.MeshSet()
    mesh_path = f"{save_path}/model-00.obj"
    ms.load_new_mesh(mesh_path)
    ms.apply_coord_laplacian_smoothing(stepsmoothnum=10)
    # Parse per-wedge UVs straight from the OBJ text (flipping V back:
    # the exporter wrote 1 - v).
    with open(mesh_path, "r") as obj_file:  # fix: file handle was leaked
        tl = obj_file.readlines()
    tex_uv = []
    uv_idx = []
    for line in tl:
        if line.startswith("vt"):
            uvs = line.split(" ")[1:3]
            tex_uv += [float(uvs[0]), 1.0-float(uvs[1])]
    tex_uv = torch.from_numpy(np.array(tex_uv)).to(torch.float32).cuda().reshape(-1, 2)
    m = ms.current_mesh()
    v_matrix = m.vertex_matrix()
    f_matrix = m.face_matrix()
    vn = m.vertex_normal_matrix()
    # Assumes the OBJ emits one vt per face corner in face order — TODO confirm.
    uv_idx = torch.arange(f_matrix.shape[0] * 3).reshape(-1, 3).to(torch.int32).cuda()
    vn = torch.tensor(vn).contiguous().cuda().to(torch.float32)

    # World-to-camera transforms for the four axis-aligned views,
    # each 1.5 units from the origin.
    front_camera = torch.tensor([[
        1,0,0,0,
        0,0,1,0,
        0,-1,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    back_camera = torch.tensor([[
        1,0,0,0,
        0,0,1,0,
        0,1,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    right_camera = torch.tensor([[
        0,-1,0,0,
        0,0,1,0,
        1,0,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    left_camera = torch.tensor([[
        0,1,0,0,
        0,0,1,0,
        -1,0,0,-1.5,
        0,0,0,1,
    ]]).to(torch.float32).reshape(4,4).cuda()
    # Order must match target_images below (front, left, back, right).
    frames = [front_camera, left_camera, back_camera, right_camera]

    target_images = []
    for target_image in [front_image, left_image, back_image, right_image]:
        target_images.append(torch.from_numpy(np.asarray(target_image.convert("RGB"))).to(torch.float32).cuda() / 255.)

    pos = torch.tensor(v_matrix, dtype=torch.float32).contiguous().cuda()
    tri = torch.tensor(f_matrix, dtype=torch.int32).contiguous().cuda()

    kd_map = (torch.tensor(np.asarray(Image.open(f"{save_path}/texture_kd.jpg"))) / 255.).cuda()
    # Hook for shifting the mesh before projection; currently zero.
    translate_tensor = torch.zeros((1,1,3)).cuda()
    pos = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()],-1).unsqueeze(0)
    glctx = dr.RasterizeCudaContext()
    target_texture = []
    target_mask = []
    rast_uvs = []
    with torch.no_grad():
        for idx, trans in enumerate(frames):
            target_image = target_images[idx]
            look_at = -torch.linalg.inv(trans)[:3,2]
            mvp = rec_mvp(trans, h=target_images[0].shape[0], w=target_images[0].shape[1])
            trans_pos = pos.clone()
            trans_pos[...,:3] += translate_tensor
            view_pos = torch.matmul(mvp, trans_pos.unsqueeze(-1)).squeeze(-1)
            texture, mask, rast_uv = back_to_texture(glctx, look_at, view_pos, tri, target_image, tex_uv, uv_idx, idx, vn)
            target_texture.append(texture)
            target_mask.append(mask)
            rast_uvs.append(rast_uv)
    kd_map, opt_mask = aggregate_texture(kd_map, target_texture, target_mask, rast_uvs)
    opt_mask = opt_mask[0]
    Image.fromarray((np.clip(kd_map.detach().cpu().numpy() * 255, 0, 255)).astype(np.uint8)).save(f"{save_path}/refined_texture_kd.jpg")

    #ms.save_current_mesh(f"{save_path}/model-00.obj")
    # Point the material at the refined texture.
    with open(f"{save_path}/model-00.mtl", "w") as f:
        f.write(f"newmtl default\nKa 0.0 0.0 0.0\nmap_Kd refined_texture_kd.jpg\nKs 0.0 0.0 0.0")
\ No newline at end of file
diff --git a/3D_Stage/webui.py b/3D_Stage/webui.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebfaee7db54c1f1e769cec2779c789be9c75bc34
--- /dev/null
+++ b/3D_Stage/webui.py
@@ -0,0 +1,162 @@
+import os
+import json
+import tqdm
+import cv2
+import numpy as np
+import torch, lrm
+import torch.nn.functional as F
+from lrm.utils.config import load_config
+from datetime import datetime
+import gradio as gr
+from pygltflib import GLTF2
+from PIL import Image
+from huggingface_hub import hf_hub_download
+
+from refine import refine
+
+device = "cuda"
+
+import trimesh
+import pymeshlab
+import numpy as np
+
+from huggingface_hub import hf_hub_download, list_repo_files
+
+repo_id = "zjpshadow/CharacterGen"
+all_files = list_repo_files(repo_id, revision="main")
+
+for file in all_files:
+ if os.path.exists("../" + file):
+ continue
+ if file.startswith("3D_Stage"):
+ hf_hub_download(repo_id, file, local_dir="../")
+
+def traverse(path, back_proj):
+ mesh = trimesh.load(f"{path}/model-00.obj")
+ mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(90.0), [-1, 0, 0]))
+ mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(180.0), [0, 1, 0]))
+
+ cmesh = pymeshlab.Mesh(mesh.vertices, mesh.faces)
+ ms = pymeshlab.MeshSet()
+ ms.add_mesh(cmesh)
+ ms.apply_coord_laplacian_smoothing(stepsmoothnum=4)
+ mesh.vertices = ms.current_mesh().vertex_matrix()
+
+ mesh.export(f'{path}/output.glb', file_type='glb')
+
+ image = Image.open(f"{path}/{'refined_texture_kd.jpg' if back_proj else 'texture_kd.jpg'}")
+ texture = np.array(image)
+ vertex_colors = np.zeros((mesh.vertices.shape[0], 4), dtype=np.uint8)
+
+ for vertex_index in range(len(mesh.visual.uv)):
+ uv = mesh.visual.uv[vertex_index]
+ x = int(uv[0] * (texture.shape[1] - 1))
+ y = int((1 - uv[1]) * (texture.shape[0] - 1))
+
+ color = texture[y, x, :3]
+ vertex_colors[vertex_index] = [color[0], color[1], color[2], 255]
+ return trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, vertex_colors=vertex_colors)
+
+class Inference_API:
+    """Loads the LRM reconstruction system once and converts four posed
+    character views into a textured 3D mesh (OBJ and GLB on disk)."""
+
+    def __init__(self):
+        # Load config
+        self.cfg = load_config("configs/infer.yaml", makedirs=False)
+        # Load system
+        print("Loading system")
+        # lrm.find resolves the configured system class by name; the model is
+        # moved to the module-level `device` and frozen in eval mode.
+        self.system = lrm.find(self.cfg.system_cls)(self.cfg.system).to(device)
+        self.system.eval()
+
+    def process_images(self, img_input0, img_input1, img_input2, img_input3, back_proj):
+        """Run four-view reconstruction and export the resulting mesh.
+
+        Args:
+            img_input0..img_input3: the four view images (labeled back,
+                front, right, left in the UI).
+            back_proj: when True, additionally run back-projection texture
+                refinement before baking vertex colors.
+
+        Returns:
+            (save_dir, path to output.obj, path to output.glb)
+        """
+        # Conditioning camera poses come from a fixed meta file shipped with the app.
+        meta = json.load(open("material/meta.json"))
+        c2w_cond = [np.array(loc["transform_matrix"]) for loc in meta["locations"]]
+        c2w_cond = torch.from_numpy(np.stack(c2w_cond, axis=0)).float()[None].to(device)
+
+        # Prepare input data
+        rgb_cond = []
+        files = [img_input0, img_input1, img_input2, img_input3]
+        new_image = []
+        for file in files:
+            image = np.array(file)
+            # Drop the alpha channel; both the model and refine() consume RGB.
+            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
+            new_image.append(Image.fromarray(image.astype(np.uint8)).convert("RGB"))
+            rgb = cv2.resize(image, (self.cfg.data.cond_width,
+                                self.cfg.data.cond_height)).astype(np.float32) / 255.0
+            rgb_cond.append(rgb)
+        assert len(rgb_cond) == 4, "Please provide 4 images"
+
+        # Stack to a batch of one: (1, 4, H, W, 3), values in [0, 1].
+        rgb_cond = torch.from_numpy(np.stack(rgb_cond, axis=0)).float()[None].to(device)
+
+        # Run inference
+        with torch.no_grad():
+            scene_codes = self.system({"rgb_cond": rgb_cond, "c2w_cond": c2w_cond})
+            exporter_output = self.system.exporter([f"{i:02d}" for i in range(rgb_cond.shape[0])], scene_codes)
+
+        # Save output under a timestamped directory.
+        save_dir = os.path.join("./outputs", datetime.now().strftime("@%Y%m%d-%H%M%S"))
+        os.makedirs(save_dir, exist_ok=True)
+        self.system.set_save_dir(save_dir)
+
+        # Each exporter record names the save method to dispatch to
+        # (save_type "x" -> self.system.save_x).
+        for out in exporter_output:
+            save_func_name = f"save_{out.save_type}"
+            save_func = getattr(self.system, save_func_name)
+            save_func(f"{out.save_name}", **out.params)
+
+        if back_proj:
+            # NOTE(review): the images are reordered (1, 0, 3, 2) here --
+            # confirm this matches refine()'s expected front/back/left/right
+            # argument order.
+            refine(save_dir, new_image[1], new_image[0], new_image[3], new_image[2])
+
+        new_obj = traverse(save_dir, back_proj)
+        new_obj.export(f'{save_dir}/output.obj', file_type='obj')
+
+        # Flatten PBR shading so the baked texture displays unmodified.
+        # NOTE(review): baseColorFactor alpha of 100.0 exceeds the glTF
+        # spec's [0, 1] range -- confirm this is intentional.
+        gltf = GLTF2().load(f'{save_dir}/output.glb')
+        for material in gltf.materials:
+            if material.pbrMetallicRoughness:
+                material.pbrMetallicRoughness.baseColorFactor = [1.0, 1.0, 1.0, 100.0]
+                material.pbrMetallicRoughness.metallicFactor = 0.0
+                material.pbrMetallicRoughness.roughnessFactor = 1.0
+        gltf.save(f'{save_dir}/output.glb')
+
+        return save_dir, f"{save_dir}/output.obj", f"{save_dir}/output.glb"
+
+# Instantiate once at import time so the model loads before the UI serves.
+inferapi = Inference_API()
+
+# Define the interface
+with gr.Blocks() as demo:
+    gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration")
+    gr.Markdown("# 3D Stage: Four View Images to 3D Mesh")
+    with gr.Row(variant="panel"):
+        with gr.Column():
+            # Four fixed-view inputs; the back/front/right/left order here
+            # must match what process_images expects positionally.
+            with gr.Row():
+                img_input0 = gr.Image(type="pil", label="Back Image", image_mode="RGBA", width=256, height=384)
+                img_input1 = gr.Image(type="pil", label="Front Image", image_mode="RGBA", width=256, height=384)
+            with gr.Row():
+                img_input2 = gr.Image(type="pil", label="Right Image", image_mode="RGBA", width=256, height=384)
+                img_input3 = gr.Image(type="pil", label="Left Image", image_mode="RGBA", width=256, height=384)
+            with gr.Row():
+                # Clicking an example populates all four image inputs.
+                gr.Examples(
+                    examples=
+                        [["material/examples/1/1.png",
+                        "material/examples/1/2.png",
+                        "material/examples/1/3.png",
+                        "material/examples/1/4.png"]],
+                    label="Example Images",
+                    inputs=[img_input0, img_input1, img_input2, img_input3]
+                )
+        with gr.Column():
+            with gr.Row():
+                back_proj = gr.Checkbox(label="Back Projection")
+                submit_button = gr.Button("Process")
+            output_dir = gr.Textbox(label="Output Directory")
+            with gr.Column():
+                with gr.Tab("GLB"):
+                    output_model_glb = gr.Model3D( label="Output Model (GLB Format)", height = 768)
+                    gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
+                with gr.Tab("OBJ"):
+                    output_model_obj = gr.Model3D( label="Output Model (OBJ Format)", height = 768)
+                    gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
+
+    # process_images returns (save_dir, obj_path, glb_path), wired in that
+    # order to the directory textbox and the two model viewers.
+    submit_button.click(inferapi.process_images, inputs=[img_input0, img_input1, img_input2, img_input3, back_proj],
+                        outputs=[output_dir, output_model_obj, output_model_glb])
+
+# Run the interface
+demo.launch()
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/README.md b/README.md
index 9241e43cf99951683b8885bf28be80f805c1c0ee..c6440c88ba3ba1c915f8dcf18fef946927ed3f7c 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,137 @@
---
+license: apache-2.0
title: CharacterGen
-emoji: 👁
-colorFrom: red
-colorTo: blue
sdk: gradio
-sdk_version: 6.6.0
-python_version: '3.12'
-app_file: app.py
+sdk_version: 5.14.0
+python_version: "3.10"
+emoji: 🏃
+colorFrom: gray
+colorTo: red
pinned: false
-license: apache-2.0
+short_description: Gradio demo of CharacterGen (SIGGRAPH 2024)
+hardware: zero-gpu
---
+# CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration
+
+This is the official codebase of SIGGRAPH'24 (TOG) [CharacterGen](https://charactergen.github.io/).
+
+
+
+- [x] Rendering Script of VRM model, including blender and three-js.
+- [x] Inference code for 2D generation stage.
+- [x] Inference code for 3D generation stage.
+
+## Quick Start
+
+### 1. Prepare environment
+
+`pip install -r requirements.txt`
+
+### 2. Download the weight
+
+Install `huggingface-cli` first.
+
+```bash
+huggingface-cli download --resume-download zjpshadow/CharacterGen --include 2D_Stage/* --local-dir .
+huggingface-cli download --resume-download zjpshadow/CharacterGen --include 3D_Stage/* --local-dir .
+```
+
+If the download fails, you can download the whole repository and move the files to the right folders.
+
+### 3. Run the script
+
+#### Run the whole pipeline
+```bash
+python webui.py
+```
+
+#### Only Run 2D Stage
+
+```bash
+cd 2D_Stage
+python webui.py
+```
+
+#### Only Run 3D Stage
+
+```bash
+cd 3D_Stage
+python webui.py
+```
+
+## Get the Anime3D Dataset
+
+Due to licensing policy, we cannot redistribute the raw VRM-format 3D character data.
+You can download the VRoid dataset following the [PAniC-3D](https://github.com/ShuhongChen/panic3d-anime-reconstruction) instructions.
+Then you can render the characters with Blender or three-js using our released rendering scripts.
+
+### Blender
+
+First, you should install [Blender](https://www.blender.org/) and [the VRM addon for Blender](https://github.com/saturday06/VRM-Addon-for-Blender).
+
+Then you can render the VRM and export an OBJ of the VRM posed by an FBX animation.
+
+```bash
+blender -b --python render_script/blender/render.py importVrmPath importFbxPath outputFolder [is_apose]
+```
+
+The last argument specifies whether to use the A-pose; if set, the output is in A-pose; otherwise, the output uses the pose of a frame from the FBX animation.
+
+### [three-vrm](https://github.com/pixiv/three-vrm)
+
+**Much quicker than blender VRM add-on.**
+
+Install [Node.js](https://nodejs.org/) first to use the npm environment.
+
+```bash
+cd render_script/three-js
+npm install three @pixiv/three-vrm
+```
+
+If you want to render depth-map images of the VRM, you should replace three-vrm with [my version](render_script/three-js/src/three-vrm.js).
+
+First, run the backend to receive the data from the frontend (default port is `17070`); remember to change the folder path.
+
+```bash
+pip install fastapi uvicorn aiofiles pillow numpy
+python up_backend.py
+```
+
+Second, run the frontend to render the images.
+
+```bash
+npm run dev
+```
+
+Then open the website http://localhost:5173/; it uses 2 threads to render the images, which takes about 1 day.
+
+## Our Result
+
+| Single Input Image | 2D Multi-View Images | 3D Character |
+|-------|-------|-------|
+|  |  |
|
+|  |  |
|
+|  |  |
|
+
+# Acknowledgements
+
+This project is built upon **[Tune-A-Video](https://github.com/showlab/Tune-A-Video)** and **[TripoSR](https://github.com/VAST-AI-Research/TripoSR)**.
+And the rendering scripts are built upon **[three-vrm](https://github.com/pixiv/three-vrm)** and **[VRM-Addon-for-Blender](https://github.com/saturday06/VRM-Addon-for-Blender)**.
+Thanks very much to many friends for their unselfish help with our work. We're extremely grateful to **[Yuanchen](https://github.com/bennyguo)**, **[Yangguang](https://scholar.google.com/citations?user=a7AMvgkAAAAJ)**, and **Yuan Liang** for their guidance on code details and ideas.
+We thank all the authors for their great repos and help.
+
+# Citation
+
+If you find our code or paper helps, please consider citing:
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+```bibtex
+@article{peng2024charactergen,
+ title ={CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Canonicalization},
+ author ={Hao-Yang Peng and Jia-Peng Zhang and Meng-Hao Guo and Yan-Pei Cao and Shi-Min Hu},
+ journal ={ACM Transactions on Graphics (TOG)},
+ year ={2024},
+ volume ={43},
+ number ={4},
+ doi ={10.1145/3658217}
+}
+```
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..5681c451056531b66ebcf070a569d5ae03711874
--- /dev/null
+++ b/app.py
@@ -0,0 +1,483 @@
+import gradio as gr
+from PIL import Image
+import glob
+
+import spaces
+
+import io
+import argparse
+import os
+import random
+from typing import Dict, Optional, Tuple
+from omegaconf import OmegaConf
+import numpy as np
+
+import torch
+import torch.utils.checkpoint
+
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers.utils import check_min_version
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
+from torchvision import transforms
+
+import sys
+
+sys.path.append("2D_Stage")
+sys.path.append("3D_Stage")
+from tuneavideo.models.unet_mv2d_condition import UNetMV2DConditionModel
+from tuneavideo.models.unet_mv2d_ref import UNetMV2DRefModel
+from tuneavideo.models.PoseGuider import PoseGuider
+from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
+from tuneavideo.util import shifted_noise
+from einops import rearrange
+import PIL
+from PIL import Image
+from torchvision.utils import save_image
+import json
+import cv2
+
+import lrm
+import trimesh
+from lrm.utils.config import load_config
+from refine import refine
+from datetime import datetime
+import gradio as gr
+from pygltflib import GLTF2
+
+import onnxruntime as rt
+from huggingface_hub import hf_hub_download, list_repo_files
+from rm_anime_bg.cli import get_mask, SCALE
+import pymeshlab
+
+
def download_model_files():
    """Download the 2D/3D stage model files from the Hugging Face Hub.

    Intended to run once at startup; files that already exist on disk
    are skipped so repeated calls are cheap.
    """
    repo_id = "zjpshadow/CharacterGen"
    for name in list_repo_files(repo_id, revision="main"):
        if os.path.exists(name):
            # Already fetched on a previous run.
            continue
        if name.startswith(("2D_Stage", "3D_Stage")):
            hf_hub_download(repo_id, name, local_dir=".")
+
+
# Fail fast if the installed diffusers is older than the version this
# pipeline was written against.
check_min_version("0.24.0")

# Module-level logger (accelerate-aware).
logger = get_logger(__name__, log_level="INFO")
+
def set_seed(seed):
    """Seed every RNG source used by the pipeline (python `random`,
    numpy, torch CPU and all CUDA devices) for reproducible generation.

    NOTE(review): this shadows the `set_seed` imported from
    accelerate.utils above — confirm the override is deliberate.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
+
def get_bg_color(bg_color):
    """Map a background-color spec to an RGB array with values in [0, 1].

    Accepts the names 'white'/'black'/'gray', the name 'random' (uniform
    random RGB), or a single float used for all three channels.

    Raises:
        NotImplementedError: for any other spec.
    """
    named = {
        'white': (1., 1., 1.),
        'black': (0., 0., 0.),
        'gray': (0.5, 0.5, 0.5),
    }
    if bg_color in named:
        return np.array(named[bg_color], dtype=np.float32)
    if bg_color == 'random':
        # Drawn fresh on each call.
        return np.random.rand(3)
    if isinstance(bg_color, float):
        return np.array([bg_color] * 3, dtype=np.float32)
    raise NotImplementedError
+
def process_image(image, totensor):
    """Normalize a character image for the 2D stage.

    Crops the image to the bounding box of its non-transparent pixels,
    pastes it centered on a 2:3 (width:height) canvas, resizes to
    512x768, and alpha-composites it onto a gray background.

    Args:
        image: PIL image (converted to RGBA if needed).
        totensor: transform (e.g. torchvision ToTensor) applied to the
            final float32 HxWx3 array.

    Returns:
        Whatever ``totensor`` produces for the composited image.

    Raises:
        ValueError: if the image has no non-transparent pixels.

    Side effect: saves the composited preview to ``input.png``.
    """
    if not image.mode == "RGBA":
        image = image.convert("RGBA")

    # Bounding box of non-transparent pixels.
    non_transparent = np.nonzero(np.array(image)[..., 3])
    if non_transparent[0].size == 0:
        # Previously this crashed inside .min() with an opaque numpy error.
        raise ValueError("Input image is fully transparent")
    min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
    min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
    # PIL's crop box is exclusive on the right/bottom, so +1 keeps the
    # last non-transparent row/column (the original dropped one pixel).
    image = image.crop((min_x, min_y, max_x + 1, max_y + 1))

    # Paste centered on a canvas with the model's 2:3 aspect ratio.
    max_dim = max(image.width, image.height)
    max_height = max_dim
    max_width = int(max_dim / 3 * 2)
    new_image = Image.new("RGBA", (max_width, max_height))
    left = (max_width - image.width) // 2
    top = (max_height - image.height) // 2
    new_image.paste(image, (left, top))

    image = new_image.resize((512, 768), resample=PIL.Image.BICUBIC)
    image = np.array(image).astype(np.float32) / 255.
    assert image.shape[-1] == 4  # RGBA
    alpha = image[..., 3:4]
    bg_color = get_bg_color("gray")
    # Alpha-composite onto the gray background.
    image = image[..., :3] * alpha + bg_color * (1 - alpha)
    # Save a preview of the exact model input for debugging.
    Image.fromarray((image * 255).astype(np.uint8)).save("input.png")
    return totensor(image)
+
class rm_bg_api:
    """Anime-aware background remover backed by the skytnt/anime-seg
    ONNX segmentation model (isnetis.onnx)."""

    def __init__(self, force_cpu: Optional[bool] = True):
        # Download (or reuse the locally cached) segmentation weights.
        session_infer_path = hf_hub_download(
            repo_id="skytnt/anime-seg", filename="isnetis.onnx",
        )
        providers: list[str] = ["CPUExecutionProvider"]
        # Opt into CUDA only when requested AND available in this build.
        if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
            providers = ["CUDAExecutionProvider"]

        self.session_infer = rt.InferenceSession(
            session_infer_path, providers=providers,
        )

    def _remove_background_impl(
        self,
        imgs: list[np.ndarray],
        alpha_min: float,
        alpha_max: float,
    ) -> list:
        """Segment each image and return RGBA PIL images whose alpha
        channel is the (thresholded) predicted foreground mask.

        Mask values below ``alpha_min`` are clamped to 0, above
        ``alpha_max`` to 1; values in between stay as soft alpha.
        """
        process_imgs = []
        for img in imgs:
            img = np.array(img)
            # CHANGE to RGB
            if img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            # Float mask; assumed HxWx1 so it broadcasts against img — verify.
            mask = get_mask(self.session_infer, img)

            mask[mask < alpha_min] = 0.0  # type: ignore
            mask[mask > alpha_max] = 1.0  # type: ignore

            # RGB premultiplied by the mask, alpha scaled to uint8 range.
            img_after = (mask * img).astype(np.uint8)  # type: ignore
            mask = (mask * SCALE).astype(np.uint8)  # type: ignore
            img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)
            # NOTE(review): this repeated mask is never used — dead store.
            mask = mask.repeat(3, axis=2)
            process_imgs.append(Image.fromarray(img_after))
        return process_imgs
+
+
class Inference2D_API:
    """2D stage: multi-view diffusion pipeline that turns one character
    image into four pose-calibrated views (Tune-A-Video backbone with a
    multi-view UNet plus a reference UNet for the input image).

    Constructor keyword arguments mirror 2D_Stage/configs/infer.yaml.
    """

    def __init__(self,
                 pretrained_model_path: str,
                 image_encoder_path: str,
                 ckpt_dir: str,
                 validation: Dict,
                 local_crossattn: bool = True,
                 unet_from_pretrained_kwargs=None,
                 unet_condition_type=None,
                 use_pose_guider=False,
                 use_shifted_noise=False,
                 use_noise=True,
                 device="cuda"
                 ):
        self.validation = validation
        self.use_noise = use_noise
        self.use_shifted_noise = use_shifted_noise
        self.unet_condition_type = unet_condition_type
        # Config paths are written relative to 2D_Stage/; re-root them so
        # the app can run from the repository root.
        image_encoder_path = image_encoder_path.replace("./", "./2D_Stage/")
        ckpt_dir = ckpt_dir.replace("./", "./2D_Stage/")

        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)
        feature_extractor = CLIPImageProcessor()
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
        # Both UNets are initialized from the same 2D weights; `unet`
        # denoises the four views, `ref_unet` encodes the reference image.
        unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
        ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
        if use_pose_guider:
            pose_guider = PoseGuider(noise_latent_channels=4).to("cuda")
        else:
            pose_guider = None

        # Checkpoint file layout depends on whether a pose guider was
        # trained: with one, pytorch_model_1.bin is the guider and
        # pytorch_model_2.bin the ref UNet; without, _1.bin is the ref UNet.
        unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
        if use_pose_guider:
            pose_guider_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
            ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_2.bin"), map_location="cpu")
            pose_guider.load_state_dict(pose_guider_params)
        else:
            ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
        unet.load_state_dict(unet_params)
        ref_unet.load_state_dict(ref_unet_params)

        # Inference-only: run all sub-models frozen in fp16.
        weight_dtype = torch.float16

        text_encoder.to(device, dtype=weight_dtype)
        image_encoder.to(device, dtype=weight_dtype)
        vae.to(device, dtype=weight_dtype)
        ref_unet.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)

        vae.requires_grad_(False)
        unet.requires_grad_(False)
        ref_unet.requires_grad_(False)

        noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
        self.validation_pipeline = TuneAVideoPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=self.tokenizer, unet=unet, ref_unet=ref_unet,feature_extractor=feature_extractor,image_encoder=image_encoder,
            scheduler=noise_scheduler
        )
        self.validation_pipeline.enable_vae_slicing()
        self.validation_pipeline.set_progress_bar_config(disable=True)
        self.generator = torch.Generator(device=device)

    @torch.no_grad()
    def _inference_impl(self, input_image, val_width, val_height,
                        use_shifted_noise=False, crop=False, seed=100, timestep=20):
        """Generate the four views for one input image.

        Args:
            input_image: PIL image of the character.
            val_width, val_height: generation resolution.
            crop: crop pose templates to a 512x768 center strip instead
                of resizing them.
            seed: RNG seed; timestep: number of DDIM inference steps.

        Returns:
            List of 4 PIL images (the generated views).
        """
        set_seed(seed)
        totensor = transforms.ToTensor()

        # pose.json pairs a 4x4 camera matrix with a pose-template image
        # path for each target view.
        metas = json.load(open("./2D_Stage/material/pose.json", "r"))
        cameras = []
        pose_images = []
        input_path = "./2D_Stage/material"
        for lm in metas:
            # Transpose, keep the 3x4 part, flatten to a 12-vector.
            cameras.append(torch.tensor(np.array(lm[0]).reshape(4, 4).transpose(1,0)[:3, :4]).reshape(-1))
            if not crop:
                # NOTE(review): PIL resize takes (width, height), so
                # (val_height, val_width) looks swapped; the UI always
                # calls with crop=True so this branch is unexercised — verify.
                pose_images.append(totensor(np.asarray(Image.open(os.path.join(input_path, lm[1])).resize(
                    (val_height, val_width), resample=PIL.Image.BICUBIC)).astype(np.float32) / 255.))
            else:
                # Crop the 768x768 pose template to a 512x768 center strip.
                pose_image = Image.open(os.path.join(input_path, lm[1]))
                crop_area = (128, 0, 640, 768)
                pose_images.append(totensor(np.array(pose_image.crop(crop_area)).astype(np.float32)) / 255.)
        camera_matrixs = torch.stack(cameras).unsqueeze(0).to("cuda")
        pose_imgs_in = torch.stack(pose_images).to("cuda")
        prompts = "high quality, best quality"
        prompt_ids = self.tokenizer(
            prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        ).input_ids[0]

        # (B*Nv, 3, H, W)
        B = 1
        weight_dtype = torch.bfloat16
        imgs_in = process_image(input_image, totensor)
        imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")

        with torch.autocast("cuda", dtype=weight_dtype):
            imgs_in = imgs_in.to("cuda")
            # B*Nv images
            out = self.validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=self.generator,
                                           num_inference_steps=timestep,
                                           camera_matrixs=camera_matrixs.to(weight_dtype), prompt_ids=prompt_ids,
                                           height=val_height, width=val_width, unet_condition_type=self.unet_condition_type,
                                           pose_guider=None, pose_image=pose_imgs_in, use_noise=self.use_noise,
                                           use_shifted_noise=use_shifted_noise, **self.validation).videos
        out = rearrange(out, "B C f H W -> (B f) C H W", f=self.validation.video_length)

        # Round-trip each view through an in-memory PNG so callers get
        # plain PIL images.
        image_outputs = []
        for bs in range(4):
            img_buf = io.BytesIO()
            save_image(out[bs], img_buf, format='PNG')
            img_buf.seek(0)
            img = Image.open(img_buf)
            image_outputs.append(img)
        torch.cuda.empty_cache()
        return image_outputs
+
+
def traverse(path, back_proj, smooth_iter):
    """Post-process an exported mesh: convert its texture to PNG, merge
    close vertices, Laplacian-smooth, rotate into GLB orientation and
    bake the texture into per-vertex colors.

    Writes ``output.glb`` (textured) into *path* and returns a
    trimesh.Trimesh carrying the baked vertex colors.
    """
    tex_stem = "refined_texture_kd" if back_proj else "texture_kd"
    tex_png = f"{path}/{tex_stem}.png"

    # pymeshlab pass: attach the PNG texture, merge vertices, smooth.
    ms = pymeshlab.MeshSet()
    ms.load_new_mesh(f"{path}/model-00.obj")
    Image.open(f"{path}/{tex_stem}.jpg").save(tex_png, 'PNG')
    ms.set_texture_per_mesh(textname=tex_png)
    ms.meshing_merge_close_vertices()
    ms.apply_coord_laplacian_smoothing(stepsmoothnum=smooth_iter)
    ms.save_current_mesh(f"{path}/temp-00.obj", save_vertex_normal=False, save_wedge_normal=False, save_vertex_color=False)

    # Rotate the smoothed mesh into the viewer's orientation and export GLB.
    mesh = trimesh.load(f"{path}/temp-00.obj", process=False)
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(90.0), [-1, 0, 0]))
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(180.0), [0, 1, 0]))
    mesh.export(f'{path}/output.glb', file_type='glb')

    # Sample the texture at every vertex's UV to get per-vertex colors.
    texture = np.array(Image.open(tex_png))
    vertex_colors = np.zeros((mesh.vertices.shape[0], 4), dtype=np.uint8)
    for vertex_index, uv in enumerate(mesh.visual.uv):
        col = int(uv[0] * (texture.shape[1] - 1))
        row = int((1 - uv[1]) * (texture.shape[0] - 1))  # image rows grow downward
        rgb = texture[row, col, :3]
        vertex_colors[vertex_index] = [rgb[0], rgb[1], rgb[2], 255]
    return trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, vertex_colors=vertex_colors, process=False)
+
class Inference3D_API:
    """3D stage: lifts the four generated views to a textured mesh with
    the LRM-based system configured by 3D_Stage/configs/infer.yaml."""

    def __init__(self, device="cuda"):
        self.cfg = load_config("3D_Stage/configs/infer.yaml", makedirs=False)
        print("Loading system")
        self.device = device
        # Config paths are written relative to 3D_Stage/; re-root them so
        # the app can run from the repository root.
        self.cfg.system.weights = self.cfg.system.weights.replace("./", "./3D_Stage/")
        self.cfg.system.image_tokenizer.pretrained_model_name_or_path = \
            self.cfg.system.image_tokenizer.pretrained_model_name_or_path.replace("./", "./3D_Stage/")
        self.cfg.system.renderer.tet_dir = self.cfg.system.renderer.tet_dir.replace("./", "./3D_Stage/")
        self.cfg.system.exporter.output_path = self.cfg.system.exporter.output_path.replace("./", "./3D_Stage/")
        self.system = lrm.find(self.cfg.system_cls)(self.cfg.system).to(self.device)
        self.system.eval()

    def _process_images_impl(self, img_input0, img_input1, img_input2, img_input3, back_proj, smooth_iter):
        """Reconstruct a mesh from the four views (back, front, right, left
        per the UI labels).

        Returns (save_dir, path_to_obj, path_to_glb); all outputs are
        written under 3D_Stage/outputs/@<timestamp>/.
        """
        # Fixed camera-to-world matrices for the four canonical views.
        meta = json.load(open("./3D_Stage/material/meta.json"))
        c2w_cond = [np.array(loc["transform_matrix"]) for loc in meta["locations"]]
        c2w_cond = torch.from_numpy(np.stack(c2w_cond, axis=0)).float()[None].to(self.device)

        rgb_cond = []
        files = [img_input0, img_input1, img_input2, img_input3]
        new_images = []
        for file in files:
            image = np.array(file)
            image = Image.fromarray(image)
            # Pad non-square views onto a centered square canvas.
            if image.width != image.height:
                max_dim = max(image.width, image.height)
                new_image = Image.new("RGBA", (max_dim, max_dim))
                left = (max_dim - image.width) // 2
                top = (max_dim - image.height) // 2
                new_image.paste(image, (left, top))
                image = new_image
            # Debug snapshot (overwritten each iteration, keeps last view).
            image.save("input_3D.png")

            image = cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB)
            rgb = cv2.resize(image, (self.cfg.data.cond_width,
                                     self.cfg.data.cond_height)).astype(np.float32) / 255.0
            new_images.append(Image.fromarray(image.astype(np.uint8)).convert("RGB"))
            rgb_cond.append(rgb)
        assert len(rgb_cond) == 4, "Please provide 4 images"

        rgb_cond = torch.from_numpy(np.stack(rgb_cond, axis=0)).float()[None].to(self.device)

        with torch.no_grad():
            scene_codes = self.system({"rgb_cond": rgb_cond, "c2w_cond": c2w_cond})
            exporter_output = self.system.exporter([f"{i:02d}" for i in range(rgb_cond.shape[0])], scene_codes)

        save_dir = os.path.join("./3D_Stage/outputs", datetime.now().strftime("@%Y%m%d-%H%M%S"))
        os.makedirs(save_dir, exist_ok=True)
        self.system.set_save_dir(save_dir)

        # Each exporter output names the save method to dispatch to
        # (save_type "obj" -> system.save_obj, etc.).
        for out in exporter_output:
            save_func_name = f"save_{out.save_type}"
            save_func = getattr(self.system, save_func_name)
            save_func(f"{out.save_name}", **out.params)
        if back_proj:
            # Back-project views onto the texture; the reordering is
            # presumably refine()'s expected view order — TODO confirm.
            refine(save_dir, new_images[1], new_images[0], new_images[3], new_images[2])

        new_obj = traverse(save_dir, back_proj, smooth_iter)
        new_obj.export(f'{save_dir}/output.obj', file_type='obj')

        # Neutralize PBR factors so viewers don't darken the baked texture.
        # NOTE(review): baseColorFactor alpha of 100.0 exceeds the glTF
        # [0, 1] range — confirm intended.
        gltf = GLTF2().load(f'{save_dir}/output.glb')
        for material in gltf.materials:
            if material.pbrMetallicRoughness:
                material.pbrMetallicRoughness.baseColorFactor = [1.0, 1.0, 1.0, 100.0]
                material.pbrMetallicRoughness.metallicFactor = 0.0
                material.pbrMetallicRoughness.roughnessFactor = 1.0
        gltf.save(f'{save_dir}/output.glb')

        return save_dir, f"{save_dir}/output.obj", f"{save_dir}/output.glb"
+
+
# Module-level singleton instances, lazily initialized inside the
# @spaces.GPU entry points so heavy models only load on first use.
_remove_api = None
_infer2dapi = None
_infer3dapi = None
+
+
@spaces.GPU
def run_remove_background(imgs, alpha_min, alpha_max):
    """Remove the background from a batch of images, building the ONNX
    segmentation session on first call (ZeroGPU entry point)."""
    global _remove_api
    _remove_api = _remove_api or rm_bg_api()
    return _remove_api._remove_background_impl(imgs, alpha_min, alpha_max)
+
+
@spaces.GPU
def run_inference2d(input_image, val_width, val_height,
                    use_shifted_noise=False, crop=False, seed=100, timestep=20):
    """Generate four pose-calibrated views from one character image,
    loading the 2D-stage diffusion pipeline on first call."""
    global _infer2dapi
    if _infer2dapi is None:
        download_model_files()
        cfg = OmegaConf.load("./2D_Stage/configs/infer.yaml")
        _infer2dapi = Inference2D_API(**cfg)
    return _infer2dapi._inference_impl(
        input_image, val_width, val_height,
        use_shifted_noise=use_shifted_noise, crop=crop, seed=seed, timestep=timestep)
+
+
@spaces.GPU
def run_inference3d(img_input0, img_input1, img_input2, img_input3, back_proj, smooth_iter):
    """Reconstruct a textured 3D mesh from the four generated views,
    loading the 3D-stage system on first call."""
    global _infer3dapi
    if _infer3dapi is None:
        download_model_files()
        _infer3dapi = Inference3D_API()
    views = (img_input0, img_input1, img_input2, img_input3)
    return _infer3dapi._process_images_impl(*views, back_proj, smooth_iter)
+
+
@torch.no_grad()
def main():
    """Build and launch the Gradio demo (2D multi-view stage + 3D stage)."""

    def gen4views(image, width, height, seed, timestep, remove_bg):
        # Optionally strip the user image's background, generate the 4
        # views, then remove the gray diffusion background from each view.
        if remove_bg:
            image = run_remove_background(
                imgs=[np.array(image)],
                alpha_min=0.1,
                alpha_max=0.9,
            )[0]
        return run_remove_background(
            imgs=run_inference2d(
                image, width, height, crop=True, seed=seed, timestep=timestep
            ), alpha_min=0.2, alpha_max=0.9)

    with gr.Blocks() as demo:
        gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration")
        with gr.Row():
            # Left panel: input image and 2D-stage settings.
            with gr.Column(variant="panel"):
                img_input = gr.Image(type="pil", label="Upload Image(without background)", image_mode="RGBA", width=768, height=512)
                gr.Examples(
                    label="Example Images",
                    examples=glob.glob("./2D_Stage/material/examples/*.png"),
                    inputs=[img_input]
                )
                with gr.Row():
                    width_input = gr.Number(label="Width", value=512)
                    height_input = gr.Number(label="Height", value=768)
                    seed_input = gr.Number(label="Seed", value=2333)
                remove_bg = gr.Checkbox(label="Remove Background (with algorithm)", value=True)
            # Middle panel: generated 4 views (also editable by hand).
            with gr.Column(variant="panel"):
                timestep = gr.Slider(minimum=10, maximum=70, step=1, value=40, label="Timesteps")
                button1 = gr.Button(value="Generate 4 Views")
                with gr.Row():
                    img_input0 = gr.Image(type="pil", label="Back Image", image_mode="RGBA", width=256, height=384)
                    img_input1 = gr.Image(type="pil", label="Front Image", image_mode="RGBA", width=256, height=384)
                with gr.Row():
                    img_input2 = gr.Image(type="pil", label="Right Image", image_mode="RGBA", width=256, height=384)
                    img_input3 = gr.Image(type="pil", label="Left Image", image_mode="RGBA", width=256, height=384)
            # Right panel: 3D-stage settings and outputs.
            with gr.Column(variant="panel"):
                smooth_iter = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Laplacian Smoothing Iterations")
                with gr.Row():
                    back_proj = gr.Checkbox(label="Back Projection")
                    button2 = gr.Button(value="Generate 3D Mesh")
                output_dir = gr.Textbox(label="Output Directory")
                with gr.Row():
                    with gr.Tab("GLB"):
                        output_model_glb = gr.Model3D( label="Output Model (GLB Format)", height=512)
                        gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
                    with gr.Tab("OBJ"):
                        output_model_obj = gr.Model3D( label="Output Model (OBJ Format)")
                        gr.Markdown("Note: The model shown here's texture is mapped to vertex. Download to get correct results.")
        button1.click(
            fn=gen4views,
            inputs=[img_input, width_input, height_input, seed_input, timestep, remove_bg],
            # The 4 generated views are routed to the (Right, Back, Left,
            # Front) slots — assumed from this wiring; confirm the
            # pipeline's view order.
            outputs=[img_input2, img_input0, img_input3, img_input1]
        )
        button2.click(
            run_inference3d,
            inputs=[img_input0, img_input1, img_input2, img_input3, back_proj, smooth_iter],
            outputs=[output_dir, output_model_obj, output_model_glb]
        )
    demo.queue()
    demo.launch()

if __name__ == "__main__":
    main()
diff --git a/final_texture.png b/final_texture.png
new file mode 100644
index 0000000000000000000000000000000000000000..4b3c4225e2d1fd0cc7d34f8950f05b6a8578e56a
--- /dev/null
+++ b/final_texture.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd3ef257ea3be83344ab7ddd0100c5ec4568a28b4aff05952a7b4fbc394fb240
+size 835726
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5a3de257693a4ff57e1f1a655f9c45ecff2bf4b6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,31 @@
+gradio==5.14.0
+rm_anime_bg
+# for 2D_stage
+accelerate
+transformers==4.32.1
+diffusers==0.24.0
+huggingface_hub==0.25.2
+ipdb
+einops
+imageio
+
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.4.0
+torchvision==0.19.0
+xformers==0.0.27.post2
+
+onnxruntime
+omegaconf
+# for 3D_stage
+pytorch_lightning
+jaxtyping
+wandb
+lpips
+https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
+ninja
+open3d
+trimesh
+pymeshlab
+pygltflib
+omegaconf
+typeguard==4.1.5
\ No newline at end of file