# NOTE(review): the original lines here ("Spaces:" / "Runtime error") are
# HuggingFace-Spaces page-scrape residue, not source code.
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import math | |
| from einops import rearrange | |
| from vit.vision_transformer import MemEffAttention, Attention | |
| # from xformers.triton import FusedLayerNorm as LayerNorm | |
| from torch.nn import LayerNorm | |
| from xformers.components.feedforward import fused_mlp | |
| # from xformers.components.feedforward import mlp | |
| from xformers.components.activations import build_activation, Activation | |
class PositionalEncoding(nn.Module):
    """Sinusoidal (NeRF-style) positional encoding.

    Maps each scalar coordinate c to sin(2**k * pi * c) and
    cos(2**k * pi * c) for k in [start_octave, start_octave + num_octaves),
    concatenated over octaves and input dimensions.
    """

    def __init__(self, num_octaves=8, start_octave=0):
        super().__init__()
        self.num_octaves = num_octaves
        self.start_octave = start_octave

    def forward(self, coords, rays=None):
        """Encode ``coords`` of shape (batch, num_points, dim).

        Returns a tensor of shape (batch, num_points, 2 * dim * num_octaves),
        all sine features first, then all cosine features. ``rays`` is
        accepted for interface compatibility but unused.
        """
        batch_size, num_points, dim = coords.shape

        octaves = torch.arange(self.start_octave,
                               self.start_octave + self.num_octaves)
        # Match device/dtype of the input coordinates.
        octaves = octaves.float().to(coords)
        multipliers = 2 ** octaves * math.pi  # (num_octaves,)

        # (B, N, D, 1) * (O,) broadcasts to (B, N, D, O) directly; the
        # original's unsqueeze loop (and its unused `embed_fns` list) were
        # dead code and have been removed.
        scaled_coords = coords.unsqueeze(-1) * multipliers

        sines = torch.sin(scaled_coords).reshape(batch_size, num_points,
                                                 dim * self.num_octaves)
        cosines = torch.cos(scaled_coords).reshape(batch_size, num_points,
                                                   dim * self.num_octaves)
        return torch.cat((sines, cosines), -1)
class RayEncoder(nn.Module):
    """Fourier-encode a camera position together with ray directions.

    Two ``PositionalEncoding`` modules are used: one for the (single)
    camera position, one for the per-pixel / per-point ray directions.
    """

    def __init__(self, pos_octaves=8, pos_start_octave=0, ray_octaves=4,
                 ray_start_octave=0):
        super().__init__()
        self.pos_encoding = PositionalEncoding(num_octaves=pos_octaves,
                                               start_octave=pos_start_octave)
        self.ray_encoding = PositionalEncoding(num_octaves=ray_octaves,
                                               start_octave=ray_start_octave)

    def forward(self, pos, rays):
        if rays.dim() == 4:
            # Image-shaped rays (B, H, W, C): return a channel-first
            # feature map with the position encoding tiled over H x W.
            b, h, w, _ = rays.shape

            pos_feat = self.pos_encoding(pos.unsqueeze(1))
            pos_feat = pos_feat.view(b, pos_feat.shape[-1], 1, 1)
            pos_feat = pos_feat.repeat(1, 1, h, w)

            ray_feat = self.ray_encoding(rays.flatten(1, 2))
            ray_feat = ray_feat.view(b, h, w, -1).permute(0, 3, 1, 2)

            return torch.cat((pos_feat, ray_feat), 1)

        # Point-shaped rays (B, N, C): concatenate along the feature axis.
        return torch.cat((self.pos_encoding(pos), self.ray_encoding(rays)),
                         -1)
| # Transformer implementation based on ViT | |
| # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py | |
class PreNorm(nn.Module):
    """Wrap ``fn`` so that its input is layer-normalized first (pre-norm)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """ViT-style two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Built as a list first, then wrapped in nn.Sequential so the
        # parameter names (net.0, net.3, ...) stay checkpoint-compatible.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
| # class Attention(nn.Module): | |
| # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None): | |
| # super().__init__() | |
| # inner_dim = dim_head * heads | |
| # project_out = not (heads == 1 and dim_head == dim) | |
| # self.heads = heads | |
| # self.scale = dim_head ** -0.5 | |
| # self.attend = nn.Softmax(dim=-1) | |
| # if selfatt: | |
| # self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) | |
| # else: | |
| # self.to_q = nn.Linear(dim, inner_dim, bias=False) | |
| # self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False) | |
| # self.to_out = nn.Sequential( | |
| # nn.Linear(inner_dim, dim), | |
| # nn.Dropout(dropout) | |
| # ) if project_out else nn.Identity() | |
| # def forward(self, x, z=None): | |
| # if z is None: | |
| # qkv = self.to_qkv(x).chunk(3, dim=-1) | |
| # else: | |
| # q = self.to_q(x) | |
| # k, v = self.to_kv(z).chunk(2, dim=-1) | |
| # qkv = (q, k, v) | |
| # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) | |
| # dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale | |
| # attn = self.attend(dots) | |
| # out = torch.matmul(attn, v) | |
| # out = rearrange(out, 'b h n d -> b n (h d)') | |
| # return self.to_out(out) | |
class Transformer(nn.Module):
    """Pre-norm transformer encoder stack (ViT-style).

    Each of ``depth`` layers applies residual memory-efficient
    self-attention followed by a residual fused MLP, both wrapped in
    ``PreNorm``. ``selfatt`` and ``kv_dim`` are retained for interface
    compatibility but are currently unused — attention here is always
    self-attention via ``MemEffAttention``.
    """

    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.,
                 selfatt=True, kv_dim=None, no_flash_op=False):
        super().__init__()
        attn_cls = MemEffAttention

        def make_layer():
            # One encoder layer: pre-norm attention + pre-norm fused MLP.
            attention = attn_cls(
                dim,
                num_heads=heads,
                qkv_bias=True,
                qk_norm=True,  # as in ViT-22B
                no_flash_op=no_flash_op,
            )
            # NOTE(review): floor division — assumes mlp_dim is a multiple
            # of dim, otherwise the hidden width is silently rounded down.
            feedforward = fused_mlp.FusedMLP(
                dim,
                hidden_layer_multiplier=mlp_dim // dim,
                dropout=dropout,
                activation=Activation.GeLU,
            )
            return nn.ModuleList([PreNorm(dim, attention),
                                  PreNorm(dim, feedforward)])

        self.layers = nn.ModuleList([make_layer() for _ in range(depth)])

    def forward(self, x):
        for attention, feedforward in self.layers:  # type: ignore
            x = x + attention(x)
            x = x + feedforward(x)
        return x