| import torch |
| import torch.nn as nn |
| from typing import Optional |
| from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM |
|
|
|
|
| """ |
| Custom models are currently supposed to load completely from AutoModel.from_pretrained(path, trust_remote_code=True) |
| """ |
|
|
|
|
class CustomModelForEmbedding(nn.Module):
    """Wrapper around ``AutoModel.from_pretrained(..., trust_remote_code=True)``
    that exposes the backbone's last hidden state for embedding extraction.

    If the remote model bundles its own tokenizer (exposed as a ``tokenizer``
    attribute), it is surfaced on this wrapper as ``self.tokenizer``.
    """

    def __init__(self, model_path: str, dtype: Optional[torch.dtype] = None):
        """Load the underlying model.

        Args:
            model_path: HF hub id or local path of the checkpoint.
            dtype: Optional torch dtype to load the weights in; ``None`` keeps
                the checkpoint's default.
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)
        # Some trust_remote_code models ship a tokenizer; surface it if present
        # so callers (e.g. build_custom_model) can reuse it.
        if hasattr(self.model, 'tokenizer'):
            self.tokenizer = self.model.tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        # Accepted for HF-style API compatibility but NOT forwarded to the
        # underlying model; hidden states are never returned here.
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> "torch.Tensor | tuple[torch.Tensor, tuple]":
        """Run the backbone and return its last hidden state.

        Args:
            input_ids: Token id tensor for the backbone.
            attention_mask: Optional attention mask, forwarded as-is.
            output_attentions: When truthy, also return attention weights.

        Returns:
            ``last_hidden_state``, or the tuple
            ``(last_hidden_state, attentions)`` when ``output_attentions``
            is truthy.
        """
        if output_attentions:
            out = self.model(input_ids, attention_mask=attention_mask, output_attentions=output_attentions)
            return out.last_hidden_state, out.attentions
        return self.model(input_ids, attention_mask=attention_mask).last_hidden_state
|
|
|
|
def build_custom_model(model_path: str, masked_lm: bool = False, dtype: Optional[torch.dtype] = None, **kwargs):
    """Build a ``(model, tokenizer)`` pair from a trust_remote_code checkpoint.

    Args:
        model_path: HF hub id or local path of the checkpoint.
        masked_lm: When True, load an ``AutoModelForMaskedLM`` head instead of
            the embedding wrapper.
        dtype: Optional torch dtype for the weights.
        **kwargs: Ignored; kept for call-site compatibility.

    Returns:
        Tuple of the model (in eval mode) and its tokenizer.
    """
    if masked_lm:
        model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    else:
        model = CustomModelForEmbedding(model_path, dtype=dtype).eval()
    # Prefer a tokenizer bundled with the model itself; fall back to
    # AutoTokenizer. Catch only AttributeError (the attribute may be absent) —
    # a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    try:
        tokenizer = model.tokenizer
    except AttributeError:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
|
|
|
|
def build_custom_tokenizer(model_path: str, **kwargs):
    """Load and return the tokenizer for *model_path*, allowing remote code.

    Args:
        model_path: HF hub id or local path of the checkpoint.
        **kwargs: Ignored; kept for call-site compatibility.
    """
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
|
|
|
| if __name__ == "__main__": |
| |
| model, tokenizer = build_custom_model('answerdotai/ModernBERT-base') |
| print(model) |
| print(tokenizer) |
| seq = 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL' |
| encoded = tokenizer.encode(seq) |
| decoded = tokenizer.decode(encoded) |
| print(encoded) |
| print(decoded) |
|
|