| import gradio as gr |
| import torch |
| from pathlib import Path |
| from helpers import load_model, parse_fasta_string |
| from io import StringIO |
| import csv |
| import tempfile |
| import transformers |
| |
# Keep the console quiet: only errors from the transformers library are logged.
transformers.logging.set_verbosity_error()

# --- Model / runtime configuration -----------------------------------------
MODEL_NAME = "esm2_t33_650M_UR50D"  # backbone identifier handed to load_model
CURRENT_DIR = Path(__file__).parent  # directory that contains this script
PATH_MODEL = CURRENT_DIR.joinpath("model")  # presumably the fine-tuned weights dir; passed to load_model
DEVICE = "cpu"  # device for the model and for every input tensor

# One-letter codes of the 20 standard amino acids; any other symbol is rejected.
VALID_AMINO_ACIDS = {*"ACDEFGHIKLMNPQRSTVWY"}

# Loaded once at import time so every request reuses the same model/tokenizer.
model, tokenizer = load_model(MODEL_NAME, PATH_MODEL, DEVICE)
|
|
# Input length limits enforced before inference.
_MIN_SEQ_LEN = 20
_MAX_SEQ_LEN = 2000
# Header row of the downloadable CSV (same columns as the on-screen table).
_CSV_HEADER = ["Rank", "ID", "Predicted Tm [°C]", "Thermostable"]
# Fallback cutoff if the threshold control delivers no value (UI default is 60 °C).
_DEFAULT_THRESHOLD = 60.0


def _read_fasta_input(seq_text, seq_file):
    """Return raw FASTA text from the uploaded file or, failing that, the textbox.

    The uploaded file takes precedence over the textbox.

    Returns:
        The FASTA text, or ``None`` when neither source provided any input.

    Raises:
        gr.Error: If the uploaded file cannot be opened or decoded as UTF-8.
    """
    if seq_file is not None:
        # Gradio hands the upload over as a path (str-like); older file
        # wrappers exposed the path via ``.name``. Accept either form.
        path = getattr(seq_file, "name", seq_file)
        try:
            # "utf-8-sig" drops a leading BOM that would otherwise sit in front
            # of the first '>' header line.
            with open(path, "r", encoding="utf-8-sig") as f:
                return f.read()
        except (OSError, UnicodeDecodeError) as e:
            raise gr.Error(f"Could not read the uploaded file: {e}") from e
    if seq_text and seq_text.strip():
        return seq_text
    return None


def _sequence_error(seq_id, seq):
    """Return a user-facing error message if *seq* is unacceptable, else ``None``.

    *seq* must already be upper-cased.
    """
    if len(seq) < _MIN_SEQ_LEN:
        return f"Sequence '{seq_id}' is too short (<{_MIN_SEQ_LEN} amino acids)."
    if len(seq) > _MAX_SEQ_LEN:
        return f"Sequence '{seq_id}' is too long (>{_MAX_SEQ_LEN} amino acids)."
    invalid = set(seq) - VALID_AMINO_ACIDS
    if invalid:
        # Sorted so the message is deterministic (set iteration order is not).
        bad_chars = "".join(sorted(invalid))
        return f"Invalid characters in sequence '{seq_id}': {bad_chars}"
    return None


def _predict_single(seq):
    """Run the model on one sequence and return its raw predicted Tm (°C)."""
    # NOTE(review): inputs are truncated to 512 tokens while up to
    # _MAX_SEQ_LEN residues are accepted, so the tail of long sequences never
    # reaches the model. Confirm this matches how the model was trained.
    inputs = tokenizer(seq, return_tensors="pt", max_length=512, truncation=True, padding=True)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # The regression head yields one value per sequence; reduce it to a float.
    return outputs.logits.squeeze().item()


def _write_results_csv(table):
    """Write *table* below a header row to a temporary CSV file and return its path.

    ``delete=False`` keeps the file on disk after closing, because Gradio
    serves it for download after this function has returned.
    """
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        newline="",  # csv writes its own "\r\n"; avoid "\r\r\n" on Windows
        suffix=".csv",
        prefix="TmProt_results_",
        delete=False,
    ) as tmp:
        writer = csv.writer(tmp)
        writer.writerow(_CSV_HEADER)
        writer.writerows(table)
    return tmp.name


