File size: 7,816 Bytes
b8c7219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gradio as gr
import torch
from pathlib import Path
from helpers import load_model, parse_fasta_string
from io import StringIO
import csv
import tempfile
import transformers
# Mute the ESM weight warnings emitted by transformers (only errors are logged).
transformers.logging.set_verbosity_error()

# Constants
# Name of the checkpoint handed to load_model.
MODEL_NAME = "esm2_t33_650M_UR50D"
# Directory containing this script; the model directory is resolved relative to it.
CURRENT_DIR = Path(__file__).parent
PATH_MODEL = CURRENT_DIR / "model"
# Inference device; also used to move tokenized inputs in predict_tm.
DEVICE = "cpu"
# Alternative with automatic GPU selection (original note: "around 2 mins for fireprot";
# NOTE(review): unclear which device that timing refers to - confirm before re-enabling):
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# One-letter codes of the 20 standard amino acids; any other character is rejected.
VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")

# Loaded once at import time so every prediction request reuses the same objects.
model, tokenizer = load_model(MODEL_NAME, PATH_MODEL, DEVICE)

def predict_tm(seq_text, seq_file, threshold):
    """Predict melting temperatures for FASTA sequences and rank them.

    Args:
        seq_text: FASTA-formatted text pasted by the user; may be empty or None.
        seq_file: Uploaded FASTA file, or None. Either a plain path string or
            an object exposing the path as ``.name``. Takes precedence over
            ``seq_text`` when given.
        threshold: Thermostability cutoff in °C; any value ``float()`` accepts.

    Returns:
        A ``(table, csv_path)`` tuple. ``table`` is a list of
        ``[rank, id, tm, "Yes"/"No"]`` rows sorted by descending predicted Tm
        (rank starts at 1, "Yes" means tm > threshold); ``csv_path`` is the
        path of a temporary CSV file containing a header plus the same rows.

    Raises:
        gr.Error: If no input is provided, FASTA parsing fails, no records are
            found, or any sequence is shorter than 20, longer than 2000, or
            contains characters outside ``VALID_AMINO_ACIDS``. Raising (rather
            than returning a lone string) is required because the click handler
            is wired to two output components.
    """
    if seq_file is not None:
        # Accept both a plain str path and a wrapper object carrying `.name`.
        fasta_path = getattr(seq_file, "name", seq_file)
        with open(fasta_path, "r", encoding="utf-8") as f:
            fasta_str = f.read()
    elif seq_text and seq_text.strip():
        fasta_str = seq_text
    else:
        raise gr.Error("Please provide a sequence via text or file.")

    try:
        records = parse_fasta_string(fasta_str)
    except Exception as e:
        # The parser's failure modes are not visible here, so report any of them to the user.
        raise gr.Error(f"FASTA parsing failed: {e}") from e

    if not records:
        raise gr.Error("No valid sequences found.")

    # Validate every record up front so a bad entry late in the input does not
    # waste model inference on the entries before it.
    sequences = []
    for record in records:
        seq = record["sequence"].upper()
        if len(seq) < 20:
            raise gr.Error(f"Sequence '{record['id']}' is too short (<20 amino acids).")
        if len(seq) > 2000:
            raise gr.Error(f"Sequence '{record['id']}' is too long (>2000 amino acids).")
        invalid_chars = set(seq) - VALID_AMINO_ACIDS
        if invalid_chars:
            # Sorted so the message is deterministic (set iteration order is not).
            invalid = "".join(sorted(invalid_chars))
            raise gr.Error(f"Invalid characters in sequence '{record['id']}': {invalid}")
        sequences.append((record["id"], seq))

    cutoff = float(threshold)

    results = []
    with torch.no_grad():
        for seq_id, seq in sequences:
            # NOTE(review): truncation at max_length=512 tokens drops the tail of longer
            # sequences even though up to 2000 residues pass validation - confirm this
            # matches how the model was trained.
            inputs = tokenizer(seq, return_tensors="pt", max_length=512, truncation=True, padding=True)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            prediction = model(**inputs).logits.squeeze().item()
            results.append({"id": seq_id, "tm": round(prediction, 2)})

    # Highest predicted Tm first; the sort is stable, so ties keep input order.
    results.sort(key=lambda r: r["tm"], reverse=True)
    table = [
        [rank, r["id"], r["tm"], "Yes" if r["tm"] > cutoff else "No"]
        for rank, r in enumerate(results, 1)
    ]

    # newline="" lets csv.writer control line endings (avoids "\r\r\n" on Windows).
    # delete=False: the file must outlive this call so it can be served for download.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".csv", mode="w", encoding="utf-8", newline=""
    ) as tmp:
        writer = csv.writer(tmp)
        writer.writerow(["Rank", "ID", "Predicted Tm [°C]", "Thermostable"])
        writer.writerows(table)
        tmp_path = tmp.name
    return table, tmp_path

