File size: 28,190 Bytes
7968cb0
22a9b72
c3d5dd7
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f640a5
7968cb0
 
 
c3d5dd7
 
 
 
26dd42d
3a2c568
 
 
7968cb0
 
 
3a2c568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26dd42d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7968cb0
28e5ab6
3a2c568
28e5ab6
 
 
 
 
 
 
 
 
7968cb0
49eda9c
e0c0158
 
 
 
7968cb0
c3d5dd7
7968cb0
 
 
 
 
 
 
 
e0c0158
 
7968cb0
 
 
 
 
 
 
e0c0158
7968cb0
 
e0c0158
b962172
c3d5dd7
 
 
 
 
 
7968cb0
c3d5dd7
7968cb0
c3d5dd7
7968cb0
 
 
 
 
 
c3d5dd7
 
 
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3d5dd7
7968cb0
 
26dd42d
 
 
 
 
 
 
1a1ed59
 
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3d5dd7
7968cb0
e0c0158
 
7968cb0
e0c0158
c3d5dd7
e0c0158
c3d5dd7
7968cb0
 
 
 
 
 
e0c0158
 
c3d5dd7
e0c0158
7968cb0
 
 
 
c3d5dd7
7968cb0
c3d5dd7
7968cb0
 
c3d5dd7
e0c0158
 
 
 
 
 
7968cb0
 
 
49eda9c
e0c0158
7968cb0
 
 
 
 
 
 
 
 
e0c0158
 
7968cb0
 
 
 
 
 
 
e0c0158
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3d5dd7
7968cb0
 
 
30583fe
7968cb0
30583fe
 
 
 
 
 
 
7968cb0
 
30583fe
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3d5dd7
7968cb0
e0c0158
 
7968cb0
e0c0158
c3d5dd7
e0c0158
c3d5dd7
7968cb0
 
 
 
 
e0c0158
 
 
c3d5dd7
e0c0158
 
 
 
 
c3d5dd7
e0c0158
 
 
 
 
 
7968cb0
 
 
 
c3d5dd7
7968cb0
c3d5dd7
7968cb0
 
c3d5dd7
e0c0158
 
 
 
 
 
7968cb0
e0c0158
 
c3d5dd7
e0c0158
 
3a2c568
 
 
 
c3d5dd7
e0c0158
 
 
7968cb0
e0c0158
 
 
c3d5dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7968cb0
 
 
 
 
 
c3d5dd7
 
 
 
 
 
 
 
 
7968cb0
c3d5dd7
 
 
 
 
 
 
7968cb0
 
 
c3d5dd7
e0c0158
7968cb0
 
c3d5dd7
 
7968cb0
e0c0158
7968cb0
 
 
e0c0158
7968cb0
 
c3d5dd7
 
28e5ab6
 
 
c3d5dd7
 
 
 
e0c0158
c3d5dd7
 
e0c0158
7968cb0
e0c0158
 
 
 
 
 
 
 
 
 
 
7968cb0
 
 
 
e0c0158
c3d5dd7
 
e0c0158
 
 
 
 
 
 
 
7968cb0
e0c0158
 
 
 
 
7968cb0
 
 
3a2c568
 
 
e0c0158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dc8258
e0c0158
c3d5dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0c0158
3a2c568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7968cb0
3a2c568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0c0158
3a2c568
 
 
 
 
 
 
 
 
e0c0158
3a2c568
 
 
 
e0c0158
 
3a2c568
 
 
 
 
 
e0c0158
3a2c568
 
 
e0c0158
3a2c568
 
 
 
7968cb0
 
d0dbab5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
import sys
import spaces
import os, shutil
import gradio as gr
from data.scripts.data_utils import parse_PDB
from utils.utils import ClassConfig, DataCollatorForTokenRegression, process_in_batches_and_combine, get_dot_separated_name
from models.T5_encoder_per_token import PT5_classification_model
from data.scripts.get_enm_fluctuations_for_dataset import get_fluctuation_for_json_dict
import argparse
import yaml
import torch
from pathlib import Path
from Bio import SeqIO
import json
import os
import warnings
from datetime import datetime
from pathlib import Path
BASE_DIR = Path(__file__).resolve().parent
LOCAL_COMPONENT_PATH = BASE_DIR / "gradio_molecule3d" / "backend"
sys.path.insert(0, str(LOCAL_COMPONENT_PATH))
from gradio_molecule3d.molecule3d import Molecule3D
from Bio.PDB import PDBParser, PDBIO
from biotite.structure import annotate_sse
import biotite.structure.io as strucio
import biotite.structure.residues as residues
import numpy as np
from huggingface_hub import hf_hub_download, utils
import biotite.structure.io.pdb as pdb
import biotite.structure as struc
import biotite.sequence as seq