def predict_tm(seq_text, seq_file, threshold):
    """Predict melting temperatures for every sequence in a FASTA input.

    Args:
        seq_text: FASTA text from the textbox (used only when no file is uploaded).
        seq_file: Uploaded FASTA file (path) or ``None``.
        threshold: Thermostability cutoff in °C as delivered by the radio
            control (a numeric string such as ``"60"``).

    Returns:
        ``(table, csv_path)``. *table* holds rows ``[rank, id, tm, "Yes"/"No"]``
        sorted by predicted Tm, highest first. *csv_path* is a CSV file with
        the same rows plus a header.

    Raises:
        gr.Error: For missing, unparsable, or invalid input. The function raises
            instead of returning a message because the click handler expects
            two output values, and Gradio shows the ``gr.Error`` text to the user.
    """
    fasta_str = _read_fasta_input(seq_text, seq_file)
    if fasta_str is None:
        raise gr.Error("Please provide a sequence via text or file.")

    try:
        records = parse_fasta_string(fasta_str)
    except Exception as e:  # the helper's exception types are not visible here
        raise gr.Error(f"FASTA parsing failed: {str(e)}") from e
    if not records:
        raise gr.Error("No valid sequences found.")

    entries = [(record["id"], record["sequence"].upper()) for record in records]
    # Validate everything up front so a bad record does not waste inference
    # on the records that precede it.
    for seq_id, seq in entries:
        message = _sequence_error(seq_id, seq)
        if message is not None:
            raise gr.Error(message)

    cutoff = float(threshold) if threshold is not None else _DEFAULT_THRESHOLD
    results = [(seq_id, round(_predict_single(seq), 2)) for seq_id, seq in entries]
    # Stable sort: sequences with equal Tm keep their input order.
    results.sort(key=lambda item: item[1], reverse=True)
    table = [
        [rank, seq_id, tm, "Yes" if tm > cutoff else "No"]
        for rank, (seq_id, tm) in enumerate(results, 1)
    ]
    return table, _write_results_csv(table)
|
|
# Static images ship next to this script. Resolve them from the script's
# directory (as the model path is) rather than the current working directory,
# so the app also works when launched from elsewhere.
ASSETS_DIR = CURRENT_DIR / "assets"

# Shared options for the two decorative, non-interactive logo images.
_LOGO_IMAGE_KWARGS = dict(
    width=100,
    height=100,
    show_label=False,
    show_download_button=False,
    container=False,
    show_share_button=False,
    show_fullscreen_button=False,
    interactive=False,
)

# Clickable examples; each row is [FASTA text, uploaded file, threshold].
_EXAMPLES = [
    [
        ">I1W5V5\n"
        "MSIENLSSNKSFGGWHKQYSHVSNTLNCAMRFAIYLPPQASTGAKVPVLYWLSGLTCSDENFMQKAGAQRLAAELGIAIVAPDTSPRGEGVADDEGYDLGQGAGFYVNATQAPWNRHYQMYDYVVNELPELIESMFPVSDKRAIAGHSMGGHGALTIALRNPERYQSVSAFSPINNPVNCPWGQKAFTAYLGKDTDTWREYDASLLMRAAKQYVPALVDQGEADNFLAEQLKPEVLEAAASSNNYPLELRSHEGYDHSYYFIASFIEDHLRFHSNYLNA\n",
        None,
        "60",
    ],
    [
        ">R4YJ85\n"
        "MINLEKALAGRRILIVDDLVEARSSLKKMATILGGDNIDVATDGIEAMSLIHEHEYDIVLSDYNLGRTKDGQQILEEARFTQRLRATSLFIVITGENAIDMVMGALEYDPDGYITKPYTLNMLKERLIRIITIKEELRKVNKAIDLQKYDLAIKYCLEVLDSNPRLRLPASRILGQLLMRQKRFQQALKIYSQLLNERSVSWAKLGQAICIFKLGDPNSALALLNRALVDHPLYVQCYDWIAKILLTLDKPLEAQAALEKAIVISPKAVLRQMELGRIAYENGDMVTAEPAFKYSVRLGRFSCHKSAKNYLQFVRSAQALLINPKERQTQNKANEAFRALTELKQDFSDDKDSLFEASIVESKTHLKMENLDEAKRSANDAEDMLAKLECPKIDYKLQMTETFIETDQSVKAQKMIDELKSAELSDKQIIMLNRLDNDLNGEALKRHSTSLNDQGVSHYEKGELEEAIIAFDQATHYEQAGISVLLNSIQAKISLMERDSPDKKILKNVRSLLIRIGEIAKDDERFARYSRLRKTYDRLCRAAAK\n",
        None,
        "50",
    ],
]

