File size: 7,816 Bytes
b8c7219 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 | import gradio as gr
import torch
from pathlib import Path
from helpers import load_model, parse_fasta_string
from io import StringIO
import csv
import tempfile
import transformers
# Mute the ESM weight warnings: transformers only logs errors from here on.
transformers.logging.set_verbosity_error()

# Constants
MODEL_NAME = "esm2_t33_650M_UR50D"
# The model directory is resolved relative to this file, not the working directory.
CURRENT_DIR = Path(__file__).parent
PATH_MODEL = CURRENT_DIR / "model"
# Inference is pinned to the CPU; the GPU-aware alternative is kept for reference.
DEVICE = "cpu"
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # around 2 mins for fireprot
# One-letter codes accepted in input sequences; predict_tm rejects anything else.
VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")

# Loaded once at import time so every prediction request reuses the same objects.
model, tokenizer = load_model(MODEL_NAME, PATH_MODEL, DEVICE)
def predict_tm(seq_text, seq_file, threshold):
    """Predict melting temperatures for FASTA sequences and rank them.

    Args:
        seq_text: FASTA-formatted text pasted by the user. Only used when no
            file is supplied.
        seq_file: Uploaded FASTA file, or ``None``. May be a plain path string
            or an object exposing the path as ``.name``. Takes precedence over
            ``seq_text``.
        threshold: Tm cutoff in °C; any value ``float()`` accepts. A sequence
            is flagged thermostable when its predicted Tm is strictly above it.

    Returns:
        A ``(table, csv_path)`` tuple. ``table`` is a list of
        ``[rank, id, tm, "Yes" | "No"]`` rows sorted by descending predicted
        Tm; ``csv_path`` is the path of a temporary CSV file holding a header
        row followed by the same rows.

    Raises:
        gr.Error: If no input is given, FASTA parsing fails, no records are
            found, or any sequence is shorter than 20, longer than 2000, or
            contains characters outside ``VALID_AMINO_ACIDS``. The first
            offending sequence aborts the whole batch.
    """
    # Errors are raised as gr.Error instead of being returned as a bare string:
    # the click handler feeds two output components, so a single string return
    # cannot be mapped onto them and the message would never reach the user.
    if seq_file is not None:
        # Accept both a plain path string and a wrapper exposing ``.name``.
        fasta_path = getattr(seq_file, "name", seq_file)
        with open(fasta_path, "r", encoding="utf-8") as f:
            fasta_str = f.read()
    elif seq_text and seq_text.strip():
        fasta_str = seq_text
    else:
        raise gr.Error("Please provide a sequence via text or file.")

    try:
        records = parse_fasta_string(fasta_str)
    except Exception as e:
        raise gr.Error(f"FASTA parsing failed: {e}") from e
    if not records:
        raise gr.Error("No valid sequences found.")

    # Convert once, and before inference, so a bad threshold fails fast.
    cutoff = float(threshold)

    results = []
    for record in records:
        seq = record["sequence"].upper()
        if len(seq) < 20:
            raise gr.Error(f"Sequence '{record['id']}' is too short (<20 amino acids).")
        if len(seq) > 2000:
            raise gr.Error(f"Sequence '{record['id']}' is too long (>2000 amino acids).")
        invalid_chars = set(seq) - VALID_AMINO_ACIDS
        if invalid_chars:
            # Sorted so the message is deterministic (set order is arbitrary).
            invalid = "".join(sorted(invalid_chars))
            raise gr.Error(f"Invalid characters in sequence: {invalid}")

        # NOTE(review): the tokenizer truncates at max_length=512 while
        # sequences up to 2000 residues are accepted above, so long sequences
        # are scored on a truncated input - confirm this is intended.
        inputs = tokenizer(seq, return_tensors="pt", max_length=512, truncation=True, padding=True)
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        prediction = outputs.logits.squeeze().item()
        results.append({"id": record["id"], "tm": round(prediction, 2)})

    results_sorted = sorted(results, key=lambda x: x["tm"], reverse=True)
    table = [
        [rank, r["id"], r["tm"], "Yes" if r["tm"] > cutoff else "No"]
        for rank, r in enumerate(results_sorted, start=1)
    ]

    # newline="" lets the csv module control line endings; without it the
    # writer's "\r\n" terminators become "\r\r\n" on Windows.
    # delete=False keeps the file on disk after the handle closes so the
    # returned path stays valid for the download button.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".csv", mode="w", encoding="utf-8", newline=""
    ) as tmp:
        writer = csv.writer(tmp)
        writer.writerow(["Rank", "ID", "Predicted Tm [°C]", "Thermostable"])
        writer.writerows(table)
        tmp_path = tmp.name
    return table, tmp_path
