Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from .layer import QK_Norm_TransformerBlock, init_weights | |
class LVSMTransformer(nn.Module):
    """Stack of ``QK_Norm_TransformerBlock`` layers applied sequentially.

    Args:
        dim: Model width passed as the first argument to every block.
        d_head: Per-head width passed as the second argument to every block.
        n_layer: Number of blocks in the stack.
        special_init: If True, re-initialize each block with ``init_weights``
            using a scaled standard deviation (see ``depth_init``). If False,
            call ``init_weights`` on each block with only the module argument,
            i.e. whatever defaults ``init_weights`` defines.
        depth_init: Only consulted when ``special_init`` is True. If True,
            block ``idx`` (0-based) uses std ``0.02 / sqrt(2 * (idx + 1))``,
            so deeper blocks get a smaller std. If False, every block uses
            ``0.02 / sqrt(2 * n_layer)``.
        use_qk_norm: Forwarded as ``use_qk_norm`` to every block.
    """

    def __init__(
        self,
        dim=768,
        d_head=64,
        n_layer=24,
        special_init=True,
        depth_init=True,
        use_qk_norm=True,
    ):
        super().__init__()
        # Register the blocks directly as a ModuleList so their parameters are
        # tracked from the start. All blocks are constructed before any is
        # re-initialized, matching the original ordering (relevant if
        # construction and init_weights draw from the global RNG).
        self.transformer_blocks = nn.ModuleList(
            QK_Norm_TransformerBlock(dim, d_head, use_qk_norm=use_qk_norm)
            for _ in range(n_layer)
        )

        if special_init:
            for idx, block in enumerate(self.transformer_blocks):
                if depth_init:
                    weight_init_std = 0.02 / (2 * (idx + 1)) ** 0.5
                else:
                    weight_init_std = 0.02 / (2 * n_layer) ** 0.5
                # Bind the std as a default argument so the callback does not
                # close over the loop variable (late-binding pitfall, ruff B023).
                block.apply(
                    lambda module, std=weight_init_std: init_weights(module, std)
                )
        else:
            for block in self.transformer_blocks:
                block.apply(init_weights)

    def forward(self, x):
        """Run ``x`` through every block in order and return the last output.

        NOTE(review): ``x`` is presumably (batch, tokens, dim); the
        ``__main__`` smoke test uses (2, 64, 768). Confirm against
        ``QK_Norm_TransformerBlock``.
        """
        for blk in self.transformer_blocks:
            x = blk(x)
        return x
if __name__ == '__main__':
    # Smoke test: push a random batch through the default model and print the
    # output shape. Fall back to CPU so the script does not crash on machines
    # without CUDA; autocast follows the selected device type.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = LVSMTransformer().to(device)
    # Allocate directly on the target device instead of creating on CPU and copying.
    x = torch.randn(2, 64, 768, device=device)
    # Only the output shape is needed, so skip autograd bookkeeping.
    with torch.autocast(device.type, dtype=torch.bfloat16), torch.no_grad():
        y = model(x)
    print(y.shape)