---
tags:
- generation
- protein-sequence
- rna-sequence
- pytorch
---

# Protein to RNA CDS Sequence Generation Model

This is a custom PyTorch model that generates RNA coding sequences (CDS) from protein sequences. It uses a transformer-based architecture that combines an ESM-2 encoder with a Mixture-of-Experts (MoE) layer.

## Model Architecture

The model `ActorModel_encoder_esm2` is defined in `utils.py`. 

The key parameters used for instantiation are:

- `d_model`: Dimension of the model's internal representation (768).
- `nhead`: Number of attention heads (8).
- `num_encoder_layers`: Number of transformer encoder layers (8).
- `dim_feedforward`: Dimension of the feedforward network (`d_model * 2`).
- `esm2_dim`: Dimension of the ESM-2 embeddings (1280 for esm2_t33_650M_UR50D).
- `dropout`: Dropout rate (0.3).
- `num_experts`: Number of experts in the MoE layer (6).
- `top_k_experts`: Number of top experts to use (2).
- `device`: The device to run the model on.

## Files in this Repository

- `homo_mrna.pt`: The PyTorch state_dict of the trained model for Homo sapiens mRNA.
- `homo_circ.pt`: The PyTorch state_dict of the trained model for Homo sapiens circular RNA.
- `Arabidopsis.pt`: The PyTorch state_dict of the trained model for Arabidopsis thaliana mRNA.
- `CR.pt`: The PyTorch state_dict of the trained model for Chlamydomonas reinhardtii mRNA.
- `EscherichiaColi.pt`: The PyTorch state_dict of the trained model for Escherichia coli mRNA.
- `PC.pt`: The PyTorch state_dict of the trained model for Penicillium chrysogenum mRNA.
- `TK.pt`: The PyTorch state_dict of the trained model for Thermococcus kodakarensis KOD1 mRNA.
- `utils.py`: Contains the definition of the `ActorModel_encoder_esm2` class and the `Tokenizer` class.
- `transformer_encoder_MoE.py`: Contains the definition of the `Encoder` class.
- `README.md`: This file.

## How to Load the Model

Since this is a custom model, you need to download `utils.py`, `transformer_encoder_MoE.py`, and the desired `.pt` weight file, then instantiate the model class and load the state dictionary.

1.  **Download Files:**
    You can download the files using the `huggingface_hub` library:
    ```python
    from huggingface_hub import hf_hub_download
    import os

    repo_id = "sglin/RNARL"
    local_dir = "./my_RNARL"

    # Download model weights and utils.py
    hf_hub_download(repo_id=repo_id, filename="homo_mrna.pt", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="homo_circ.pt", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="Arabidopsis.pt", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="CR.pt", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="EscherichiaColi.pt", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="PC.pt", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="TK.pt", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="utils.py", local_dir=local_dir)
    hf_hub_download(repo_id=repo_id, filename="transformer_encoder_MoE.py", local_dir=local_dir)

    # Now utils.py, transformer_encoder_MoE.py, and the model weights are in ./my_RNARL
    ```
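
    Alternatively, `snapshot_download` (also from `huggingface_hub`) fetches every file in the repository in a single call:
    ```python
    from huggingface_hub import snapshot_download

    # Downloads all model weights and Python files into ./my_RNARL
    snapshot_download(repo_id="sglin/RNARL", local_dir="./my_RNARL")
    ```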

2.  **Import Model Class:**
    
    ```python
    # If the downloaded files are not in your current working directory,
    # add ./my_RNARL to the Python path first:
    # import sys
    # sys.path.append("./my_RNARL")

    from utils import Tokenizer, ActorModel_encoder_esm2
    ```

3.  **Load ESM-2 (Dependency):**
    The model requires the ESM-2 encoder, which must be loaded separately, typically from the Hugging Face Hub.
    ```python
    from transformers import AutoTokenizer, EsmModel
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    esm2_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
    esm2_model.eval()
    esm2_dim = esm2_model.config.hidden_size # Get the actual dimension
    ```
    *Note:* The original training script loaded ESM-2 from a local path (`./esm2_model_t33_650M_UR50D`). Unless the ESM-2 files are also provided in this repository (usually unnecessary, since they are already on the Hub), load the checkpoint directly from the official `facebook/esm2_t33_650M_UR50D` repo, as shown above.
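
    To get the per-residue embeddings that are fed to the custom model, run the encoder on a protein sequence. This is a minimal sketch using the standard `transformers` API; the example sequence is arbitrary, and exactly how `ActorModel_encoder_esm2` consumes the embeddings is defined in `utils.py`:
    ```python
    protein_seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # arbitrary example sequence

    inputs = esm2_tokenizer(protein_seq, return_tensors="pt").to(device)
    with torch.no_grad():
        # last_hidden_state: (batch, sequence_length, esm2_dim)
        esm2_emb = esm2_model(**inputs).last_hidden_state
    ```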

4.  **Instantiate Custom Model and Load Weights:**
    Instantiate `ActorModel_encoder_esm2` with the parameters used during training and load the state dictionary.
    ```python
    # Define the parameters used during training
    d_model = 768
    nhead = 8
    num_encoder_layers = 8
    dim_feedforward = d_model * 2 # or the exact value you used
    dropout = 0.3
    num_experts = 6
    top_k_experts = 2
    # vocab_size needs to match your Tokenizer
    tokenizer = Tokenizer() # Instantiate your custom tokenizer
    vocab_size = len(tokenizer.tokens) # Get vocab size from your tokenizer

    # Instantiate the model
    model = ActorModel_encoder_esm2(
        vocab_size=vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dim_feedforward,
        esm2_dim=esm2_dim, # Use the esm2_model's dimension
        dropout=dropout,
        num_experts=num_experts,
        top_k_experts=top_k_experts,
        device=device
    )

    # Load the state dictionary
    model_weights_path = os.path.join(local_dir, "homo_mrna.pt")
    model.load_state_dict(torch.load(model_weights_path, map_location=device))
    model.to(device)
    model.eval()

    print("Model loaded successfully!")

    # Now you can use the 'model' object for inference
    # Remember you also need your Tokenizer and the ESM-2 tokenizer/model
    ```
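
5.  **Run Inference (Sketch):**
    The forward/generation interface of `ActorModel_encoder_esm2` is defined in `utils.py` and is not documented here, so the commented-out call below is a hypothetical illustration of how the pieces fit together, not the confirmed API; check `utils.py` for the actual signature.
    ```python
    # Reuse the ESM-2 embeddings computed in step 3 (esm2_emb).
    # HYPOTHETICAL call -- the real forward signature lives in utils.py:
    # with torch.no_grad():
    #     logits = model(esm2_emb)             # e.g. (1, seq_len, vocab_size)
    # cds_ids = logits.argmax(dim=-1)          # greedy decode over the codon vocabulary
    # cds_seq = tokenizer.decode(cds_ids[0])   # if Tokenizer exposes a decode() method
    ```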

## Dependencies

- `torch`
- `transformers`
- `huggingface_hub`
- `pandas`
- `numpy`
- The ESM-2 checkpoint used during training (`facebook/esm2_t33_650M_UR50D`, or whichever checkpoint you trained with).
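
A minimal install (no versions are pinned in this repository, so recent releases are assumed):
```bash
pip install torch transformers huggingface_hub pandas numpy
```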

## License

MIT / Apache 2.0

## Contact

linsg4521@sjtu.edu.cn