| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .deeplabv3 import DeepLabV3 |
| from .simple_encoders.high_resolution_encoder import HighResoEncoder |
| from .segformer import LowResolutionViT, TriplanePredictorViT |
| from modules.commons.loralib.layers import MergedLoRALinear, LoRALinear, LoRAConv2d |
| import copy |
| from utils.commons.hparams import hparams |
|
|
|
|
class Img2PlaneModel(nn.Module):
    """Predict a triplane representation from a single input image.

    Pipeline: image (+ optional alpha / camera channels + normalized
    coordinate channels) -> low-res encoder (DeepLabV3) -> low-res ViT,
    in parallel a high-res encoder, then a triplane-predictor ViT fuses
    both feature streams into three axis-aligned feature planes.
    """

    def __init__(self, out_channels=96, hp=None, lora_args=None):
        """
        Args:
            out_channels: total channels of the predicted planes; reshaped
                to 3 planes of out_channels//3 each in forward().
            hp: optional hparams dict; falls back to a deepcopy of the
                global `hparams` when None.
            lora_args: forwarded to TriplanePredictorViT for LoRA layers.
        """
        super().__init__()
        # Rebind the module-level `hparams` to the per-instance copy so
        # later reads in this module see it. NOTE(review): this mutates
        # module-global state on every construction — verify that two
        # instances with different hparams never coexist.
        global hparams
        self.hparams = hp if hp is not None else copy.deepcopy(hparams)
        hparams = self.hparams

        self.input_mode = hparams.get("img2plane_input_mode", "rgb")
        if self.input_mode == 'rgb':
            in_channels = 3
        elif self.input_mode == 'rgb_alpha':
            in_channels = 4  # RGB + 1 alpha-mask channel
        elif self.input_mode == 'rgb_camera':
            # Project the 25-dim camera vector into 3 extra image channels.
            # NOTE(review): 25 presumably = flattened 4x4 extrinsic + 3x3
            # intrinsic — confirm against the caller.
            self.camera_to_channel = nn.Linear(25, 3)
            in_channels = 3 + 3
        elif self.input_mode == 'rgb_alpha_camera':
            self.camera_to_channel = nn.Linear(25, 3)
            in_channels = 4 + 3
        else:
            # Fail fast: previously an unrecognized mode left `in_channels`
            # unbound and raised a confusing NameError below.
            raise ValueError(
                f"Unknown img2plane_input_mode: {self.input_mode!r}; expected one of "
                "'rgb', 'rgb_alpha', 'rgb_camera', 'rgb_alpha_camera'")

        # Two normalized pixel-coordinate channels appended in forward().
        in_channels += 2
        self.low_reso_encoder = DeepLabV3(in_channels=in_channels)
        self.high_reso_encoder = HighResoEncoder(in_dim=in_channels)
        self.low_reso_vit = LowResolutionViT()
        self.triplane_predictor_vit = TriplanePredictorViT(out_channels=out_channels, img2plane_backbone_scale=hparams['img2plane_backbone_scale'], lora_args=lora_args)

    def forward(self, x, cond=None, **synthesis_kwargs):
        """
        x: original image, [B, 3, H=512, W=512]
        return: predicted triplane, [B, 3, C, H', W'] after the view/flip below
        optional (in `cond`):
            ref_alphas: 0/1 mask, if img2plane, all ones; if secc2plane, only ones for head, [B, 1, H, W]
            ref_cameras: camera pose of the input img, [B, 25]
        """
        bs, _, H, W = x.shape

        if self.input_mode in ['rgb_alpha', 'rgb_alpha_camera']:
            if cond is None or cond.get("ref_alphas") is None:
                # Fallback mask derived from the image itself: a pixel is
                # foreground when its channel mean exceeds -0.999 (assumes
                # inputs in [-1, 1] with pure-black background — TODO confirm).
                ref_alphas = (x.mean(dim=1, keepdim=True) >= -0.999).float()
            else:
                ref_alphas = cond["ref_alphas"]
            x = torch.cat([x, ref_alphas], dim=1)
        if self.input_mode in ['rgb_camera', 'rgb_alpha_camera']:
            # Camera modes require cond["ref_cameras"]; no fallback exists.
            ref_cameras = cond["ref_cameras"]
            # Broadcast the 3-dim camera embedding over the full spatial grid.
            camera_feat = self.camera_to_channel(ref_cameras).reshape(bs, 3, 1, 1).repeat([1, 1, H, W])
            x = torch.cat([x, camera_feat], dim=1)

        # Append normalized pixel-coordinate channels (a simple positional
        # encoding). indexing='ij' is the legacy torch.meshgrid default —
        # explicit here to silence the >=1.10 deprecation warning without
        # changing the produced values.
        grid_x, grid_y = torch.meshgrid(torch.arange(H, device=x.device), torch.arange(W, device=x.device), indexing='ij')
        grid_x = grid_x / H
        # NOTE(review): divides by H rather than W — only equivalent to a
        # true normalization for square inputs (H == W); confirm intended.
        grid_y = grid_y / H
        expanded_x = grid_x[None, None, :, :].repeat(bs, 1, 1, 1)
        expanded_y = grid_y[None, None, :, :].repeat(bs, 1, 1, 1)
        x = torch.cat([x, expanded_x, expanded_y], dim=1)

        feat_low = self.low_reso_encoder(x)
        feat_low_after_vit = self.low_reso_vit(feat_low)
        feat_high = self.high_reso_encoder(x)

        planes = self.triplane_predictor_vit(feat_low_after_vit, feat_high)

        # [B, 3*C, h, w] -> [B, 3, C, h, w]: one feature plane per axis pair.
        planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])

        # Flip planes to match the renderer's axis conventions.
        # NOTE(review): the exact convention (which axis maps to image
        # up/left) is defined by the downstream sampler — confirm there.
        planes_xy = planes[:, 0]
        planes_xy = torch.flip(planes_xy, [2])
        planes_xz = planes[:, 1]
        planes_xz = torch.flip(planes_xz, [2])
        planes_zy = planes[:, 2]
        planes_zy = torch.flip(planes_zy, [2, 3])
        planes = torch.stack([planes_xy, planes_xz, planes_zy], dim=1)
        return planes