with gr.Blocks(theme=gr.themes.Origin()) as demo:
    # Header: logo | title and description | second logo.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(str(ASSETS_DIR / "TmProt_logo.png"), **_LOGO_IMAGE_KWARGS)
        with gr.Column(scale=7):
            gr.Markdown("""
# TmProt
## Protein Thermostability Predictor
""")
            # The blank line between "Paper" and "GitHub" keeps Markdown from
            # merging them into one line.
            gr.Markdown(value="""
### TmProt is a machine-learning-based protein thermostability predictor that leverages a fine-tuned ESM-2 protein language model to estimate melting temperatures (Tm) of protein sequences. It accepts protein sequences in FASTA format (either pasted as text or uploaded as a file) and outputs predicted Tm values ranked from highest to lowest, each classified as thermostable or not according to a user-defined threshold.

**Paper:** [https://doi.org/10.64898/2026.05.07.723192](https://doi.org/10.64898/2026.05.07.723192)

**GitHub:** [https://github.com/loschmidt/TmProt](https://github.com/loschmidt/TmProt)
"""
            )
        with gr.Column(scale=1):
            gr.Image(str(ASSETS_DIR / "logo.png"), **_LOGO_IMAGE_KWARGS)

    # Main panel: inputs on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=4):
            seq_text = gr.Textbox(
                label="FASTA sequences",
                lines=6,
                placeholder=">seq\nMKTIIALSYIFCLVFA",
                value="",
            )
            seq_file = gr.File(label="Or upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
            btn = gr.Button("Predict")
            cutoff_bins = [str(x) for x in range(20, 101, 10)]
            cutoff_bar = gr.Radio(
                choices=cutoff_bins,
                label="Select thermostability threshold (°C)",
                info="Default is 60°C",
                value="60",
            )
        with gr.Column(scale=4):
            output = gr.Dataframe(headers=["Rank", "ID", "Predicted Tm [°C]", "Thermostable"], label="Results")
            download_btn = gr.DownloadButton(label="Download CSV")

    btn.click(
        predict_tm,
        inputs=[seq_text, seq_file, cutoff_bar],
        outputs=[output, download_btn],
    )

    with gr.Row():
        gr.Examples(
            examples=_EXAMPLES,
            inputs=[seq_text, seq_file, cutoff_bar],
            label="Click an example to try TmProt instantly",
            examples_per_page=2,
        )

    # Footer: centred feature and model description.
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=7):
            gr.Markdown(value="""
## Features

- Predict protein melting temperature (Tm) from amino acid sequences

- Accepts input via FASTA text or FASTA file upload

- Supports sequences from 20 to 2000 amino acids in length

- Outputs a ranked table with predicted Tm and thermostability status based on user-chosen threshold

- CSV download option for easy export and downstream analysis

## Model Overview

- Base Model: facebook/esm2_t33_650M_UR50D (650M parameters)

- Fine-tuning method: LoRA (Low-Rank Adaptation) using PEFT framework

- Task: Regression prediction of protein melting temperature (Tm)

- Training Data: ProMelt dataset (merged Meltome Atlas + ProTherm) with ~45,000 protein sequences and experimental Tm values

- Output: Single linear regression output neuron predicting Tm in °C
"""
            )
        with gr.Column(scale=1):
            pass


if __name__ == "__main__":
    demo.launch(share=True, allowed_paths=[str(ASSETS_DIR)])
|
|