from typing import List, Optional, Tuple, Union
from collections import namedtuple
import copy
import gc
import json
import math
import os
import random
from functools import partial

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from tqdm import tqdm

from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
from mamba_ssm.modules.block import Block
from mamba_ssm.models.mixer_seq_simple import _init_weights
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from .generation import GenerationMixin
from .bamba_modules import BertEmbeddings, BertPooler, BertPreTrainingHeads, BlockCrossAttention
from .bamba_config import BambaConfig, BambaEncoderDecoderConfig
|
|
|
|
def create_block(
    d_model,
    d_intermediate,
    block_class,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Create one residual mixer block for the backbone.

    Layers whose index appears in ``attn_layer_idx`` get a multi-head
    attention (MHA) mixer; all other layers get a Mamba SSM mixer selected by
    ``ssm_cfg["layer"]`` ("Mamba1" by default, or "Mamba2").

    Args:
        d_model: model (hidden) dimension.
        d_intermediate: hidden size of the gated MLP; 0 disables the MLP.
        block_class: block wrapper class, e.g. ``Block`` or ``BlockCrossAttention``.
        ssm_cfg: extra kwargs for the SSM mixer; the special key ``"layer"``
            selects the Mamba variant and is popped before construction.
        attn_layer_idx: layer indices that use MHA instead of an SSM mixer.
        attn_cfg: extra kwargs for the MHA mixer.
        norm_epsilon: epsilon for the (RMS)LayerNorm.
        rms_norm: use RMSNorm instead of LayerNorm.
        residual_in_fp32: keep the residual stream in fp32.
        fused_add_norm: use the fused add+norm kernel inside the block.
        layer_idx: index of this layer within the backbone.
        device, dtype: factory kwargs forwarded to all submodules.

    Returns:
        An instantiated ``block_class`` with ``layer_idx`` attached.
    """
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Deep-copy before pop() so the caller's shared ssm_cfg dict is not
        # mutated across layers. (ssm_cfg is already guaranteed non-None above.)
        ssm_cfg = copy.deepcopy(ssm_cfg)
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
        # Fixed: ``Mamba`` was referenced but never imported; it is now
        # imported from mamba_ssm at the top of the file.
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )
    block = block_class(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    # The cross-attention submodule is not built through factory_kwargs, so
    # cast/move it explicitly to match the rest of the block.
    if isinstance(block, BlockCrossAttention) and factory_kwargs["dtype"] is not None:
        block.encoder_attn.type(factory_kwargs["dtype"]).to(factory_kwargs["device"])
    block.layer_idx = layer_idx
    return block
|
|
|
|
class BambaMixerModel(nn.Module):
    """Backbone stack of mixer blocks shared by the Bamba encoder and decoder.

    Encoder mode (``is_decoder=False``): BERT-style embeddings (token +
    position + token-type via ``BertEmbeddings``), plain ``Block`` layers, and
    ``forward`` returns ``(hidden_states, pooled_output)``.

    Decoder mode (``is_decoder=True``): a bare ``nn.Embedding``,
    ``BlockCrossAttention`` layers that can attend over
    ``encoder_hidden_states``, and ``forward`` returns ``hidden_states`` only.
    """

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        max_position_embeddings: int,
        is_decoder: bool = False,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.is_decoder = is_decoder

        # Decoder uses a bare token embedding; the encoder uses BERT-style
        # embeddings (token + position + token-type).
        if is_decoder:
            self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
        else:
            self.embedding = BertEmbeddings(vocab_size, d_model, max_position_embeddings, **factory_kwargs)

        # The fused add+norm path requires the optional Triton kernels whose
        # import may have failed at module load (they are None in that case).
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        # Decoder blocks carry cross-attention over the encoder outputs.
        if is_decoder:
            block_class = BlockCrossAttention
        else:
            block_class = Block

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    block_class=block_class,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        # Final norm applied after the residual stream is folded back in.
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        # Encoder additionally pools the sequence for sentence-level heads.
        if not is_decoder:
            self.pooler = BertPooler(d_model, **factory_kwargs)

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                # Blocks with an MLP have two residual additions per layer.
                n_residuals_per_layer=1 if d_intermediate == 0 else 2,
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Pre-allocate per-layer inference state, keyed by layer index."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, token_type_ids=None, inference_params=None, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs):
        """Run the block stack.

        Returns ``hidden_states`` in decoder mode, or
        ``(hidden_states, pooled_output)`` in encoder mode.
        """
        if self.is_decoder:
            hidden_states = self.embedding(input_ids)
        else:
            hidden_states = self.embedding(input_ids, token_type_ids)
        residual = None
        for layer in self.layers:
            if self.is_decoder:
                # Cross-attention blocks also consume the encoder outputs.
                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **mixer_kwargs
                )
            else:
                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params, **mixer_kwargs
                )
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Fused add+norm kernel: prenorm=False makes it return only the
            # normalized output (residual is folded in by the kernel).
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm)
            )
        if not self.is_decoder:
            pooled_output = self.pooler(hidden_states)
            return hidden_states, pooled_output
        return hidden_states
|
|
|
|
class BambaEncoder(nn.Module):
    """Bamba encoder: mixer backbone plus BERT-style pretraining heads.

    ``forward`` returns a namedtuple with MLM ``logits``,
    ``seq_relationship_logits`` (from the pooled output), and the final
    ``hidden_states``.
    """

    def __init__(
        self,
        config: BambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        max_position_embeddings = config.max_position_embeddings
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        # Pad the vocab size so embedding/projection matmuls tile efficiently.
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = BambaMixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            is_decoder=False,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.cls = BertPreTrainingHeads(vocab_size, d_model, **factory_kwargs)

        # Second init pass covers the pretraining heads (the backbone already
        # initialized its own submodules).
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Tie the MLM output projection to the input word embeddings.

        Fixed: the original referenced ``self.lm_head`` (which the encoder
        never creates — its projection lives in ``self.cls``) and
        ``self.backbone.embedding.weight`` (BertEmbeddings exposes
        ``word_embeddings.weight`` instead, see BambaEncoderDecoder.tie_weights).
        """
        if self.config.tie_embeddings:
            # NOTE(review): assumes BertPreTrainingHeads exposes
            # ``predictions.decoder`` like HF BERT — verify in bamba_modules.
            self.cls.predictions.decoder.weight = self.backbone.embedding.word_embeddings.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone layers."""
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states, pooled_output = self.backbone(input_ids, token_type_ids, inference_params=inference_params, **mixer_kwargs)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits, seq_relationship_score = self.cls(hidden_states, pooled_output)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "seq_relationship_logits", "hidden_states"])
        return CausalLMOutput(logits=lm_logits, seq_relationship_logits=seq_relationship_score, hidden_states=hidden_states)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        """Load config and weights from a HF-style checkpoint."""
        config_data = load_config_hf(pretrained_model_name)
        # Fixed: the original instantiated ``MambaConfig``, which is not
        # imported anywhere in this module; this model uses ``BambaConfig``.
        config = BambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """Save model weights and config JSON to ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)

        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f, indent=4)
