---
tags:
- generation
- protein-sequence
- rna-sequence
- pytorch
---
# Protein to RNA CDS Sequence Generation Model
This is a custom PyTorch model that generates RNA coding sequences (CDS) from protein sequences. Its transformer-based architecture combines an ESM-2 protein encoder with a Mixture-of-Experts (MoE) layer.
## Model Architecture
The model `ActorModel_encoder_esm2` is defined in `utils.py`.
The key parameters used for instantiation are:
- `d_model`: Dimension of the model's internal representation (768).
- `nhead`: Number of attention heads (8).
- `num_encoder_layers`: Number of transformer encoder layers (8).
- `dim_feedforward`: Dimension of the feedforward network (`d_model * 2`, i.e. 1536).
- `esm2_dim`: Dimension of the ESM-2 embeddings (1280 for esm2_t33_650M_UR50D).
- `dropout`: Dropout rate (0.3).
- `num_experts`: Number of experts in the MoE layer (6).
- `top_k_experts`: Number of top experts to use (2).
- `device`: The device to run the model on.
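To make `num_experts` and `top_k_experts` concrete, here is a minimal, generic top-k gated MoE feed-forward layer in PyTorch. It is an illustrative sketch built on our own assumptions: a linear router, GELU experts, and a softmax over only the selected router logits. The class name `TopKMoE` is hypothetical, and the routing actually implemented in `transformer_encoder_MoE.py` may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKMoE(nn.Module):
    """Generic top-k gated MoE feed-forward layer (illustrative, not the repository's Encoder)."""

    def __init__(self, d_model=768, dim_feedforward=1536, num_experts=6, top_k=2, dropout=0.3):
        super().__init__()
        self.top_k = top_k
        # Router: one logit per expert for every token (assumed linear gate).
        self.gate = nn.Linear(d_model, num_experts)
        # Each expert is an ordinary position-wise feed-forward network.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, dim_feedforward),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(dim_feedforward, d_model),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        logits = self.gate(x)                                   # (batch, seq_len, num_experts)
        topk_vals, topk_idx = logits.topk(self.top_k, dim=-1)   # keep the k best experts per token
        weights = F.softmax(topk_vals, dim=-1)                  # renormalise over the chosen experts
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            idx = topk_idx[..., slot]                           # (batch, seq_len) expert id per token
            w = weights[..., slot].unsqueeze(-1)                # (batch, seq_len, 1)
            for e, expert in enumerate(self.experts):
                mask = idx == e                                 # tokens routed to expert e in this slot
                if mask.any():
                    out[mask] += w[mask] * expert(x[mask])
        return out
```

With `num_experts=6` and `top_k_experts=2`, each token's output is a weighted sum of only two expert networks, so per-token compute stays close to that of two dense feed-forward layers even though six sets of expert weights are stored.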
## Files in this Repository
- `homo_mrna.pt`: The PyTorch state_dict of the trained model for Homo sapiens mRNA.
- `homo_circ.pt`: The PyTorch state_dict of the trained model for Homo sapiens circular RNA.
- `Arabidopsis.pt`: The PyTorch state_dict of the trained model for Arabidopsis thaliana mRNA.
- `CR.pt`: The PyTorch state_dict of the trained model for Chlamydomonas reinhardtii mRNA.
- `EscherichiaColi.pt`: The PyTorch state_dict of the trained model for Escherichia coli mRNA.
- `PC.pt`: The PyTorch state_dict of the trained model for Penicillium chrysogenum mRNA.
- `TK.pt`: The PyTorch state_dict of the trained model for Thermococcus kodakarensis KOD1 mRNA.
- `utils.py`: Contains the definition of the `ActorModel_encoder_esm2` class and the `Tokenizer` class.
- `transformer_encoder_MoE.py`: Contains the definition of the `Encoder` class.
- `README.md`: This file.
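Each checkpoint targets a single organism, so a small lookup table can make it easier to pick the right file. The file names below come from this repository; the organism keys and the `files_for` helper are hypothetical conveniences, not part of `utils.py`.

```python
# Hypothetical organism keys; only the checkpoint file names come from this repository.
CHECKPOINTS = {
    "homo_sapiens_mrna": "homo_mrna.pt",
    "homo_sapiens_circrna": "homo_circ.pt",
    "arabidopsis_thaliana": "Arabidopsis.pt",
    "chlamydomonas_reinhardtii": "CR.pt",
    "escherichia_coli": "EscherichiaColi.pt",
    "penicillium_chrysogenum": "PC.pt",
    "thermococcus_kodakarensis": "TK.pt",
}

# Source files needed to instantiate ActorModel_encoder_esm2.
CODE_FILES = ["utils.py", "transformer_encoder_MoE.py"]


def files_for(organism: str) -> list:
    """Return the checkpoint for `organism` plus the source files needed to load it."""
    if organism not in CHECKPOINTS:
        raise ValueError(f"Unknown organism {organism!r}; choose one of {sorted(CHECKPOINTS)}")
    return [CHECKPOINTS[organism], *CODE_FILES]
```

For example, `files_for("escherichia_coli")` returns `["EscherichiaColi.pt", "utils.py", "transformer_encoder_MoE.py"]`.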
## How to Load the Model
Because this is a custom model, you need to download `utils.py`, `transformer_encoder_MoE.py`, and the `.pt` checkpoint you want, then instantiate the model class and load its state dictionary.
1. **Download Files:**
You can download the files using the `huggingface_hub` library:
```python
from huggingface_hub import hf_hub_download
import os
repo_id = "sglin/RNARL"
local_dir = "./my_RNARL"
# Download all model weights plus the two source files
hf_hub_download(repo_id=repo_id, filename="homo_mrna.pt", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="homo_circ.pt", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="Arabidopsis.pt", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="CR.pt", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="EscherichiaColi.pt", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="PC.pt", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="TK.pt", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="utils.py", local_dir=local_dir)
hf_hub_download(repo_id=repo_id, filename="transformer_encoder_MoE.py", local_dir=local_dir)
# utils.py, transformer_encoder_MoE.py and the model weights are now in ./my_RNARL
```
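If you only need one organism, you do not have to fetch all seven checkpoints. The hypothetical `download_rnarl` helper below is a sketch that downloads the two source files plus a single checkpoint. It relies only on `hf_hub_download`, which returns the local path of the file it downloaded.

```python
CODE_FILES = ("utils.py", "transformer_encoder_MoE.py")


def download_rnarl(weights_file: str, local_dir: str = "./my_RNARL",
                   repo_id: str = "sglin/RNARL") -> str:
    """Download the source files and one checkpoint; return the checkpoint's local path."""
    # Imported inside the function so this helper can be defined without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    for filename in CODE_FILES:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
    return hf_hub_download(repo_id=repo_id, filename=weights_file, local_dir=local_dir)
```

For example, `download_rnarl("TK.pt")` fetches only the Thermococcus kodakarensis checkpoint together with the code needed to load it.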
2. **Import Model Class:**
```python
# Make utils.py and transformer_encoder_MoE.py importable.
# Skip the sys.path line if you copied both files to your working directory.
import sys
sys.path.append("./my_RNARL")
from utils import Tokenizer, ActorModel_encoder_esm2
```
3. **Load ESM-2 (Dependency):**
The model requires the ESM-2 encoder. You'll need to load it separately, typically from Hugging Face Hub.
```python
from transformers import AutoTokenizer, EsmModel
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm2_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
esm2_model.eval()
esm2_dim = esm2_model.config.hidden_size  # 1280 for esm2_t33_650M_UR50D
```
*Note:* The original training script loaded ESM-2 from a local copy (`./esm2_model_t33_650M_UR50D`). This repository does not include the ESM-2 weights, so the example above loads them from the official `facebook/esm2_t33_650M_UR50D` repository on the Hugging Face Hub instead.
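The ESM-2 tokenizer wraps each protein in a `<cls>` token at the start and an `<eos>` token at the end, so `last_hidden_state` contains two more positions than there are residues. The hypothetical helper below is a sketch that extracts one embedding per residue. Whether `ActorModel_encoder_esm2` expects embeddings with or without these special positions, or computes them internally, is defined in `utils.py` and not documented here, so treat this as an assumption.

```python
import torch


@torch.no_grad()
def esm2_residue_embeddings(protein: str, esm2_tokenizer, esm2_model, device) -> torch.Tensor:
    """Return per-residue ESM-2 embeddings with shape (len(protein), hidden_size)."""
    encoded = esm2_tokenizer(protein, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    hidden = esm2_model(**inputs).last_hidden_state  # (1, len(protein) + 2, hidden_size)
    return hidden[0, 1:-1]  # drop the <cls> (first) and <eos> (last) positions
```

With the objects from step 3, `esm2_residue_embeddings("MKTAYIAK", esm2_tokenizer, esm2_model, device)` should return a tensor of shape `(8, 1280)`.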
4. **Instantiate Custom Model and Load Weights:**
Instantiate `ActorModel_encoder_esm2` with the hyperparameters listed under **Model Architecture**, then load the state dictionary of the checkpoint you need (here `homo_mrna.pt`).
```python
import os
import torch
# Hyperparameters used during training (see "Model Architecture").
# `device`, `esm2_dim` and `local_dir` come from the previous steps.
d_model = 768
nhead = 8
num_encoder_layers = 8
dim_feedforward = d_model * 2
dropout = 0.3
num_experts = 6
top_k_experts = 2
# vocab_size must match the Tokenizer defined in utils.py
tokenizer = Tokenizer()
vocab_size = len(tokenizer.tokens)
# Instantiate the model
model = ActorModel_encoder_esm2(
vocab_size=vocab_size,
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
dim_feedforward=dim_feedforward,
esm2_dim=esm2_dim, # Use the esm2_model's dimension
dropout=dropout,
num_experts=num_experts,
top_k_experts=top_k_experts,
device=device
)
# Load the state dictionary
model_weights_path = os.path.join(local_dir, "homo_mrna.pt")
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
model.eval()
print("Model loaded successfully!")
# The model is ready for inference. Inference also needs the Tokenizer
# from utils.py and the ESM-2 tokenizer/model loaded in step 3.
```
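Whatever decoding procedure you use, a generated CDS should translate back to the input protein. The sketch below checks this with the standard genetic code (NCBI translation table 1). That choice is an assumption: organisms or organelles that use an alternative genetic code need a different table. The helper names are hypothetical and not part of `utils.py`.

```python
from itertools import product

# Standard genetic code: codons enumerated in UCAG order (third base varies fastest).
BASES = "UCAG"
AMINO_ACIDS = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
CODON_TABLE = {"".join(codon): aa for codon, aa in zip(product(BASES, repeat=3), AMINO_ACIDS)}


def translate_cds(cds: str) -> str:
    """Translate an RNA (or DNA) CDS up to the first stop codon ('*')."""
    seq = cds.upper().replace("T", "U")
    if len(seq) % 3 != 0:
        raise ValueError(f"CDS length {len(seq)} is not a multiple of 3")
    protein = []
    for i in range(0, len(seq), 3):
        aa = CODON_TABLE[seq[i:i + 3]]  # KeyError for codons containing non-ACGU characters
        if aa == "*":
            break
        protein.append(aa)
    return "".join(protein)


def encodes_protein(cds: str, protein: str) -> bool:
    """True if `cds` translates exactly to `protein` (a premature stop codon gives False)."""
    return translate_cds(cds) == protein.upper()
```

For example, `translate_cds("AUGGCUUAA")` returns `"MA"`, and `encodes_protein("AUGAAAUGA", "MK")` returns `True`.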
## Dependencies
- `torch`
- `transformers`
- `huggingface_hub`
- `pandas`
- `numpy`
- The ESM-2 model `facebook/esm2_t33_650M_UR50D` from the Hugging Face Hub.
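Before running the loading steps, you can confirm that the Python packages above are installed. This is a generic sketch based on `importlib.util.find_spec`. It only checks that each module can be found, not which version is installed, and `missing_dependencies` is a hypothetical helper name.

```python
import importlib.util

# Importable module names for the packages listed above.
REQUIRED_MODULES = ("torch", "transformers", "huggingface_hub", "pandas", "numpy")


def missing_dependencies(modules=REQUIRED_MODULES) -> list:
    """Return the modules from `modules` that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_dependencies()
    print("All dependencies found." if not missing else f"Missing: {', '.join(missing)}")
```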
## License
[ MIT, Apache 2.0]
## Contact
[linsg4521@sjtu.edu.cn]