Spaces:
Configuration error
Configuration error
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.config import cfg
from .embedder import get_embedder
class Nerf(nn.Module):
    """NeRF-style MLP whose ``forward`` maps input features to a single alpha value.

    ``forward`` only runs the point trunk (``pts_linears``) and the density head
    (``alpha_linear``). The view-direction layers (``views_linears``,
    ``feature_linear``, ``rgb_linear``) are still created so that the parameter
    layout matches the official NeRF release and :meth:`load_weights_from_keras`
    can fill them.
    """

    def __init__(self,
                 D=8,
                 W=256,
                 input_ch=3,
                 input_ch_views=3,
                 skips=None,
                 use_viewdirs=False):
        """Build the MLP.

        Args:
            D: number of linear layers in the point trunk.
            W: hidden width of each trunk layer.
            input_ch: size of the last dimension of the tensor given to ``forward``.
            input_ch_views: input width of the view-direction features; only used
                to size ``views_linears``.
            skips: trunk layer indices after which the original input is
                concatenated back onto the hidden features. Defaults to ``[4]``.
            use_viewdirs: whether to create ``feature_linear``; required by
                :meth:`load_weights_from_keras`.
        """
        super(Nerf, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        # Copy into a fresh list so neither a shared default nor the caller's
        # list is aliased by this module.
        self.skips = [4] if skips is None else list(skips)
        self.use_viewdirs = use_viewdirs
        # Entry i + 1 of pts_linears consumes the concatenation produced after
        # layer i when i is a skip index, hence the W + input_ch input width.
        self.pts_linears = nn.ModuleList([nn.Linear(input_ch, W)] + [
            nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W)
            for i in range(D - 1)
        ])
        # Follows the official code release
        # (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105):
        # a single W // 2 layer. The paper instead describes extra W // 2 -> W // 2
        # layers after this one.
        self.views_linears = nn.ModuleList(
            [nn.Linear(input_ch_views + W, W // 2)])
        if self.use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
        # forward() always needs the density head, so it is created regardless
        # of use_viewdirs; rgb_linear is kept alongside it for weight loading.
        self.alpha_linear = nn.Linear(W, 1)
        self.rgb_linear = nn.Linear(W // 2, 3)

    def forward(self, x):
        """Return the alpha prediction for the features ``x``.

        Args:
            x: tensor whose last dimension is ``input_ch``.

        Returns:
            Tensor with the leading dimensions of ``x`` and a last dimension of 1.
        """
        input_pts = x
        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                # Re-inject the raw input features (skip connection).
                h = torch.cat([input_pts, h], -1)
        return self.alpha_linear(h)

    def load_weights_from_keras(self, weights):
        """Copy weights exported from the official Keras NeRF model into this module.

        Args:
            weights: indexable sequence of numpy arrays ordered as kernel/bias pairs:
                the D trunk layers, then feature_linear, views_linears[0],
                rgb_linear and alpha_linear. Each array is transposed before
                loading (Keras Dense kernels are stored as (in, out)).

        Raises:
            AssertionError: if the model was built with ``use_viewdirs=False``.
        """
        assert self.use_viewdirs, "Not implemented if use_viewdirs=False"

        def _load(linear, idx):
            # Copy the kernel/bias pair stored at weights[idx], weights[idx + 1].
            linear.weight.data = torch.from_numpy(np.transpose(weights[idx]))
            linear.bias.data = torch.from_numpy(np.transpose(weights[idx + 1]))

        for i in range(self.D):
            _load(self.pts_linears[i], 2 * i)
        _load(self.feature_linear, 2 * self.D)
        _load(self.views_linears[0], 2 * self.D + 2)
        _load(self.rgb_linear, 2 * self.D + 4)
        _load(self.alpha_linear, 2 * self.D + 6)
class Network(nn.Module):
    """Positional encoding plus a coarse :class:`Nerf`, evaluated in row chunks."""

    def __init__(self):
        """Build the embedders and the coarse model from the global ``cfg``."""
        super(Network, self).__init__()
        # get_embedder returns (embedding function, embedded width).
        self.embed_fn, input_ch = get_embedder(cfg.xyz_res)
        # The view embedder only determines input_ch_views here; forward()
        # does not call embeddirs_fn.
        self.embeddirs_fn, input_ch_views = get_embedder(cfg.view_res)
        skips = [4]
        self.model = Nerf(D=cfg.netdepth,
                          W=cfg.netwidth,
                          input_ch=input_ch,
                          skips=skips,
                          input_ch_views=input_ch_views,
                          use_viewdirs=cfg.use_viewdirs)
        # A fine model (cfg.netdepth_fine / cfg.netwidth_fine) is not built;
        # forward(model='fine') falls back to self.model in that case.

    def batchify(self, fn, chunk):
        """Return a callable that applies ``fn`` to ``inputs`` in row chunks.

        Args:
            fn: callable applied to slices ``inputs[i:i + chunk]``.
            chunk: number of rows per call, or None to apply ``fn`` in one call.

        Returns:
            A callable whose outputs are concatenated along dim 0. Inputs with no
            rows are passed to ``fn`` directly.
        """
        if chunk is None:
            return fn

        def ret(inputs):
            if inputs.shape[0] == 0:
                # torch.cat rejects an empty list; let fn handle the empty batch.
                return fn(inputs)
            return torch.cat(
                [fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
        return ret

    def forward(self, inputs, model=''):
        """Embed ``inputs`` and evaluate the selected network over them.

        Args:
            inputs: tensor whose last dimension is the raw coordinate size
                expected by ``embed_fn``; leading dimensions are flattened.
            model: ``'fine'`` to use ``self.model_fine`` if one exists; anything
                else selects the coarse ``self.model``.

        Returns:
            Tensor with the leading dimensions of ``inputs`` and the network's
            output size as its last dimension.
        """
        fn = self.model
        if model == 'fine':
            # No fine network is created in __init__; like the reference NeRF,
            # reuse the coarse network when a fine one is absent.
            fine = getattr(self, 'model_fine', None)
            if fine is not None:
                fn = fine
        inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
        embedded = self.embed_fn(inputs_flat)
        outputs_flat = self.batchify(fn, cfg.netchunk)(embedded)
        return torch.reshape(outputs_flat,
                             list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])