from data.scripts.data_utils import modify_bfactor_biotite

def get_first_chain_id(pdb_file):
    """Return the chain ID of the first amino-acid atom in model 1 of *pdb_file*.

    If the amino-acid filter selects nothing, the chain ID of the very first
    atom of the model is returned instead. An empty string is returned when the
    model holds no atoms or when reading/parsing fails for any reason (the
    failure is printed as a warning rather than raised).
    """
    try:
        model = pdb.PDBFile.read(pdb_file).get_structure(model=1)
        amino_acid_atoms = model[struc.filter_amino_acids(model)]

        # Preferred answer: chain of the first atom that passes the protein filter.
        if len(amino_acid_atoms) != 0:
            return amino_acid_atoms.chain_id[0]

        # Nothing passed the filter: fall back to whatever atom comes first.
        if len(model) > 0:
            return model.chain_id[0]
        return ""
    except Exception as err:
        # Best-effort helper: report the problem and let the caller continue.
        print(f"Warning: Biotite failed to detect chain for {pdb_file}: {err}")
        return ""

def get_weights_path(repo_id, filename):
    """
    Resolve *filename* from the Hub repository *repo_id* to a local path.

    The lookup is first attempted with ``local_files_only=True``; only when
    that raises one of the "not found" errors is the call repeated with
    ``local_files_only=False``.
    """
    print(f"Looking for {filename} in {repo_id}...")
    try:
        # First attempt: restrict the lookup to files that are already local.
        resolved = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
    except (utils.EntryNotFoundError, utils.LocalEntryNotFoundError, FileNotFoundError):
        # Second attempt: allow hf_hub_download to go beyond local files.
        print("Weights not found locally. Downloading from HF Hub...")
        resolved = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=False)
    return resolved

def process_pdb_file(pdb_file, backbones, sequences, names):
    """Parse the first chain of *pdb_file* and append it to the three lists.

    The entry name is the path with its last four characters removed. Entries
    whose sequence is longer than 1023 residues are reported and NOT appended,
    so callers can detect a skip by checking whether the lists grew.

    Returns:
        The same ``(backbones, sequences, names)`` lists that were passed in.
    """
    entry_name = pdb_file[:-4]
    chain_id = get_first_chain_id(pdb_file)
    parsed = parse_PDB(pdb_file, name=entry_name, input_chain_list=[chain_id])[0]

    chain_coords = parsed[f'coords_chain_{chain_id}']
    chain_seq = parsed[f'seq_chain_{chain_id}']

    if len(chain_seq) <= 1023:
        backbones.append(chain_coords)
        sequences.append(chain_seq)
        names.append(entry_name)
    else:
        print("Sequence length is greater than 1023, skipping {}".format(entry_name))

    return backbones, sequences, names

