import sys
import os
import json
from datetime import datetime
from pathlib import Path

import spaces  # enables Hugging Face ZeroGPU support
import gradio as gr
import yaml
import torch
import numpy as np
from Bio import SeqIO
from Bio.PDB import PDBParser, PDBIO
from biotite.structure import annotate_sse
import biotite.structure as struc
import biotite.structure.io as strucio
import biotite.structure.io.pdb as pdb
import biotite.structure.residues as residues
from huggingface_hub import hf_hub_download, utils

from data.scripts.data_utils import parse_PDB, modify_bfactor_biotite
from data.scripts.get_enm_fluctuations_for_dataset import get_fluctuation_for_json_dict
from utils.utils import ClassConfig, DataCollatorForTokenRegression, process_in_batches_and_combine, get_dot_separated_name
from models.T5_encoder_per_token import PT5_classification_model

# Make the locally vendored gradio_molecule3d component importable
BASE_DIR = Path(__file__).resolve().parent
LOCAL_COMPONENT_PATH = BASE_DIR / "gradio_molecule3d" / "backend"
sys.path.insert(0, str(LOCAL_COMPONENT_PATH))
from gradio_molecule3d.molecule3d import Molecule3D

def get_first_chain_id(pdb_file):
    try:
        # Load the PDB file
        f = pdb.PDBFile.read(pdb_file)
        # Get the structure (model 1)
        atom_array = f.get_structure(model=1)
        # Filter for amino acids (protein only); this handles standard ATOM
        # records and also common HETATMs like MSE automatically, if defined
        protein_mask = struc.filter_amino_acids(atom_array)
        protein_atoms = atom_array[protein_mask]
        if len(protein_atoms) == 0:
            # Fallback: if filter_amino_acids is too strict for this file,
            # just grab the chain of the first ATOM record found.
            if len(atom_array) > 0:
                return atom_array.chain_id[0]
            return ""
        # Return the chain ID of the first protein atom found
        return protein_atoms.chain_id[0]
    except Exception as e:
        print(f"Warning: Biotite failed to detect chain for {pdb_file}: {e}")
        return ""

def get_weights_path(repo_id, filename):
    """
    Tries to get the local path immediately. If not found, downloads it.
    """
    print(f"Looking for {filename} in {repo_id}...")
    try:
        # 1. Fastest: try resolving entirely from the local cache (no internet check)
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_files_only=True
        )
    except (utils.EntryNotFoundError, utils.LocalEntryNotFoundError, FileNotFoundError):
        # 2. Fallback: if not found locally, download it (cached for next time)
        print("Weights not found locally. Downloading from HF Hub...")
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_files_only=False
        )
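
# Hedged usage example (the repo id is the one used below; the filename is
# whatever the config's *_model_path entry points at):
#   weights_path = get_weights_path("Honzus24/Flexpert_weights", file_weights)
# The first call may hit the network; later calls resolve from the local
# Hugging Face cache without an internet check.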

def process_pdb_file(pdb_file, backbones, sequences, names):
    _name = pdb_file[:-4]  # strip the ".pdb" extension
    _chain = get_first_chain_id(pdb_file)
    parsed_pdb = parse_PDB(pdb_file, name=_name, input_chain_list=[_chain])[0]
    backbone, sequence = parsed_pdb['coords_chain_{}'.format(_chain)], parsed_pdb['seq_chain_{}'.format(_chain)]
    if len(sequence) > 1023:
        # Skip sequences longer than the model supports
        print("Sequence length is greater than 1023, skipping {}".format(_name))
    else:
        backbones.append(backbone)
        sequences.append(sequence)
        names.append(_name)
    return backbones, sequences, names
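
# The accumulator lists are threaded through repeated calls, so multiple PDB
# files collect in parallel, e.g. (hypothetical file names):
#   backbones, sequences, names = process_pdb_file("a.pdb", [], [], [])
#   backbones, sequences, names = process_pdb_file("b.pdb", backbones, sequences, names)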

def flex_seq(input_seq, input_file):
    if not input_seq:
        input_seq = ""
    if not input_seq.strip() and not input_file:
        return None, "Provide one or more files or input sequences"
    if input_file:
        if len(input_file) == 1:
            input_file = input_file[0]
            filename, suffix = os.path.splitext(input_file)
        else:
            suffix = ".pdb_list"
    output_name = datetime.now().strftime('%Y%m%d_%H%M%S')
    sequences = []
    names = []
    backbones = []
    pdb_files = []
    datapoint_for_eval = 'all'
    if input_seq:
        suffix = ""
        proteins = input_seq.strip().split('\n')
        if len(proteins) % 2 != 0:
            raise ValueError("You must adhere to the .fasta format")
        for record in range(0, len(proteins), 2):
            if proteins[record].startswith(">"):
                name = proteins[record][1:]
                sequence = proteins[record + 1]
            else:
                raise ValueError("You must adhere to the .fasta format")
            if datapoint_for_eval == 'all':
                names.append(name)
                sequences.append(sequence)
                backbones.append(None)
    elif suffix == ".fasta":
        for record in SeqIO.parse(input_file, "fasta"):
            name = record.name
            if datapoint_for_eval == 'all':
                names.append(name)
                sequences.append(str(record.seq))
                backbones.append(None)
    elif suffix == ".pdb":
        backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
        pdb_files.append(input_file)
    elif suffix == ".pdb_list":
        for i in input_file:
            backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
            pdb_files.append(i)
    env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
    # Set the folder for the Hugging Face cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set the GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']
    config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
    class_config.adaptor_architecture = 'no-adaptor'
    config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
    model.to(config['inference_args']['device'])
    repo_id = "Honzus24/Flexpert_weights"
    file_weights = config['inference_args']['seq_model_path']
    # Get the weights path (instant if cached)
    weights_path = get_weights_path(repo_id, file_weights)
    # Load the weights
    state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    data_to_collate = []
    for sequence in sequences:
        # Ensure that missing residues in the sequence are represented as 'X',
        # not '-', due to the tokenizer vocabulary
        sequence = sequence.replace('-', 'X')
        tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
        tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device'])
        data_to_collate.append({'input_ids': tokenized_seq[0, :], 'attention_mask': attention_mask[0, :]})
    data_collator = DataCollatorForTokenRegression(tokenizer)
    batch = data_collator(data_to_collate)  # the collator expects a list of examples
    batch.to(model.device)
    # Predict
    with torch.no_grad():
        output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
    predictions = output_logits[:, :, 0]  # includes the prediction for the added special token
    # Subselect the predictions using the attention mask
    output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "seq"))
    output_filename.parent.mkdir(parents=True, exist_ok=True)
    output_files = []
    output_message = "Success"
    for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
        output_filename_new = output_filename.with_stem("{}_".format(name.split("/")[-1]) + output_filename.stem)
        with open(output_filename_new.with_suffix('.txt'), 'w') as f:
            f.write("Residue Number\tResidue ID\tFlexibility\n")
            prediction = prediction[mask.bool()]
            assert len(prediction) == len(sequence) + 1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence) + 1)
            p = prediction.tolist()[:-1]  # drop the prediction for the special token
            for i in range(len(p)):
                f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
        output_files.append(str(output_filename_new.with_suffix('.txt')))
    if suffix == ".pdb" or suffix == ".pdb_list":
        for name, pdb_file, prediction in zip(names, pdb_files, predictions):
            _prediction = prediction[:-1].reshape(1, -1)
            _outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
            print("Saving prediction to {}.".format(_outname))
            modify_bfactor_biotite(pdb_file, None, _outname, _prediction)  # writing the prediction without the last token
            output_files.append(str(_outname))
    _outname = output_filename.with_name(name.split("/")[-1] + output_filename.stem + '.fasta')
    with open(_outname, 'w') as f:
        print("Saving fasta to {}.".format(_outname))
        for name, sequence in zip(names, sequences):
            f.write('>' + name + '\n')
            f.write(sequence + '\n')
    output_files.append(str(_outname))
    return output_files, output_message
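
# Hedged example of the Flexpert-Seq entry point (hypothetical sequence):
#   files, msg = flex_seq(">MyProtein\nMKTAYIAKQR", None)
# writes a per-residue flexibility .txt and a .fasta under the directory given
# by config['inference_args']['prediction_output_dir'] and returns their paths
# with msg == "Success".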

def flex_3d(input_file):
    if not input_file:
        # Keep the three-element return shape expected by the caller
        return None, "Provide one or more .pdb files", None
    if len(input_file) == 1:
        input_file = input_file[0]
        filename, suffix = os.path.splitext(input_file)
    else:
        suffix = ".pdb_list"
    output_name = datetime.now().strftime('%Y%m%d_%H%M%S')
    sequences = []
    names = []
    backbones = []
    pdb_files = []
    flucts_list = []
    datapoint_for_eval = 'all'
    if suffix == ".pdb":
        backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
        pdb_files.append(input_file)
    elif suffix == ".jsonl":
        for line in open(input_file, 'r'):
            _dict = json.loads(line)
            if 'fluctuations' in _dict.keys():
                print("fluctuations are precomputed, using them")
                dot_separated_name = get_dot_separated_name(key='pdb_name', _dict=_dict)
                if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                    names.append(_dict['pdb_name'])
                    backbones.append(None)
                    sequences.append(_dict['sequence'])
                    flucts_list.append(_dict['fluctuations'] + [0.0])  # padding for the end special token
                continue
            dot_separated_name = get_dot_separated_name(key='name', _dict=_dict)
            if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                backbones.append(_dict['coords'])
                sequences.append(_dict['seq'])
                names.append(dot_separated_name)
                flucts_list.append(None)  # placeholder keeps flucts_list aligned with backbones
    elif suffix == ".pdb_list":
        for i in input_file:
            backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
            pdb_files.append(i)
    env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
    # Set the folder for the Hugging Face cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set the GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']
    config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
    class_config.adaptor_architecture = 'conv'
    config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
    model.to(config['inference_args']['device'])
    repo_id = "Honzus24/Flexpert_weights"
    print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
    file_weights = config['inference_args']['3d_model_path']
    # Get the weights path (instant if cached)
    weights_path = get_weights_path(repo_id, file_weights)
    # Load the weights
    state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    data_to_collate = []
    for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
        if backbone is not None:
            _dict = {'coords': backbone, 'seq': sequence}
            flucts, _ = get_fluctuation_for_json_dict(_dict, enm_type=config['inference_args']['enm_type'])
            flucts = flucts.tolist()
            flucts.append(0.0)  # to match the special token appended to the sequence
            flucts = torch.tensor(flucts).to(config['inference_args']['device'])
        else:
            # Precomputed fluctuations (already padded); converted to a tensor
            # for consistency with the branch above
            flucts = torch.tensor(flucts_list[idx]).to(config['inference_args']['device'])
        # Ensure that missing residues in the sequence are represented as 'X',
        # not '-', due to the tokenizer vocabulary
        sequence = sequence.replace('-', 'X')
        tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
        tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device'])
        data_to_collate.append({'input_ids': tokenized_seq[0, :], 'attention_mask': attention_mask[0, :], 'enm_vals': flucts})
    # Use the data collator to process the input
    data_collator = DataCollatorForTokenRegression(tokenizer)
    batch = data_collator(data_to_collate)  # the collator expects a list of examples
    batch.to(model.device)
    # Predict
    with torch.no_grad():
        output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
    predictions = output_logits[:, :, 0]  # includes the prediction for the added special token
    # Subselect the predictions using the attention mask
    output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "3D"))
    output_filename.parent.mkdir(parents=True, exist_ok=True)
    output_files = []
    output_message = "Success"
    for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
        output_filename_new = output_filename.with_stem("{}_".format(name.split("/")[-1]) + output_filename.stem)
        with open(output_filename_new.with_suffix('.txt'), 'w') as f:
            f.write("Residue Number\tResidue ID\tFlexibility\n")
            prediction = prediction[mask.bool()]
            assert len(prediction) == len(sequence) + 1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence) + 1)
            p = prediction.tolist()[:-1]  # drop the prediction for the special token
            for i in range(len(p)):
                f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
        output_files.append(str(output_filename_new.with_suffix('.txt')))
    output_files_enm = []
    # One ENM .txt file per protein, each holding that protein's values
    for enm_prediction, name in zip(batch['enm_vals'], names):
        _outname_new = output_filename.with_name("{}".format(name.split("/")[-1]) + '_enm_' + output_filename.stem + '.txt')
        with open(_outname_new, 'w') as f:
            print("Saving ENM predictions to {}.".format(_outname_new))
            f.write('>' + name + '\n')
            f.write(', '.join([str(p) for p in enm_prediction.tolist()[:-1]]) + '\n')
        output_files_enm.append(str(_outname_new))
    if suffix == ".pdb" or suffix == ".pdb_list":
        for name, pdb_file, prediction in zip(names, pdb_files, predictions):
            _prediction = prediction[:-1].reshape(1, -1)
            _outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
            print("Saving prediction to {}.".format(_outname))
            modify_bfactor_biotite(pdb_file, None, _outname, _prediction)  # writing the prediction without the last token
            output_files.append(str(_outname))
    _outname = output_filename.with_name(name.split("/")[-1] + output_filename.stem + '.fasta')
    with open(_outname, 'w') as f:
        print("Saving fasta to {}.".format(_outname))
        for name, sequence in zip(names, sequences):
            f.write('>' + name + '\n')
            f.write(sequence + '\n')
    output_files.append(str(_outname))
    if suffix == ".pdb" or suffix == ".pdb_list":
        for name, pdb_file, enm_vals_single in zip(names, pdb_files, batch['enm_vals']):
            _outname = output_filename.with_name('{}_enm_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
            print("Saving ENM prediction to {}.".format(_outname))
            _enm_vals = enm_vals_single[:-1].reshape(1, -1)
            # Keep the values within the printable range of the PDB B-factor
            # column and replace NaNs before writing
            eps = 1e-6
            _enm_vals = torch.clip(_enm_vals, -100 + eps, 1000 - eps)
            _enm_vals = torch.nan_to_num(_enm_vals, nan=0.0)
            _enm_vals = torch.round(_enm_vals, decimals=2)
            modify_bfactor_biotite(pdb_file, None, _outname, _enm_vals)  # writing the values without the last token
            output_files_enm.append(str(_outname))
    return output_files, output_message, output_files_enm
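
# Hedged example of the Flexpert-3D entry point (hypothetical file name):
#   files, msg, enm_files = flex_3d(["examples/protein.pdb"])
# In addition to the .txt/.fasta/.pdb outputs of the sequence branch, this also
# writes the ENM-derived fluctuations, both as a .txt file and baked into the
# B-factor column of a .pdb file, and returns those paths separately.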

def rescale_bfactors(pdb_file):
    """Min-max scale the B-factors of a PDB file to [0, 1], using the range
    observed between the first and the last secondary-structure element, with
    the termini clipped into that range."""
    base, ext = os.path.splitext(pdb_file)
    # Create the new filename
    out_file = base + "-scaled" + ext
    atom_array = strucio.load_structure(pdb_file)
    sse = annotate_sse(atom_array)
    # First residue index that is part of a helix ("a") or a sheet ("b")
    start = 0
    for i, item in enumerate(sse):
        if item == "a" or item == "b":
            start = i
            break
    # Last such residue index, found by scanning the reversed annotation
    sse = sse[::-1]
    end = 0
    for i, item in enumerate(sse):
        if item == "a" or item == "b":
            end = i
            break
    end = len(sse) - end - 1
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("prot", pdb_file)
    # Collect all B-factors
    bfactors = [atom.bfactor for atom in structure.get_atoms()]
    # Convert residue indices to atom indices
    res_starts = residues.get_residue_starts(atom_array)
    start = res_starts[start]
    end = res_starts[end]
    bfactors_start = bfactors[:start]
    bfactors_end = bfactors[end:]
    bfactors_struct = bfactors[start:end]
    min_b = min(bfactors_struct)
    max_b = max(bfactors_struct)
    # Clip the termini to the range observed within the structured core
    # (np.clip takes the bounds positionally)
    bfactors_start = np.clip(bfactors_start, min_b, max_b)
    bfactors_end = np.clip(bfactors_end, min_b, max_b)
    bfactors = np.concatenate((bfactors_start, bfactors_struct, bfactors_end))

    def scale(b):
        if max_b == min_b:
            return 0.5  # arbitrary mid value
        return (b - min_b) / (max_b - min_b)

    # Rescale all atoms
    for i, atom in enumerate(structure.get_atoms()):
        atom.set_bfactor(scale(bfactors[i]))
    # Save to the *new* file path
    io = PDBIO()
    io.set_structure(structure)
    io.save(out_file)
    return out_file
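
# Worked example of the scaling above: with min_b = 0.2 and max_b = 1.0, a
# B-factor of 0.6 maps to (0.6 - 0.2) / (1.0 - 0.2) = 0.5; terminal values
# outside [0.2, 1.0] are clipped first, so they land exactly on 0 or 1.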

def clear_files():
    folder = 'prediction_results/'
    if not os.path.isdir(folder):
        os.makedirs(folder)
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        os.remove(file_path)

def handle_seq_prediction(input_seq, input_file):
    clear_files()
    main_files, message = flex_seq(input_seq, input_file)
    if main_files is None:
        # Nothing was predicted (e.g. empty input); surface the message only
        return None, message, None
    fasta_index = next(
        (i for i, filename in enumerate(main_files) if filename.endswith(".fasta"))
    )
    txt_index = next(
        (i for i in range(len(main_files) - 1, -1, -1) if main_files[i].endswith(".txt"))
    )
    # The .pdb files written by flex_seq sit between the last .txt and the .fasta
    pdb_files_for_viz = [str(f) for f in main_files[txt_index + 1:fasta_index] if f.endswith(".pdb")]
    pdb_files_for_viz_scaled = [str(rescale_bfactors(f)) for f in main_files[txt_index + 1:fasta_index] if f.endswith(".pdb")]
    main_files.extend(pdb_files_for_viz_scaled)
    pdb_files_for_viz.extend(pdb_files_for_viz_scaled)
    return main_files, message, pdb_files_for_viz

def handle_3d_prediction(input_file):
    clear_files()
    main_files, message, enm_files = flex_3d(input_file)
    if main_files is None:
        # Nothing was predicted (e.g. no file provided); surface the message only
        return None, message, None
    fasta_index = next(
        (i for i, filename in enumerate(main_files) if filename.endswith(".fasta"))
    )
    pdb_files_for_viz = [f for f in main_files if f.endswith(".pdb")]
    pdb_files_for_viz_scaled = [rescale_bfactors(f) for f in main_files[1:fasta_index] if f.endswith(".pdb")]
    pdb_files_for_viz.extend(pdb_files_for_viz_scaled)
    main_files.extend(pdb_files_for_viz_scaled)
    main_files.extend(enm_files)
    return main_files, message, pdb_files_for_viz

def clear_outputs():
    return "", None, None


PRIMARY = "primary"
SECONDARY = "secondary"


def switch_component_view(button_label):
    # Default: hide both input components and clear their values
    updates = {
        "text_visible": gr.update(visible=False),
        "file_visible": gr.update(visible=False),
        "text_clear": "",
        "file_clear": []
    }
    # Updates for the button colors
    button_updates = {
        "text_variant": gr.update(variant=SECONDARY),
        "file_variant": gr.update(variant=SECONDARY)
    }
    if button_label == "Text Input":
        updates["text_visible"] = gr.update(visible=True)
        button_updates["text_variant"] = gr.update(variant=PRIMARY)
    elif button_label == "File Input":
        updates["file_visible"] = gr.update(visible=True)
        button_updates["file_variant"] = gr.update(variant=PRIMARY)
    return [
        updates["text_visible"],
        updates["file_visible"],
        updates["text_clear"],
        updates["file_clear"],
        button_updates["text_variant"],
        button_updates["file_variant"]
    ]
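
# The list returned by switch_component_view must line up positionally with the
# `all_outputs` list defined inside the Flexpert-Seq tab below:
#   [col_text_input, col_file_input, input_seq, input_file, text_button, file_button]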

theme = gr.themes.Base(
    neutral_hue="gray",
    primary_hue="slate")
gr.set_static_paths(["prediction_results"])

with gr.Blocks(theme=theme) as demo:
    gr.Image("Flexpert_logo.png", show_label=False, interactive=False)
    gr.Markdown(value="""
## About Flexpert
This web version of Flexpert predicts the per-residue flexibility of a protein from either a pasted sequence or uploaded .pdb/.fasta files.
### Inputs:
#### Flexpert-Seq:
* **Text** - Enter one or more proteins according to the specified format.
* **File** - Select either a .fasta file containing one or more proteins, or one or more .pdb files, each containing a single-chain protein.
* **Note:** You can use only one of the **Text** and **File** input options per prediction.
#### Flexpert-3D:
* **File** - Select one or more .pdb files, each containing a single-chain protein.
### Outputs:
#### Files:
* Depending on your input, different output files appear:
    * A **.txt file** with the per-residue flexibility of each protein **always appears**.
    * A **.fasta file** with all the proteins appears.
    * If you input a **.pdb file**, two .pdb files per protein appear: one with the **'true'** per-residue flexibilities and one with the **'scaled'** per-residue flexibilities.
    * For Flexpert-3D, an additional **.pdb file** per protein appears, containing the per-residue ENM values.
#### Visualisations:
* A visualisation of the per-residue flexibility of the provided proteins is also available. It only appears if you predict the flexibility from a **.pdb file**.
* We provide both the **'real'** (flexibilities predicted by Flexpert) and the **'scaled'** (flexibilities normalised according to the maximum flexibility) visualisations.
* To toggle between visualisations, click the lowermost button on the side panel (the brush) and then choose between the files.
""")
    with gr.Tabs() as tabs:
        with gr.Tab("Flexpert-Seq", id="tab_seq"):
            with gr.Row():
                text_button = gr.Button("Text Input", variant=PRIMARY)
                file_button = gr.Button("File Input", variant=SECONDARY)
            # Column for text input (default: visible)
            with gr.Column(visible=True) as col_text_input:
                input_seq = gr.Textbox(
                    label="Paste Protein Sequences (FASTA format)",
                    placeholder=">ProteinName1\nAGFASRGT...\n>ProteinName2\nQWERTY...",
                    lines=10,
                    scale=2
                )
            # Column for file input (default: hidden)
            with gr.Column(visible=False) as col_file_input:
                input_file = gr.File(label="Select one or more .pdb files OR a .fasta file containing one or more proteins", file_count="multiple", file_types=['.fasta', '.pdb'])
            predict_seq = gr.Button("Predict")
            all_outputs = [
                col_text_input,
                col_file_input,
                input_seq,
                input_file,
                text_button,
                file_button
            ]
            text_button.click(
                fn=switch_component_view,
                inputs=[text_button],
                outputs=all_outputs
            )
            file_button.click(
                fn=switch_component_view,
                inputs=[file_button],
                outputs=all_outputs
            )
        with gr.Tab("Flexpert-3D"):
            input_file_3d = gr.File(label="Select one or more .pdb files", file_count="multiple", file_types=['.pdb'])
            predict_3d = gr.Button("Predict")
    output_text = gr.Textbox(label="Output message", placeholder="The output message statement will be displayed here")
    # Representation settings for the Molecule3D viewer
    reps = [
        {
            "model": 0,
            "chain": "",
            "resname": "",
            "style": "cartoon",  # or "stick", "sphere", "surface"
            "color": "alphafold",  # this is the key: use the AlphaFold color scheme
            "around": 0,
            "byres": False,
            "opacity": 1.0,
        }
    ]
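    # Hedged note on the visualisation: the predicted flexibilities are written
    # into the B-factor column of the output .pdb files, and the "alphafold"
    # colour scheme above reads that per-residue value (relabelled as
    # "Flexibility" via confidenceLabel below), so the cartoon colouring
    # reflects the predicted flexibility.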
    molecule_output = Molecule3D(label="Protein Structure", height=500, file_count="multiple", reps=reps, confidenceLabel="Flexibility")
    output_files = gr.File(file_count="multiple", type="filepath")
    clear_button = gr.ClearButton([input_seq, input_file, input_file_3d, output_text, molecule_output, output_files])
    with gr.Row():
        logos = gr.Image("logos.png", show_label=False, interactive=False)
    # Clear the outputs whenever the user switches tabs or input modes
    tabs.select(
        fn=clear_outputs,
        inputs=None,
        outputs=[output_text, molecule_output, output_files]
    )
    text_button.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[output_text, molecule_output, output_files]
    )
    file_button.click(
        fn=clear_outputs,
        inputs=None,
        outputs=[output_text, molecule_output, output_files]
    )
    # Connect the buttons to their respective functions.
    predict_seq.click(handle_seq_prediction, inputs=[input_seq, input_file], outputs=[output_files, output_text, molecule_output])
    predict_3d.click(handle_3d_prediction, inputs=[input_file_3d], outputs=[output_files, output_text, molecule_output])

# Launch the interface
demo.launch(show_error=True)