import sys
import os
import spaces
import gradio as gr
from data.scripts.data_utils import parse_PDB
from utils.utils import ClassConfig, DataCollatorForTokenRegression, process_in_batches_and_combine, get_dot_separated_name
from models.T5_encoder_per_token import PT5_classification_model
from data.scripts.get_enm_fluctuations_for_dataset import get_fluctuation_for_json_dict
import yaml
import torch
import json
from datetime import datetime
from pathlib import Path
from Bio import SeqIO
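# Make the locally bundled gradio_molecule3d backend importable ahead of any installed version.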
BASE_DIR = Path(__file__).resolve().parent
LOCAL_COMPONENT_PATH = BASE_DIR / "gradio_molecule3d" / "backend"
sys.path.insert(0, str(LOCAL_COMPONENT_PATH))
from gradio_molecule3d.molecule3d import Molecule3D
from Bio.PDB import PDBParser, PDBIO
from biotite.structure import annotate_sse
import biotite.structure.io as strucio
import biotite.structure.residues as residues
import numpy as np
from huggingface_hub import hf_hub_download, utils
import biotite.structure.io.pdb as pdb
import biotite.structure as struc
import biotite.sequence as seq
from data.scripts.data_utils import modify_bfactor_biotite
def get_first_chain_id(pdb_file):
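    """Return the chain ID of the first protein chain in a PDB file ("" if the file is empty or parsing fails)."""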
try:
# Load the PDB file
f = pdb.PDBFile.read(pdb_file)
# Get structure (model 1)
atom_array = f.get_structure(model=1)
# Filter for amino acids (Protein only)
# This handles standard ATOMs and also common HETATMs like MSE automatically if defined
protein_mask = struc.filter_amino_acids(atom_array)
protein_atoms = atom_array[protein_mask]
if len(protein_atoms) == 0:
# Fallback: if filter_amino_acids is too strict for this file,
# just grab the chain of the first ATOM record found.
if len(atom_array) > 0:
return atom_array.chain_id[0]
return ""
# Return the chain ID of the first protein atom found
return protein_atoms.chain_id[0]
except Exception as e:
print(f"Warning: Biotite failed to detect chain for {pdb_file}: {e}")
return ""
def get_weights_path(repo_id, filename):
"""
Tries to get the local path immediately. If not found, downloads it.
"""
print(f"Looking for {filename} in {repo_id}...")
try:
# 1. FASTEST: Try loading entirely from local cache (no internet check)
return hf_hub_download(
repo_id=repo_id,
filename=filename,
local_files_only=True
)
except (utils.EntryNotFoundError, utils.LocalEntryNotFoundError, FileNotFoundError):
# 2. FALLBACK: If not found locally, download it (cached for next time)
print(f"Weights not found locally. Downloading from HF Hub...")
return hf_hub_download(
repo_id=repo_id,
filename=filename,
local_files_only=False
)
def process_pdb_file(pdb_file, backbones, sequences, names):
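    """Parse the first protein chain of a PDB file and append its backbone coordinates,
    sequence and name to the running lists. Sequences longer than 1023 residues are
    skipped, presumably to fit the encoder's input limit once the special token is added."""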
    _name = os.path.splitext(pdb_file)[0]
_chain = get_first_chain_id(pdb_file)
parsed_pdb = parse_PDB(pdb_file, name=_name, input_chain_list=[_chain])[0]
backbone, sequence = parsed_pdb['coords_chain_{}'.format(_chain)], parsed_pdb['seq_chain_{}'.format(_chain)]
if len(sequence) > 1023:
print("Sequence length is greater than 1023, skipping {}".format(_name))
else:
backbones.append(backbone)
sequences.append(sequence)
names.append(_name)
return backbones, sequences, names
@spaces.GPU
def flex_seq(input_seq, input_file):
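    """Flexpert-Seq: predict per-residue flexibility from sequence alone (pasted FASTA
    text, a .fasta file, or one or more .pdb files). Returns the output file paths and
    a status message."""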
if not input_seq:
input_seq = ""
if not input_seq.strip() and not input_file:
return None, "Provide a file/s or a input sequence/s"
if input_file:
if len(input_file) == 1:
input_file = input_file[0]
filename, suffix = os.path.splitext(input_file)
else:
suffix = ".pdb_list"
    output_name = datetime.now().strftime('%Y%m%d_%H%M%S')
sequences = []
names = []
backbones = []
flucts_list = []
pdb_files = []
datapoint_for_eval = 'all'
if input_seq:
suffix = ""
proteins = input_seq.strip().split('\n')
        if len(proteins) % 2 != 0:
            raise ValueError("Input must follow the FASTA format: alternating '>name' and sequence lines")
        for record in range(0, len(proteins), 2):
            if proteins[record].startswith(">"):
                name = proteins[record][1:]
                sequence = proteins[record+1]
            else:
                raise ValueError("Input must follow the FASTA format: alternating '>name' and sequence lines")
if datapoint_for_eval == 'all':
names.append(name)
sequences.append(sequence)
backbones.append(None)
elif suffix == ".fasta":
for record in SeqIO.parse(input_file, "fasta"):
name = record.name
if datapoint_for_eval == 'all':
names.append(name)
sequences.append(str(record.seq))
backbones.append(None)
elif suffix == ".pdb":
backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
pdb_files.append(input_file)
elif suffix == ".pdb_list":
for i in input_file:
backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
pdb_files.append(i)
    with open('configs/env_config.yaml', 'r') as fh:
        env_config = yaml.load(fh, Loader=yaml.FullLoader)
    # Set the folder for the Hugging Face cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set the GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']
    with open('configs/train_config.yaml', 'r') as fh:
        config = yaml.load(fh, Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
class_config.adaptor_architecture = 'no-adaptor'
config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
model.to(config['inference_args']['device'])
repo_id = "Honzus24/Flexpert_weights"
file_weights = config['inference_args']['seq_model_path']
# Get path (instant if cached)
weights_path = get_weights_path(repo_id, file_weights)
# Load weights
state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
model.load_state_dict(state_dict, strict=False)
model.eval()
data_to_collate = []
for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
#Ensure that the missing residues in the sequence are not represented as '-' but as 'X'
sequence = sequence.replace('-', 'X') #due to the tokenizer vocabulary
tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device'])
data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:]})
data_collator = DataCollatorForTokenRegression(tokenizer)
    batch = data_collator(data_to_collate)  # the collator pads the per-protein examples into a single batch
batch.to(model.device)
# Predict
with torch.no_grad():
output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
predictions = output_logits[:,:,0] #includes the prediction for the added token
# subselect the predictions using the attention mask
output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "seq"))
output_filename.parent.mkdir(parents=True, exist_ok=True)
output_files = []
output_message = "Success"
for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
output_filename_new = output_filename.with_stem("{}_".format(name.split("/")[-1]) + output_filename.stem)
with open(output_filename_new.with_suffix('.txt'), 'w') as f:
f.write("Residue Number\tResidue ID\tFlexibility\n")
prediction = prediction[mask.bool()]
            assert len(prediction) == len(sequence) + 1, "Prediction length {} is not equal to sequence length + 1 ({})".format(len(prediction), len(sequence) + 1)
p = prediction.tolist()[:-1]
for i in range(len(p)):
f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
output_files.append(str(output_filename_new.with_suffix('.txt')))
if suffix == ".pdb" or suffix == ".pdb_list":
for name, pdb_file, prediction in zip(names, pdb_files, predictions):
_prediction = prediction[:-1].reshape(1,-1)
_outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
print("Saving prediction to {}.".format(_outname))
modify_bfactor_biotite(pdb_file, None, _outname, _prediction) #writing the prediction without the last token
output_files.append(str(_outname))
    _outname = output_filename.with_name(name.split("/")[-1] + '_' + output_filename.stem + '.fasta')
with open(_outname, 'w') as f:
print("Saving fasta to {}.".format(_outname))
for name, sequence in zip(names, sequences):
f.write('>' + name + '\n')
f.write(sequence + '\n')
output_files.append(str(_outname))
return output_files, output_message
@spaces.GPU
def flex_3d(input_file):
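    """Flexpert-3D: predict per-residue flexibility from structure, feeding the model
    the sequence together with ENM fluctuations derived from the backbone. Returns the
    output file paths, a status message and the ENM output file paths."""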
    if not input_file:
        return None, "Provide at least one .pdb file", None
if len(input_file) == 1:
input_file = input_file[0]
filename, suffix = os.path.splitext(input_file)
else:
suffix = ".pdb_list"
    output_name = datetime.now().strftime('%Y%m%d_%H%M%S')
sequences = []
names = []
backbones = []
pdb_files = []
flucts_list = []
datapoint_for_eval = 'all'
if suffix == ".pdb":
backbones, sequences, names = process_pdb_file(input_file, backbones, sequences, names)
pdb_files.append(input_file)
elif suffix == ".jsonl":
for line in open(input_file, 'r'):
_dict = json.loads(line)
            if 'fluctuations' in _dict:
                print("Fluctuations are precomputed, using them.")
dot_separated_name = get_dot_separated_name(key='pdb_name', _dict=_dict)
if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
names.append(_dict['pdb_name'])
backbones.append(None)
sequences.append(_dict['sequence'])
flucts_list.append(_dict['fluctuations']+[0.0]) #padding for end cls token
continue
dot_separated_name = get_dot_separated_name(key='name', _dict=_dict)
if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
backbones.append(_dict['coords'])
sequences.append(_dict['seq'])
names.append(dot_separated_name)
elif suffix == ".pdb_list":
for i in input_file:
backbones, sequences, names = process_pdb_file(i, backbones, sequences, names)
pdb_files.append(i)
    with open('configs/env_config.yaml', 'r') as fh:
        env_config = yaml.load(fh, Loader=yaml.FullLoader)
    # Set the folder for the Hugging Face cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set the GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']
    with open('configs/train_config.yaml', 'r') as fh:
        config = yaml.load(fh, Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
class_config.adaptor_architecture = 'conv'
config['inference_args']['device'] = config['inference_args']['device'] if torch.cuda.is_available() else 'cpu'
model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
model.to(config['inference_args']['device'])
repo_id = "Honzus24/Flexpert_weights"
print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
file_weights = config['inference_args']['3d_model_path']
# Get path (instant if cached)
weights_path = get_weights_path(repo_id, file_weights)
# Load weights
state_dict = torch.load(weights_path, map_location=config['inference_args']['device'])
model.load_state_dict(state_dict, strict=False)
model.eval()
data_to_collate = []
for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
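        # For structural inputs, compute ENM (elastic network model) fluctuations from the backbone;
        # otherwise reuse the precomputed values loaded from the .jsonl input.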
if backbone is not None:
_dict = {'coords': backbone, 'seq': sequence}
flucts, _ = get_fluctuation_for_json_dict(_dict, enm_type = config['inference_args']['enm_type'])
flucts = flucts.tolist()
flucts.append(0.0) #To match the special token for the sequence
flucts = torch.tensor(flucts).to(config['inference_args']['device'])
else:
flucts = flucts_list[idx]
#Ensure that the missing residues in the sequence are not represented as '-' but as 'X'
sequence = sequence.replace('-', 'X') #due to the tokenizer vocabulary
tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device'])
data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:], 'enm_vals': flucts})
# Use the data collator to process the input
data_collator = DataCollatorForTokenRegression(tokenizer)
    batch = data_collator(data_to_collate)  # the collator pads the per-protein examples into a single batch
batch.to(model.device)
# Predict
with torch.no_grad():
output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
predictions = output_logits[:,:,0] #includes the prediction for the added token
# subselect the predictions using the attention mask
output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "3D"))
output_filename.parent.mkdir(parents=True, exist_ok=True)
output_files = []
output_message = "Success"
for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
output_filename_new = output_filename.with_stem("{}_".format(name.split("/")[-1]) + output_filename.stem)
with open(output_filename_new.with_suffix('.txt'), 'w') as f:
f.write("Residue Number\tResidue ID\tFlexibility\n")
prediction = prediction[mask.bool()]
            assert len(prediction) == len(sequence) + 1, "Prediction length {} is not equal to sequence length + 1 ({})".format(len(prediction), len(sequence) + 1)
p = prediction.tolist()[:-1]
for i in range(len(p)):
f.write(f"{i:<10}\t{sequence[i]:<20}\t{round(p[i], 4):<10}\n")
output_files.append(str(output_filename_new.with_suffix('.txt')))
    output_files_enm = []
    # Write one ENM file per protein
    for enm_prediction, name in zip(batch['enm_vals'], names):
        _outname_new = output_filename.with_name(name.split("/")[-1] + '_enm_' + output_filename.stem + '.txt')
        with open(_outname_new, 'w') as f:
            print("Saving ENM values to {}.".format(_outname_new))
            f.write('>' + name + '\n')
            f.write(', '.join(str(p) for p in enm_prediction.tolist()[:-1]) + '\n')  # drop the padding value added for the special token
        output_files_enm.append(str(_outname_new))
if suffix == ".pdb" or suffix == ".pdb_list":
for name, pdb_file, prediction in zip(names, pdb_files, predictions):
_prediction = prediction[:-1].reshape(1,-1)
_outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
print("Saving prediction to {}.".format(_outname))
modify_bfactor_biotite(pdb_file, None, _outname, _prediction) #writing the prediction without the last token
output_files.append(str(_outname))
    _outname = output_filename.with_name(name.split("/")[-1] + '_' + output_filename.stem + '.fasta')
with open(_outname, 'w') as f:
print("Saving fasta to {}.".format(_outname))
for name, sequence in zip(names, sequences):
f.write('>' + name + '\n')
f.write(sequence + '\n')
output_files.append(str(_outname))
if suffix == ".pdb" or suffix == ".pdb_list":
for name, pdb_file, enm_vals_single in zip(names, pdb_files, batch['enm_vals']):
_outname = output_filename.with_name('{}_enm_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
print("Saving ENM prediction to {}.".format(_outname))
_enm_vals = enm_vals_single[:-1].reshape(1,-1)
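            # Clamp and sanitize the ENM values so they fit the fixed-width (%6.2f) B-factor column of the PDB format.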
eps = 1e-6
_enm_vals = torch.clip(_enm_vals, -100+eps, 1000-eps)
_enm_vals = torch.nan_to_num(_enm_vals, nan=0.0)
_enm_vals = torch.round(_enm_vals, decimals=2)
modify_bfactor_biotite(pdb_file, None, _outname, _enm_vals) #writing the prediction without the last token
output_files_enm.append(str(_outname))
return output_files, output_message, output_files_enm
def rescale_bfactors(pdb_file):
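    """Min-max rescale the B-factors of a PDB file to [0, 1]. The scaling range is taken
    from the structured region (between the first and last secondary-structure element);
    B-factors of the unstructured termini are clipped to that range. Writes the result
    to '<name>-scaled.pdb' and returns that path."""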
base, ext = os.path.splitext(pdb_file)
# Create the new filename
out_file = base + "-scaled" + ext
atom_array = strucio.load_structure(pdb_file)
sse = annotate_sse(atom_array)
start = 0
for i, item in enumerate(sse):
if item == "a" or item == "b":
start = i
break
sse = sse[::-1]
end = 0
for i, item in enumerate(sse):
if item == "a" or item == "b":
end = i
break
end = len(sse) - end - 1
parser = PDBParser(QUIET=True)
structure = parser.get_structure("prot", pdb_file)
# Collect all bfactors
bfactors = [atom.bfactor for atom in structure.get_atoms()]
res_starts = residues.get_residue_starts(atom_array)
start = res_starts[start]
end = res_starts[end]
bfactors_start = bfactors[:start]
bfactors_end = bfactors[end:]
bfactors_struct = bfactors[start:end]
min_b = min(bfactors_struct)
max_b = max(bfactors_struct)
    bfactors_start = np.clip(bfactors_start, min_b, max_b)
    bfactors_end = np.clip(bfactors_end, min_b, max_b)
bfactors = np.concatenate((bfactors_start, bfactors_struct, bfactors_end))
def scale(b):
if max_b == min_b:
return 0.5 # arbitrary mid value
return ((b - min_b) / (max_b - min_b))
# Rescale all atoms
for i, atom in enumerate(structure.get_atoms()):
atom.set_bfactor(scale(bfactors[i]))
# Save to the *new* file path
io = PDBIO()
io.set_structure(structure)
io.save(out_file)
return out_file
def clear_files():
    folder = 'prediction_results/'
    os.makedirs(folder, exist_ok=True)
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)
def handle_seq_prediction(input_seq, input_file):
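    """Run Flexpert-Seq and prepare its outputs for the UI: downloadable files, the
    status message and the .pdb files (raw and B-factor-rescaled) for the 3D viewer."""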
    clear_files()
    main_files, message = flex_seq(input_seq, input_file)
    if main_files is None:
        return None, message, None
fasta_index = next(
(i for i, filename in enumerate(main_files) if filename.endswith(".fasta"))
)
txt_index = next(
(i for i in range(len(main_files) - 1, -1, -1) if main_files[i].endswith(".txt"))
)
    pdb_files_for_viz = [str(f) for f in main_files[txt_index+1:fasta_index] if f.endswith('.pdb')]
    pdb_files_for_viz_scaled = [str(rescale_bfactors(f)) for f in main_files[txt_index+1:fasta_index] if f.endswith('.pdb')]
main_files.extend(pdb_files_for_viz_scaled)
pdb_files_for_viz.extend(pdb_files_for_viz_scaled)
return main_files, message, pdb_files_for_viz
def handle_3d_prediction(input_file):
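    """Run Flexpert-3D and prepare its outputs for the UI: downloadable files (including
    the ENM outputs), the status message and the .pdb files for the 3D viewer."""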
    clear_files()
    main_files, message, enm_files = flex_3d(input_file)
    if main_files is None:
        return None, message, None
fasta_index = next(
(i for i, filename in enumerate(main_files) if filename.endswith(".fasta"))
)
txt_index = next(
(i for i in range(len(main_files) - 1, -1, -1) if main_files[i].endswith(".txt"))
)
    pdb_files_for_viz = [f for f in main_files if f.endswith('.pdb')]
    pdb_files_for_viz_scaled = [rescale_bfactors(f) for f in main_files[txt_index+1:fasta_index] if f.endswith('.pdb')]
pdb_files_for_viz.extend(pdb_files_for_viz_scaled)
main_files.extend(pdb_files_for_viz_scaled)
main_files.extend(enm_files)
return main_files, message, pdb_files_for_viz
def clear_outputs():
return "", None, None
PRIMARY = "primary"
SECONDARY = "secondary"
def switch_component_view(button_label):
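    """Show the selected input column (text or file), clear both inputs and highlight
    the corresponding button."""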
updates = {
"text_visible": gr.update(visible=False),
"file_visible": gr.update(visible=False),
"text_clear": "",
"file_clear": []
}
# Updates for button colors
button_updates = {
"text_variant": gr.update(variant=SECONDARY),
"file_variant": gr.update(variant=SECONDARY)
}
if button_label == "Text Input":
updates["text_visible"] = gr.update(visible=True)
button_updates["text_variant"] = gr.update(variant=PRIMARY)
elif button_label == "File Input":
updates["file_visible"] = gr.update(visible=True)
button_updates["file_variant"] = gr.update(variant=PRIMARY)
return [
updates["text_visible"],
updates["file_visible"],
updates["text_clear"],
updates["file_clear"],
button_updates["text_variant"],
button_updates["file_variant"]
]
theme = gr.themes.Base(
neutral_hue="gray",
primary_hue="slate")
gr.set_static_paths(["prediction_results"])
with gr.Blocks(theme=theme) as demo:
gr.Image("Flexpert_logo.png", show_label=False, interactive=False)
gr.Markdown(value="""
## About Flexpert
On the web version of Flexpert, you can compute the per-residue flexibility of a protein by entering it as text or by uploading .pdb/.fasta files.
### Inputs:
#### Flexpert-Seq:
* **Text** - Enter one or more proteins in FASTA format (see the example below this list).
* **File** - Select either a .fasta file containing one or more proteins, or one or more .pdb files, each containing a single-chain protein.
* **Note:** You can use either the **Text** or the **File** input for a single prediction, not both.
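For example, pasting the following into the **Text** field predicts two proteins at once (the names and sequences are purely illustrative):
```
>ProteinName1
MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
>ProteinName2
GSHMGSGSGLNDIFEAQKIEWHE
```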
#### Flexpert-3D:
* **File** - Select one or more .pdb files with a single-chain protein in the file.
### Outputs:
#### Files:
* Depending on your input, different output files appear:
* A **.txt file** with the per-residue flexibility for all proteins **always appears**.
* A **.fasta file** with all the input sequences appears.
* If you input a **.pdb file**, two .pdb files per protein appear: one with the **'true'** per-residue flexibilities and one with the **'scaled'** per-residue flexibilities.
* For Flexpert-3D, one more **.pdb file** per protein appears, containing the per-residue ENM values.
#### Visualisations:
* A visualisation of the per-residue flexibility is also available, but only when you predict the flexibility from a **.pdb file**.
* We provide both the **'true'** visualisation (flexibilities as predicted by Flexpert) and the **'scaled'** one (flexibilities normalised to the [0, 1] range).
* To toggle between them, click the lowest button on the viewer's side panel (the brush icon) and choose a file.
""")
with gr.Tabs() as tabs:
with gr.Tab("Flexpert-Seq", id="tab_seq"):
with gr.Row():
text_button = gr.Button("Text Input", variant=PRIMARY)
file_button = gr.Button("File Input", variant=SECONDARY)
with gr.Column(visible=True) as col_text_input:
input_seq = gr.Textbox(
label="Paste Protein Sequences (FASTA format)",
placeholder=">ProteinName1\nAGFASRGT...\n>ProteinName2\nQWERTY...",
lines=10,
scale=2
)
# Column for File Input (Default: Hidden)
with gr.Column(visible=False) as col_file_input:
input_file = gr.File(label="Select one or more .pdb files OR a .fasta file containing one or more proteins", file_count="multiple", file_types = ['.fasta', '.pdb'])
predict_seq = gr.Button("Predict")
all_outputs = [
col_text_input,
col_file_input,
input_seq,
input_file,
text_button,
file_button
]
text_button.click(
fn=switch_component_view,
inputs=[text_button],
outputs= all_outputs
)
file_button.click(
fn=switch_component_view,
inputs=[file_button],
outputs= all_outputs
)
with gr.Tab("Flexpert-3D"):
input_file_3d = gr.File(label="Select one or more .pdb files", file_count = "multiple", file_types = ['.pdb'])
predict_3d = gr.Button("Predict")
output_text = gr.Textbox(label = "Output message", placeholder="The output message statement will be displayed here")
reps = [
{
"model": 0,
"chain": "",
"resname": "",
"style": "cartoon", # or "stick", "sphere", "surface"
"color": "alphafold", # This is the key - use alphafold color scheme
"around": 0,
"byres": False,
"opacity": 1.0,
}
]
molecule_output = Molecule3D(label="Protein Structure", height=500, file_count = "multiple", reps = reps, confidenceLabel="Flexibility")
output_files = gr.File(file_count="multiple", type = "filepath")
clear_button = gr.ClearButton([input_seq, input_file, input_file_3d, output_text, molecule_output, output_files])
with gr.Row():
logos = gr.Image("logos.png", show_label=False, interactive=False)
tabs.select(
fn=clear_outputs,
inputs=None,
outputs=[output_text, molecule_output, output_files]
)
text_button.click(
fn=clear_outputs,
inputs=None,
outputs=[output_text, molecule_output, output_files]
)
file_button.click(
fn=clear_outputs,
inputs=None,
outputs=[output_text, molecule_output, output_files]
)
# Connect the buttons to their respective functions.
predict_seq.click(handle_seq_prediction, inputs=[input_seq, input_file], outputs=[output_files, output_text, molecule_output])
predict_3d.click(handle_3d_prediction, inputs=[input_file_3d], outputs=[output_files, output_text, molecule_output])
# Launch the interface
demo.launch(show_error=True)