#ref: https://huggingface.co/blog/AmelieSchreiber/esm-interact
import gradio as gr

import os
from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, AutoModelForMaskedLM, TrainerCallback, EsmForProteinFolding
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.optim import AdamW
from torch.distributions import Categorical

import pandas as pd
from pathlib import Path

#import wandb
import numpy as np
from datetime import datetime
import time
from plot_pdb import plot_struc

import requests
import Bio.PDB

# Constants & Globals
HF_TOKEN = os.environ.get("HF_token")
print("HF_TOKEN:",HF_TOKEN)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # models users can choose from

PEFT_MODEL_OPTIONS = [
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-09-19_06",
    "wangjin2000/esm2_t12_35M_UR50D_PPI_2024-09-19_09",
    "wangjin2000/esm2_t33_650M_UR50D_PPI_2024-09-19_10",
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-11-19_04",
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # fine-tuned models; the last three are the original foundation models

#build datasets
class ProteinDataset(Dataset):
    def __init__(self, file, tokenizer, peptide_length):
        data = pd.read_csv(file)
        self.tokenizer = tokenizer
        #self.proteins = data["Receptor Sequence"].tolist()
        #self.peptides = data["Binder"].tolist()
        self.proteins = data["P_Sequence"].tolist()  #header defined by Lin Qiao
        self.peptides = data["p_Sequence"].tolist()
        self.max_length_pm = 1000 + 2 + peptide_length  # assume protein <= 1000 AAs; +2 for the CLS/EOS special tokens
        
    def __len__(self):
        return len(self.proteins)

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]

        masked_peptide = '<mask>' * len(peptide_seq)
        complex_seq = protein_seq + masked_peptide
        # Tokenize and pad the complex sequence
        complex_input = self.tokenizer(complex_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True).to(device)

        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()
        
        # Create labels (tokens for ground truth AAs)
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)["input_ids"].to(device).squeeze()
        
        # Set non-masked positions in the labels tensor to -100
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# fine-tuning function
def finetune(base_model_path, peptide_length):   #, train_dataset, test_dataset):
    
    #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #load base model
    base_model = EsmForMaskedLM.from_pretrained(base_model_path).to(device)
    
    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")

    #train_dataset = ProteinDataset("./datasets/pepnn_train.csv", tokenizer, peptide_length)
    #test_dataset = ProteinDataset("./datasets/pepnn_test.csv", tokenizer, peptide_length)
    train_dataset = ProteinDataset("./datasets/peptide-protein-train.csv", tokenizer, peptide_length)
    eval_dataset = ProteinDataset("./datasets/peptide-protein-eval.csv", tokenizer, peptide_length)
    
    model_name_base = base_model_path.split("/")[1]
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    lr = 0.0007984276816171436
    save_path = f"{model_name_base}_PPI_{timestamp}"
    
    training_args = TrainingArguments(
        output_dir=save_path,  #f"{model_name_base}_PPI_{timestamp}",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=16,
        warmup_steps=501,
        logging_dir=None,
        logging_steps=10,
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        save_strategy='epoch',
        metric_for_best_model='eval_loss',
        save_total_limit=5,
        gradient_accumulation_steps=2,
        push_to_hub=True,  #jw 20240918
        hub_token=HF_TOKEN,  #jw 20240918
        dataloader_pin_memory=False,  #jw 20241119 true for CPU
    )

    # Initialize Trainer
    trainer = Trainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        optimizers=(AdamW(base_model.parameters(), lr=lr), None),
    )

    # Train and Save Model
    trainer.train()

    return save_path
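
# Because push_to_hub=True, checkpoints are also uploaded to the Hub. A
# minimal reload sketch (assumes the final weights were written to save_path,
# e.g. via trainer.save_model(save_path), or pushed to a matching Hub repo):
#
#   model = AutoModelForMaskedLM.from_pretrained(save_path)
#   model.eval()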

def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    start = time.time()
    sequence = protein_seq + binder_seq 
    
    print("129:model.device in compute_pseudo_perplexity",model.device)
    original_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
    length_of_binder = len(binder_seq)
    
    # Prepare a batch with each row having one masked token from the binder sequence
    masked_inputs = original_input.repeat(length_of_binder, 1)
    positions_to_mask = torch.arange(-length_of_binder - 1, -1, device=model.device)
    masked_inputs[torch.arange(length_of_binder), positions_to_mask] = tokenizer.mask_token_id
    
    # Prepare labels for the masked tokens
    labels = torch.full_like(masked_inputs, -100)
    labels[torch.arange(length_of_binder), positions_to_mask] = original_input[0, positions_to_mask]

    # Get model predictions and calculate loss
    with torch.no_grad():
        outputs = model(masked_inputs, labels=labels)
        loss = outputs.loss
             
    # Loss is already averaged by the model
    avg_loss = loss.item()
    pseudo_perplexity = np.exp(avg_loss)

    end = time.time()
    elapsed = end - start
    #print(f'compute_pseudo_perplexity time: {elapsed:.4f} seconds')
    
    return pseudo_perplexity
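
# Worked example of the metric above: the model's loss is the mean
# cross-entropy over the masked binder positions, and
# pseudo_perplexity = exp(loss). A mean loss of 2.3 nats therefore gives
# exp(2.3) ~= 9.97, i.e. the model is about as uncertain as a uniform
# choice among ~10 residues at each binder position; lower is better.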

# compute pLDDT and iPTM from the ESMFold model directly; very slow
def compute_plddt_iptm(protein_seq, binder_seq):
    start = time.time()
    # always the ESMFold model
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

    model.eval()
    # based on https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb 
    #model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
    #model = model.cuda()
    #model.esm = model.esm.half()     # Uncomment to switch the stem to float16
    #torch.backends.cuda.matmul.allow_tf32 = True
    print("168:model.device in compute_plddt_iptm",model.device)
    
    sequence = protein_seq + binder_seq
    inputs = tokenizer(sequence, return_tensors='pt', add_special_tokens=False).to(device)

    # Get model predictions 
    with torch.no_grad():
        outputs = model(**inputs)
        plddt = outputs.plddt
        ptm = outputs.ptm.item()

    avg_plddt = plddt[0, :, 1].mean().item()  # per-residue pLDDT at the CA atom (atom index 1)
    #iPTM = ptm
    #print("170: iPTM:",iPTM)
    
    end = time.time()
    elapsed = end - start
    #print(f'compute_plddt_iptm time: {elapsed:.4f} seconds')
    
    return avg_plddt, ptm

# call the ESM Atlas web API to get pLDDT
def get_plddt(protein_seq, binder_seq):
    start = time.time()
    sequence = protein_seq + binder_seq
    
    retries = 0
    pdb_str = None
    url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
    while retries < 3 and pdb_str is None:
        response = requests.post(url, data=sequence, verify=False)
        pdb_str = response.text
        if pdb_str == "INTERNAL SERVER ERROR":
            retries += 1
            time.sleep(0.1)
            pdb = None #pdb = str = None

    if pdb_str is None:
        raise RuntimeError("ESM Atlas folding API failed after 3 retries")

    #save a pdb format file
    name = sequence[:3] + sequence[-3:]  # combine the first and last 3 AAs of the sequence into a filename
    outpath = Path.cwd() / f"PDB-{name}.pdb"
    with open(outpath, "w") as f:
        f.write(pdb_str)
    outpath_str = str(outpath)

    #get pdb column values
    p = Bio.PDB.PDBParser()
    structure = p.get_structure('myStructureName', outpath_str)
    ids = [a.get_id() for a in structure.get_atoms()]
    pLDDTs = [a.get_bfactor() for a in structure.get_atoms()]     

    avg_plddt = np.mean(pLDDTs)
    
    ptm = 0  # placeholder for iPTM
    end = time.time()
    elapsed = end - start
    print(f'get_plddt time: {elapsed:.4f} seconds')
    
    return avg_plddt, ptm
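
# Note: ESMFold, like AlphaFold, stores per-atom pLDDT in the PDB B-factor
# column, which is why get_plddt() averages Bio.PDB get_bfactor() values.
# The Atlas endpoint returns only a PDB string, so iPTM is left at 0 here.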
    
def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_length = 15, top_k = 3, num_binders = 5, plddt_iptm_yes="no"):
    start = time.time()
    
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    binders_with_ppl_plddt = []
    n = 0
    for _ in range(num_binders):
        n += 1
        #print("n in num_binders:", n)
        # Generate binder
        masked_peptide = '<mask>' * peptide_length
        input_sequence = protein_seq + masked_peptide
        inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
        #inputs = tokenizer(input_sequence, return_tensors="pt").to(device)
        print("198:model.device in generate_:",model.device)
        
        with torch.no_grad():
            logits = model(**inputs).logits

        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]
        
        # Apply top-k sampling
        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)                                
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')
        
        # Compute PPL for the generated binder
        ppl = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)

        # Get PLDDT from ESMFold model
        if plddt_iptm_yes == "yes":
            plddt, iptm = get_plddt(protein_seq, generated_binder)
            #plddt, iptm = compute_plddt_iptm(protein_seq, generated_binder) #too time-consuming
        else:
            plddt, iptm = [0, 0]
        
        # Add the generated binder and its PPL to the results list
        binders_with_ppl_plddt.append([generated_binder, ppl, plddt, iptm])

    end = time.time()
    elapsed = end - start
    print(f'generate_peptide_for_single_sequence: {elapsed:.4f} seconds')
    
    return binders_with_ppl_plddt
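
# A self-contained sketch of the top-k sampling step used above, on dummy
# logits (the shapes are illustrative; ESM-2 has a 33-token vocabulary):
#
#   logits_at_masks = torch.randn(15, 33)              # 15 masks x 33 vocab
#   top_k_logits, top_k_indices = logits_at_masks.topk(3, dim=-1)
#   probs = torch.nn.functional.softmax(top_k_logits, dim=-1)
#   sampled = Categorical(probs).sample()              # position in top-3 set
#   token_ids = top_k_indices.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
#
# Restricting the softmax to the top-3 logits before sampling keeps each
# generated residue plausible while still adding diversity across binders.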

# Predict peptide binder with finetuned model
def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15,  num_binders=4, plddt_iptm_yes="no"):
    # Load the model
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)  # .to(device) omitted: inference runs on the CPU

    # Ensure the model is in evaluation mode
    loaded_model.eval()
    
    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # set top_k mutations for each AA position
    top_k=3
    if isinstance(input_seqs, str):  # Single sequence
        binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders, plddt_iptm_yes)
        results_df = pd.DataFrame(binders, columns=['Binder', 'PPL', 'pLDDT', 'iPTM'])

    elif isinstance(input_seqs, list):  # List of sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
            for binder, ppl, plddt, iptm in binders:
                results.append([seq, binder, ppl, plddt, iptm])
        results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM'])
    print(results_df)

    # join the target protein and the predicted peptide with a 20-glycine (G) linker
    separator = 'G' * 20
    peptide_lp = results_df['Binder'][results_df['PPL'].idxmin()] #Choosing the one with the lowest perplexity
    PPC = input_seqs + separator +  peptide_lp
    
    return results_df, PPC
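
# Example call (mirrors what the "Predict peptide sequence" button wires up
# below; the truncated sequence is illustrative):
#
#   df, ppc = predict_peptide(
#       "facebook/esm2_t6_8M_UR50D",
#       "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-09-19_06",
#       "MSGIALSRLAQERK...",
#       peptide_length=15, num_binders=4, plddt_iptm_yes="no",
#   )
#   # df lists each binder with its PPL/pLDDT/iPTM; ppc concatenates the
#   # protein, the 20-G linker, and the lowest-PPL binder for folding.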

def predict_peptide_from_file(base_model_path, finetuned_model_path, file_obj, max_peptide_length=15, num_binders=5, plddt_iptm_yes="no"):
    start = time.time()
    # Load the model
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)  # .to(device) omitted: inference runs on the CPU

    # Ensure the model is in evaluation mode
    loaded_model.eval()
    
    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    input_df = pd.read_csv(file_obj, header=0)  # avoid shadowing the builtin input()

    results = []

    for i, row in input_df.iterrows():
        print("sequence:", i)
        #protein_seq = row['Receptor Sequence']
        #peptide_seq = row['Peptide Sequence']
        protein_seq = row['P_Sequence']
        peptide_seq = row['p_Sequence']
        peptide_length = min([len(peptide_seq), max_peptide_length])  # use the ground-truth peptide length, capped at max_peptide_length

        #get metrics for ground truth peptide
        ppl = compute_pseudo_perplexity(loaded_model, tokenizer, protein_seq, peptide_seq)
        if plddt_iptm_yes == "yes":
            plddt, iptm = get_plddt(protein_seq, peptide_seq)
            #plddt, iptm = compute_plddt_iptm(protein_seq, peptide_seq) #too time-consuming
        else:
            plddt, iptm = [0, 0]

        results.append([protein_seq, peptide_seq, ppl, plddt, iptm, 1]) # flag 1 for ground truth peptide

        # set top_k mutations for each AA position
        top_k=3
        #predict peptides
        binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, protein_seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
        
        for binder, ppl, plddt, iptm in binders:    
            results.append([protein_seq, binder, ppl, plddt, iptm, 0]) # flag 0 for generated peptide
        
    results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM', 'GT_Flag'])
    
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    outpath = (
        Path.cwd() / f"predicted_peptides_{timestamp}.csv"
    )
    
    str_outpath = str(outpath)  # workaround: the latest gr.File expects a string path

    results_df.to_csv(str_outpath, header=True, index=False)
    
    end = time.time()
    elapsed = end - start
    print(f'predict_peptide_from_file: {elapsed:.4f} seconds')
    
    return str_outpath   

def suggest(option):
    if option == "Protein:P63279":
        suggestion = "MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS"
    elif option == "Default protein":
        #suggestion = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
        suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" 
    elif option == "Antifreeze protein":
        suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"
    elif option == "AI Generated protein":
        suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"
    elif option == "7-bladed propeller fold":
        suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"
    else:
        suggestion = ""
    return suggestion

demo = gr.Blocks(title="ESM2 for Micropeptide Protein Interaction (ESM2MPI)")

with demo:
    gr.Markdown("# ESM2 for Micropertide Protein Interaction (ESM2MPI)")
    #gr.Textbox(debug_result)
    
    with gr.Column():
        gr.Markdown("## Select a base model and a corresponding finetuned model")

        with gr.Row():   
            with gr.Column(scale=5, variant="compact"):
                base_model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Base Model Name",
                    interactive = True,
                )  
                PEFT_model_name = gr.Dropdown(
                    choices=PEFT_MODEL_OPTIONS,
                    value=PEFT_MODEL_OPTIONS[0],
                    label="Finetuned Model Name",
                    interactive = True,
                )
                with gr.Row():
                    peptide_length=gr.Slider(minimum=10, maximum=100, step=1, label="Peptide Maximum Length", value=15)
                    num_pred_peptides=gr.Slider(minimum=1, maximum=10, step=1, label="Number of Predicted Peptides", value=5)
                    plddt_iptm_yes=gr.Radio(["yes", "no"],label="Compute pLDDT and iPTM (slow!)", value="no")
            with gr.Column(scale=5, variant="compact"): 
                    name = gr.Dropdown(
                        label="Choose a Sample Protein", 
                        value="Protein:P63279", 
                        choices=["Default protein", "Antifreeze protein", "Protein:P63279", "AI Generated protein", "7-bladed propeller fold", "custom"]
                    )
                    uploaded_file = gr.File(
                        label="Local File Upload",
                        file_count="single",
                        file_types=[".tsv", ".csv"],
                        type="filepath",
                        height=40,
                    )
        gr.Markdown(
                "## Predict peptide sequence:"
                )
        with gr.Row():
            with gr.Column(variant="compact", scale = 8):  
                input_seq = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Protein:P63279 to be predicted:",
                    value="MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
                    placeholder="Paste your protein sequence here...",
                    interactive = True,
                ) 
                text_pos = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Sequency Position:",
                    placeholder=
                    "012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789",
                    interactive=False,
                )
            with gr.Column(variant="compact", scale = 2):
                    predict_btn = gr.Button(
                        value="Predict peptide sequence from a protein sequence", 
                        interactive=True, 
                        variant="primary",
                    )
                    plot_struc_btn = gr.Button(value="Plot ESMFold predicted structure", variant="primary")
                
                    predict_file_btn = gr.Button(
                        value="Predict peptide from a local file", 
                        interactive=True, 
                        variant="primary",
                    )
        with gr.Row():  
            with gr.Column(variant="compact", scale = 5):
                output_text = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Output",
                    placeholder="Output",
                )   
            with gr.Column(variant="compact", scale = 5):  
                finetune_button = gr.Button(
                    value="Finetune Pre-trained Model", 
                    interactive=True, 
                    variant="primary",
                )
        with gr.Row(): 
            output_viewer = gr.HTML()
            output_file = gr.File(
                label="Download as Text File",
                file_count="single",
                type="filepath",  
                interactive=False,
            )  
            
    # select protein sample
    name.change(fn=suggest, inputs=name, outputs=input_seq)   

    # "Predict peptide sequence" actions
    predict_btn.click(
        fn = predict_peptide,
        inputs=[base_model_name,PEFT_model_name,input_seq,peptide_length,num_pred_peptides,plddt_iptm_yes], 
        outputs = [output_text, input_seq], 
    )

    # "Predict peptide from a local file" actions
    predict_file_btn.click(
        fn = predict_peptide_from_file,
        inputs=[base_model_name,PEFT_model_name,uploaded_file,peptide_length,num_pred_peptides,plddt_iptm_yes], 
        outputs = [output_file], 
    )
    
    # "Finetune Pre-trained Model" actions
    finetune_button.click(
        fn = finetune,
        inputs=[base_model_name,peptide_length],
        outputs = [output_text],
    )
           
    # plot protein structure
    plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer])

    
demo.launch()