| """ |
| We use the FastPLM implementation of E1. |
| """ |
| import sys |
| import os |
| import torch |
| import torch.nn as nn |
| from typing import Optional, Union, List, Dict, Tuple |
|
|
# Make the sibling 'FastPLMs' checkout importable without installation:
# resolve the directory two levels above this file, append 'FastPLMs', and
# prepend it to sys.path exactly once so the `e1_fastplms` import below works.
_FASTPLMS = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'FastPLMs')
if _FASTPLMS not in sys.path:
    sys.path.insert(0, _FASTPLMS)
|
|
| from e1_fastplms.modeling_e1 import ( |
| E1Model, |
| E1ForMaskedLM, |
| E1ForSequenceClassification, |
| E1ForTokenClassification, |
| ) |
| from .base_tokenizer import BaseSequenceTokenizer |
| from .e1_utils import E1BatchPreparer |
|
|
|
|
# Short preset names mapped to their HuggingFace Hub repository ids.
# Keys are the accepted `preset` values for the factory functions below.
presets = {
    'E1-150': 'Synthyra/Profluent-E1-150M',
    'E1-300': 'Synthyra/Profluent-E1-300M',
    'E1-600': 'Synthyra/Profluent-E1-600M',
}
|
|
|
|
class E1TokenizerWrapper(BaseSequenceTokenizer):
    """Adapt an :class:`E1BatchPreparer` to the BaseSequenceTokenizer call API.

    Calling the wrapper tokenizes one sequence or a list of sequences and
    returns the batch kwargs dict produced by the underlying preparer.
    """

    def __init__(self, tokenizer: E1BatchPreparer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        # A lone string is treated as a batch of one; extra kwargs are
        # accepted for interface compatibility but not used here.
        batch = [sequences] if isinstance(sequences, str) else sequences
        return self.tokenizer.get_batch_kwargs(batch)
|
|
|
|
class E1ForEmbedding(nn.Module):
    """Thin wrapper exposing a pretrained E1 backbone as an embedding model.

    ``forward`` returns the backbone's last hidden state; when
    ``output_attentions`` is truthy it instead returns a
    ``(last_hidden_state, attentions)`` tuple.
    """

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        self.e1 = E1Model.from_pretrained(model_path, dtype=dtype)

    def forward(
        self,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
        # NOTE(review): ``output_hidden_states`` is accepted but never
        # forwarded as True to the backbone — presumably kept only for
        # signature compatibility; confirm with callers.
        if not output_attentions:
            out = self.e1(**kwargs, output_hidden_states=False, output_attentions=False)
            return out.last_hidden_state
        out = self.e1(**kwargs, output_attentions=output_attentions)
        return out.last_hidden_state, out.attentions
|
|
|
|
def get_e1_tokenizer(preset: str, model_path: str = None):
    """Return a fresh :class:`E1TokenizerWrapper` around an E1BatchPreparer.

    NOTE(review): both ``preset`` and ``model_path`` are currently unused —
    ``E1BatchPreparer`` is constructed without arguments. They appear to be
    kept for signature parity with other tokenizer factories; confirm.
    """
    return E1TokenizerWrapper(E1BatchPreparer())
|
|
|
|
def build_e1_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Load a pretrained E1 model in eval mode together with its tokenizer.

    Args:
        preset: Key into the module-level ``presets`` mapping; used to resolve
            the HuggingFace repo id when ``model_path`` is not given.
        masked_lm: If True, load ``E1ForMaskedLM``; otherwise wrap the backbone
            in ``E1ForEmbedding``.
        dtype: Optional torch dtype forwarded to ``from_pretrained``.
        model_path: Explicit model path/repo id; overrides ``preset`` when set.
        **kwargs: Accepted for interface compatibility but currently unused.

    Returns:
        ``(model, tokenizer)`` tuple; the model is in ``.eval()`` mode.

    Raises:
        ValueError: If ``model_path`` is not given and ``preset`` is unknown
            (previously surfaced as an unhelpful ``KeyError``).
    """
    if not model_path:
        if preset not in presets:
            raise ValueError(
                f"Unknown E1 preset {preset!r}; expected one of {sorted(presets)} "
                "or an explicit model_path"
            )
        model_path = presets[preset]
    if masked_lm:
        model = E1ForMaskedLM.from_pretrained(model_path, dtype=dtype).eval()
    else:
        model = E1ForEmbedding(model_path, dtype=dtype).eval()
    tokenizer = get_e1_tokenizer(preset)
    return model, tokenizer
|
|
|
|
def get_e1_for_training(preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype: torch.dtype = None, model_path: str = None):
    """Load an E1 variant suitable for fine-tuning, plus its tokenizer.

    Args:
        preset: Key into the module-level ``presets`` mapping; used to resolve
            the HuggingFace repo id when ``model_path`` is not given.
        tokenwise: If True (and not ``hybrid``), load a token-classification
            head; otherwise a sequence-classification head.
        num_labels: Number of labels forwarded to the classification heads.
            NOTE(review): defaults to None and is passed through unchecked —
            confirm the heads tolerate that.
        hybrid: If True, load the bare ``E1Model`` backbone instead of any head.
        dtype: Optional torch dtype forwarded to ``from_pretrained``.
        model_path: Explicit model path/repo id; overrides ``preset`` when set.

    Returns:
        ``(model, tokenizer)`` tuple; the model is in ``.eval()`` mode.

    Raises:
        ValueError: If ``model_path`` is not given and ``preset`` is unknown
            (previously surfaced as an unhelpful ``KeyError``).
    """
    if not model_path:
        if preset not in presets:
            raise ValueError(
                f"Unknown E1 preset {preset!r}; expected one of {sorted(presets)} "
                "or an explicit model_path"
            )
        model_path = presets[preset]
    if hybrid:
        model = E1Model.from_pretrained(model_path, dtype=dtype).eval()
    elif tokenwise:
        model = E1ForTokenClassification.from_pretrained(model_path, num_labels=num_labels, dtype=dtype).eval()
    else:
        model = E1ForSequenceClassification.from_pretrained(model_path, num_labels=num_labels, dtype=dtype).eval()
    tokenizer = get_e1_tokenizer(preset)
    return model, tokenizer
|
|
|
|
if __name__ == '__main__':
    # Smoke test: load the smallest preset (downloads weights on first run)
    # and tokenize a duplicated dummy sequence to exercise the batch path.
    model, tokenizer = build_e1_model('E1-150')
    print(model)
    print(tokenizer)
    print(tokenizer(['MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL', 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL']))
|
|