|
|
|
|
class BambaDecoder(nn.Module, GenerationMixin):
    """Bamba decoder: cross-attention mixer backbone plus an LM head.

    ``forward`` returns a namedtuple with next-token ``logits``; generation
    utilities come from ``GenerationMixin``.
    """

    def __init__(
        self,
        config: BambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        max_position_embeddings = config.max_position_embeddings
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        # Pad the vocab size so embedding/lm-head matmuls tile efficiently.
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = BambaMixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            is_decoder=True,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Second init pass covers the lm_head (the backbone already
        # initialized its own submodules).
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Share the LM head weight with the token embedding when configured."""
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone layers."""
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, inference_params=None, num_last_tokens=0, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(
            input_ids, token_type_ids, inference_params=inference_params, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **mixer_kwargs
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        """Load config and weights from a HF-style checkpoint."""
        config_data = load_config_hf(pretrained_model_name)
        # Fixed: the original instantiated ``MambaConfig``, which is not
        # imported anywhere in this module; this model uses ``BambaConfig``.
        config = BambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """Save model weights and config JSON to ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)

        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f, indent=4)
|
|
|
|
class BambaEncoderDecoder(nn.Module, GenerationMixin):
    """Full Bamba seq2seq model: a BambaEncoder feeding a BambaDecoder.

    Also offers SMILES embedding extraction helpers (``extract_embeddings``,
    ``encode``) when constructed with a ``tokenizer``.
    """

    def __init__(
        self,
        config: BambaEncoderDecoderConfig,
        tokenizer=None,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        self.encoder_config = config.encoder_config
        self.decoder_config = config.decoder_config
        factory_kwargs = {"device": device, "dtype": dtype}
        self.tokenizer = tokenizer

        super().__init__()
        self.encoder = BambaEncoder(self.encoder_config, **factory_kwargs)
        self.decoder = BambaDecoder(self.decoder_config, **factory_kwargs)

        # Remembered so the embedding helpers can move inputs to the model.
        self.device = device

        self.tie_weights()
        self._set_seed(config.seed)

    def tie_weights(self):
        """Share the encoder word embeddings with the decoder token embedding."""
        if self.config.tie_word_embeddings:
            self.decoder.backbone.embedding.weight = self.encoder.backbone.embedding.word_embeddings.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for the decoder (the generating side)."""
        return self.decoder.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, encoder_input_ids, decoder_input_ids, token_type_ids=None, attention_mask=None, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        encoder_hidden_states = self.encoder(encoder_input_ids, inference_params=inference_params, **mixer_kwargs).hidden_states
        lm_logits = self.decoder(decoder_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, inference_params=inference_params, **mixer_kwargs).logits
        if num_last_tokens > 0:
            # Fixed: the original sliced an undefined ``hidden_states``
            # (NameError) and discarded the result; trim the logits instead.
            lm_logits = lm_logits[:, -num_last_tokens:]
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        """Load config and weights from a HF-style checkpoint."""
        config_data = load_config_hf(pretrained_model_name)
        # Fixed: the original instantiated ``MambaConfig``, which is not
        # imported; this model is configured by ``BambaEncoderDecoderConfig``.
        config = BambaEncoderDecoderConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """Save model weights and config JSON to ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)

        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f, indent=4)

    def _set_seed(self, value):
        """Seed all RNGs and force deterministic cuDNN for reproducibility."""
        print('Random Seed:', value)
        random.seed(value)
        torch.manual_seed(value)
        torch.cuda.manual_seed(value)
        torch.cuda.manual_seed_all(value)
        np.random.seed(value)
        cudnn.deterministic = True
        cudnn.benchmark = False

    def extract_embeddings(self, smiles):
        """Mean-pool encoder hidden states over non-padding tokens.

        Args:
            smiles: string or list of SMILES strings to embed.

        Returns:
            Tensor of shape (batch, d_model) of masked mean-pooled embeddings.
        """
        tokens = self.tokenizer(smiles, padding=True, truncation=True, return_tensors='pt')

        idx = tokens['input_ids'].to(self.device)
        mask = tokens['attention_mask'].to(self.device)
        outputs = self.encoder(input_ids=idx)
        hidden_states = outputs.hidden_states

        # Masked mean pooling: zero out padding positions, then divide by the
        # (clamped) number of real tokens per sequence.
        token_embeddings = hidden_states
        input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        embeddings = sum_embeddings / sum_mask

        return embeddings

    def encode(self, smiles, useCuda=False, batch_size=100, return_torch=False):
        """Extract SMILES embeddings efficiently in batches.

        Args:
            smiles: a single SMILES string or an iterable of SMILES strings.
            useCuda: kept for backward compatibility; currently unused (the
                model's construction-time device is used instead).
            batch_size: approximate number of SMILES per forward pass.
            return_torch: return a ``torch.Tensor`` instead of a DataFrame.
        """
        smiles = pd.Series(smiles) if isinstance(smiles, str) else pd.Series(list(smiles))

        # Fixed: guard against n_split == 0 (empty input), which would make
        # np.array_split raise.
        n_split = smiles.shape[0] // batch_size if smiles.shape[0] >= batch_size else smiles.shape[0]
        n_split = max(n_split, 1)
        embeddings = [
            self.extract_embeddings(list(batch)).cpu().detach().numpy()
            for batch in tqdm(np.array_split(smiles, n_split))
        ]
        flat_list = [item for sublist in embeddings for item in sublist]

        # Release intermediate GPU/CPU buffers before returning.
        torch.cuda.empty_cache()
        gc.collect()

        if return_torch:
            return torch.tensor(flat_list)
        return pd.DataFrame(flat_list)
|
|