| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from step1x3d_geometry.utils.typing import * |
| from step1x3d_geometry.utils.checkpoint import checkpoint |
|
|
| from .utils import init_linear |
| from .attention import ResidualAttentionBlock |
|
|
|
|
class Perceiver(nn.Module):
    """A sequential stack of residual self-attention blocks.

    Despite the name, this is a plain transformer encoder: ``layers``
    identically-configured ``ResidualAttentionBlock`` instances are applied
    one after another to the input token sequence. All per-block options
    (flash attention, qk-norm, checkpointing, ...) are forwarded verbatim
    to every block.
    """

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        use_flash: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        # Build one identically-configured attention block per layer.
        blocks = [
            ResidualAttentionBlock(
                n_ctx=n_ctx,
                width=width,
                heads=heads,
                init_scale=init_scale,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                use_flash=use_flash,
                use_checkpoint=use_checkpoint,
            )
            for _layer in range(layers)
        ]
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through every residual attention block in order.

        The tensor shape is preserved; each block consumes the previous
        block's output.
        """
        out = x
        for resblock in self.resblocks:
            out = resblock(out)
        return out
|
|