| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Literal |
| |
|
| | from transformers.models.modernbert.configuration_modernbert import ModernBertConfig |
| |
|
| |
|
class ModChemBertConfig(ModernBertConfig):
    """
    Configuration class for ModChemBert models.

    This configuration class extends ModernBertConfig with additional parameters specific to
    chemical molecule modeling and custom pooling strategies for classification/regression tasks.
    It accepts all arguments and keyword arguments from ModernBertConfig.

    Args:
        classifier_pooling (str, optional): Pooling strategy for sequence classification.
            Available options:
            - "cls": Use CLS token representation
            - "mean": Attention-weighted average pooling
            - "sum_mean": Sum all hidden states across layers, then mean pool over sequence (ChemLM approach)
            - "sum_sum": Sum all hidden states across layers, then sum pool over sequence
            - "mean_mean": Mean all hidden states across layers, then mean pool over sequence
            - "mean_sum": Mean all hidden states across layers, then sum pool over sequence
            - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
            - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
            - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
            - "mean_seq_mha": Mean pooling over last k states + multi-head attention with CLS as query
            - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
            Defaults to "max_seq_mha".
        classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 4.
        classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 0.0.
        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max/mean pooling
            strategies (max_cls, max_seq_mha, mean_seq_mha, max_seq_mean). Defaults to 8.
        *args: Variable length argument list passed to ModernBertConfig.
        **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.

    Raises:
        ValueError: If ``classifier_pooling`` is not one of the supported options.

    Note:
        This class inherits all configuration parameters from ModernBertConfig including
        hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, etc.
    """

    model_type = "modchembert"

    def __init__(
        self,
        *args,
        classifier_pooling: Literal[
            "cls",
            "mean",
            "sum_mean",
            "sum_sum",
            "mean_mean",
            "mean_sum",
            "max_cls",
            "cls_mha",
            "max_seq_mha",
            "mean_seq_mha",
            "max_seq_mean",
        ] = "max_seq_mha",
        classifier_pooling_num_attention_heads: int = 4,
        classifier_pooling_attention_dropout: float = 0.0,
        classifier_pooling_last_k: int = 8,
        **kwargs,
    ):
        # Validate before calling the parent constructor so an invalid value
        # fails fast with the full list of supported strategies.
        valid_classifier_pooling_options = [
            "cls",
            "mean",
            "sum_mean",
            "sum_sum",
            "mean_mean",
            "mean_sum",
            "max_cls",
            "cls_mha",
            "max_seq_mha",
            "mean_seq_mha",
            "max_seq_mean",
        ]
        if classifier_pooling not in valid_classifier_pooling_options:
            raise ValueError(
                f"Invalid value for `classifier_pooling`, should be one of {valid_classifier_pooling_options}, "
                f"but is {classifier_pooling}."
            )

        # ModernBertConfig itself only accepts "cls" or "mean" for
        # classifier_pooling, so pass a sentinel it will accept ("cls") to
        # satisfy the parent's validation, then override with the real
        # (extended) value below.
        super().__init__(*args, classifier_pooling="cls", **kwargs)

        self.classifier_pooling = classifier_pooling
        self.classifier_pooling_num_attention_heads = classifier_pooling_num_attention_heads
        self.classifier_pooling_attention_dropout = classifier_pooling_attention_dropout
        self.classifier_pooling_last_k = classifier_pooling_last_k
        # Map the Auto* factory classes to this package's custom
        # implementations so `trust_remote_code` loading resolves correctly.
        self.auto_map = {
            "AutoConfig": "configuration_modchembert.ModChemBertConfig",
            "AutoModel": "modeling_modchembert.ModChemBertModel",
            "AutoModelForMaskedLM": "modeling_modchembert.ModChemBertForMaskedLM",
            "AutoModelForSequenceClassification": "modeling_modchembert.ModChemBertForSequenceClassification",
        }
|