# tmprot / app.py
# GitLab CI
# Latest changes
# b8c7219
import gradio as gr
import torch
from pathlib import Path
from helpers import load_model, parse_fasta_string
from io import StringIO
import csv
import tempfile
import transformers
# Restrict transformers logging to errors; this mutes the ESM warning about weights.
transformers.logging.set_verbosity_error()
# Constants
MODEL_NAME = "esm2_t33_650M_UR50D"
CURRENT_DIR = Path(__file__).parent
# The "model" directory is resolved relative to this file, not the working directory.
PATH_MODEL = CURRENT_DIR / "model"
DEVICE = "cpu"
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # around 2 mins for fireprot
# One-letter residue codes accepted by predict_tm; any other character is rejected.
VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
# Loaded once at import time; predict_tm reuses these module-level objects on every call.
model, tokenizer = load_model(MODEL_NAME, PATH_MODEL, DEVICE)
def predict_tm(seq_text, seq_file, threshold):
    """Predict melting temperatures (Tm) for FASTA input and rank the results.

    A supplied file takes precedence over the pasted text. All records are
    validated before any model inference runs, so an invalid record never
    costs a forward pass on the records preceding it.

    Args:
        seq_text: FASTA-formatted text; used only when ``seq_file`` is None.
        seq_file: Uploaded FASTA file or None. Either a path string or an
            object exposing the path as ``.name``.
        threshold: Thermostability cut-off in °C; any value ``float()``
            accepts (the UI passes a string such as ``"60"``).

    Returns:
        A ``(table, csv_path)`` tuple. ``table`` is a list of
        ``[rank, id, tm, "Yes" | "No"]`` rows sorted by descending Tm, and
        ``csv_path`` is the path of a temporary CSV file holding the same
        rows under a header line.

    Raises:
        gr.Error: If no input is provided, FASTA parsing fails, no records
            are found, or a sequence is shorter than 20 residues, longer
            than 2000 residues, or contains characters outside
            ``VALID_AMINO_ACIDS``.
    """
    # Errors are raised as gr.Error instead of being returned as a bare
    # string: the click handler binds this function to two outputs, so a
    # single returned string cannot be mapped onto them.
    if seq_file is not None:
        # Accept both a plain path and a file-like wrapper carrying `.name`.
        fasta_path = getattr(seq_file, "name", seq_file)
        with open(fasta_path, "r", encoding="utf-8") as handle:
            fasta_str = handle.read()
    elif seq_text and seq_text.strip():
        fasta_str = seq_text
    else:
        raise gr.Error("Please provide a sequence via text or file.")

    try:
        records = parse_fasta_string(fasta_str)
    except Exception as exc:  # boundary: surface any parser failure to the user
        raise gr.Error(f"FASTA parsing failed: {exc}") from exc

    if not records:
        raise gr.Error("No valid sequences found.")

    # Pass 1: validate everything up front (cheap) before touching the model.
    validated = []
    for record in records:
        seq = record["sequence"].upper()
        if len(seq) < 20:
            raise gr.Error(f"Sequence '{record['id']}' is too short (<20 amino acids).")
        if len(seq) > 2000:
            raise gr.Error(f"Sequence '{record['id']}' is too long (>2000 amino acids).")
        unexpected = set(seq) - VALID_AMINO_ACIDS
        if unexpected:
            # Sorted so the message is deterministic (set order is not).
            invalid = "".join(sorted(unexpected))
            raise gr.Error(f"Invalid characters in sequence: {invalid}")
        validated.append((record["id"], seq))

    # Pass 2: one forward pass per sequence, gradients disabled.
    # NOTE(review): the tokenizer truncates to max_length=512 tokens although
    # sequences of up to 2000 residues pass validation — confirm that
    # predicting from a truncated sequence is intended.
    results = []
    with torch.no_grad():
        for record_id, seq in validated:
            inputs = tokenizer(seq, return_tensors="pt", max_length=512, truncation=True, padding=True)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            outputs = model(**inputs)
            prediction = outputs.logits.squeeze().item()
            results.append({"id": record_id, "tm": round(prediction, 2)})

    cutoff = float(threshold)  # converted once instead of once per row
    results_sorted = sorted(results, key=lambda r: r["tm"], reverse=True)
    table = [
        [rank, r["id"], r["tm"], "Yes" if r["tm"] > cutoff else "No"]
        for rank, r in enumerate(results_sorted, 1)
    ]

    # newline="" lets the csv module control line endings; without it the
    # writer's "\r\n" would be expanded to "\r\r\n" on Windows.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".csv", mode="w", encoding="utf-8", newline=""
    ) as tmp:
        writer = csv.writer(tmp)
        writer.writerow(["Rank", "ID", "Predicted Tm [°C]", "Thermostable"])
        writer.writerows(table)
        tmp_path = tmp.name

    return table, tmp_path
