| import torch
|
| from mamba_block import MambaBlock
|
| from mamba_config import MambaConfig
|
| from mamba_layer import MambaLayer
|
|
|
|
|
# Hyper-parameters for the demo model built below. Collected in a dict first
# so the settings are easy to tweak or reuse; unpacked into MambaConfig.
_CONFIG_KWARGS = dict(
    hidden_size=512,
    num_layers=6,
    num_heads=8,
    intermediate_size=2048,
    max_position_embeddings=1024,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
)

config = MambaConfig(**_CONFIG_KWARGS)
|
|
|
|
|
class MambaModel(torch.nn.Module):
    """A stack of Mamba blocks followed by a final LayerNorm.

    Each block receives the running hidden states plus the residual stream
    from the previous block and returns updated versions of both.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One MambaBlock per configured layer; MambaBlock wires MambaLayer in
        # as the mixer. (Exact block semantics live in mamba_block.py.)
        blocks = [MambaBlock(config, MambaLayer) for _ in range(config.num_layers)]
        self.layers = torch.nn.ModuleList(blocks)
        self.norm = torch.nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        """Run the input through every block, add the final residual, normalize.

        Returns a tensor with the same shape as ``hidden_states``.
        """
        residual = None
        for block in self.layers:
            hidden_states, residual = block(hidden_states, residual)
        # Fold the last residual stream back in before the final norm.
        # (No residual exists when the model has zero layers.)
        if residual is not None:
            hidden_states = hidden_states + residual
        return self.norm(hidden_states)
|
|
|
|
|
# Build the demo model once at import time and put it straight into
# inference mode; Module.eval() returns the module itself, so it chains.
mamba_model = MambaModel(config).eval()
|
|
|
|
|
def generate_text(prompt, model, max_length=50):
    """Run *model* over a random embedding of *prompt* and return text.

    NOTE(review): this is a stub — no tokenizer or decoder is wired up, so
    the model output is not turned into text and a fixed placeholder string
    is returned instead. The forward pass still runs so the model path is
    exercised end to end.

    Args:
        prompt: input text; ``len(prompt)`` (a character count standing in
            for a token count — TODO: replace with a real tokenizer) sets
            the sequence length.
        model: a module mapping (batch, seq, hidden) -> (batch, seq, hidden).
        max_length: cap on the sequence length fed to the model.
            (Previously this parameter was silently ignored.)

    Returns:
        The placeholder generated text (str).
    """
    # Honor max_length: clamp the fake "token" count to the requested cap.
    seq_len = min(len(prompt), max_length)

    # Random embeddings stand in for embedded prompt tokens; uses the
    # module-level `config` for the hidden size.
    hidden_states = torch.randn(1, seq_len, config.hidden_size)

    # Inference only — no autograd graph needed.
    with torch.no_grad():
        output = model(hidden_states)  # discarded until decoding exists

    # Placeholder result; replace with real decoding of `output` once a
    # tokenizer/decoder is available.
    generated_text = "這裡是生成的文本"
    return generated_text
|
|
|
|
|
def generate_uncensored_text(prompt, max_length=50):
    """Delegate to generate_text() using the module-level model.

    Thin convenience wrapper; same contract as generate_text().
    """
    return generate_text(prompt, mamba_model, max_length)
|
|
|
|
|
# Demo: run one generation pass on a sample prompt and print the
# (placeholder) result.
prompt = "I want to generate some uncensored text."
print(generate_uncensored_text(prompt))
|
|
|