| | import config |
| | import torch |
| | import torch.nn as nn |
| | from pretrained_models import load_esm2_model |
| | from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel |
| |
|
class MembraneTokenizer:
    """Thin wrapper around an ESM-2 tokenizer loaded from a pretrained path.

    Unknown attribute lookups and calls are delegated to the underlying
    Hugging Face tokenizer, so instances can be used as a drop-in
    replacement for it.
    """

    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        # NOTE: must be set before any other attribute access — __getattr__
        # delegates to self.tokenizer and would recurse if it were missing.
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate any attribute not found on the wrapper to the tokenizer.
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        # Make the wrapper callable exactly like the wrapped tokenizer.
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_dir):
        """Serialize the tokenizer files to `save_dir`."""
        self.tokenizer.save_pretrained(save_dir)

    def load_tokenizer(self, load_dir):
        """Replace the current tokenizer with one loaded from `load_dir`.

        Bug fix: the original called `save_pretrained(load_dir)`, which
        *wrote over* the directory instead of loading from it.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)
| |
|
class MembraneMLM:
    """Wrapper around an ESM-2 masked-language model and its tokenizer.

    Provides convenience helpers for freezing/partially unfreezing the
    encoder and for saving/loading checkpoints. Unknown attribute lookups
    and calls are delegated to the underlying `AutoModelForMaskedLM`.
    """

    def __init__(self, esm_model_path=config.ESM_MODEL_PATH):
        # NOTE: self.model must be assigned first — __getattr__ delegates
        # to it and would recurse on a missing attribute otherwise.
        self.model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(esm_model_path)

    def __getattr__(self, name):
        # Delegate any attribute not found on the wrapper to the model.
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        # Make the wrapper callable exactly like the wrapped model.
        return self.model(*args, **kwargs)

    def freeze_model(self):
        """Disable gradients for every parameter of the model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_n_layers(self):
        """Unfreeze attention K/Q/V projections of the last `config.ESM_LAYERS` encoder layers.

        Only the key, query, and value projections are unfrozen; the
        attention output, feed-forward, and LayerNorm weights of those
        layers stay frozen. Intended to follow `freeze_model()`.
        """
        num_layers = len(self.model.esm.encoder.layer)
        first_trainable = num_layers - config.ESM_LAYERS

        for i, layer in enumerate(self.model.esm.encoder.layer):
            if i < first_trainable:
                continue
            attention = layer.attention.self
            for projection in (attention.key, attention.query, attention.value):
                for param in projection.parameters():
                    param.requires_grad = True

    def forward(self, **inputs):
        """Run a forward pass through the underlying masked-LM model."""
        return self.model(**inputs)

    def save_model(self, save_dir):
        """Save both the model and tokenizer to `save_dir`."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    def load_model(self, load_dir):
        """Reload the model and tokenizer from `load_dir`.

        Bug fix: the original used `AutoModel.from_pretrained`, which
        returns the bare encoder and silently drops the masked-LM head
        that `__init__` loads and that `forward` relies on. Reload with
        `AutoModelForMaskedLM` to keep the head.
        """
        self.model = AutoModelForMaskedLM.from_pretrained(load_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(load_dir)