@spaces.GPU
def flex_seq(input_seq, input_file):
    """Predict per-residue flexibility from sequence only and write the result files.

    Args:
        input_seq: FASTA-style text (a header line starting with '>' followed by
            one sequence line per record), or None/empty when files are used.
            Text takes precedence over ``input_file`` when both are given.
        input_file: List of uploaded file paths, or None/empty. A single path is
            dispatched on its extension (.fasta or .pdb); several paths are all
            treated as PDB files.

    Returns:
        ``(output_files, message)``. ``output_files`` lists, in this order, one
        .txt per protein, one .pdb per PDB input, and a single .fasta file. It is
        ``None`` when there was nothing to predict.

    Raises:
        ValueError: If the pasted text is not made of header/sequence line pairs.
    """
    # Treat None and whitespace-only text as "no text", so that a stray space in
    # the textbox does not shadow an uploaded file.
    input_seq = (input_seq or "").strip()

    if not input_seq and not input_file:
        return None, "Provide a file/s or a input sequence/s"

    suffix = ""
    if input_file:
        if len(input_file) == 1:
            input_file = input_file[0]
            suffix = os.path.splitext(input_file)[1].lower()
        else:
            suffix = ".pdb_list"

    output_name = datetime.now().strftime('%Y%m%d_%H%M%S')

    sequences = []
    names = []
    backbones = []
    pdb_files = []

    if input_seq:
        suffix = ""
        # splitlines() + strip() tolerates CRLF line endings and blank lines.
        fasta_lines = [line.strip() for line in input_seq.splitlines() if line.strip()]
        if len(fasta_lines) % 2 != 0:
            raise ValueError("You must adhere to the .fasta format")
        for header, sequence in zip(fasta_lines[0::2], fasta_lines[1::2]):
            # The header must START with '>', otherwise header[1:] would silently
            # eat the first character of something that is not a header.
            if not header.startswith(">"):
                raise ValueError("You must adhere to the .fasta format")
            names.append(header[1:])
            sequences.append(sequence)
            backbones.append(None)

    elif suffix == ".fasta":
        for record in SeqIO.parse(input_file, "fasta"):
            names.append(record.name)
            sequences.append(str(record.seq))
            backbones.append(None)

    elif suffix in (".pdb", ".pdb_list"):
        candidate_paths = [input_file] if suffix == ".pdb" else list(input_file)
        for path in candidate_paths:
            n_before = len(names)
            backbones, sequences, names = process_pdb_file(path, backbones, sequences, names)
            # process_pdb_file skips over-long chains; only keep the path when an
            # entry was really added so names/pdb_files stay aligned.
            if len(names) > n_before:
                pdb_files.append(path)

    if not sequences:
        return None, "No valid protein sequences were found in the input"

    with open('configs/env_config.yaml', 'r') as fh:
        env_config = yaml.load(fh, Loader=yaml.FullLoader)
    # Set folder for huggingface cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set gpu device
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']

    with open('configs/train_config.yaml', 'r') as fh:
        config = yaml.load(fh, Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
    class_config.adaptor_architecture = 'no-adaptor'
    if not torch.cuda.is_available():
        config['inference_args']['device'] = 'cpu'
    device = config['inference_args']['device']

    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
    model.to(device)
    repo_id = "Honzus24/Flexpert_weights"
    file_weights = config['inference_args']['seq_model_path']

    # Get path (instant if cached)
    weights_path = get_weights_path(repo_id, file_weights)

    # Load weights
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    data_to_collate = []
    for sequence in sequences:
        # Missing residues must be 'X', not '-', because of the tokenizer vocabulary.
        # Only the tokenized copy is changed; the written outputs keep the original.
        tokenizer_out = tokenizer(' '.join(sequence.replace('-', 'X')), add_special_tokens=True, return_tensors='pt')
        data_to_collate.append({
            'input_ids': tokenizer_out['input_ids'].to(device)[0, :],
            'attention_mask': tokenizer_out['attention_mask'].to(device)[0, :],
        })

    data_collator = DataCollatorForTokenRegression(tokenizer)
    batch = data_collator(data_to_collate)
    batch.to(model.device)

    # Predict
    with torch.no_grad():
        output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
        predictions = output_logits[:, :, 0]  # includes the prediction for the added token

    output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "seq"))
    output_filename.parent.mkdir(parents=True, exist_ok=True)
    output_files = []
    output_message = "Success"
    residue_predictions = []  # per protein: predictions without padding / special token

    for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
        # Sub-select the real tokens with the attention mask (drops batch padding).
        prediction = prediction[mask.bool()]
        assert len(prediction) == len(sequence)+1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1)

        per_residue = prediction[:-1]  # drop the value of the trailing special token
        residue_predictions.append(per_residue)

        txt_path = output_filename.with_stem("{}_".format(name.split("/")[-1]) + output_filename.stem).with_suffix('.txt')
        with open(txt_path, 'w') as f:
            f.write("Residue Number\tResidue ID\tFlexibility\n")
            for i, (residue, value) in enumerate(zip(sequence, per_residue.tolist())):
                f.write(f"{i:<10}\t{residue:<20}\t{round(value, 4):<10}\n")
        output_files.append(str(txt_path))

    if suffix in (".pdb", ".pdb_list"):
        # Use the masked per-residue values: the raw batch rows of shorter proteins
        # also contain padding positions that must not end up in the B-factors.
        for name, pdb_file, per_residue in zip(names, pdb_files, residue_predictions):
            _outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
            print("Saving prediction to {}.".format(_outname))
            modify_bfactor_biotite(pdb_file, None, _outname, per_residue.reshape(1, -1))
            output_files.append(str(_outname))

    # One FASTA file for all proteins; it is named after the last protein (this
    # matches the previous behaviour, which relied on a leaked loop variable).
    fasta_path = output_filename.with_name(names[-1].split("/")[-1] + output_filename.stem + '.fasta')
    print("Saving fasta to {}.".format(fasta_path))
    with open(fasta_path, 'w') as f:
        for name, sequence in zip(names, sequences):
            f.write('>' + name + '\n')
            f.write(sequence + '\n')
    output_files.append(str(fasta_path))

    return output_files, output_message

