"""Gradio front end for TmProt, a protein melting-temperature (Tm) predictor.

The module loads the fine-tuned model once at import time, exposes
``predict_tm`` as the prediction callback and builds the ``demo`` Blocks UI.
"""

import csv
import tempfile
from io import StringIO  # noqa: F401  (kept: file-level import from the original)
from pathlib import Path

import gradio as gr
import torch
import transformers

from helpers import load_model, parse_fasta_string

# mute esm warning for weights
transformers.logging.set_verbosity_error()

# Constants
MODEL_NAME = "esm2_t33_650M_UR50D"
CURRENT_DIR = Path(__file__).parent
PATH_MODEL = CURRENT_DIR / "model"
DEVICE = "cpu"
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# around 2 mins for fireprot
VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")

# Accepted sequence length range, in amino acids.
MIN_SEQ_LEN = 20
MAX_SEQ_LEN = 2000

# NOTE(review): the tokenizer truncates every input to this many tokens even
# though sequences of up to MAX_SEQ_LEN residues pass validation, so long
# sequences are scored on a prefix only. Confirm this is intended before
# changing it -- a different value changes the predictions.
MAX_TOKENS = 512

RESULT_HEADERS = ["Rank", "ID", "Predicted Tm [°C]", "Thermostable"]

HEADER_MD = """
# TmProt
## Protein Thermostability Predictor
"""

ABOUT_MD = """
### TmProt is a machine-learning-based protein thermostability predictor that leverages a fine-tuned ESM-2 protein language model to estimate melting temperatures (Tm) of protein sequences. It enables users to upload protein sequences in FASTA format (either pasted as text or uploaded as a file), and outputs predicted Tm values ranked by a user-defined thermostability threshold.

**Paper:** [https://doi.org/10.64898/2026.05.07.723192](https://doi.org/10.64898/2026.05.07.723192)
**GitHub:** [https://github.com/loschmidt/TmProt](https://github.com/loschmidt/TmProt)
"""

DETAILS_MD = """
## Features
- Predict protein melting temperature (Tm) from amino acid sequences
- Accepts input via FASTA text or FASTA file upload
- Supports sequences from 20 to 2000 amino acids in length
- Outputs a ranked table with predicted Tm and thermostability status based on user-chosen threshold
- CSV download option for easy export and downstream analysis

## Model Overview
- Base Model: facebook/esm2_t33_650M_UR50D (650M parameters)
- Fine-tuning method: LoRA (Low-Rank Adaptation) using PEFT framework
- Task: Regression prediction of protein melting temperature (Tm)
- Training Data: ProMelt dataset (merged Meltome Atlas + ProTherm) with ~45,000 protein sequences and experimental Tm values
- Output: Single linear regression output neuron predicting Tm in °C
"""

# FASTA needs the header and the residues on separate lines.
EXAMPLE_I1W5V5 = """>I1W5V5
MSIENLSSNKSFGGWHKQYSHVSNTLNCAMRFAIYLPPQASTGAKVPVLYWLSGLTCSDENFMQKAGAQRLAAELGIAIVAPDTSPRGEGVADDEGYDLGQGAGFYVNATQAPWNRHYQMYDYVVNELPELIESMFPVSDKRAIAGHSMGGHGALTIALRNPERYQSVSAFSPINNPVNCPWGQKAFTAYLGKDTDTWREYDASLLMRAAKQYVPALVDQGEADNFLAEQLKPEVLEAAASSNNYPLELRSHEGYDHSYYFIASFIEDHLRFHSNYLNA
"""

EXAMPLE_R4YJ85 = """>R4YJ85
MINLEKALAGRRILIVDDLVEARSSLKKMATILGGDNIDVATDGIEAMSLIHEHEYDIVLSDYNLGRTKDGQQILEEARFTQRLRATSLFIVITGENAIDMVMGALEYDPDGYITKPYTLNMLKERLIRIITIKEELRKVNKAIDLQKYDLAIKYCLEVLDSNPRLRLPASRILGQLLMRQKRFQQALKIYSQLLNERSVSWAKLGQAICIFKLGDPNSALALLNRALVDHPLYVQCYDWIAKILLTLDKPLEAQAALEKAIVISPKAVLRQMELGRIAYENGDMVTAEPAFKYSVRLGRFSCHKSAKNYLQFVRSAQALLINPKERQTQNKANEAFRALTELKQDFSDDKDSLFEASIVESKTHLKMENLDEAKRSANDAEDMLAKLECPKIDYKLQMTETFIETDQSVKAQKMIDELKSAELSDKQIIMLNRLDNDLNGEALKRHSTSLNDQGVSHYEKGELEEAIIAFDQATHYEQAGISVLLNSIQAKISLMERDSPDKKILKNVRSLLIRIGEIAKDDERFARYSRLRKTYDRLCRAAAK
"""

model, tokenizer = load_model(MODEL_NAME, PATH_MODEL, DEVICE)


def _read_fasta_input(seq_text, seq_file) -> str:
    """Return the raw FASTA text, preferring the uploaded file over the textbox."""
    if seq_file is not None:
        # Accept either a plain path or an object exposing the path as ``.name``.
        path = getattr(seq_file, "name", seq_file)
        with open(path, "r", encoding="utf-8") as handle:
            return handle.read()
    if seq_text and seq_text.strip():
        return seq_text
    raise gr.Error("Please provide a sequence via text or file.")


