# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from accelerate.logging import get_logger
import loratorch as lora

from .embedder import CameraEmbedder
from .transformer import TransformerDecoder
from .rendering.synthesizer import TriplaneSynthesizer
from .swin_transformer import CrossAttentionLayer

logger = get_logger(__name__)


class ModelLRM(nn.Module):
    """
    Full model of the basic single-view large reconstruction model (LRM),
    with optional LoRA fine-tuning and optional fusion of front/back view
    triplanes via a convolutional stack or Swin cross-attention.
    """
    def __init__(self, camera_embed_dim: int, rendering_samples_per_ray: int,
                 transformer_dim: int, transformer_layers: int, transformer_heads: int,
                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
                 encoder_freeze: bool = True, encoder_type: str = 'dino',
                 encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768,
                 model_lora_rank: int = 0, conv_fuse: bool = False,
                 swin_ca_fuse: bool = False, ca_dim: int = 32, ca_depth: int = 2,
                 ca_num_heads: int = 8, ca_window_size: int = 2):
        super().__init__()

        # attributes
        self.encoder_feat_dim = encoder_feat_dim
        self.camera_embed_dim = camera_embed_dim
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim
        self.conv_fuse = conv_fuse
        self.swin_ca_fuse = swin_ca_fuse

        # modules
        self.encoder = self._encoder_fn(encoder_type)(
            model_name=encoder_model_name,
            freeze=encoder_freeze,
        )
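        # raw camera vector: flattened 3x4 extrinsics (12) + 4 intrinsic values
        # (assumed to be fx, fy, cx, cy, matching raw_dim=12+4 below)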
        self.camera_embedder = CameraEmbedder(
            raw_dim=12+4, embed_dim=camera_embed_dim,
        )
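        # one learnable positional token per cell of each of the three planes,
        # i.e. shape [1, 3 * low_res**2, transformer_dim]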
        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5)
        if model_lora_rank > 0:
            self.transformer = TransformerDecoder(
                block_type='cond_mod',
                num_layers=transformer_layers, num_heads=transformer_heads,
                inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=camera_embed_dim,
                lora_rank=model_lora_rank,
            )
            lora.mark_only_lora_as_trainable(self.transformer)
        else:
            self.transformer = TransformerDecoder(
                block_type='cond_mod',
                num_layers=transformer_layers, num_heads=transformer_heads,
                inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=camera_embed_dim,
            )
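        # the deconvolution upsamples each plane 2x; 2 * triplane_low_res is
        # assumed to equal triplane_high_res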
        self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
        self.synthesizer = TriplaneSynthesizer(
            triplane_dim=triplane_dim, samples_per_ray=rendering_samples_per_ray,
        )
        if model_lora_rank > 0:
            if self.conv_fuse:
                # convolutional stack that fuses concatenated front/back planes
                self.front_back_conv = nn.ModuleList([
                    nn.Conv2d(in_channels=triplane_dim*2, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1),
                    nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]),
                    nn.GELU(),
                    nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1),
                    nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]),
                    nn.GELU(),
                    nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim, kernel_size=(3, 3), stride=(1, 1), padding=1),
                ])
                self.freeze_modules(encoder=True, camera_embedder=True,
                                    pos_embed=False, transformer=False, upsampler=False,
                                    synthesizer=False)
            elif self.swin_ca_fuse:
                self.swin_cross_attention = CrossAttentionLayer(dim=ca_dim, depth=ca_depth, num_heads=ca_num_heads, window_size=ca_window_size)
                self.freeze_modules(encoder=True, camera_embedder=True,
                                    pos_embed=False, transformer=False, upsampler=False,
                                    synthesizer=False)
            else:
                raise ValueError("When model_lora_rank > 0, either conv_fuse or swin_ca_fuse must be enabled to fuse the front and back planes.")

    def freeze_modules(self, encoder=False, camera_embedder=False,
                       pos_embed=False, transformer=False, upsampler=False,
                       synthesizer=False):
        """
        Freeze the specified modules (their parameters stop receiving gradients).
        """
        if encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
        if camera_embedder:
            for param in self.camera_embedder.parameters():
                param.requires_grad = False
        if pos_embed:
            # pos_embed is a single nn.Parameter, not a module, so set
            # requires_grad directly (it has no .parameters() method)
            self.pos_embed.requires_grad = False
        if transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False
        if upsampler:
            for param in self.upsampler.parameters():
                param.requires_grad = False
        if synthesizer:
            for param in self.synthesizer.parameters():
                param.requires_grad = False

    @staticmethod
    def _encoder_fn(encoder_type: str):
        encoder_type = encoder_type.lower()
        assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type"
        if encoder_type == 'dino':
            from .encoders.dino_wrapper import DinoWrapper
            logger.info("Using DINO as the encoder")
            return DinoWrapper
        elif encoder_type == 'dinov2':
            from .encoders.dinov2_wrapper import Dinov2Wrapper
            logger.info("Using DINOv2 as the encoder")
            return Dinov2Wrapper

    def forward_transformer(self, image_feats, camera_embeddings):
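        """
        Decode triplane tokens conditioned on image features and modulated by
        the camera embedding. Shapes (as implied by the asserts here and in
        forward_planes): image_feats [N, L_cond, encoder_feat_dim],
        camera_embeddings [N, camera_embed_dim]; returns
        [N, 3 * low_res**2, transformer_dim].
        """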
        assert image_feats.shape[0] == camera_embeddings.shape[0], \
            "Batch size mismatch for image_feats and camera_embeddings!"
        N = image_feats.shape[0]
        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        x = self.transformer(
            x,
            cond=image_feats,
            mod=camera_embeddings,
        )
        return x

    def reshape_upsample(self, tokens):
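        """
        Reshape the flat token sequence into three low-res feature planes and
        upsample each 2x with the shared deconvolution:
        tokens [N, 3 * low_res**2, D] -> planes [N, 3, D', 2*low_res, 2*low_res].
        """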
        N = tokens.shape[0]
        H = W = self.triplane_low_res
        x = tokens.view(N, 3, H, W, -1)
        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        x = self.upsampler(x)  # [3*N, D', H', W']
        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
        x = x.contiguous()
        return x

    def forward_planes(self, image, camera):
        # image: [N, C_img, H_img, W_img]
        # camera: [N, D_cam_raw]
        N = image.shape[0]

        # encode image
        image_feats = self.encoder(image)
        assert image_feats.shape[-1] == self.encoder_feat_dim, \
            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"

        # embed camera
        camera_embeddings = self.camera_embedder(camera)
        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"

        # transformer generating planes
        tokens = self.forward_transformer(image_feats, camera_embeddings)
        planes = self.reshape_upsample(tokens)
        assert planes.shape[0] == N, "Batch size mismatch for planes"
        assert planes.shape[1] == 3, "Expected 3 planes"
        return planes

    def forward(self, image, source_camera, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size: int,
                image_back=None):
        # image: [N, C_img, H_img, W_img]
        # source_camera: [N, D_cam_raw]
        # render_cameras: [N, M, D_cam_render]
        # render_anchors: [N, M, 2]
        # render_resolutions: [N, M, 1]
        # render_bg_colors: [N, M, 1]
        # render_region_size: int
        assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera"
        assert image.shape[0] == render_cameras.shape[0], "Batch size mismatch for image and render_cameras"
        assert image.shape[0] == render_anchors.shape[0], "Batch size mismatch for image and render_anchors"
        assert image.shape[0] == render_resolutions.shape[0], "Batch size mismatch for image and render_resolutions"
        assert image.shape[0] == render_bg_colors.shape[0], "Batch size mismatch for image and render_bg_colors"
        N, M = render_cameras.shape[:2]

        if image_back is not None:
            front_planes = self.forward_planes(image, source_camera)
            back_planes = self.forward_planes(image_back, source_camera)
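            # Mirror the back-view planes into the front view's coordinate
            # frame before fusion: each plane is flipped along the axes that
            # reverse when the object is seen from the opposite side.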
            # XY plane: flip both spatial axes
            back_planes[:, 0, :, :, :] = torch.flip(back_planes[:, 0, :, :, :], dims=[-2, -1])
            # XZ plane: flip the last axis
            back_planes[:, 1, :, :, :] = torch.flip(back_planes[:, 1, :, :, :], dims=[-1])
            # YZ plane: flip the second-to-last axis
            back_planes[:, 2, :, :, :] = torch.flip(back_planes[:, 2, :, :, :], dims=[-2])

            # fuse the front planes and the back planes
            bs, num_planes, channels, height, width = front_planes.shape
            if self.conv_fuse:
                planes = torch.cat((front_planes, back_planes), dim=2)
                planes = planes.reshape(-1, channels*2, height, width)
                # apply the convolutional fusion stack layer by layer
                for layer in self.front_back_conv:
                    planes = layer(planes)
                planes = planes.view(bs, num_planes, -1, height, width)
            elif self.swin_ca_fuse:
                # [bs, 3, C, H, W] -> [bs*3, C, H*W] -> [bs*3, H*W, C]
                front_planes = front_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous()
                back_planes = back_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous()
                planes = self.swin_cross_attention(front_planes, back_planes, height, width)[0]
                planes = planes.permute(0, 2, 1).reshape(bs, num_planes, channels, height, width)
            else:
                raise ValueError("image_back was provided, but neither conv_fuse nor swin_ca_fuse is enabled.")
        else:
            planes = self.forward_planes(image, source_camera)

        # render target views
        render_results = self.synthesizer(planes, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size)
        assert render_results['images_rgb'].shape[0] == N, "Batch size mismatch for render_results"
        assert render_results['images_rgb'].shape[1] == M, "Number of rendered views should be consistent with render_cameras"

        return {
            'planes': planes,
            **render_results,
        }
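

# Minimal smoke-test sketch (not part of the original module). It constructs
# the model with assumed, illustrative hyperparameters and runs a dummy
# forward_planes pass. Note: the encoder downloads 'facebook/dino-vitb16'
# weights, and the relative imports above require running this file as a
# module inside its package.
if __name__ == "__main__":
    model = ModelLRM(
        camera_embed_dim=1024, rendering_samples_per_ray=96,
        transformer_dim=512, transformer_layers=12, transformer_heads=8,
        triplane_low_res=32, triplane_high_res=64, triplane_dim=32,
    )
    image = torch.rand(1, 3, 224, 224)  # [N, C_img, H_img, W_img]
    camera = torch.rand(1, 16)          # flattened 3x4 extrinsics + 4 intrinsics
    with torch.no_grad():
        planes = model.forward_planes(image, camera)
    print(planes.shape)                 # expected: [1, 3, 32, 64, 64]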