@spaces.GPU
def flex_3d(input_file):
    """Predict per-residue flexibility from structure (Flexpert-3D) and write result files.

    Args:
        input_file: List of uploaded file paths, or None/empty. A single path is
            dispatched on its extension (.pdb or .jsonl); several paths are all
            treated as PDB files.

    Returns:
        ``(output_files, message, output_files_enm)``. ``output_files`` lists, in
        this order, one .txt per protein, one .pdb per PDB input, and a single
        .fasta file. ``output_files_enm`` lists one ENM .txt per protein followed
        by one ENM .pdb per PDB input. Both lists are ``None`` when there was
        nothing to predict.
    """
    if not input_file:
        # Always hand back three values: the caller unpacks a 3-tuple.
        return None, "Provide a file or a input sequence", None

    if len(input_file) == 1:
        input_file = input_file[0]
        suffix = os.path.splitext(input_file)[1].lower()
    else:
        suffix = ".pdb_list"

    output_name = datetime.now().strftime('%Y%m%d_%H%M%S')

    sequences = []
    names = []
    backbones = []
    pdb_files = []
    flucts_list = []  # precomputed fluctuations, one per entry whose backbone is None

    datapoint_for_eval = 'all'

    if suffix in (".pdb", ".pdb_list"):
        candidate_paths = [input_file] if suffix == ".pdb" else list(input_file)
        for path in candidate_paths:
            n_before = len(names)
            backbones, sequences, names = process_pdb_file(path, backbones, sequences, names)
            # process_pdb_file skips over-long chains; only keep the path when an
            # entry was really added so names/pdb_files stay aligned.
            if len(names) > n_before:
                pdb_files.append(path)

    elif suffix == ".jsonl":
        with open(input_file, 'r') as fh:
            for line in fh:
                if not line.strip():
                    continue
                _dict = json.loads(line)

                if 'fluctuations' in _dict:
                    print("fluctuations are precomputed, using them")
                    dot_separated_name = get_dot_separated_name(key='pdb_name', _dict=_dict)
                    if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                        names.append(_dict['pdb_name'])
                        backbones.append(None)
                        sequences.append(_dict['sequence'])
                        flucts_list.append(_dict['fluctuations']+[0.0])  # padding for end cls token
                    continue

                dot_separated_name = get_dot_separated_name(key='name', _dict=_dict)

                if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                    backbones.append(_dict['coords'])
                    sequences.append(_dict['seq'])
                    names.append(dot_separated_name)

    if not sequences:
        return None, "No valid protein structures were found in the input", None

    with open('configs/env_config.yaml', 'r') as fh:
        env_config = yaml.load(fh, Loader=yaml.FullLoader)
    # Set folder for huggingface cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set gpu device
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']

    with open('configs/train_config.yaml', 'r') as fh:
        config = yaml.load(fh, Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
    class_config.adaptor_architecture = 'conv'
    if not torch.cuda.is_available():
        config['inference_args']['device'] = 'cpu'
    device = config['inference_args']['device']

    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)

    model.to(device)
    repo_id = "Honzus24/Flexpert_weights"
    print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
    file_weights = config['inference_args']['3d_model_path']

    # Get path (instant if cached)
    weights_path = get_weights_path(repo_id, file_weights)

    # Load weights
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    data_to_collate = []
    # Precomputed fluctuations exist only for entries without a backbone, in the
    # same order; consuming them with an iterator keeps them aligned even when a
    # file mixes precomputed and raw entries.
    precomputed_flucts = iter(flucts_list)
    for backbone, sequence in zip(backbones, sequences):

        if backbone is not None:
            _dict = {'coords': backbone, 'seq': sequence}
            flucts, _ = get_fluctuation_for_json_dict(_dict, enm_type=config['inference_args']['enm_type'])
            flucts = flucts.tolist()
            flucts.append(0.0)  # To match the special token for the sequence
            flucts = torch.tensor(flucts).to(device)
        else:
            flucts = next(precomputed_flucts)

        # Missing residues must be 'X', not '-', because of the tokenizer vocabulary.
        # Only the tokenized copy is changed; the written outputs keep the original.
        tokenizer_out = tokenizer(' '.join(sequence.replace('-', 'X')), add_special_tokens=True, return_tensors='pt')

        data_to_collate.append({
            'input_ids': tokenizer_out['input_ids'].to(device)[0, :],
            'attention_mask': tokenizer_out['attention_mask'].to(device)[0, :],
            'enm_vals': flucts,
        })

    # Use the data collator to process the input
    data_collator = DataCollatorForTokenRegression(tokenizer)

    batch = data_collator(data_to_collate)
    batch.to(model.device)

    # Predict
    with torch.no_grad():
        output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
        predictions = output_logits[:, :, 0]  # includes the prediction for the added token

    output_filename = Path(config['inference_args']['prediction_output_dir'].format(output_name, "3D"))
    output_filename.parent.mkdir(parents=True, exist_ok=True)
    output_files = []
    output_message = "Success"
    residue_predictions = []  # per protein: predictions without padding / special token

    for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
        # Sub-select the real tokens with the attention mask (drops batch padding).
        prediction = prediction[mask.bool()]
        assert len(prediction) == len(sequence)+1, "Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence)+1)

        per_residue = prediction[:-1]  # drop the value of the trailing special token
        residue_predictions.append(per_residue)

        txt_path = output_filename.with_stem("{}_".format(name.split("/")[-1]) + output_filename.stem).with_suffix('.txt')
        with open(txt_path, 'w') as f:
            f.write("Residue Number\tResidue ID\tFlexibility\n")
            for i, (residue, value) in enumerate(zip(sequence, per_residue.tolist())):
                f.write(f"{i:<10}\t{residue:<20}\t{round(value, 4):<10}\n")
        output_files.append(str(txt_path))

    output_files_enm = []

    # One ENM text file per protein, holding only that protein's values.
    # NOTE(review): slicing to len(sequence) assumes each 'enm_vals' row is
    # [per-residue values, special-token value, right padding] -- confirm against
    # DataCollatorForTokenRegression.
    for enm_vals_single, name, sequence in zip(batch['enm_vals'], names, sequences):
        _outname_new = output_filename.with_name("{}".format(name.split("/")[-1]) + '_enm_' + output_filename.stem + '.txt')
        print("Saving ENM predictions to {}.".format(_outname_new))
        with open(_outname_new, 'w') as f:
            f.write('>' + name + '\n')
            f.write(', '.join([str(p) for p in enm_vals_single.tolist()[:len(sequence)]]) + '\n')
        output_files_enm.append(str(_outname_new))

    if suffix in (".pdb", ".pdb_list"):
        # Use the masked per-residue values: the raw batch rows of shorter proteins
        # also contain padding positions that must not end up in the B-factors.
        for name, pdb_file, per_residue in zip(names, pdb_files, residue_predictions):
            _outname = output_filename.with_name('{}_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
            print("Saving prediction to {}.".format(_outname))
            modify_bfactor_biotite(pdb_file, None, _outname, per_residue.reshape(1, -1))
            output_files.append(str(_outname))

    # One FASTA file for all proteins; it is named after the last protein (this
    # matches the previous behaviour, which relied on a leaked loop variable).
    fasta_path = output_filename.with_name(names[-1].split("/")[-1] + output_filename.stem + '.fasta')
    print("Saving fasta to {}.".format(fasta_path))
    with open(fasta_path, 'w') as f:
        for name, sequence in zip(names, sequences):
            f.write('>' + name + '\n')
            f.write(sequence + '\n')
    output_files.append(str(fasta_path))

    if suffix in (".pdb", ".pdb_list"):
        eps = 1e-6
        for name, pdb_file, enm_vals_single, sequence in zip(names, pdb_files, batch['enm_vals'], sequences):
            _outname = output_filename.with_name('{}_enm_'.format(name.split("/")[-1]) + output_filename.stem + '.pdb')
            print("Saving ENM prediction to {}.".format(_outname))
            _enm_vals = enm_vals_single[:len(sequence)].reshape(1, -1)
            # Clip, zero out NaNs and round to two decimals before the values are
            # written as B-factors.
            _enm_vals = torch.clip(_enm_vals, -100+eps, 1000-eps)
            _enm_vals = torch.nan_to_num(_enm_vals, nan=0.0)
            _enm_vals = torch.round(_enm_vals, decimals=2)
            modify_bfactor_biotite(pdb_file, None, _outname, _enm_vals)
            output_files_enm.append(str(_outname))

    return output_files, output_message, output_files_enm

def rescale_bfactors(pdb_file):
    """Write a copy of *pdb_file* whose B-factors are min-max scaled to [0, 1].

    The scaling range is taken from the atoms between the first and the last
    residue annotated as helix ('a') or sheet ('b'). B-factors of the atoms
    before and after that span are clipped into the range first, so extreme
    values at the termini do not dominate the scale.

    Args:
        pdb_file: Path of the PDB file to read.

    Returns:
        Path of the written file: the input path with "-scaled" inserted before
        the extension.
    """
    base, ext = os.path.splitext(pdb_file)
    out_file = base + "-scaled" + ext

    atom_array = strucio.load_structure(pdb_file)
    sse = annotate_sse(atom_array)

    # Residue indices of the first and last helix/sheet element. Without any
    # such element the whole chain (first to last residue) is used.
    structured = [i for i, item in enumerate(sse) if item == "a" or item == "b"]
    first_res = structured[0] if structured else 0
    last_res = structured[-1] if structured else len(sse) - 1

    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("prot", pdb_file)

    # Collect all bfactors
    bfactors = [atom.bfactor for atom in structure.get_atoms()]

    # Translate residue indices into atom indices.
    # NOTE(review): `end` is the first atom of the last structured residue, so
    # that residue itself falls into the clipped tail -- confirm this is intended.
    res_starts = residues.get_residue_starts(atom_array)
    start = res_starts[first_res]
    end = res_starts[last_res]

    bfactors_struct = bfactors[start:end]

    # With a single structured residue the span is empty and min()/max() would
    # raise; fall back to the range of the whole structure in that case.
    range_source = bfactors_struct if bfactors_struct else bfactors
    min_b = min(range_source)
    max_b = max(range_source)

    # Bounds are passed positionally: the `min=`/`max=` keyword spelling is only
    # accepted by recent NumPy releases.
    bfactors_start = np.clip(bfactors[:start], min_b, max_b)
    bfactors_end = np.clip(bfactors[end:], min_b, max_b)

    clipped = np.concatenate((bfactors_start, bfactors_struct, bfactors_end))

    # Rescale all atoms; a flat profile maps to an arbitrary mid value.
    span = max_b - min_b
    for atom, value in zip(structure.get_atoms(), clipped):
        atom.set_bfactor(0.5 if span == 0 else float((value - min_b) / span))

    # Save to the *new* file path
    io = PDBIO()
    io.set_structure(structure)
    io.save(out_file)

    return out_file

def clear_files():
    """Empty the ``prediction_results/`` folder, creating it if it is missing.

    Regular files and symlinks are unlinked; real sub-directories are removed
    recursively, because ``os.remove`` raises on a directory.
    """
    folder = 'prediction_results/'
    os.makedirs(folder, exist_ok=True)

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if os.path.isdir(file_path) and not os.path.islink(file_path):
            shutil.rmtree(file_path)
        else:
            os.remove(file_path)

def handle_seq_prediction(input_seq, input_file):
    """Run Flexpert-Seq for the UI and collect the files to download and to display.

    Args:
        input_seq: Text pasted into the sequence textbox (may be empty/None).
        input_file: Uploaded file paths (may be empty/None).

    Returns:
        ``(download_files, message, viewer_files)``. ``download_files`` holds
        every produced file plus a "-scaled" copy of each .pdb; ``viewer_files``
        holds the .pdb files followed by their scaled copies. When ``flex_seq``
        produced no files, both lists are ``None`` and only the message is set.
    """
    clear_files()

    main_files, message = flex_seq(input_seq, input_file)

    # flex_seq signals "nothing to predict" with None; just show its message
    # instead of failing while iterating over None.
    if main_files is None:
        return None, message, None

    pdb_files_for_viz = [str(f) for f in main_files if f.endswith('.pdb')]
    pdb_files_for_viz_scaled = [str(rescale_bfactors(f)) for f in pdb_files_for_viz]

    main_files.extend(pdb_files_for_viz_scaled)
    pdb_files_for_viz.extend(pdb_files_for_viz_scaled)

    return main_files, message, pdb_files_for_viz


def handle_3d_prediction(input_file):
    """Run Flexpert-3D for the UI and collect the files to download and to display.

    Args:
        input_file: Uploaded file paths (may be empty/None).

    Returns:
        ``(download_files, message, viewer_files)``. ``download_files`` holds
        every produced file, a "-scaled" copy of each prediction .pdb, and the
        ENM files; ``viewer_files`` holds the prediction .pdb files followed by
        their scaled copies. When ``flex_3d`` produced no files, both lists are
        ``None`` and only the message is set.
    """
    clear_files()

    result = flex_3d(input_file)
    main_files, message = result[0], result[1]

    # flex_3d signals "nothing to predict" with None as its first value; index
    # the result instead of unpacking so a short tuple does not raise.
    if main_files is None:
        return None, message, None
    enm_files = result[2]

    pdb_files_for_viz = [f for f in main_files if f.endswith('.pdb')]
    pdb_files_for_viz_scaled = [rescale_bfactors(f) for f in pdb_files_for_viz]

    pdb_files_for_viz.extend(pdb_files_for_viz_scaled)
    main_files.extend(pdb_files_for_viz_scaled)
    main_files.extend(enm_files)

    return main_files, message, pdb_files_for_viz

def clear_outputs():
    """Return the reset values for the output message, the 3D viewer and the file list."""
    blank_message = ""
    blank_viewer = None
    blank_files = None
    return blank_message, blank_viewer, blank_files

# Button variant names; the "Text Input" / "File Input" buttons are switched
# between these two to highlight the active input mode.
PRIMARY = "primary"
SECONDARY = "secondary"

def switch_component_view(button_label):
    """Build the updates that switch between the text and the file input panel.

    Args:
        button_label: Label of the clicked button ("Text Input" or "File Input").
            For any other label both panels are hidden and both buttons get the
            secondary variant.

    Returns:
        A six-item list: visibility update for the text panel, visibility update
        for the file panel, the cleared text value, the cleared file value, and
        the variant updates for the text and the file button.
    """
    text_selected = button_label == "Text Input"
    file_selected = button_label == "File Input"

    def variant_for(selected):
        # The selected mode's button is highlighted, the other one is dimmed.
        return gr.update(variant=PRIMARY if selected else SECONDARY)

    return [
        gr.update(visible=text_selected),
        gr.update(visible=file_selected),
        "",
        [],
        variant_for(text_selected),
        variant_for(file_selected),
    ]

# Base Gradio theme with a gray neutral hue and a slate primary hue; it is
# handed to gr.Blocks when the interface is built.
theme = gr.themes.Base(
    neutral_hue="gray",
    primary_hue="slate")

# Register the folder the predictions are written to as a Gradio static path.
# NOTE(review): presumably so generated files can be served to the browser
# directly -- confirm against the gr.set_static_paths documentation.
gr.set_static_paths(["prediction_results"])

# Interface definition: logo and help text, one tab per predictor, shared
# output widgets, and finally the event wiring.
with gr.Blocks(theme=theme) as demo:
    gr.Image("Flexpert_logo.png", show_label=False, interactive=False)
    # Static help text rendered above the tabs.
    gr.Markdown(value="""
        ## About Flexpert

        On the web-version of Flexpert you can calculate the per-residue flexibility of a protein by either inputting the protein as a string or through .pdb/.fasta files.

        ### Inputs:

        #### Flexpert-Seq:

        * **Text** - Enter one or more proteins according to the specified format.
        * **File** - Select either .fasta file containing one or more proteins, or one or more .pdb files with a single-chain protein in the file.
        * **Note:** You can only select either **Text** or **File** input options per a single prediction.

        #### Flexpert-3D:

        * **File** - Select one or more .pdb files with a single-chain protein in the file.

        ### Outputs:

        #### Files:

        * Depending on your input, different output files appear:
            * A **.txt file** with the per-residue flexibility for all proteins **always appears**.
            * A **.fasta file** appears with all the proteins.
            * If you input a **.pdb file**, two .pdb files per protein appear, one with **'true'** per-residue flexibilities and **'scaled'** per-residue flexibilities.
            * For Flexpert-3D, another **.pdb file** per protein also appears containing per-residue ENM values.

        #### Visualisations:

        * You will notice that there is a possibility of seeing a visualisation of the per-residue flexibility of the provided proteins. These visualisations can only appear if you predict the flexibility via a **.pdb file**.
        * We provide both the **'real'** (flexibilities predicted by Flexpert) and the **'scaled'** (flexibilities normalised according to the maximum flexibility) visualisations.
        * To toggle between visualisations, click the lower-most button on the side-panel (the brush) and then choose between files.

        """)
    with gr.Tabs() as tabs:
        # Tab 1: sequence-only predictor with a text/file input switch.
        with gr.Tab("Flexpert-Seq", id="tab_seq"):
            # Mode buttons; the text mode starts as the highlighted one.
            with gr.Row():
                text_button = gr.Button("Text Input", variant=PRIMARY)
                file_button = gr.Button("File Input", variant=SECONDARY)
    
            # Column for Text Input (Default: Visible)
            with gr.Column(visible=True) as col_text_input:
                input_seq = gr.Textbox(
                    label="Paste Protein Sequences (FASTA format)",
                    placeholder=">ProteinName1\nAGFASRGT...\n>ProteinName2\nQWERTY...",
                    lines=10,
                    scale=2
                )
    
            # Column for File Input (Default: Hidden)
            with gr.Column(visible=False) as col_file_input:
                input_file = gr.File(label="Select one or more .pdb files OR a .fasta file containing one or more proteins", file_count="multiple", file_types = ['.fasta', '.pdb'])
    
            predict_seq = gr.Button("Predict")
    
            # Components updated by switch_component_view; the order must match
            # the order of the list that function returns.
            all_outputs = [
                col_text_input, 
                col_file_input, 
                input_seq, 
                input_file, 
                text_button,
                file_button
            ]
    
            # Each mode button passes itself as the input, so the handler
            # receives that button's label.
            text_button.click(
                fn=switch_component_view,
                inputs=[text_button],
                outputs= all_outputs
            )
    
            file_button.click(
                fn=switch_component_view,
                inputs=[file_button],
                outputs= all_outputs
            )
    
    
        # Tab 2: structure-based predictor, file input only.
        with gr.Tab("Flexpert-3D"):
            input_file_3d = gr.File(label="Select one or more .pdb files", file_count = "multiple", file_types = ['.pdb'])
    
            predict_3d = gr.Button("Predict")
    
        # Output widgets shared by both tabs.
        output_text = gr.Textbox(label = "Output message", placeholder="The output message statement will be displayed here")
    
        # Representation settings handed to the Molecule3D viewer.
        reps = [
            {
            "model": 0,
            "chain": "",
            "resname": "",
            "style": "cartoon",  # or "stick", "sphere", "surface"
            "color": "alphafold",  # This is the key - use alphafold color scheme
            "around": 0,
            "byres": False,
            "opacity": 1.0,
            }
        ]
    
        molecule_output = Molecule3D(label="Protein Structure", height=500, file_count = "multiple", reps = reps, confidenceLabel="Flexibility")
    
        output_files = gr.File(file_count="multiple", type = "filepath")
    
        # Clear button covering every input and output component.
        clear_button = gr.ClearButton([input_seq, input_file, input_file_3d, output_text, molecule_output, output_files])
    
        with gr.Row():
            logos = gr.Image("logos.png", show_label=False, interactive=False)

        # Reset the outputs whenever a tab is selected or the input mode changes.
        tabs.select(
            fn=clear_outputs,
            inputs=None,
            outputs=[output_text, molecule_output, output_files]
        )

        text_button.click(
            fn=clear_outputs, 
            inputs=None, 
            outputs=[output_text, molecule_output, output_files]
        )
        
        file_button.click(
            fn=clear_outputs, 
            inputs=None, 
            outputs=[output_text, molecule_output, output_files]
        )
        
        # Connect the buttons to their respective functions.
        # Both handlers return (files, message, viewer files), matching the
        # order of the outputs lists below.
        predict_seq.click(handle_seq_prediction, inputs=[input_seq, input_file], outputs=[output_files, output_text, molecule_output])
        predict_3d.click(handle_3d_prediction, inputs=[input_file_3d], outputs=[output_files, output_text, molecule_output])

# Launch the interface. This runs at import time (there is no __main__ guard).
# NOTE(review): show_error=True presumably surfaces handler exceptions in the
# browser -- confirm against the Gradio launch() documentation.
demo.launch(show_error=True)