def _validate_sequence(record_id, seq) -> None:
    """Raise ``gr.Error`` if *seq* is out of the length range or has invalid residues."""
    if len(seq) < MIN_SEQ_LEN:
        raise gr.Error(
            f"Sequence '{record_id}' is too short (<{MIN_SEQ_LEN} amino acids)."
        )
    if len(seq) > MAX_SEQ_LEN:
        raise gr.Error(
            f"Sequence '{record_id}' is too long (>{MAX_SEQ_LEN} amino acids)."
        )
    unknown = set(seq) - VALID_AMINO_ACIDS
    if unknown:
        # Sorted so the message is the same on every run.
        invalid = "".join(sorted(unknown))
        raise gr.Error(f"Invalid characters in sequence: {invalid}")


def _predict_single(seq) -> float:
    """Run the model on one sequence and return its raw predicted Tm."""
    inputs = tokenizer(
        seq,
        return_tensors="pt",
        max_length=MAX_TOKENS,
        truncation=True,
        padding=True,
    )
    inputs = {key: tensor.to(DEVICE) for key, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.squeeze().item()


def _write_csv(rows) -> str:
    """Write the result rows to a persistent temporary CSV file and return its path."""
    # newline="" lets the csv module control line endings; without it text-mode
    # newline translation doubles the carriage returns on Windows.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".csv", mode="w", encoding="utf-8", newline=""
    ) as tmp:
        writer = csv.writer(tmp)
        writer.writerow(RESULT_HEADERS)
        writer.writerows(rows)
        return tmp.name


def predict_tm(seq_text, seq_file, threshold):
    """Predict Tm for every FASTA record and rank the results.

    Args:
        seq_text: FASTA text pasted by the user; used only when no file is given.
        seq_file: Uploaded FASTA file (path, or object with a ``.name`` path),
            or ``None``.
        threshold: Thermostability cutoff in °C; anything ``float()`` accepts.

    Returns:
        A ``(table, csv_path)`` tuple. ``table`` is a list of
        ``[rank, id, tm, "Yes" | "No"]`` rows sorted by descending predicted Tm,
        and ``csv_path`` is the path of a temporary CSV file holding the same
        rows under a header line.

    Raises:
        gr.Error: If no input is given, the FASTA cannot be parsed, no records
            are found, or a sequence fails validation. An exception is raised
            rather than a message returned because the callback feeds two
            output components, which a single returned string cannot fill.
    """
    fasta_str = _read_fasta_input(seq_text, seq_file)

    try:
        records = parse_fasta_string(fasta_str)
    except Exception as exc:  # UI boundary: report any parser failure to the user
        raise gr.Error(f"FASTA parsing failed: {exc}") from exc
    if not records:
        raise gr.Error("No valid sequences found.")

    # Validate everything up front so one bad record does not waste the model
    # runs of the records before it.
    sequences = []
    for record in records:
        seq = record["sequence"].upper()
        _validate_sequence(record["id"], seq)
        sequences.append((record["id"], seq))

    results = [
        {"id": record_id, "tm": round(_predict_single(seq), 2)}
        for record_id, seq in sequences
    ]
    results.sort(key=lambda item: item["tm"], reverse=True)

    cutoff = float(threshold)
    table = [
        [rank, item["id"], item["tm"], "Yes" if item["tm"] > cutoff else "No"]
        for rank, item in enumerate(results, 1)
    ]
    return table, _write_csv(table)


def _logo_image(path):
    """Create a fixed-size, non-interactive logo image without any toolbar."""
    return gr.Image(
        path,
        width=100,
        height=100,
        show_label=False,
        show_download_button=False,
        container=False,
        show_share_button=False,
        show_fullscreen_button=False,
        interactive=False,
    )


demo = gr.Blocks(theme=gr.themes.Origin())

with demo:
    with gr.Row():
        with gr.Column(scale=1):
            _logo_image("assets/TmProt_logo.png")
        with gr.Column(scale=7):
            gr.Markdown(HEADER_MD)
            gr.Markdown(value=ABOUT_MD)
        with gr.Column(scale=1):
            _logo_image("assets/logo.png")

    with gr.Row():
        with gr.Column(scale=4):
            seq_text = gr.Textbox(
                label="FASTA sequences",
                lines=6,
                placeholder=">seq\nMKTIIALSYIFCLVFA",
                value="",
            )
            seq_file = gr.File(
                label="Or upload FASTA file",
                file_types=[".fasta", ".fa", ".txt"],
                type="filepath",
            )
            btn = gr.Button("Predict")
            cutoff_bins = [str(x) for x in range(20, 101, 10)]
            cutoff_bar = gr.Radio(
                choices=cutoff_bins,
                label="Select thermostability threshold (°C)",
                info="Default is 60°C",
                value="60",
            )
        with gr.Column(scale=4):
            output = gr.Dataframe(headers=RESULT_HEADERS, label="Results")
            download_btn = gr.DownloadButton(label="Download CSV")

    btn.click(
        predict_tm,
        inputs=[seq_text, seq_file, cutoff_bar],
        outputs=[output, download_btn],
    )

    with gr.Row():
        gr.Examples(
            examples=[
                # [FASTA text, seq_file (None: the text is used), threshold]
                [EXAMPLE_I1W5V5, None, "60"],
                [EXAMPLE_R4YJ85, None, "50"],
            ],
            inputs=[seq_text, seq_file, cutoff_bar],
            label="Click an example to try TmProt instantly",
            examples_per_page=2,
        )

    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=7):
            gr.Markdown(value=DETAILS_MD)
        with gr.Column(scale=1):
            pass


if __name__ == "__main__":
    demo.launch(share=True, allowed_paths=['./assets'])