# Build the Gradio Blocks UI: header row, input/output row, examples, and an info section.
demo = gr.Blocks(theme=gr.themes.Origin())
with demo:
    # Header row: logo | title and description | logo (column scales 1:7:1).
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image("assets/TmProt_logo.png", width=100, height=100, show_label=False, show_download_button=False, container=False, show_share_button=False, show_fullscreen_button=False, interactive=False)
        with gr.Column(scale=7):
            gr.Markdown("""
# TmProt
## Protein Thermostability Predictor
""")
            gr.Markdown(value="""
### TmProt is a machine-learning-based protein thermostability predictor that leverages a fine-tuned ESM-2 protein language model to estimate melting temperatures (Tm) of protein sequences. It enables users to upload protein sequences in FASTA format (either pasted as text or uploaded as a file), and outputs predicted Tm values ranked by a user-defined thermostability threshold.
**Paper:** [https://doi.org/10.64898/2026.05.07.723192](https://doi.org/10.64898/2026.05.07.723192)
**GitHub:** [https://github.com/loschmidt/TmProt](https://github.com/loschmidt/TmProt)
"""
            )
        with gr.Column(scale=1):
            gr.Image("assets/logo.png", width=100, height=100, show_label=False, show_download_button=False, container=False, show_share_button=False, show_fullscreen_button=False, interactive=False)
    # Main row: inputs on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=4):
            seq_text = gr.Textbox(
                label="FASTA sequences",
                lines=6,
                placeholder=">seq\nMKTIIALSYIFCLVFA",
                value="",
            )
            seq_file = gr.File(label="Or upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
            btn = gr.Button("Predict")
            # Threshold choices "20" through "100" in steps of 10; values are strings.
            cutoff_bins = [str(x) for x in range(20, 101, 10)]
            cutoff_bar = gr.Radio(
                choices=cutoff_bins,
                label="Select thermostability threshold (°C)",
                info="Default is 60°C",
                value="60"
            )
        with gr.Column(scale=4):
            # Column headers match the CSV header row written by predict_tm.
            output = gr.Dataframe(headers=["Rank", "ID", "Predicted Tm [°C]", "Thermostable"], label="Results")
            download_btn = gr.DownloadButton(label="Download CSV")
    # On click, call predict_tm with (text, file, threshold); its (table, csv_path)
    # result is bound to the results table and the download button.
    btn.click(
        predict_tm,
        inputs=[seq_text, seq_file, cutoff_bar],
        outputs=[output, download_btn]
    )
    # Clickable examples; each entry supplies values for [seq_text, seq_file, cutoff_bar].
    with gr.Row():
        gr.Examples(
            examples = [
                [""">I1W5V5
MSIENLSSNKSFGGWHKQYSHVSNTLNCAMRFAIYLPPQASTGAKVPVLYWLSGLTCSDENFMQKAGAQRLAAELGIAIVAPDTSPRGEGVADDEGYDLGQGAGFYVNATQAPWNRHYQMYDYVVNELPELIESMFPVSDKRAIAGHSMGGHGALTIALRNPERYQSVSAFSPINNPVNCPWGQKAFTAYLGKDTDTWREYDASLLMRAAKQYVPALVDQGEADNFLAEQLKPEVLEAAASSNNYPLELRSHEGYDHSYYFIASFIEDHLRFHSNYLNA
""",
                None, # seq_file is None (we use text)
                "60" # threshold
                ],
                [
                """>R4YJ85
MINLEKALAGRRILIVDDLVEARSSLKKMATILGGDNIDVATDGIEAMSLIHEHEYDIVLSDYNLGRTKDGQQILEEARFTQRLRATSLFIVITGENAIDMVMGALEYDPDGYITKPYTLNMLKERLIRIITIKEELRKVNKAIDLQKYDLAIKYCLEVLDSNPRLRLPASRILGQLLMRQKRFQQALKIYSQLLNERSVSWAKLGQAICIFKLGDPNSALALLNRALVDHPLYVQCYDWIAKILLTLDKPLEAQAALEKAIVISPKAVLRQMELGRIAYENGDMVTAEPAFKYSVRLGRFSCHKSAKNYLQFVRSAQALLINPKERQTQNKANEAFRALTELKQDFSDDKDSLFEASIVESKTHLKMENLDEAKRSANDAEDMLAKLECPKIDYKLQMTETFIETDQSVKAQKMIDELKSAELSDKQIIMLNRLDNDLNGEALKRHSTSLNDQGVSHYEKGELEEAIIAFDQATHYEQAGISVLLNSIQAKISLMERDSPDKKILKNVRSLLIRIGEIAKDDERFARYSRLRKTYDRLCRAAAK
""",
                None,
                "50"
                ],
            ],
            inputs=[seq_text, seq_file, cutoff_bar],
            label="Click an example to try TmProt instantly",
            examples_per_page=2,
        )
    # Info row: empty side columns (scale 1) flank the text column (scale 7),
    # the same 1:7:1 proportions as the header row.
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=7):
            gr.Markdown(value="""
## Features
- Predict protein melting temperature (Tm) from amino acid sequences
- Accepts input via FASTA text or FASTA file upload
- Supports sequences from 20 to 2000 amino acids in length
- Outputs a ranked table with predicted Tm and thermostability status based on user-chosen threshold
- CSV download option for easy export and downstream analysis
## Model Overview
- Base Model: facebook/esm2_t33_650M_UR50D (650M parameters)
- Fine-tuning method: LoRA (Low-Rank Adaptation) using PEFT framework
- Task: Regression prediction of protein melting temperature (Tm)
- Training Data: ProMelt dataset (merged Meltome Atlas + ProTherm) with ~45,000 protein sequences and experimental Tm values
- Output: Single linear regression output neuron predicting Tm in °C
"""
            )
        with gr.Column(scale=1):
            pass
if __name__ == "__main__":
    # Script entry point: launch the app with sharing enabled and the local
    # ./assets directory listed in allowed_paths.
    launch_options = {
        "share": True,
        "allowed_paths": ['./assets'],
    }
    demo.launch(**launch_options)