| --- |
| tags: |
| - generation |
| - protein-sequence |
| - rna-sequence |
| - pytorch |
| --- |
| |
| # Protein to RNA CDS Sequence Generation Model |
|
|
| This custom PyTorch model generates RNA coding sequences (CDS) from protein sequences. It uses a transformer-based architecture that combines an ESM-2 encoder with a Mixture-of-Experts (MoE) layer. |
|
|
| ## Model Architecture |
|
|
| The model `ActorModel_encoder_esm2` is defined in `utils.py`. |
|
|
| The key parameters used for instantiation are: |
|
|
| - `d_model`: Dimension of the model's internal representation (768). |
| - `nhead`: Number of attention heads (8). |
| - `num_encoder_layers`: Number of transformer encoder layers (8). |
| - `dim_feedforward`: Dimension of the feedforward network (`d_model * 2`). |
| - `esm2_dim`: Dimension of the ESM-2 embeddings (1280 for `esm2_t33_650M_UR50D`). |
| - `dropout`: Dropout rate (0.3). |
| - `num_experts`: Number of experts in the MoE layer (6). |
| - `top_k_experts`: Number of top experts to use (2). |
| - `device`: The device to run the model on. |
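
To make `num_experts` and `top_k_experts` concrete, the sketch below shows a common top-k gating scheme in plain Python: a softmax over per-expert gate scores, selection of the k highest, and renormalization of the kept weights. It assumes the MoE layer follows this standard routing pattern; the actual routing in `transformer_encoder_MoE.py` may differ, and `top_k_route` is a hypothetical name used only for illustration.

```python
import math


def top_k_route(gate_logits, top_k=2):
    """Return [(expert_index, weight), ...] for the top_k highest-scoring experts.

    Illustrative only: this is generic top-k gating, not the repository's code.
    """
    # Softmax over all experts (shifted by the max for numerical stability).
    shift = max(gate_logits)
    exps = [math.exp(x - shift) for x in gate_logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the top_k experts by probability.
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    # Renormalize so the kept weights sum to 1.
    kept = sum(probs[i] for i in chosen)
    return [(i, probs[i] / kept) for i in chosen]


# Six experts with two active per token, matching num_experts=6 and top_k_experts=2.
routes = top_k_route([0.1, 2.0, -1.0, 0.5, 1.5, 0.0], top_k=2)
```

With this kind of routing, only two of the six expert feedforward networks run for each token, so capacity grows with the number of experts while per-token compute stays close to that of a two-expert layer.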
|
|
| ## Files in this Repository |
|
|
| - `homo_mrna.pt`: The PyTorch state_dict of the trained model for Homo sapiens mRNA. |
| - `homo_circ.pt`: The PyTorch state_dict of the trained model for Homo sapiens circular RNA. |
| - `Arabidopsis.pt`: The PyTorch state_dict of the trained model for Arabidopsis thaliana mRNA. |
| - `CR.pt`: The PyTorch state_dict of the trained model for Chlamydomonas reinhardtii mRNA. |
| - `EscherichiaColi.pt`: The PyTorch state_dict of the trained model for Escherichia coli mRNA. |
| - `PC.pt`: The PyTorch state_dict of the trained model for Penicillium chrysogenum mRNA. |
| - `TK.pt`: The PyTorch state_dict of the trained model for Thermococcus kodakarensis KOD1 mRNA. |
| - `utils.py`: Contains the definition of the `ActorModel_encoder_esm2` class and the `Tokenizer` class. |
| - `transformer_encoder_MoE.py`: Contains the definition of the `Encoder` class. |
| - `README.md`: This file. |
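
Because several checkpoint names are abbreviations (`CR.pt`, `PC.pt`, `TK.pt`), a small lookup table can make scripts easier to read. The sketch below maps a descriptive label to each filename listed above; the labels and the `checkpoint_for` helper are hypothetical conveniences, not part of the repository.

```python
# Checkpoint filenames as listed in this README; the keys are hypothetical labels.
CHECKPOINTS = {
    "homo_sapiens_mrna": "homo_mrna.pt",
    "homo_sapiens_circrna": "homo_circ.pt",
    "arabidopsis_thaliana": "Arabidopsis.pt",
    "chlamydomonas_reinhardtii": "CR.pt",
    "escherichia_coli": "EscherichiaColi.pt",
    "penicillium_chrysogenum": "PC.pt",
    "thermococcus_kodakarensis": "TK.pt",
}


def checkpoint_for(target):
    """Return the checkpoint filename for a label, or raise a readable error."""
    try:
        return CHECKPOINTS[target]
    except KeyError:
        raise ValueError(
            f"Unknown target {target!r}; choose from {sorted(CHECKPOINTS)}"
        ) from None
```

For example, `checkpoint_for("escherichia_coli")` returns `"EscherichiaColi.pt"`, which can then be passed as the `filename` argument when downloading.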
|
|
| ## How to Load the Model |
|
|
| Because this is a custom model, download `utils.py`, `transformer_encoder_MoE.py`, and the `.pt` checkpoint you need, then instantiate the model class and load the state dictionary. |
|
|
| 1. **Download Files:** |
| You can download the files using the `huggingface_hub` library: |
| ```python |
| from huggingface_hub import hf_hub_download |
| import os |
| |
| repo_id = "sglin/RNARL" |
| local_dir = "./my_RNARL" |
| |
| # Download the model weights, utils.py, and transformer_encoder_MoE.py |
| hf_hub_download(repo_id=repo_id, filename="homo_mrna.pt", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="homo_circ.pt", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="Arabidopsis.pt", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="CR.pt", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="EscherichiaColi.pt", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="PC.pt", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="TK.pt", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="utils.py", local_dir=local_dir) |
| hf_hub_download(repo_id=repo_id, filename="transformer_encoder_MoE.py", local_dir=local_dir) |
| |
| # utils.py, transformer_encoder_MoE.py, and the model weights are now in ./my_RNARL |
| ``` |
| |
| 2. **Import Model Class:** |
| |
| ```python |
| # utils.py and transformer_encoder_MoE.py were downloaded to ./my_RNARL, |
| # so add that directory to the import path first. |
| import sys |
| |
| sys.path.append("./my_RNARL") |
| |
| from utils import Tokenizer, ActorModel_encoder_esm2 |
| |
| # If you copied both files into your working directory instead, |
| # the sys.path line is not needed. |
| ``` |
| |
| 3. **Load ESM-2 (Dependency):** |
| The model requires the ESM-2 encoder. You'll need to load it separately, typically from Hugging Face Hub. |
| ```python |
| from transformers import AutoTokenizer, EsmModel |
| import torch |
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
| esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D") |
| esm2_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device) |
| esm2_model.eval() |
| esm2_dim = esm2_model.config.hidden_size # Get the actual dimension |
| ``` |
| *Note:* The original training script loaded ESM-2 from a local path (`./esm2_model_t33_650M_UR50D`). The ESM-2 files are not bundled with this repository because they are already available on the Hugging Face Hub. If you keep a local copy, pass its path to `from_pretrained` instead. |
| |
| 4. **Instantiate Custom Model and Load Weights:** |
| Instantiate `ActorModel_encoder_esm2` with the parameters used during training, then load the state dictionary. |
| ```python |
| # Hyperparameters used during training |
| d_model = 768 |
| nhead = 8 |
| num_encoder_layers = 8 |
| dim_feedforward = d_model * 2 |
| dropout = 0.3 |
| num_experts = 6 |
| top_k_experts = 2 |
| |
| # vocab_size must match the custom Tokenizer from utils.py |
| tokenizer = Tokenizer() |
| vocab_size = len(tokenizer.tokens) |
| |
| # Instantiate the model |
| model = ActorModel_encoder_esm2( |
|     vocab_size=vocab_size, |
|     d_model=d_model, |
|     nhead=nhead, |
|     num_encoder_layers=num_encoder_layers, |
|     dim_feedforward=dim_feedforward, |
|     esm2_dim=esm2_dim,  # Use the esm2_model's dimension |
|     dropout=dropout, |
|     num_experts=num_experts, |
|     top_k_experts=top_k_experts, |
|     device=device, |
| ) |
| |
| # Load the state dictionary |
| model_weights_path = os.path.join(local_dir, "homo_mrna.pt") |
| model.load_state_dict(torch.load(model_weights_path, map_location=device)) |
| model.to(device) |
| model.eval() |
| |
| print("Model loaded successfully!") |
| |
| # Now you can use the 'model' object for inference |
| # Remember you also need your Tokenizer and the ESM-2 tokenizer/model |
| ``` |
| |
| ## Dependencies |
|
|
| - `torch` |
| - `transformers` |
| - `huggingface_hub` |
| - `pandas` |
| - `numpy` |
| - The ESM-2 checkpoint `esm2_t33_650M_UR50D`, loaded through `transformers`. |
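
Before running the loading steps, it can help to confirm that the listed packages are installed. The following is a minimal sketch using only the standard library; `REQUIRED` and `missing_packages` are hypothetical names, and the check covers package presence only, not versions.

```python
import importlib.util

# Top-level package names from the dependency list above.
REQUIRED = ["torch", "transformers", "huggingface_hub", "pandas", "numpy"]


def missing_packages(names):
    """Return the top-level package names that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All listed packages are importable.")
```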
|
|
| ## License |
|
|
| [ MIT, Apache 2.0] |
|
|
| ## Contact |
|
|
| [linsg4521@sjtu.edu.cn] |
|
|