ThorbenFroehlking commited on
Commit
af13564
·
0 Parent(s):
Files changed (4) hide show
  1. README.md +14 -0
  2. app.py +653 -0
  3. model_loader.py +640 -0
  4. requirements.txt +14 -0
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Test Webpage
3
+ emoji: 🐢
4
+ colorFrom: blue
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.7.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: test_webpage
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import gradio as gr
3
+ import requests
4
+ from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select
5
+ from Bio.PDB.Polypeptide import is_aa
6
+ from Bio.SeqUtils import seq1
7
+ from typing import Optional, Tuple
8
+ import numpy as np
9
+ import os
10
+ from gradio_molecule3d import Molecule3D
11
+
12
+ from model_loader import load_model
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import DataLoader
18
+
19
+ import re
20
+ import pandas as pd
21
+ import copy
22
+
23
+ import transformers
24
+ from transformers import AutoTokenizer, DataCollatorForTokenClassification
25
+
26
+ from datasets import Dataset
27
+
28
+ from scipy.special import expit
29
+
30
+ # Load model and move to device
31
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50'
32
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_cryptic'
33
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_database'
34
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full'
35
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_0925'
36
+ #checkpoint = 'ThorbenF/prot_t5_xl_uniref50_0925_v2'
37
+ checkpoint = 'ThorbenF/prot_t5_xl_uniref50_full_v2'
38
+ max_length = 1500
39
+ model, tokenizer = load_model(checkpoint, max_length)
40
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ model.to(device)
42
+ model.eval()
43
+
44
+ def normalize_scores(scores):
45
+ min_score = np.min(scores)
46
+ max_score = np.max(scores)
47
+ return (scores - min_score) / (max_score - min_score) if max_score > min_score else scores
48
+
49
+ def read_mol(pdb_path):
50
+ """Read PDB file and return its content as a string"""
51
+ with open(pdb_path, 'r') as f:
52
+ return f.read()
53
+
54
+ def fetch_structure(pdb_id: str, output_dir: str = ".") -> str:
55
+ """
56
+ Fetch the structure file for a given PDB ID. Prioritizes CIF files.
57
+ If a structure file already exists locally, it uses that.
58
+ """
59
+ file_path = download_structure(pdb_id, output_dir)
60
+ return file_path
61
+
62
+ def download_structure(pdb_id: str, output_dir: str) -> str:
63
+ """
64
+ Attempt to download the structure file in CIF or PDB format.
65
+ Returns the path to the downloaded file.
66
+ """
67
+ for ext in ['.cif', '.pdb']:
68
+ file_path = os.path.join(output_dir, f"{pdb_id}{ext}")
69
+ if os.path.exists(file_path):
70
+ return file_path
71
+ url = f"https://files.rcsb.org/download/{pdb_id}{ext}"
72
+ response = requests.get(url, timeout=10)
73
+ if response.status_code == 200:
74
+ with open(file_path, 'wb') as f:
75
+ f.write(response.content)
76
+ return file_path
77
+ return None
78
+
79
+ def convert_cif_to_pdb(cif_path: str, output_dir: str = ".") -> str:
80
+ """
81
+ Convert a CIF file to PDB format using BioPython and return the PDB file path.
82
+ """
83
+ pdb_path = os.path.join(output_dir, os.path.basename(cif_path).replace('.cif', '.pdb'))
84
+ parser = MMCIFParser(QUIET=True)
85
+ structure = parser.get_structure('protein', cif_path)
86
+ io = PDBIO()
87
+ io.set_structure(structure)
88
+ io.save(pdb_path)
89
+ return pdb_path
90
+
91
+ def fetch_pdb(pdb_id):
92
+ pdb_path = fetch_structure(pdb_id)
93
+ _, ext = os.path.splitext(pdb_path)
94
+ if ext == '.cif':
95
+ pdb_path = convert_cif_to_pdb(pdb_path)
96
+ return pdb_path
97
+
98
+ def create_chain_specific_pdb(input_pdb: str, chain_id: str, residue_scores: list, protein_residues: list) -> str:
99
+ """
100
+ Create a PDB file with only the selected chain and residues, replacing B-factor with prediction scores
101
+ """
102
+ parser = PDBParser(QUIET=True)
103
+ structure = parser.get_structure('protein', input_pdb)
104
+
105
+ output_pdb = f"{os.path.splitext(input_pdb)[0]}_{chain_id}_predictions_scores.pdb"
106
+
107
+ # Create scores dictionary for easy lookup
108
+ scores_dict = {resi: score for resi, score in residue_scores}
109
+
110
+ # Create a custom Select class
111
+ class ResidueSelector(Select):
112
+ def __init__(self, chain_id, selected_residues, scores_dict):
113
+ self.chain_id = chain_id
114
+ self.selected_residues = selected_residues
115
+ self.scores_dict = scores_dict
116
+
117
+ def accept_chain(self, chain):
118
+ return chain.id == self.chain_id
119
+
120
+ def accept_residue(self, residue):
121
+ return residue.id[1] in self.selected_residues
122
+
123
+ def accept_atom(self, atom):
124
+ if atom.parent.id[1] in self.scores_dict:
125
+ atom.bfactor = np.absolute(1-self.scores_dict[atom.parent.id[1]]) * 100
126
+ return True
127
+
128
+ # Prepare output PDB with selected chain and residues, modified B-factors
129
+ io = PDBIO()
130
+ selector = ResidueSelector(chain_id, [res.id[1] for res in protein_residues], scores_dict)
131
+
132
+ io.set_structure(structure[0])
133
+ io.save(output_pdb, selector)
134
+
135
+ return output_pdb
136
+
137
+ def generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, score_type):
138
+ """Generate PyMOL commands based on score type"""
139
+ pymol_commands = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
140
+
141
+ pymol_commands += f"""
142
+ # PyMOL Visualization Commands
143
+ fetch {pdb_id}, protein
144
+ hide everything, all
145
+ show cartoon, chain {segment}
146
+ color white, chain {segment}
147
+ """
148
+
149
+ # Define colors for each score bracket
150
+ bracket_colors = {
151
+ "0.0-0.2": "white",
152
+ "0.2-0.4": "lightorange",
153
+ "0.4-0.6": "yelloworange",
154
+ "0.6-0.8": "orange",
155
+ "0.8-1.0": "red"
156
+ }
157
+
158
+ # Add PyMOL commands for each score bracket
159
+ for bracket, residues in residues_by_bracket.items():
160
+ if residues: # Only add commands if there are residues in this bracket
161
+ color = bracket_colors[bracket]
162
+ resi_list = '+'.join(map(str, residues))
163
+ pymol_commands += f"""
164
+ select bracket_{bracket.replace('.', '').replace('-', '_')}, resi {resi_list} and chain {segment}
165
+ show sticks, bracket_{bracket.replace('.', '').replace('-', '_')}
166
+ color {color}, bracket_{bracket.replace('.', '').replace('-', '_')}
167
+ """
168
+ return pymol_commands
169
+
170
+ def generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence, scores, current_time, score_type):
171
+ """Generate results text based on score type"""
172
+ result_str = f"Prediction for PDB: {pdb_id}, Chain: {segment}\nDate: {current_time}\nScore Type: {score_type}\n\n"
173
+ result_str += "Residues by Score Brackets:\n\n"
174
+
175
+ # Add residues for each bracket
176
+ for bracket, residues in residues_by_bracket.items():
177
+ result_str += f"Bracket {bracket}:\n"
178
+ result_str += f"Columns: Residue Name, Residue Number, One-letter Code, {score_type} Score\n"
179
+ result_str += "\n".join([
180
+ f"{res.resname} {res.id[1]} {sequence[i]} {scores[i]:.2f}"
181
+ for i, res in enumerate(protein_residues) if res.id[1] in residues
182
+ ])
183
+ result_str += "\n\n"
184
+
185
+ return result_str
186
+
187
+
188
+
189
+
190
+ def process_pdb(pdb_id_or_file, segment, score_type='normalized'):
191
+ # Determine if input is a PDB ID or file path
192
+ if pdb_id_or_file.endswith('.pdb'):
193
+ pdb_path = pdb_id_or_file
194
+ pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]
195
+ else:
196
+ pdb_id = pdb_id_or_file
197
+ pdb_path = fetch_pdb(pdb_id)
198
+
199
+ # Determine the file format and choose the appropriate parser
200
+ _, ext = os.path.splitext(pdb_path)
201
+ parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
202
+
203
+ # Parse the structure file
204
+ structure = parser.get_structure('protein', pdb_path)
205
+
206
+ # Extract the specified chain
207
+ chain = structure[0][segment]
208
+
209
+ protein_residues = [res for res in chain if is_aa(res)]
210
+ sequence = "".join(seq1(res.resname) for res in protein_residues)
211
+ sequence_id = [res.id[1] for res in protein_residues]
212
+
213
+ input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
214
+ with torch.no_grad():
215
+ outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
216
+
217
+ # Calculate scores and normalize them
218
+ raw_scores = expit(outputs[:, 1] - outputs[:, 0])
219
+ normalized_scores = normalize_scores(raw_scores)
220
+
221
+ # Choose which scores to use based on score_type
222
+ display_scores = normalized_scores if score_type == 'normalized' else raw_scores
223
+
224
+ # Zip residues with scores to track the residue ID and score
225
+ residue_scores = [(resi, score) for resi, score in zip(sequence_id, display_scores)]
226
+
227
+ # Also save both score types for later use
228
+ raw_residue_scores = [(resi, score) for resi, score in zip(sequence_id, raw_scores)]
229
+ norm_residue_scores = [(resi, score) for resi, score in zip(sequence_id, normalized_scores)]
230
+
231
+ # Define the score brackets
232
+ score_brackets = {
233
+ "0.0-0.2": (0.0, 0.2),
234
+ "0.2-0.4": (0.2, 0.4),
235
+ "0.4-0.6": (0.4, 0.6),
236
+ "0.6-0.8": (0.6, 0.8),
237
+ "0.8-1.0": (0.8, 1.0)
238
+ }
239
+
240
+ # Initialize a dictionary to store residues by bracket
241
+ residues_by_bracket = {bracket: [] for bracket in score_brackets}
242
+
243
+ # Categorize residues into brackets
244
+ for resi, score in residue_scores:
245
+ for bracket, (lower, upper) in score_brackets.items():
246
+ if lower <= score < upper:
247
+ residues_by_bracket[bracket].append(resi)
248
+ break
249
+
250
+ # Generate timestamp
251
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
252
+
253
+ # Generate result text and PyMOL commands based on score type
254
+ display_score_type = "Normalized" if score_type == 'normalized' else "Raw"
255
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
256
+ display_scores, current_time, display_score_type)
257
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
258
+
259
+ # Create chain-specific PDB with scores in B-factor
260
+ scored_pdb = create_chain_specific_pdb(pdb_path, segment, residue_scores, protein_residues)
261
+
262
+ # Molecule visualization with updated script with color mapping
263
+ mol_vis = molecule(pdb_path, residue_scores, segment)
264
+
265
+ # Create prediction file
266
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
267
+ with open(prediction_file, "w") as f:
268
+ f.write(result_str)
269
+
270
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
271
+ os.rename(scored_pdb, scored_pdb_name)
272
+
273
+ return pymol_commands, mol_vis, [prediction_file, scored_pdb_name], raw_residue_scores, norm_residue_scores, pdb_id, segment
274
+
275
+ def molecule(input_pdb, residue_scores=None, segment='A'):
276
+ # Read PDB file content
277
+ mol = read_mol(input_pdb)
278
+
279
+ # Prepare high-scoring residues script if scores are provided
280
+ high_score_script = ""
281
+ if residue_scores is not None:
282
+ # Filter residues based on their scores
283
+ class1_score_residues = [resi for resi, score in residue_scores if 0.0 < score <= 0.2]
284
+ class2_score_residues = [resi for resi, score in residue_scores if 0.2 < score <= 0.4]
285
+ class3_score_residues = [resi for resi, score in residue_scores if 0.4 < score <= 0.6]
286
+ class4_score_residues = [resi for resi, score in residue_scores if 0.6 < score <= 0.8]
287
+ class5_score_residues = [resi for resi, score in residue_scores if 0.8 < score <= 1.0]
288
+
289
+ high_score_script = """
290
+ // Load the original model and apply white cartoon style
291
+ let chainModel = viewer.addModel(pdb, "pdb");
292
+ chainModel.setStyle({}, {});
293
+ chainModel.setStyle(
294
+ {"chain": "%s"},
295
+ {"cartoon": {"color": "white"}}
296
+ );
297
+
298
+ // Create a new model for high-scoring residues and apply red sticks style
299
+ let class1Model = viewer.addModel(pdb, "pdb");
300
+ class1Model.setStyle({}, {});
301
+ class1Model.setStyle(
302
+ {"chain": "%s", "resi": [%s]},
303
+ {"stick": {"color": "0xFFFFFF", "opacity": 0.5}}
304
+ );
305
+
306
+ // Create a new model for high-scoring residues and apply red sticks style
307
+ let class2Model = viewer.addModel(pdb, "pdb");
308
+ class2Model.setStyle({}, {});
309
+ class2Model.setStyle(
310
+ {"chain": "%s", "resi": [%s]},
311
+ {"stick": {"color": "0xFFD580", "opacity": 0.7}}
312
+ );
313
+
314
+ // Create a new model for high-scoring residues and apply red sticks style
315
+ let class3Model = viewer.addModel(pdb, "pdb");
316
+ class3Model.setStyle({}, {});
317
+ class3Model.setStyle(
318
+ {"chain": "%s", "resi": [%s]},
319
+ {"stick": {"color": "0xFFA500", "opacity": 1}}
320
+ );
321
+
322
+ // Create a new model for high-scoring residues and apply red sticks style
323
+ let class4Model = viewer.addModel(pdb, "pdb");
324
+ class4Model.setStyle({}, {});
325
+ class4Model.setStyle(
326
+ {"chain": "%s", "resi": [%s]},
327
+ {"stick": {"color": "0xFF4500", "opacity": 1}}
328
+ );
329
+
330
+ // Create a new model for high-scoring residues and apply red sticks style
331
+ let class5Model = viewer.addModel(pdb, "pdb");
332
+ class5Model.setStyle({}, {});
333
+ class5Model.setStyle(
334
+ {"chain": "%s", "resi": [%s]},
335
+ {"stick": {"color": "0xFF0000", "alpha": 1}}
336
+ );
337
+
338
+ """ % (
339
+ segment,
340
+ segment,
341
+ ", ".join(str(resi) for resi in class1_score_residues),
342
+ segment,
343
+ ", ".join(str(resi) for resi in class2_score_residues),
344
+ segment,
345
+ ", ".join(str(resi) for resi in class3_score_residues),
346
+ segment,
347
+ ", ".join(str(resi) for resi in class4_score_residues),
348
+ segment,
349
+ ", ".join(str(resi) for resi in class5_score_residues)
350
+ )
351
+
352
+ # Generate the full HTML content
353
+ html_content = f"""
354
+ <!DOCTYPE html>
355
+ <html>
356
+ <head>
357
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
358
+ <style>
359
+ .mol-container {{
360
+ width: 100%;
361
+ height: 700px;
362
+ position: relative;
363
+ }}
364
+ </style>
365
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js"></script>
366
+ <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
367
+ </head>
368
+ <body>
369
+ <div id="container" class="mol-container"></div>
370
+ <script>
371
+ let pdb = `{mol}`; // Use template literal to properly escape PDB content
372
+ $(document).ready(function () {{
373
+ let element = $("#container");
374
+ let config = {{ backgroundColor: "white" }};
375
+ let viewer = $3Dmol.createViewer(element, config);
376
+
377
+ {high_score_script}
378
+
379
+ // Add hover functionality
380
+ viewer.setHoverable(
381
+ {{}},
382
+ true,
383
+ function(atom, viewer, event, container) {{
384
+ if (!atom.label) {{
385
+ atom.label = viewer.addLabel(
386
+ atom.resn + ":" +atom.resi + ":" + atom.atom,
387
+ {{
388
+ position: atom,
389
+ backgroundColor: 'mintcream',
390
+ fontColor: 'black',
391
+ fontSize: 18,
392
+ padding: 4
393
+ }}
394
+ );
395
+ }}
396
+ }},
397
+ function(atom, viewer) {{
398
+ if (atom.label) {{
399
+ viewer.removeLabel(atom.label);
400
+ delete atom.label;
401
+ }}
402
+ }}
403
+ );
404
+
405
+ viewer.zoomTo();
406
+ viewer.render();
407
+ viewer.zoom(0.8, 2000);
408
+ }});
409
+ </script>
410
+ </body>
411
+ </html>
412
+ """
413
+
414
+ # Return the HTML content within an iframe safely encoded for special characters
415
+ return f'<iframe width="100%" height="700" srcdoc="{html_content.replace(chr(34), "&quot;").replace(chr(39), "&#39;")}"></iframe>'
416
+
417
+ with gr.Blocks(css="""
418
+ /* Customize Gradio button colors */
419
+ #visualize-btn, #predict-btn {
420
+ background-color: #FF7300; /* Deep orange */
421
+ color: white;
422
+ border-radius: 5px;
423
+ padding: 10px;
424
+ font-weight: bold;
425
+ }
426
+ #visualize-btn:hover, #predict-btn:hover {
427
+ background-color: #CC5C00; /* Darkened orange on hover */
428
+ }
429
+ """) as demo:
430
+ gr.Markdown("# Protein Binding Site Prediction")
431
+
432
+ # Mode selection
433
+ mode = gr.Radio(
434
+ choices=["PDB ID", "Upload File"],
435
+ value="PDB ID",
436
+ label="Input Mode",
437
+ info="Choose whether to input a PDB ID or upload a PDB/CIF file."
438
+ )
439
+
440
+ # Input components based on mode
441
+ pdb_input = gr.Textbox(value="2F6V", label="PDB ID", placeholder="Enter PDB ID here...")
442
+ pdb_file = gr.File(label="Upload PDB/CIF File", visible=False)
443
+ visualize_btn = gr.Button("Visualize Structure", elem_id="visualize-btn")
444
+
445
+ molecule_output2 = Molecule3D(label="Protein Structure", reps=[
446
+ {
447
+ "model": 0,
448
+ "style": "cartoon",
449
+ "color": "whiteCarbon",
450
+ "residue_range": "",
451
+ "around": 0,
452
+ "byres": False,
453
+ }
454
+ ])
455
+
456
+ with gr.Row():
457
+ segment_input = gr.Textbox(value="A", label="Chain ID (protein)", placeholder="Enter Chain ID here...",
458
+ info="Choose in which chain to predict binding sites.")
459
+ prediction_btn = gr.Button("Predict Binding Site", elem_id="predict-btn")
460
+
461
+ # Add score type selector
462
+ score_type = gr.Radio(
463
+ choices=["Normalized Scores", "Raw Scores"],
464
+ value="Normalized Scores",
465
+ label="Score Visualization Type",
466
+ info="Choose which score type to visualize"
467
+ )
468
+
469
+ molecule_output = gr.HTML(label="Protein Structure")
470
+ explanation_vis = gr.Markdown("""
471
+ Score dependent colorcoding:
472
+ - 0.0-0.2: white
473
+ - 0.2–0.4: light orange
474
+ - 0.4–0.6: yellow orange
475
+ - 0.6–0.8: orange
476
+ - 0.8–1.0: red
477
+ """)
478
+ predictions_output = gr.Textbox(label="Visualize Prediction with PyMol")
479
+ gr.Markdown("### Download:\n- List of predicted binding site residues\n- PDB with score in beta factor column")
480
+ download_output = gr.File(label="Download Files", file_count="multiple")
481
+
482
+ # Store these as state variables so we can switch between them
483
+ raw_scores_state = gr.State(None)
484
+ norm_scores_state = gr.State(None)
485
+ last_pdb_path = gr.State(None)
486
+ last_segment = gr.State(None)
487
+ last_pdb_id = gr.State(None)
488
+
489
+ def process_interface(mode, pdb_id, pdb_file, chain_id, score_type_val):
490
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
491
+
492
+ # First get the actual PDB file path
493
+ if mode == "PDB ID":
494
+ pdb_path = fetch_pdb(pdb_id) # Get the actual file path
495
+
496
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
497
+ # Store the actual file path, not just the PDB ID
498
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
499
+ elif mode == "Upload File":
500
+ _, ext = os.path.splitext(pdb_file.name)
501
+ file_path = os.path.join('./', f"{_}{ext}")
502
+ if ext == '.cif':
503
+ pdb_path = convert_cif_to_pdb(file_path)
504
+ else:
505
+ pdb_path = file_path
506
+
507
+ pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_id_result, segment = process_pdb(pdb_path, chain_id, selected_score_type)
508
+ return pymol_cmd, mol_vis, files, raw_scores, norm_scores, pdb_path, chain_id, pdb_id_result
509
+
510
+ def update_visualization_and_files(score_type_val, raw_scores, norm_scores, pdb_path, segment, pdb_id):
511
+ if raw_scores is None or norm_scores is None or pdb_path is None or segment is None or pdb_id is None:
512
+ return None, None, None
513
+
514
+ # Choose scores based on radio button selection
515
+ selected_score_type = 'normalized' if score_type_val == "Normalized Scores" else 'raw'
516
+ selected_scores = norm_scores if selected_score_type == 'normalized' else raw_scores
517
+
518
+ # Generate visualization with selected scores
519
+ mol_vis = molecule(pdb_path, selected_scores, segment)
520
+
521
+ # Generate PyMOL commands and downloadable files
522
+ # Get structure for residue info
523
+ _, ext = os.path.splitext(pdb_path)
524
+ parser = MMCIFParser(QUIET=True) if ext == '.cif' else PDBParser(QUIET=True)
525
+ structure = parser.get_structure('protein', pdb_path)
526
+ chain = structure[0][segment]
527
+ protein_residues = [res for res in chain if is_aa(res)]
528
+ sequence = "".join(seq1(res.resname) for res in protein_residues)
529
+
530
+ # Define score brackets
531
+ score_brackets = {
532
+ "0.0-0.2": (0.0, 0.2),
533
+ "0.2-0.4": (0.2, 0.4),
534
+ "0.4-0.6": (0.4, 0.6),
535
+ "0.6-0.8": (0.6, 0.8),
536
+ "0.8-1.0": (0.8, 1.0)
537
+ }
538
+
539
+ # Initialize a dictionary to store residues by bracket
540
+ residues_by_bracket = {bracket: [] for bracket in score_brackets}
541
+
542
+ # Categorize residues into brackets
543
+ for resi, score in selected_scores:
544
+ for bracket, (lower, upper) in score_brackets.items():
545
+ if lower <= score < upper:
546
+ residues_by_bracket[bracket].append(resi)
547
+ break
548
+
549
+ # Generate timestamp
550
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
551
+
552
+ # Generate result text and PyMOL commands based on score type
553
+ display_score_type = "Normalized" if selected_score_type == 'normalized' else "Raw"
554
+ scores_array = [score for _, score in selected_scores]
555
+ result_str = generate_results_text(pdb_id, segment, residues_by_bracket, protein_residues, sequence,
556
+ scores_array, current_time, display_score_type)
557
+ pymol_commands = generate_pymol_commands(pdb_id, segment, residues_by_bracket, current_time, display_score_type)
558
+
559
+ # Create chain-specific PDB with scores in B-factor
560
+ scored_pdb = create_chain_specific_pdb(pdb_path, segment, selected_scores, protein_residues)
561
+
562
+ # Create prediction file
563
+ prediction_file = f"{pdb_id}_{display_score_type.lower()}_binding_site_residues.txt"
564
+ with open(prediction_file, "w") as f:
565
+ f.write(result_str)
566
+
567
+ scored_pdb_name = f"{pdb_id}_{segment}_{display_score_type.lower()}_predictions_scores.pdb"
568
+ os.rename(scored_pdb, scored_pdb_name)
569
+
570
+ return mol_vis, pymol_commands, [prediction_file, scored_pdb_name]
571
+
572
+ def fetch_interface(mode, pdb_id, pdb_file):
573
+ if mode == "PDB ID":
574
+ return fetch_pdb(pdb_id)
575
+ elif mode == "Upload File":
576
+ _, ext = os.path.splitext(pdb_file.name)
577
+ file_path = os.path.join('./', f"{_}{ext}")
578
+ if ext == '.cif':
579
+ pdb_path = convert_cif_to_pdb(file_path)
580
+ else:
581
+ pdb_path= file_path
582
+ return pdb_path
583
+
584
+ def toggle_mode(selected_mode):
585
+ if selected_mode == "PDB ID":
586
+ return gr.update(visible=True), gr.update(visible=False)
587
+ else:
588
+ return gr.update(visible=False), gr.update(visible=True)
589
+
590
+
591
+
592
+ mode.change(
593
+ toggle_mode,
594
+ inputs=[mode],
595
+ outputs=[pdb_input, pdb_file]
596
+ )
597
+
598
+ prediction_btn.click(
599
+ process_interface,
600
+ inputs=[mode, pdb_input, pdb_file, segment_input, score_type],
601
+ outputs=[predictions_output, molecule_output, download_output,
602
+ raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id]
603
+ )
604
+
605
+ # Update visualization, PyMOL commands, and files when score type changes
606
+ score_type.change(
607
+ update_visualization_and_files,
608
+ inputs=[score_type, raw_scores_state, norm_scores_state, last_pdb_path, last_segment, last_pdb_id],
609
+ outputs=[molecule_output, predictions_output, download_output]
610
+ )
611
+
612
+ visualize_btn.click(
613
+ fetch_interface,
614
+ inputs=[mode, pdb_input, pdb_file],
615
+ outputs=molecule_output2
616
+ )
617
+
618
+ gr.Markdown("## Examples")
619
+ gr.Examples(
620
+ examples=[
621
+ ["7RPZ", "A"],
622
+ ["2IWI", "B"],
623
+ ["7LCJ", "R"],
624
+ ["4OBE", "A"]
625
+ ],
626
+ inputs=[pdb_input, segment_input],
627
+ outputs=[predictions_output, molecule_output, download_output]
628
+ )
629
+
630
+ def predict_utils(sequence):
631
+ input_ids = tokenizer(" ".join(sequence), return_tensors="pt").input_ids.to(device)
632
+ with torch.no_grad():
633
+ outputs = model(input_ids).logits.detach().cpu().numpy().squeeze()
634
+
635
+ raw_scores = expit(outputs[:, 1] - outputs[:, 0])
636
+ normalized_scores = normalize_scores(raw_scores)
637
+
638
+ return {
639
+ "raw_scores": raw_scores.tolist(),
640
+ "normalized_scores": normalized_scores.tolist()
641
+ }
642
+
643
+ dummy_input = gr.Textbox(visible=False)
644
+ dummy_output = gr.Textbox(visible=False)
645
+
646
+ dummy_btn = gr.Button("Predict Sequence", visible=False)
647
+ dummy_btn.click(
648
+ predict_utils,
649
+ inputs=[dummy_input],
650
+ outputs=[dummy_output]
651
+ )
652
+
653
+ demo.launch(share=True)
model_loader.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
7
+ from torch.utils.data import DataLoader
8
+
9
+ import re
10
+ import numpy as np
11
+ import os
12
+ import pandas as pd
13
+ import copy
14
+
15
+ import transformers, datasets
16
+ from transformers.modeling_outputs import TokenClassifierOutput
17
+ from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
18
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
19
+ from transformers import T5EncoderModel, T5Tokenizer
20
+ from transformers.models.esm.modeling_esm import EsmPreTrainedModel, EsmModel
21
+ from transformers import AutoTokenizer
22
+ from transformers import TrainingArguments, Trainer, set_seed
23
+ from transformers import DataCollatorForTokenClassification
24
+
25
+ from dataclasses import dataclass
26
+ from typing import Dict, List, Optional, Tuple, Union
27
+
28
+ # for custom DataCollator
29
+ from transformers.data.data_collator import DataCollatorMixin
30
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
31
+ from transformers.utils import PaddingStrategy
32
+
33
+ from datasets import Dataset
34
+
35
+ from scipy.special import expit
36
+
37
+ #import peft
38
+ #from peft import get_peft_config, PeftModel, PeftConfig, inject_adapter_in_model, LoraConfig
39
+
40
+ cnn_head=True #False set True for Rostlab/prot_t5_xl_half_uniref50-enc
41
+ ffn_head=False #False
42
+ transformer_head=False
43
+ custom_lora=True #False #only true for Rostlab/prot_t5_xl_half_uniref50-enc
44
+
45
+ class ClassConfig:
46
+ def __init__(self, dropout=0.2, num_labels=3):
47
+ self.dropout_rate = dropout
48
+ self.num_labels = num_labels
49
+
50
+ class T5EncoderForTokenClassification(T5PreTrainedModel):
51
+
52
+ def __init__(self, config: T5Config, class_config: ClassConfig):
53
+ super().__init__(config)
54
+ self.num_labels = class_config.num_labels
55
+ self.config = config
56
+
57
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
58
+
59
+ encoder_config = copy.deepcopy(config)
60
+ encoder_config.use_cache = False
61
+ encoder_config.is_encoder_decoder = False
62
+ self.encoder = T5Stack(encoder_config, self.shared)
63
+
64
+ self.dropout = nn.Dropout(class_config.dropout_rate)
65
+
66
+ # Initialize different heads based on class_config
67
+ if cnn_head:
68
+ self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
69
+ self.classifier = nn.Linear(512, class_config.num_labels)
70
+ elif ffn_head:
71
+ # Multi-layer feed-forward network (FFN) head
72
+ self.ffn = nn.Sequential(
73
+ nn.Linear(config.hidden_size, 512),
74
+ nn.ReLU(),
75
+ nn.Linear(512, 256),
76
+ nn.ReLU(),
77
+ nn.Linear(256, class_config.num_labels)
78
+ )
79
+ elif transformer_head:
80
+ # Transformer layer head
81
+ encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)
82
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
83
+ self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)
84
+ else:
85
+ # Default classification head
86
+ self.classifier = nn.Linear(config.hidden_size, class_config.num_labels)
87
+
88
+ self.post_init()
89
+
90
+ # Model parallel
91
+ self.model_parallel = False
92
+ self.device_map = None
93
+
94
+ def parallelize(self, device_map=None):
95
+ self.device_map = (
96
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
97
+ if device_map is None
98
+ else device_map
99
+ )
100
+ assert_device_map(self.device_map, len(self.encoder.block))
101
+ self.encoder.parallelize(self.device_map)
102
+ self.classifier = self.classifier.to(self.encoder.first_device)
103
+ self.model_parallel = True
104
+
105
+ def deparallelize(self):
106
+ self.encoder.deparallelize()
107
+ self.encoder = self.encoder.to("cpu")
108
+ self.model_parallel = False
109
+ self.device_map = None
110
+ torch.cuda.empty_cache()
111
+
112
+ def get_input_embeddings(self):
113
+ return self.shared
114
+
115
+ def set_input_embeddings(self, new_embeddings):
116
+ self.shared = new_embeddings
117
+ self.encoder.set_input_embeddings(new_embeddings)
118
+
119
+ def get_encoder(self):
120
+ return self.encoder
121
+
122
+ def _prune_heads(self, heads_to_prune):
123
+ for layer, heads in heads_to_prune.items():
124
+ self.encoder.layer[layer].attention.prune_heads(heads)
125
+
126
+ def forward(
127
+ self,
128
+ input_ids=None,
129
+ attention_mask=None,
130
+ head_mask=None,
131
+ inputs_embeds=None,
132
+ labels=None,
133
+ output_attentions=None,
134
+ output_hidden_states=None,
135
+ return_dict=None,
136
+ ):
137
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
138
+
139
+ outputs = self.encoder(
140
+ input_ids=input_ids,
141
+ attention_mask=attention_mask,
142
+ inputs_embeds=inputs_embeds,
143
+ head_mask=head_mask,
144
+ output_attentions=output_attentions,
145
+ output_hidden_states=output_hidden_states,
146
+ return_dict=return_dict,
147
+ )
148
+
149
+ sequence_output = outputs[0]
150
+ sequence_output = self.dropout(sequence_output)
151
+
152
+ # Forward pass through the selected head
153
+ if cnn_head:
154
+ # CNN head
155
+ sequence_output = sequence_output.permute(0, 2, 1) # Prepare shape for CNN
156
+ cnn_output = self.cnn(sequence_output)
157
+ cnn_output = F.relu(cnn_output)
158
+ cnn_output = cnn_output.permute(0, 2, 1) # Shape back for classifier
159
+ logits = self.classifier(cnn_output)
160
+ elif ffn_head:
161
+ # FFN head
162
+ logits = self.ffn(sequence_output)
163
+ elif transformer_head:
164
+ # Transformer head
165
+ transformer_output = self.transformer_encoder(sequence_output)
166
+ logits = self.classifier(transformer_output)
167
+ else:
168
+ # Default classification head
169
+ logits = self.classifier(sequence_output)
170
+
171
+ loss = None
172
+ if labels is not None:
173
+ loss_fct = CrossEntropyLoss()
174
+ active_loss = attention_mask.view(-1) == 1
175
+ active_logits = logits.view(-1, self.num_labels)
176
+ active_labels = torch.where(
177
+ active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)
178
+ )
179
+ valid_logits = active_logits[active_labels != -100]
180
+ valid_labels = active_labels[active_labels != -100]
181
+ valid_labels = valid_labels.to(valid_logits.device)
182
+ valid_labels = valid_labels.long()
183
+ loss = loss_fct(valid_logits, valid_labels)
184
+
185
+ if not return_dict:
186
+ output = (logits,) + outputs[2:]
187
+ return ((loss,) + output) if loss is not None else output
188
+
189
+ return TokenClassifierOutput(
190
+ loss=loss,
191
+ logits=logits,
192
+ hidden_states=outputs.hidden_states,
193
+ attentions=outputs.attentions,
194
+ )
195
+
196
+ # Modifies an existing transformer and introduce the LoRA layers
197
+
198
+ class CustomLoRAConfig:
199
+ def __init__(self):
200
+ self.lora_rank = 4
201
+ self.lora_init_scale = 0.01
202
+ self.lora_modules = ".*SelfAttention|.*EncDecAttention"
203
+ self.lora_layers = "q|k|v|o"
204
+ self.trainable_param_names = ".*layer_norm.*|.*lora_[ab].*"
205
+ self.lora_scaling_rank = 1
206
+ # lora_modules and lora_layers are speicified with regular expressions
207
+ # see https://www.w3schools.com/python/python_regex.asp for reference
208
+
209
+ class LoRALinear(nn.Module):
210
+ def __init__(self, linear_layer, rank, scaling_rank, init_scale):
211
+ super().__init__()
212
+ self.in_features = linear_layer.in_features
213
+ self.out_features = linear_layer.out_features
214
+ self.rank = rank
215
+ self.scaling_rank = scaling_rank
216
+ self.weight = linear_layer.weight
217
+ self.bias = linear_layer.bias
218
+ if self.rank > 0:
219
+ self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
220
+ if init_scale < 0:
221
+ self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
222
+ else:
223
+ self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
224
+ if self.scaling_rank:
225
+ self.multi_lora_a = nn.Parameter(
226
+ torch.ones(self.scaling_rank, linear_layer.in_features)
227
+ + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
228
+ )
229
+ if init_scale < 0:
230
+ self.multi_lora_b = nn.Parameter(
231
+ torch.ones(linear_layer.out_features, self.scaling_rank)
232
+ + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
233
+ )
234
+ else:
235
+ self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))
236
+
237
+ def forward(self, input):
238
+ if self.scaling_rank == 1 and self.rank == 0:
239
+ # parsimonious implementation for ia3 and lora scaling
240
+ if self.multi_lora_a.requires_grad:
241
+ hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
242
+ else:
243
+ hidden = F.linear(input, self.weight, self.bias)
244
+ if self.multi_lora_b.requires_grad:
245
+ hidden = hidden * self.multi_lora_b.flatten()
246
+ return hidden
247
+ else:
248
+ # general implementation for lora (adding and scaling)
249
+ weight = self.weight
250
+ if self.scaling_rank:
251
+ weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
252
+ if self.rank:
253
+ weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
254
+ return F.linear(input, weight, self.bias)
255
+
256
+ def extra_repr(self):
257
+ return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
258
+ self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
259
+ )
260
+
261
+
262
+ def modify_with_lora(transformer, config):
263
+ for m_name, module in dict(transformer.named_modules()).items():
264
+ if re.fullmatch(config.lora_modules, m_name):
265
+ for c_name, layer in dict(module.named_children()).items():
266
+ if re.fullmatch(config.lora_layers, c_name):
267
+ assert isinstance(
268
+ layer, nn.Linear
269
+ ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
270
+ setattr(
271
+ module,
272
+ c_name,
273
+ LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
274
+ )
275
+ return transformer
276
+
277
+
278
+ def load_T5_model_classification(checkpoint, num_labels, half_precision, full = False, deepspeed=True):
279
+ # Load model and tokenizer
280
+
281
+ if "ankh" in checkpoint :
282
+ model = T5EncoderModel.from_pretrained(checkpoint,resume_download=True)
283
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint,resume_download=True)
284
+
285
+ elif "prot_t5" in checkpoint:
286
+ # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)
287
+ if half_precision and deepspeed:
288
+ #tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
289
+ #model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16)#.to(torch.device('cuda')
290
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False,resume_download=True)
291
+ model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'),resume_download=True)
292
+ else:
293
+ model = T5EncoderModel.from_pretrained(checkpoint)
294
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint)
295
+
296
+ elif "ProstT5" in checkpoint:
297
+ if half_precision and deepspeed:
298
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False,resume_download=True)
299
+ model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device('cuda'),resume_download=True)
300
+ else:
301
+ model = T5EncoderModel.from_pretrained(checkpoint,resume_download=True)
302
+ tokenizer = T5Tokenizer.from_pretrained(checkpoint,resume_download=True)
303
+
304
+ # Create new Classifier model with PT5 dimensions
305
+ class_config=ClassConfig(num_labels=num_labels)
306
+ class_model=T5EncoderForTokenClassification(model.config,class_config)
307
+
308
+ # Set encoder and embedding weights to checkpoint weights
309
+ class_model.shared=model.shared
310
+ class_model.encoder=model.encoder
311
+
312
+ # Delete the checkpoint model
313
+ model=class_model
314
+ del class_model
315
+
316
+ if full == True:
317
+ return model, tokenizer
318
+
319
+ # Print number of trainable parameters
320
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
321
+ params = sum([np.prod(p.size()) for p in model_parameters])
322
+ print("T5_Classfier\nTrainable Parameter: "+ str(params))
323
+
324
+ if custom_lora:
325
+ #the linear CustomLoRAConfig allows better quality predictions, but more memory is needed
326
+ # Add model modification lora
327
+ config = CustomLoRAConfig()
328
+
329
+ # Add LoRA layers
330
+ model = modify_with_lora(model, config)
331
+
332
+ # Freeze Embeddings and Encoder (except LoRA)
333
+ for (param_name, param) in model.shared.named_parameters():
334
+ param.requires_grad = False
335
+ for (param_name, param) in model.encoder.named_parameters():
336
+ param.requires_grad = False
337
+
338
+ for (param_name, param) in model.named_parameters():
339
+ if re.fullmatch(config.trainable_param_names, param_name):
340
+ param.requires_grad = True
341
+
342
+ else:
343
+ # lora modification
344
+ peft_config = LoraConfig(
345
+ r=4, lora_alpha=1, bias="all", target_modules=["q","k","v","o"]
346
+ )
347
+
348
+ model = inject_adapter_in_model(peft_config, model)
349
+
350
+ # Unfreeze the prediction head
351
+ for (param_name, param) in model.classifier.named_parameters():
352
+ param.requires_grad = True
353
+
354
+ # Print trainable Parameter
355
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
356
+ params = sum([np.prod(p.size()) for p in model_parameters])
357
+ print("T5_LoRA_Classfier\nTrainable Parameter: "+ str(params) + "\n")
358
+
359
+ return model, tokenizer
360
+
361
+ class EsmForTokenClassificationCustom(EsmPreTrainedModel):
362
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
363
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"cnn", r"ffn", r"transformer"]
364
+
365
+ def __init__(self, config):
366
+ super().__init__(config)
367
+ self.num_labels = config.num_labels
368
+ self.esm = EsmModel(config, add_pooling_layer=False)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ if cnn_head:
372
+ self.cnn = nn.Conv1d(config.hidden_size, 512, kernel_size=3, padding=1)
373
+ self.classifier = nn.Linear(512, config.num_labels)
374
+ elif ffn_head:
375
+ # Multi-layer feed-forward network (FFN) as an alternative head
376
+ self.ffn = nn.Sequential(
377
+ nn.Linear(config.hidden_size, 512),
378
+ nn.ReLU(),
379
+ nn.Linear(512, 256),
380
+ nn.ReLU(),
381
+ nn.Linear(256, config.num_labels)
382
+ )
383
+ elif transformer_head:
384
+ # Transformer layer as an alternative head
385
+ encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=8)
386
+ self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
387
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
388
+ else:
389
+ # Default classification head
390
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
391
+
392
+ self.init_weights()
393
+
394
+ def forward(
395
+ self,
396
+ input_ids: Optional[torch.LongTensor] = None,
397
+ attention_mask: Optional[torch.Tensor] = None,
398
+ position_ids: Optional[torch.LongTensor] = None,
399
+ head_mask: Optional[torch.Tensor] = None,
400
+ inputs_embeds: Optional[torch.FloatTensor] = None,
401
+ labels: Optional[torch.LongTensor] = None,
402
+ output_attentions: Optional[bool] = None,
403
+ output_hidden_states: Optional[bool] = None,
404
+ return_dict: Optional[bool] = None,
405
+ ) -> Union[Tuple, TokenClassifierOutput]:
406
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
407
+ outputs = self.esm(
408
+ input_ids,
409
+ attention_mask=attention_mask,
410
+ position_ids=position_ids,
411
+ head_mask=head_mask,
412
+ inputs_embeds=inputs_embeds,
413
+ output_attentions=output_attentions,
414
+ output_hidden_states=output_hidden_states,
415
+ return_dict=return_dict,
416
+ )
417
+
418
+ sequence_output = outputs[0]
419
+ sequence_output = self.dropout(sequence_output)
420
+
421
+ if cnn_head:
422
+ sequence_output = sequence_output.transpose(1, 2)
423
+ sequence_output = self.cnn(sequence_output)
424
+ sequence_output = sequence_output.transpose(1, 2)
425
+ logits = self.classifier(sequence_output)
426
+ elif ffn_head:
427
+ logits = self.ffn(sequence_output)
428
+ elif transformer_head:
429
+ # Apply transformer encoder for the transformer head
430
+ sequence_output = self.transformer_encoder(sequence_output)
431
+ logits = self.classifier(sequence_output)
432
+ else:
433
+ logits = self.classifier(sequence_output)
434
+
435
+ loss = None
436
+ if labels is not None:
437
+ loss_fct = CrossEntropyLoss()
438
+ active_loss = attention_mask.view(-1) == 1
439
+ active_logits = logits.view(-1, self.num_labels)
440
+ active_labels = torch.where(
441
+ active_loss, labels.view(-1), torch.tensor(-100).type_as(labels)
442
+ )
443
+ valid_logits = active_logits[active_labels != -100]
444
+ valid_labels = active_labels[active_labels != -100]
445
+ valid_labels = valid_labels.type(torch.LongTensor).to('cuda:0')
446
+ loss = loss_fct(valid_logits, valid_labels)
447
+
448
+ if not return_dict:
449
+ output = (logits,) + outputs[2:]
450
+ return ((loss,) + output) if loss is not None else output
451
+
452
+ return TokenClassifierOutput(
453
+ loss=loss,
454
+ logits=logits,
455
+ hidden_states=outputs.hidden_states,
456
+ attentions=outputs.attentions,
457
+ )
458
+
459
+ def _init_weights(self, module):
460
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d):
461
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
462
+ if module.bias is not None:
463
+ module.bias.data.zero_()
464
+
465
+ # based on transformers DataCollatorForTokenClassification
466
+ @dataclass
467
+ class DataCollatorForTokenClassificationESM(DataCollatorMixin):
468
+ """
469
+ Data collator that will dynamically pad the inputs received, as well as the labels.
470
+ Args:
471
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
472
+ The tokenizer used for encoding the data.
473
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
474
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
475
+ among:
476
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
477
+ sequence is provided).
478
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
479
+ acceptable input length for the model if that argument is not provided.
480
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
481
+ max_length (`int`, *optional*):
482
+ Maximum length of the returned list and optionally padding length (see above).
483
+ pad_to_multiple_of (`int`, *optional*):
484
+ If set will pad the sequence to a multiple of the provided value.
485
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
486
+ 7.5 (Volta).
487
+ label_pad_token_id (`int`, *optional*, defaults to -100):
488
+ The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
489
+ return_tensors (`str`):
490
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
491
+ """
492
+
493
+ tokenizer: PreTrainedTokenizerBase
494
+ padding: Union[bool, str, PaddingStrategy] = True
495
+ max_length: Optional[int] = None
496
+ pad_to_multiple_of: Optional[int] = None
497
+ label_pad_token_id: int = -100
498
+ return_tensors: str = "pt"
499
+
500
+ def torch_call(self, features):
501
+ import torch
502
+
503
+ label_name = "label" if "label" in features[0].keys() else "labels"
504
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
505
+
506
+ no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
507
+
508
+ batch = self.tokenizer.pad(
509
+ no_labels_features,
510
+ padding=self.padding,
511
+ max_length=self.max_length,
512
+ pad_to_multiple_of=self.pad_to_multiple_of,
513
+ return_tensors="pt",
514
+ )
515
+
516
+ if labels is None:
517
+ return batch
518
+
519
+ sequence_length = batch["input_ids"].shape[1]
520
+ padding_side = self.tokenizer.padding_side
521
+
522
+ def to_list(tensor_or_iterable):
523
+ if isinstance(tensor_or_iterable, torch.Tensor):
524
+ return tensor_or_iterable.tolist()
525
+ return list(tensor_or_iterable)
526
+
527
+ if padding_side == "right":
528
+ batch[label_name] = [
529
+ # to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
530
+ # changed to pad the special tokens at the beginning and end of the sequence
531
+ [self.label_pad_token_id] + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)-1) for label in labels
532
+ ]
533
+ else:
534
+ batch[label_name] = [
535
+ [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
536
+ ]
537
+
538
+ batch[label_name] = torch.tensor(batch[label_name], dtype=torch.float)
539
+ return batch
540
+
541
+ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
542
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
543
+ import torch
544
+
545
+ # Tensorize if necessary.
546
+ if isinstance(examples[0], (list, tuple, np.ndarray)):
547
+ examples = [torch.tensor(e, dtype=torch.long) for e in examples]
548
+
549
+ length_of_first = examples[0].size(0)
550
+
551
+ # Check if padding is necessary.
552
+
553
+ are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
554
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
555
+ return torch.stack(examples, dim=0)
556
+
557
+ # If yes, check if we have a `pad_token`.
558
+ if tokenizer._pad_token is None:
559
+ raise ValueError(
560
+ "You are attempting to pad samples but the tokenizer you are using"
561
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
562
+ )
563
+
564
+ # Creating the full tensor and filling it with our data.
565
+ max_length = max(x.size(0) for x in examples)
566
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
567
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
568
+ result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
569
+ for i, example in enumerate(examples):
570
+ if tokenizer.padding_side == "right":
571
+ result[i, : example.shape[0]] = example
572
+ else:
573
+ result[i, -example.shape[0] :] = example
574
+ return result
575
+
576
+ def tolist(x):
577
+ if isinstance(x, list):
578
+ return x
579
+ elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
580
+ x = x.numpy()
581
+ return x.tolist()
582
+
583
+ #load ESM2 models
584
+ def load_esm_model_classification(checkpoint, num_labels, half_precision, full=False, deepspeed=True):
585
+
586
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
587
+
588
+
589
+ if half_precision and deepspeed:
590
+ model = EsmForTokenClassificationCustom.from_pretrained(checkpoint,
591
+ num_labels = num_labels,
592
+ ignore_mismatched_sizes=True,
593
+ torch_dtype = torch.float16)
594
+ else:
595
+ model = EsmForTokenClassificationCustom.from_pretrained(checkpoint,
596
+ num_labels = num_labels,
597
+ ignore_mismatched_sizes=True)
598
+
599
+ if full == True:
600
+ return model, tokenizer
601
+
602
+ peft_config = LoraConfig(
603
+ r=4, lora_alpha=1, bias="all", target_modules=["query","key","value","dense"]
604
+ )
605
+
606
+ model = inject_adapter_in_model(peft_config, model)
607
+
608
+ #model.gradient_checkpointing_enable()
609
+
610
+ # Unfreeze the prediction head
611
+ for (param_name, param) in model.classifier.named_parameters():
612
+ param.requires_grad = True
613
+
614
+ return model, tokenizer
615
+
616
+ def load_model(checkpoint,max_length):
617
+ #checkpoint='ThorbenF/prot_t5_xl_uniref50'
618
+ #best_model_path='ThorbenF/prot_t5_xl_uniref50/cpt.pth'
619
+ full=False
620
+ deepspeed=False
621
+ mixed=False
622
+ num_labels=2
623
+
624
+ print(checkpoint, num_labels, mixed, full, deepspeed)
625
+
626
+ # Determine model type and load accordingly
627
+ if "esm" in checkpoint:
628
+ model, tokenizer = load_esm_model_classification(checkpoint, num_labels, mixed, full, deepspeed)
629
+ else:
630
+ model, tokenizer = load_T5_model_classification(checkpoint, num_labels, mixed, full, deepspeed)
631
+
632
+
633
+ # Download the file
634
+ local_file = hf_hub_download(repo_id=checkpoint, filename="cpt.pth")
635
+
636
+ # Load the best model state
637
+ state_dict = torch.load(local_file, map_location=torch.device('cpu'), weights_only=True)
638
+ model.load_state_dict(state_dict)
639
+
640
+ return model, tokenizer
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.13.0
2
+ transformers>=4.30.0
3
+ datasets>=2.9.0
4
+ peft>=0.0.7
5
+ scipy>=1.7.0
6
+ pandas>=1.1.0
7
+ numpy>=1.19.0
8
+ scikit-learn>=0.24.0
9
+ sentencepiece
10
+ huggingface_hub>=0.15.0
11
+ requests
12
+ gradio_molecule3d
13
+ biopython>=1.81
14
+ pydantic==2.1.1