# Build the Gradio UI. Components are laid out in the order and nesting in which
# they are created inside the context managers below, so statement order matters.
demo = gr.Blocks(theme=gr.themes.Origin())
with demo:
    # Header row: logo | title and description | logo (column scales 1 / 7 / 1).
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image("assets/TmProt_logo.png", width=100, height=100, show_label=False, show_download_button=False, container=False, show_share_button=False, show_fullscreen_button=False, interactive=False)
        with gr.Column(scale=7):
            gr.Markdown("""
                        # TmProt
                        ## Protein Thermostability Predictor
                        """)
            gr.Markdown(value="""
                        ### TmProt is a machine-learning-based protein thermostability predictor that leverages a fine-tuned ESM-2 protein language model to estimate melting temperatures (Tm) of protein sequences. It enables users to upload protein sequences in FASTA format (either pasted as text or uploaded as a file), and outputs predicted Tm values ranked by a user-defined thermostability threshold.
                        
                        **Paper:** [https://doi.org/10.64898/2026.05.07.723192](https://doi.org/10.64898/2026.05.07.723192)  
                        **GitHub:** [https://github.com/loschmidt/TmProt](https://github.com/loschmidt/TmProt)
                        """
                        )
            
        with gr.Column(scale=1):
            gr.Image("assets/logo.png", width=100, height=100, show_label=False, show_download_button=False, container=False, show_share_button=False, show_fullscreen_button=False, interactive=False)
    # Main row: inputs on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=4):
            # These three components are passed to predict_tm as
            # (seq_text, seq_file, threshold) via btn.click below.
            seq_text = gr.Textbox(
                label="FASTA sequences",
                lines=6,
                placeholder=">seq\nMKTIIALSYIFCLVFA",
                value="",
            )
            seq_file = gr.File(label="Or upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
            btn = gr.Button("Predict")
            # Threshold choices as strings: "20", "30", ..., "100".
            cutoff_bins = [str(x) for x in range(20, 101, 10)]
            cutoff_bar = gr.Radio(
                choices=cutoff_bins,
                label="Select thermostability threshold (°C)",
                info="Default is 60°C",
                value="60"
            )
        with gr.Column(scale=4):
            # Column headers match the CSV header row written by predict_tm.
            output = gr.Dataframe(headers=["Rank", "ID", "Predicted Tm [°C]", "Thermostable"], label="Results")
            download_btn = gr.DownloadButton(label="Download CSV")
    # Run predict_tm on click; its two return values go to the results table
    # and the download button, in that order.
    btn.click(
        predict_tm,
        inputs=[seq_text, seq_file, cutoff_bar],
        outputs=[output, download_btn]
    )
    # Example inputs; each entry supplies values for
    # [seq_text, seq_file, cutoff_bar] in that order.
    with gr.Row():
        gr.Examples(
            examples = [
            [""">I1W5V5
MSIENLSSNKSFGGWHKQYSHVSNTLNCAMRFAIYLPPQASTGAKVPVLYWLSGLTCSDENFMQKAGAQRLAAELGIAIVAPDTSPRGEGVADDEGYDLGQGAGFYVNATQAPWNRHYQMYDYVVNELPELIESMFPVSDKRAIAGHSMGGHGALTIALRNPERYQSVSAFSPINNPVNCPWGQKAFTAYLGKDTDTWREYDASLLMRAAKQYVPALVDQGEADNFLAEQLKPEVLEAAASSNNYPLELRSHEGYDHSYYFIASFIEDHLRFHSNYLNA
             """,
                        None,           # seq_file is None (we use text)
                        "60"            # threshold
                    ],
                    [  
             """>R4YJ85
MINLEKALAGRRILIVDDLVEARSSLKKMATILGGDNIDVATDGIEAMSLIHEHEYDIVLSDYNLGRTKDGQQILEEARFTQRLRATSLFIVITGENAIDMVMGALEYDPDGYITKPYTLNMLKERLIRIITIKEELRKVNKAIDLQKYDLAIKYCLEVLDSNPRLRLPASRILGQLLMRQKRFQQALKIYSQLLNERSVSWAKLGQAICIFKLGDPNSALALLNRALVDHPLYVQCYDWIAKILLTLDKPLEAQAALEKAIVISPKAVLRQMELGRIAYENGDMVTAEPAFKYSVRLGRFSCHKSAKNYLQFVRSAQALLINPKERQTQNKANEAFRALTELKQDFSDDKDSLFEASIVESKTHLKMENLDEAKRSANDAEDMLAKLECPKIDYKLQMTETFIETDQSVKAQKMIDELKSAELSDKQIIMLNRLDNDLNGEALKRHSTSLNDQGVSHYEKGELEEAIIAFDQATHYEQAGISVLLNSIQAKISLMERDSPDKKILKNVRSLLIRIGEIAKDDERFARYSRLRKTYDRLCRAAAK
             """,
                        None,
                        "50"
                    ],
            ],
            inputs=[seq_text, seq_file, cutoff_bar],
            label="Click an example to try TmProt instantly",
            examples_per_page=2,
        )
    # Footer row: static feature/model description flanked by two empty columns
    # (scales 1 / 7 / 1, mirroring the header row).
    with gr.Row():
        with gr.Column(scale=1):
                pass
        with gr.Column(scale=7):
            gr.Markdown(value="""
                ## Features

                    - Predict protein melting temperature (Tm) from amino acid sequences

                    - Accepts input via FASTA text or FASTA file upload

                    - Supports sequences from 20 to 2000 amino acids in length

                    - Outputs a ranked table with predicted Tm and thermostability status based on user-chosen threshold

                    - CSV download option for easy export and downstream analysis

                ## Model Overview

                    - Base Model: facebook/esm2_t33_650M_UR50D (650M parameters)

                    - Fine-tuning method: LoRA (Low-Rank Adaptation) using PEFT framework

                    - Task: Regression prediction of protein melting temperature (Tm)

                    - Training Data: ProMelt dataset (merged Meltome Atlas + ProTherm) with ~45,000 protein sequences and experimental Tm values

                    - Output: Single linear regression output neuron predicting Tm in °C
            """
            )
        with gr.Column(scale=1):
            pass            

if __name__ == "__main__":
    # Start the app only when run as a script. `share=True` requests a public
    # share link, and `allowed_paths` lets Gradio serve files from ./assets.
    demo.launch(
        share=True,
        allowed_paths=["./assets"],
    )