| import torch |
| import torch.nn as nn |
| import numpy as np |
| from einops import repeat |
|
|
class SKYMLP(nn.Module):
    r"""MLP converting ray directions to sky colors (sigmoid output in [0, 1]).

    Args:
        in_channels (int): Channels of the ray-direction input.
        style_dim (int): Dimension of the style code ``z``.
        L (int): Number of positional-encoding frequency bands. Required
            when ``is_pos_embedding`` is True.
        out_channels_c (int): Number of output color channels.
        hidden_channels (int): Width of the hidden layers.
        is_pos_embedding (bool): If True, append a sinusoidal positional
            encoding of the input directions before the first layer.
    """

    def __init__(self, in_channels, style_dim, L=None, out_channels_c=3,
                 hidden_channels=256, is_pos_embedding=True):
        super(SKYMLP, self).__init__()
        if is_pos_embedding and L is None:
            # Fail fast with a clear message instead of the confusing
            # TypeError that `2 * None * in_channels` would raise below.
            raise ValueError('L (number of positional-encoding frequency '
                             'bands) must be given when is_pos_embedding '
                             'is True.')
        self.is_pos_embedding = is_pos_embedding
        self.L = L
        self.fc_z_a = nn.Linear(style_dim, hidden_channels, bias=False)
        # Positional encoding appends 2*L sin/cos bands per input channel.
        if is_pos_embedding:
            input_channel = in_channels + 2 * L * in_channels
        else:
            input_channel = in_channels
        self.fc1 = nn.Linear(input_channel, hidden_channels)
        # fc2 consumes the concatenation of fc1 features and style features.
        self.fc2 = nn.Linear(hidden_channels * 2, hidden_channels)
        self.fc_out_c = nn.Linear(hidden_channels, out_channels_c)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def positional_encoding(self, input):
        r"""NeRF-style sinusoidal positional encoding.

        Args:
            input (... x C tensor): Values to encode.
        Returns:
            (... x 2*L*C tensor): Interleaved sin/cos features at
            frequencies ``pi * 2**k`` for k in [0, L).
        """
        shape = input.shape
        freq = 2 ** torch.arange(self.L, dtype=torch.float32).to(input.device) * np.pi
        spectrum = input[..., None] * freq
        sin, cos = spectrum.sin(), spectrum.cos()
        input_enc = torch.stack([sin, cos], dim=-2)
        input_enc = input_enc.view(*shape[:-1], -1)
        return input_enc

    def forward(self, x, z):
        r"""Forward network.

        Args:
            x (B x H x W x in_channels tensor): Ray direction embeddings.
            z (B x style_dim tensor): Style codes.
        Returns:
            (B x H x W x out_channels_c tensor): Sigmoid-activated colors.
        """
        if self.is_pos_embedding:
            x = torch.cat([x, self.positional_encoding(x)], dim=-1)
        z = self.fc_z_a(z)
        assert len(x.shape) == 4, 'expected x of shape (B, H, W, C)'
        H, W = x.shape[1:3]
        # Broadcast the per-sample style vector to every spatial location
        # (zero-copy view; torch.cat below materializes it).
        z = z[:, None, None, :].expand(-1, H, W, -1)
        y = self.act(torch.cat([self.fc1(x), z], dim=-1))
        y = self.act(self.fc2(y))
        c = self.fc_out_c(y)
        return torch.sigmoid(c)
| |
class MLPNetwork2(nn.Module):
    """Defines fully-connected layer head in EG3D."""

    def __init__(self, input_dim, hidden_dim, output_dim, style_dim=270):
        super().__init__()
        self.net0 = nn.Linear(input_dim, hidden_dim)
        self.net0_act = nn.Softplus()
        self.net1_feature = nn.Linear(hidden_dim, hidden_dim // 2)
        self.net1_density = nn.Linear(hidden_dim, 1)
        self.style_dim = style_dim
        self.style_squ = nn.Linear(self.style_dim, hidden_dim // 2)
        self.grd_color_convert = nn.Linear(hidden_dim, output_dim)

    def forward(self, point_features, style=None, only_density=False):
        """Predict per-point density and (optionally) color.

        Args:
            point_features (N x M x C tensor): Features for M points in
                each of N batches.
            style (tensor or None): Style codes, either N x style_dim or
                N x M x style_dim; an all-zero code is used when omitted.
            only_density (bool): If True, skip the color branch entirely.

        Returns:
            dict: 'density' (N x M x 1) and, unless only_density,
            'color' (N x M x output_dim) squashed through a sigmoid.
        """
        batch, num_pts, feat_dim = point_features.shape
        flat = point_features.view(batch * num_pts, feat_dim)
        hidden = self.net0_act(self.net0(flat))

        result = {'density': self.net1_density(hidden).view(batch, num_pts, -1)}
        if only_density:
            return result

        color = self.net1_feature(hidden).view(batch, num_pts, -1)
        if style is None:
            # Default: a shared all-zero style code for every point.
            zeros = torch.zeros(self.style_dim, device=point_features.device)
            style = zeros.expand(batch, num_pts, -1).float()
        style = self.style_squ(style)
        if style.dim() == 2:
            # One style vector per batch -> broadcast over all M points.
            style = style.unsqueeze(1).expand(-1, num_pts, -1).float().to(point_features.device)
        fused = torch.cat([color, style], dim=-1)
        result['color'] = torch.sigmoid(self.grd_color_convert(fused))
        return result
|
|