| import gradio as gr |
| import sys |
| import random |
| import os |
| import pandas as pd |
| import torch |
| import itertools |
| from torch.utils.data import DataLoader |
| from transformers import AutoTokenizer |
|
|
| sys.path.append("scripts/") |
| from foldseek_util import get_struc_seq |
| from utils import seed_everything |
| from models import PLTNUM_PreTrainedModel |
| from datasets_ import PLTNUMDataset |
|
|
|
|
class Config:
    """Default inference settings read by the prediction helpers in this file.

    Every value is stored as an instance attribute, so each ``Config()`` starts
    from the same defaults independently of any other instance.
    """

    def __init__(self):
        # Attribute name -> default value, in the order they are assigned.
        defaults = (
            ("batch_size", 2),
            ("use_amp", False),
            ("num_workers", 1),
            ("max_length", 512),
            ("used_sequence", "left"),
            ("padding_side", "right"),
            ("task", "classification"),
            ("sequence_col", "sequence"),
            ("seed", 42),
        )
        for attribute, default in defaults:
            setattr(self, attribute, default)
|
|
|
|
def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Predict stability for uploaded PDB files and save the results as CSV.

    Args:
        model_choice: Base model name. "SaProt" selects element 2 of the value
            returned by ``get_foldseek_seq``; any other value selects element 0.
        organism_choice: Forwarded unchanged to ``predict_stability_core``.
        pdb_files: Iterable of uploads, each either an object exposing the file
            path as ``.name`` or a plain path string. ``None`` or an empty
            iterable produces a CSV that contains only the header row.
        cfg: Optional ``Config``. A fresh instance is created per call when
            omitted, so state written during prediction is never shared
            between calls.

    Returns:
        The path of the written CSV file ("/tmp/predictions.csv").

    Files for which no sequence could be extracted are listed first with empty
    prediction values; successfully processed files follow in upload order.
    """
    if cfg is None:
        cfg = Config()

    results = {
        "file_name": [],
        "raw prediction value": [],
        "binary prediction value": [],
    }
    file_names = []
    input_sequences = []

    uploads = pdb_files or []

    if uploads:
        # Best effort, as before: make the bundled foldseek binary executable.
        # A failure here is not fatal by itself; if the binary really cannot
        # run, the foldseek call below is what reports it.
        try:
            os.chmod("bin/foldseek", 0o777)
        except OSError:
            pass

    for pdb_file in uploads:
        # Accept both upload objects (path in ``.name``) and plain path strings.
        pdb_path = getattr(pdb_file, "name", pdb_file)
        sequences = get_foldseek_seq(pdb_path)

        file_name = os.path.basename(pdb_path)
        if not sequences:
            results["file_name"].append(file_name)
            results["raw prediction value"].append(None)
            results["binary prediction value"].append(None)
            continue

        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
        file_names.append(file_name)
        input_sequences.append(sequence)

    # Only load and run the model when there is something to predict.
    if input_sequences:
        raw_pred, binary_pred = predict_stability_core(
            model_choice, organism_choice, input_sequences, cfg
        )
        results["file_name"].extend(file_names)
        results["raw prediction value"].extend(raw_pred)
        results["binary prediction value"].extend(binary_pred)

    # NOTE(review): the output path is fixed, so simultaneous requests write
    # to the same file — confirm whether per-request paths are needed.
    df = pd.DataFrame(results)
    output_csv = "/tmp/predictions.csv"
    df.to_csv(output_csv, index=False)

    return output_csv
|
|
|
|
def predict_stability_with_sequence(
    model_choice, organism_choice, sequence, cfg=None
):
    """Predict stability for one typed-in sequence and save the result as CSV.

    Args:
        model_choice: Forwarded to ``predict_stability_core``.
        organism_choice: Forwarded to ``predict_stability_core``.
        sequence: The sequence text. All whitespace (spaces, tabs, line
            breaks) is removed before prediction, so a sequence pasted over
            several lines is treated as one sequence.
        cfg: Optional ``Config``. A fresh instance is created per call when
            omitted, so state written during prediction is never shared
            between calls.

    Returns:
        The path of the written CSV file ("/tmp/predictions.csv") on success.
        Otherwise a human-readable message string: "No valid sequence
        provided." for empty or whitespace-only input, or
        "An error occurred: ..." when prediction raises.
    """
    # NOTE(review): the UI binds this function's output to a gr.File
    # component, so the message strings below are handed to a file widget —
    # confirm that this is the intended way to surface errors.
    cleaned = "".join(sequence.split()) if sequence else ""
    if not cleaned:
        return "No valid sequence provided."

    if cfg is None:
        cfg = Config()

    try:
        raw_pred, binary_pred = predict_stability_core(
            model_choice, organism_choice, [cleaned], cfg
        )
        df = pd.DataFrame(
            {
                "sequence": cleaned,
                "raw prediction value": raw_pred,
                "binary prediction value": binary_pred,
            }
        )
        output_csv = "/tmp/predictions.csv"
        df.to_csv(output_csv, index=False)

        return output_csv
    except Exception as e:
        # Top-level boundary for the UI callback: report rather than crash.
        return f"An error occurred: {str(e)}"
|
|
|
|
def predict_stability_core(model_choice, organism_choice, sequences, cfg=None):
    """Select the PLTNUM checkpoint for the given choices and run ``predict``.

    Args:
        model_choice: Base model name; becomes part of the checkpoint id and
            is stored as ``cfg.architecture``.
        organism_choice: "Human" selects the "HeLa" checkpoint; any other
            value selects "NIH3T3".
        sequences: Sequences forwarded unchanged to ``predict``.
        cfg: Optional ``Config``. A fresh instance is created per call when
            omitted; a supplied instance is updated in place with ``model``,
            ``architecture`` and ``model_path``.

    Returns:
        Whatever ``predict(cfg, sequences)`` returns.
    """
    # A default of ``Config()`` in the signature would be built once at
    # definition time and then shared (and mutated) by every call.
    if cfg is None:
        cfg = Config()

    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    checkpoint = f"sagawa/PLTNUM-{model_choice}-{cell_line}"

    cfg.model = checkpoint
    cfg.architecture = model_choice
    cfg.model_path = checkpoint

    return predict(cfg, sequences)
|
|
|
|
def get_foldseek_seq(pdb_path):
    """Run foldseek on ``pdb_path`` and return the parsed entry for chain "A".

    Args:
        pdb_path: Path of the structure file handed to ``get_struc_seq``.

    Returns:
        The value stored under key "A" in the result of ``get_struc_seq``, or
        ``None`` when that result has no "A" entry. The caller in this file
        indexes the value at positions 0 and 2 and treats a falsy value as
        "no sequence found".
    """
    chain_to_seqs = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        ["A"],
        # Random id per call; presumably keeps foldseek's intermediate files
        # from clashing between calls — confirm in foldseek_util.
        process_id=random.randint(0, 10000000),
    )
    try:
        return chain_to_seqs["A"]
    except KeyError:
        # No chain "A" in this structure: report "nothing found" so the
        # caller's falsy check can skip this file instead of aborting.
        return None
|
|
|
|
def predict(cfg, sequences):
    """Run the PLTNUM model named by ``cfg`` over ``sequences``.

    Args:
        cfg: Settings object. Read: ``architecture``, ``used_sequence``,
            ``max_length``, ``seed``, ``sequence_col``, ``model_path``,
            ``padding_side``, ``batch_size``, ``num_workers``, ``use_amp``,
            ``task``. Written: ``token_length``, ``device``, ``tokenizer``.
        sequences: Sequence strings, one prediction is produced per entry.

    Returns:
        A tuple ``(predictions, binary)``: ``predictions`` is the flat list of
        model outputs (passed through a sigmoid when ``cfg.task`` is
        "classification"), and ``binary`` holds 1 where the prediction is
        greater than 0.5 and 0 otherwise. Empty input returns ``([], [])``
        without loading the tokenizer or the model.
    """
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"

    # Nothing to predict: skip the expensive tokenizer/model loading.
    if len(sequences) == 0:
        return [], []

    # The extra position for "both" is needed only for the duration of this
    # call; it is undone below so a reused cfg does not grow on every call.
    base_max_length = cfg.max_length
    if cfg.used_sequence == "both":
        cfg.max_length += 1

    try:
        seed_everything(cfg.seed)
        df = pd.DataFrame({cfg.sequence_col: sequences})

        tokenizer = AutoTokenizer.from_pretrained(
            cfg.model_path, padding_side=cfg.padding_side
        )
        cfg.tokenizer = tokenizer

        dataset = PLTNUMDataset(cfg, df, train=False)
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=cfg.device == "cuda",
            drop_last=False,
        )

        model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
        model.to(cfg.device)

        model.eval()
        predictions = []

        with torch.no_grad():
            for inputs, _ in dataloader:
                inputs = inputs.to(cfg.device)
                with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
                    preds = model(inputs)
                    if cfg.task == "classification":
                        preds = torch.sigmoid(preds)
                predictions += preds.cpu().tolist()
    finally:
        cfg.max_length = base_max_length

    # Each batch contributes a list of per-sample lists (presumably one value
    # per sample — confirm against the model head); flatten one level.
    predictions = list(itertools.chain.from_iterable(predictions))

    return predictions, [1 if x > 0.5 else 0 for x in predictions]
|
|
|
|
| |
# Build the Gradio page. Components are laid out in the order they are
# created, so the statement order below is the on-screen order.
with gr.Blocks() as demo:
    # Title and tagline.
    gr.Markdown(
        value="""
        # PLTNUM: Protein LifeTime Neural Model
        **Predict the protein half-life from its sequence or PDB file.**
        """
    )

    gr.Image(
        value="https://github.com/sagawatatsuya/PLTNUM/blob/main/model-image.png?raw=true",
        label="Model Image",
    )

    # Options shared by both input tabs, placed side by side.
    with gr.Row():
        model_choice = gr.Radio(
            label="Select PLTNUM's base model.",
            choices=["SaProt", "ESM2"],
            value="SaProt",
        )
        organism_choice = gr.Radio(
            label="Select the target organism.",
            choices=["Mouse", "Human"],
            value="Mouse",
        )

    with gr.Tabs():
        # Tab 1: one or more uploaded PDB files.
        with gr.TabItem("Upload PDB File"):
            gr.Markdown(value="### Upload your PDB files:")
            pdb_files = gr.File(file_count="multiple", label="Upload PDB Files")
            pdb_predict_button = gr.Button(value="Predict Stability")
            pdb_prediction_output = gr.File(label="Download Predictions")

            pdb_predict_button.click(
                predict_stability_with_pdb,
                [model_choice, organism_choice, pdb_files],
                pdb_prediction_output,
            )

        # Tab 2: a single sequence typed or pasted into a text box.
        with gr.TabItem("Enter Protein Sequence"):
            gr.Markdown(value="### Enter the protein sequence:")
            sequence = gr.Textbox(
                lines=8,
                label="Protein Sequence",
                placeholder="Enter your protein sequence here...",
            )
            sequence_predict_button = gr.Button(value="Predict Stability")
            sequence_prediction_output = gr.File(label="Download Predictions")

            sequence_predict_button.click(
                predict_stability_with_sequence,
                [model_choice, organism_choice, sequence],
                sequence_prediction_output,
            )

    # Usage instructions shown below the tabs.
    gr.Markdown(
        value="""
        ### How to Use:
        - **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
        - **Select Organism**: Choose between 'Mouse' or 'Human'.
        - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
        - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
        - **Predict**: Click 'Predict Stability' to receive the prediction.
        """
    )

    gr.Markdown(
        value="""
        ### About the Tool
        This tool allows researchers and scientists to predict the stability of proteins using advanced algorithms. It supports both PDB file uploads and direct sequence input.
        """
    )

# Start serving the interface as soon as this file is executed.
demo.launch()
|
|