# Page layout: header, input/output panel, examples, and a description footer.
# NOTE(review): the nesting of the layout contexts below should be confirmed
# against the rendered page. The text inside the triple-quoted strings is kept
# at column 0 on purpose so the string contents are unchanged.
demo = gr.Blocks(theme=gr.themes.Origin())
with demo:
    # Header row (1:7:1): logo, title and description, second logo.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image("assets/TmProt_logo.png", width=100, height=100, show_label=False, show_download_button=False, container=False, show_share_button=False, show_fullscreen_button=False, interactive=False)
        with gr.Column(scale=7):
            gr.Markdown("""
# TmProt
## Protein Thermostability Predictor
""")
            gr.Markdown(value="""
### TmProt is a machine-learning-based protein thermostability predictor that leverages a fine-tuned ESM-2 protein language model to estimate melting temperatures (Tm) of protein sequences. It enables users to upload protein sequences in FASTA format (either pasted as text or uploaded as a file), and outputs predicted Tm values ranked by a user-defined thermostability threshold.
**Paper:** [https://doi.org/10.64898/2026.05.07.723192](https://doi.org/10.64898/2026.05.07.723192)
**GitHub:** [https://github.com/loschmidt/TmProt](https://github.com/loschmidt/TmProt)
"""
            )
        with gr.Column(scale=1):
            gr.Image("assets/logo.png", width=100, height=100, show_label=False, show_download_button=False, container=False, show_share_button=False, show_fullscreen_button=False, interactive=False)
    # Main row: inputs on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=4):
            seq_text = gr.Textbox(
                label="FASTA sequences",
                lines=6,
                placeholder=">seq\nMKTIIALSYIFCLVFA",
                value="",
            )
            seq_file = gr.File(label="Or upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
            btn = gr.Button("Predict")
            # Threshold choices are the strings "20", "30", ..., "100".
            cutoff_bins = [str(x) for x in range(20, 101, 10)]
            cutoff_bar = gr.Radio(
                choices=cutoff_bins,
                label="Select thermostability threshold (°C)",
                info="Default is 60°C",
                value="60"
            )
        with gr.Column(scale=4):
            output = gr.Dataframe(headers=["Rank", "ID", "Predicted Tm [°C]", "Thermostable"], label="Results")
            download_btn = gr.DownloadButton(label="Download CSV")
    # Clicking "Predict" runs predict_tm; its two results are routed to the
    # results table and the CSV download button, in that order.
    btn.click(
        predict_tm,
        inputs=[seq_text, seq_file, cutoff_bar],
        outputs=[output, download_btn]
    )
    # Example inputs: each entry is [FASTA text, file, threshold], matching `inputs`.
    with gr.Row():
        gr.Examples(
            examples = [
                [""">I1W5V5
MSIENLSSNKSFGGWHKQYSHVSNTLNCAMRFAIYLPPQASTGAKVPVLYWLSGLTCSDENFMQKAGAQRLAAELGIAIVAPDTSPRGEGVADDEGYDLGQGAGFYVNATQAPWNRHYQMYDYVVNELPELIESMFPVSDKRAIAGHSMGGHGALTIALRNPERYQSVSAFSPINNPVNCPWGQKAFTAYLGKDTDTWREYDASLLMRAAKQYVPALVDQGEADNFLAEQLKPEVLEAAASSNNYPLELRSHEGYDHSYYFIASFIEDHLRFHSNYLNA
""",
                    None, # seq_file is None (we use text)
                    "60" # threshold
                ],
                [
                    """>R4YJ85
MINLEKALAGRRILIVDDLVEARSSLKKMATILGGDNIDVATDGIEAMSLIHEHEYDIVLSDYNLGRTKDGQQILEEARFTQRLRATSLFIVITGENAIDMVMGALEYDPDGYITKPYTLNMLKERLIRIITIKEELRKVNKAIDLQKYDLAIKYCLEVLDSNPRLRLPASRILGQLLMRQKRFQQALKIYSQLLNERSVSWAKLGQAICIFKLGDPNSALALLNRALVDHPLYVQCYDWIAKILLTLDKPLEAQAALEKAIVISPKAVLRQMELGRIAYENGDMVTAEPAFKYSVRLGRFSCHKSAKNYLQFVRSAQALLINPKERQTQNKANEAFRALTELKQDFSDDKDSLFEASIVESKTHLKMENLDEAKRSANDAEDMLAKLECPKIDYKLQMTETFIETDQSVKAQKMIDELKSAELSDKQIIMLNRLDNDLNGEALKRHSTSLNDQGVSHYEKGELEEAIIAFDQATHYEQAGISVLLNSIQAKISLMERDSPDKKILKNVRSLLIRIGEIAKDDERFARYSRLRKTYDRLCRAAAK
""",
                    None,
                    "50"
                ],
            ],
            inputs=[seq_text, seq_file, cutoff_bar],
            label="Click an example to try TmProt instantly",
            examples_per_page=2,
        )
    # Footer row (1:7:1): the outer columns are empty, presumably as spacers so
    # the text lines up with the header's middle column.
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=7):
            gr.Markdown(value="""
## Features
- Predict protein melting temperature (Tm) from amino acid sequences
- Accepts input via FASTA text or FASTA file upload
- Supports sequences from 20 to 2000 amino acids in length
- Outputs a ranked table with predicted Tm and thermostability status based on user-chosen threshold
- CSV download option for easy export and downstream analysis
## Model Overview
- Base Model: facebook/esm2_t33_650M_UR50D (650M parameters)
- Fine-tuning method: LoRA (Low-Rank Adaptation) using PEFT framework
- Task: Regression prediction of protein melting temperature (Tm)
- Training Data: ProMelt dataset (merged Meltome Atlas + ProTherm) with ~45,000 protein sequences and experimental Tm values
- Output: Single linear regression output neuron predicting Tm in °C
"""
            )
        with gr.Column(scale=1):
            pass
if __name__ == "__main__":
    # Serve the app only when this file is executed directly. The options ask
    # Gradio for a share link and allow it to serve files from ./assets.
    launch_options = {
        "share": True,
        "allowed_paths": ["./assets"],
    }
    demo.launch(**launch_options)
|