Spaces:
Runtime error
Runtime error
import torch
from transformers import MambaConfig, MambaModel, Mamba2Config, Mamba2Model

# Demo script: push a batch of 2-D feature maps through a randomly
# initialised Mamba backbone by flattening the spatial grid into a sequence,
# then restore the [B, C, H, W] layout on the output.

# Select the device once. The tensor and the model must live on the same
# device, and hard-coding "cuda" crashes on hosts without a GPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA version: {torch.version.cuda}")

# Random input feature map in channels-first layout: [B, C, H, W].
batch, channel, height, width = 256, 16, 8, 8
x = torch.randn(batch, channel, height, width, device=device)
print(f'x: {x.shape}')

# Move channels last and flatten the spatial grid so every pixel becomes one
# sequence position carrying a C-dimensional embedding.
B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
print(f'Permuted x: {x.shape}')
x = x.reshape(B, H * W, C)  # [B, L, C], L = H * W
print(f'Reshaped x: {x.shape}')

# Initializing a Mamba configuration. hidden_size is tied to the channel
# count so the input can be fed as `inputs_embeds` without a projection.
# NOTE(review): vocab_size=0 presumably works because the token embedding
# table is never used when only `inputs_embeds` is passed -- confirm.
configuration = MambaConfig(vocab_size=0, hidden_size=channel, num_hidden_layers=2)
# Mamba2 alternative:
# configuration = Mamba2Config(hidden_size=channel)

# Initializing a model (with random weights) from the configuration.
model = MambaModel(configuration).to(device)
# Mamba2 alternative:
# model = Mamba2Model(configuration).to(device)
print(f'Model: {model}')

# Accessing the model configuration.
configuration = model.config
print(f'Configuration: {configuration}')

# Feed the embeddings directly, bypassing token lookup.
y = model(inputs_embeds=x, return_dict=True).last_hidden_state
print(f'y: {y.shape}')

# Undo the flattening; -1 lets the last dimension follow the model's hidden
# size (configured above to equal the channel count).
y = y.reshape(B, H, W, -1)
print(f'Reshaped y: {y.shape}')
y = y.permute(0, 3, 1, 2)  # [B, C, H, W]
print(f'Permuted y: {y.shape}')