"""ProtCLM model builders: embedding wrapper, tokenizer wrapper, and classification heads."""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from .base_tokenizer import BaseSequenceTokenizer


# Known checkpoints, keyed by preset name.
presets = {
    "ProtCLM-1b": "biomap-research/proteinglm-1b-clm",
}


class ProtCLMTokenizerWrapper(BaseSequenceTokenizer):
    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs):
        # Accept a single sequence or a batch; default to padded PyTorch tensors.
        if isinstance(sequences, str):
            sequences = [sequences]
        kwargs.setdefault("return_tensors", "pt")
        kwargs.setdefault("padding", "longest")
        kwargs.setdefault("add_special_tokens", True)
        return self.tokenizer(sequences, **kwargs)


class ProtCLMForEmbedding(nn.Module):
    """Wraps the ProtCLM backbone and returns per-token hidden states."""

    def __init__(self, model_path: str, dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.plm = AutoModel.from_pretrained(
            model_path, torch_dtype=dtype, trust_remote_code=True
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Neither flag is forwarded to the backbone, so reject both explicitly.
        assert not (output_attentions or output_hidden_states), (
            "output_attentions and output_hidden_states are not supported "
            "by ProtCLMForEmbedding."
        )
        out = self.plm(input_ids=input_ids, attention_mask=attention_mask)
        return out.last_hidden_state
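

# A minimal, hypothetical helper (not part of the original API): the embedding
# model above returns per-token states, so callers typically pool them into a
# single vector per sequence. This sketch assumes `attention_mask` marks real
# tokens with 1 and padding with 0.
def mean_pool_embeddings(
    token_embeddings: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    # Zero out padded positions, then average over the real tokens only.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts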


def get_protCLM_tokenizer(preset: str, model_path: Optional[str] = None) -> BaseSequenceTokenizer:
    # Prefer an explicit local path; fall back to the preset's hub checkpoint.
    return ProtCLMTokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path or presets[preset], trust_remote_code=True)
    )


def build_protCLM(
    preset: str,
    masked_lm: bool = False,
    dtype: Optional[torch.dtype] = None,
    model_path: Optional[str] = None,
    **kwargs,
) -> Tuple[nn.Module, BaseSequenceTokenizer]:
    # ProtCLM is a causal LM, so masked-LM mode is rejected outright.
    if masked_lm:
        raise ValueError(f"Model {preset} does not support masked language modeling")
    model_path = model_path or presets[preset]
    model = ProtCLMForEmbedding(model_path, dtype=dtype).eval()
    tokenizer = get_protCLM_tokenizer(preset, model_path=model_path)
    return model, tokenizer


def get_protCLM_for_training(
    preset: str,
    tokenwise: bool = False,
    num_labels: Optional[int] = None,
    hybrid: bool = False,
    dtype: Optional[torch.dtype] = None,
    model_path: Optional[str] = None,
) -> Tuple[nn.Module, BaseSequenceTokenizer]:
    model_path = model_path or presets[preset]
    if hybrid:
        # Hybrid mode returns the bare backbone; a task head is attached elsewhere.
        model = AutoModel.from_pretrained(
            model_path, torch_dtype=dtype, trust_remote_code=True
        ).eval()
    elif tokenwise:
        # Per-residue labels use a token-classification head.
        model = AutoModelForTokenClassification.from_pretrained(
            model_path, num_labels=num_labels, torch_dtype=dtype, trust_remote_code=True
        ).eval()
    else:
        # Per-sequence labels use a sequence-classification head.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path, num_labels=num_labels, torch_dtype=dtype, trust_remote_code=True
        ).eval()
    tokenizer = get_protCLM_tokenizer(preset, model_path=model_path)
    return model, tokenizer
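
# Hypothetical usage sketch (sequences and label count are illustrative only):
#
#     model, tokenizer = get_protCLM_for_training(
#         "ProtCLM-1b", tokenwise=True, num_labels=3
#     )
#     batch = tokenizer(["MKTAYIAKQR", "MEEPQSDPSV"])
#     logits = model(**batch).logits  # per-residue logits over 3 classes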


if __name__ == "__main__":
    # Smoke test: build the embedding model and tokenize a sample sequence.
    model, tokenizer = build_protCLM("ProtCLM-1b")
    print(model)
    print(tokenizer)
    print(tokenizer("MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL"))