42Cummer committed on
Commit
f2a576a
·
verified ·
1 Parent(s): 64ac21d

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +162 -21
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,15 +1,81 @@
1
  import gradio as gr # type: ignore
2
  import os
3
  from gradio_molecule3d import Molecule3D
 
4
 
5
  # Import your custom modules from the /scripts folder
6
  from scripts.download import download_and_clean_pdb
7
  from scripts.generator import run_broteinshake_generator
8
- from scripts.refine import polish_design
9
  from scripts.visualize import create_design_plot
10
 
11
  # --- HELPER FUNCTIONS ---
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def get_all_sequences(fasta_file: str) -> str:
14
  """Get all designed sequences from FASTA file."""
15
  sequences = []
@@ -69,15 +135,52 @@ def extract_best_sequence(fasta_file: str) -> str:
69
  else:
70
  raise ValueError(f"No valid designs found in {fasta_file}")
71
 
72
- def run_part1(pdb_id, fixed_chains, variable_chains):
73
  """Downloads the PDB and runs ProteinMPNN design."""
74
  try:
75
  # Step 1: Secure the template
76
  pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # Step 2: Generate Optimized Sequences
79
  # This creates the .fa files you need for the ESM Atlas
80
- run_broteinshake_generator(pdb_path, fixed_chains, variable_chains)
 
 
81
 
82
  # Get all sequences and the best one
83
  fa_file = os.path.join("generated", pdb_id.lower(), "seqs", f"{pdb_id.lower()}_clones.fa")
@@ -123,23 +226,20 @@ def run_part2(pdb_id, uploaded_esm_file):
123
 
124
  # Call the new lightweight Biopython-only script
125
  print(f"πŸ” Starting alignment for PDB ID: {pdb_id}, File: {file_path}")
126
- final_pdb_path, rmsd_val = polish_design(pdb_id, file_path)
127
 
128
  # Validate output
129
  if not final_pdb_path or not os.path.exists(final_pdb_path):
130
  return None, f"❌ Error: Alignment failed - output file not created: {final_pdb_path}"
131
 
132
- if rmsd_val is None:
133
  return None, "❌ Error: Alignment failed - RMSD calculation returned None"
134
 
135
- print(f"βœ… Alignment successful: RMSD = {rmsd_val:.3f} Γ…, Output: {final_pdb_path}")
 
 
 
136
 
137
- # Update the report to focus on Backbone Alignment
138
- report = (
139
- f"βœ… Validation Successful!\n"
140
- f"🎯 RMSD: {rmsd_val:.3f} Γ…\n"
141
- f"🧬 Status: High-Precision Backbone Match"
142
- )
143
  return final_pdb_path, report
144
  except FileNotFoundError as e:
145
  error_msg = f"❌ File Error: {str(e)}"
@@ -205,12 +305,54 @@ with gr.Blocks(theme=dark_biohub, css=biohub_css) as demo:
205
  # TAB 1: GENERATIVE DESIGN
206
  with gr.Tab("1. Sequence Generation"):
207
  gr.Markdown("Enter a PDB ID to 'repaint' its binder interface using ProteinMPNN.")
208
- with gr.Row():
209
- pdb_input = gr.Textbox(label="Target PDB ID", placeholder="e.g., 3kas", value="3kas")
210
- f_chains = gr.Textbox(label="Fixed Chains (Lock)", value="A")
211
- v_chains = gr.Textbox(label="Variable Chains (Key)", value="B")
212
 
213
- gen_btn = gr.Button("πŸš€ Generate Optimized Sequences", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  # Stack components vertically
216
  fa_output = gr.Code(
@@ -226,7 +368,7 @@ with gr.Blocks(theme=dark_biohub, css=biohub_css) as demo:
226
  gr.Markdown("### System Status")
227
  status1 = gr.Markdown()
228
 
229
- gen_btn.click(run_part1, inputs=[pdb_input, f_chains, v_chains],
230
  outputs=[fa_output, plot_output, status1])
231
 
232
  # TAB 2: STRUCTURAL VALIDATION
@@ -284,6 +426,5 @@ with gr.Blocks(theme=dark_biohub, css=biohub_css) as demo:
284
 
285
  # Launch the app
286
  if __name__ == "__main__":
287
- # FORCE DARK MODE: Add the theme toggle directly to the URL
288
- demo.launch(server_port=7860)
289
- # Visit: http://127.0.0.1:7860/?__theme=dark
 
1
  import gradio as gr # type: ignore
2
  import os
3
  from gradio_molecule3d import Molecule3D
4
+ from Bio.PDB import PDBParser
5
 
6
  # Import your custom modules from the /scripts folder
7
  from scripts.download import download_and_clean_pdb
8
  from scripts.generator import run_broteinshake_generator
9
+ from scripts.refine import polish_design, process_results
10
  from scripts.visualize import create_design_plot
11
 
12
  # --- HELPER FUNCTIONS ---
13
 
14
def get_pdb_chains(pdb_file):
    """Return the sorted, de-duplicated chain IDs found in a PDB file.

    Args:
        pdb_file: Path to a PDB file on disk. May be ``None`` or empty.

    Returns:
        A sorted list of the unique chain IDs collected from every model in
        the parsed structure. An empty list is returned when the path is
        falsy, does not exist, or parsing raises.
    """
    if not pdb_file or not os.path.exists(pdb_file):
        return []
    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("temp", pdb_file)
        # A set comprehension collapses chain IDs that repeat across models,
        # so no intermediate list()/set() wrappers are needed before sorting.
        return sorted({chain.id for model in structure for chain in model})
    except Exception as e:
        # Best-effort helper: callers treat an empty list as "no chains",
        # so report the problem and fall back rather than propagate.
        print(f"Error extracting chains: {e}")
        return []
26
+
27
def load_pdb_and_extract_chains(pdb_id):
    """Download a PDB entry and prepare the chain selector for it.

    Args:
        pdb_id: PDB identifier entered by the user. Surrounding whitespace
            is ignored, both for the download and in status messages.

    Returns:
        A 4-tuple in the order of the outputs wired to the "Load PDB"
        button: an update for the chain checkbox group, a status message,
        an update for the generate button's interactivity, and the list of
        chain IDs to keep in state (empty on any failure).
    """

    def _failure(message):
        # Every failure path clears the selector and disables generation.
        return (
            gr.update(choices=[], value=[]),
            message,
            gr.update(interactive=False),
            [],
        )

    # Normalise once so the download and every message use the same ID.
    pdb_id = pdb_id.strip() if pdb_id else ""
    if not pdb_id:
        return _failure("⚠️ Please enter a PDB ID")

    label = pdb_id.upper()
    try:
        # Download the PDB, then read its chain IDs
        pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
        chains = get_pdb_chains(pdb_path)

        if not chains:
            return _failure(f"⚠️ No chains found in {label}")

        # Single-chain proteins are supported: the only chain is selected
        # automatically and generation is enabled straight away.
        if len(chains) == 1:
            status_msg = f"✅ Loaded {label}: Single-chain protein - will redesign chain {chains[0]}"
            return (
                gr.update(choices=chains, value=chains),
                status_msg,
                gr.update(interactive=True),
                chains,
            )

        status_msg = f"✅ Loaded {label}: Found {len(chains)} chain(s) - {', '.join(chains)}"
        # Multi-chain: the button starts disabled until the selection is valid.
        # NOTE(review): every chain is pre-selected here, which the selection
        # validator rejects ("at least one chain must remain fixed") - confirm
        # whether the intended initial selection is all chains or none.
        return (
            gr.update(choices=chains, value=chains),
            status_msg,
            gr.update(interactive=False),
            chains,
        )
    except Exception as e:
        # UI boundary: surface the failure in the status area instead of raising.
        error_msg = f"❌ Error loading {label}: {str(e)}"
        print(error_msg)
        return _failure(error_msg)
56
+
57
def validate_chain_selection(selected_chains, available_chains_state):
    """Decide whether the current chain selection allows generation.

    Args:
        selected_chains: Chain IDs currently ticked in the checkbox group.
        available_chains_state: Chain IDs stored in state when the PDB was
            loaded; may be ``None`` before any PDB has been loaded.

    Returns:
        A 3-tuple: an update toggling the generate button's interactivity,
        a status/warning message, and ``available_chains_state`` passed
        through unchanged.
    """
    if not selected_chains:
        enabled = False
        warning = "⚠️ Please select at least one chain to redesign"
    else:
        available_chains = available_chains_state or []
        chosen = ", ".join(selected_chains)

        if len(available_chains) == 1:
            # Single-chain protein: the only chain may be redesigned.
            enabled = True
            warning = f"✅ Single-chain protein: Will redesign chain {available_chains[0]}"
        elif available_chains and len(selected_chains) >= len(available_chains):
            # Multi-chain: selecting everything would leave no fixed chain.
            enabled = False
            warning = f"⚠️ Cannot select all chains - at least one chain must remain fixed. Selected: {chosen}"
        else:
            enabled = True
            warning = f"✅ {len(selected_chains)} chain(s) selected for redesign: {chosen}"

    return gr.update(interactive=enabled), warning, available_chains_state
78
+
79
  def get_all_sequences(fasta_file: str) -> str:
80
  """Get all designed sequences from FASTA file."""
81
  sequences = []
 
135
  else:
136
  raise ValueError(f"No valid designs found in {fasta_file}")
137
 
138
+ def run_part1(pdb_id, fixed_chains, variable_chains, temperature=0.1, selected_chains=None):
139
  """Downloads the PDB and runs ProteinMPNN design."""
140
  try:
141
  # Step 1: Secure the template
142
  pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
143
 
144
+ # Handle chain selection logic
145
+ # If chains are selected via checkbox, use those as variable chains
146
+ # Otherwise, use the text input (backward compatibility)
147
+ all_chains = get_pdb_chains(pdb_path)
148
+
149
+ # Check if single-chain protein
150
+ is_single_chain = len(all_chains) == 1
151
+
152
+ if selected_chains and len(selected_chains) > 0:
153
+ # Selected chains = variable chains, rest = fixed
154
+ variable_chains = "".join(selected_chains)
155
+ fixed_chains = "".join([c for c in all_chains if c not in selected_chains])
156
+
157
+ # For single-chain: no fixed chains (will use different ProteinMPNN command)
158
+ # For multi-chain: Validate must have at least one fixed chain
159
+ if not is_single_chain and (not fixed_chains or len(fixed_chains) == 0):
160
+ raise ValueError(f"Cannot redesign all chains - at least one chain must remain fixed. Selected: {', '.join(selected_chains)}, Available: {', '.join(all_chains)}")
161
+
162
+ if is_single_chain:
163
+ print(f"πŸ“‹ Single-chain mode: Redesigning chain {variable_chains}")
164
+ else:
165
+ print(f"πŸ“‹ Using chain selector: Fixed={fixed_chains}, Variable={variable_chains}")
166
+ else:
167
+ # If no chains selected, use text inputs (default behavior)
168
+ # For single-chain, if variable_chains is empty, use the only chain
169
+ if is_single_chain and not variable_chains:
170
+ variable_chains = all_chains[0]
171
+ fixed_chains = ""
172
+ # For multi-chain: Validate text inputs don't select all chains
173
+ elif not is_single_chain and fixed_chains and variable_chains:
174
+ all_selected = set(fixed_chains + variable_chains)
175
+ if len(all_selected) >= len(all_chains):
176
+ raise ValueError(f"Cannot redesign all chains - at least one chain must remain fixed.")
177
+ print(f"πŸ“‹ Using text inputs: Fixed={fixed_chains}, Variable={variable_chains}")
178
+
179
  # Step 2: Generate Optimized Sequences
180
  # This creates the .fa files you need for the ESM Atlas
181
+ print(f"🌑️ Temperature: {temperature}")
182
+ print(f"πŸ”§ Parameters: Fixed chains={fixed_chains}, Variable chains={variable_chains}, Temp={temperature}")
183
+ run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=temperature)
184
 
185
  # Get all sequences and the best one
186
  fa_file = os.path.join("generated", pdb_id.lower(), "seqs", f"{pdb_id.lower()}_clones.fa")
 
226
 
227
  # Call the new lightweight Biopython-only script
228
  print(f"πŸ” Starting alignment for PDB ID: {pdb_id}, File: {file_path}")
229
+ final_pdb_path, global_rmsd, core_rmsd, high_conf_rmsd = polish_design(pdb_id, file_path)
230
 
231
  # Validate output
232
  if not final_pdb_path or not os.path.exists(final_pdb_path):
233
  return None, f"❌ Error: Alignment failed - output file not created: {final_pdb_path}"
234
 
235
+ if high_conf_rmsd is None:
236
  return None, "❌ Error: Alignment failed - RMSD calculation returned None"
237
 
238
+ print(f"βœ… Alignment successful: Global RMSD = {global_rmsd:.3f} Γ…, Core RMSD = {core_rmsd:.3f} Γ…, High-Conf RMSD = {high_conf_rmsd:.3f} Γ…")
239
+
240
+ # Generate detailed validation report
241
+ report = process_results(pdb_id, final_pdb_path, global_rmsd, high_conf_rmsd)
242
 
 
 
 
 
 
 
243
  return final_pdb_path, report
244
  except FileNotFoundError as e:
245
  error_msg = f"❌ File Error: {str(e)}"
 
305
  # TAB 1: GENERATIVE DESIGN
306
  with gr.Tab("1. Sequence Generation"):
307
  gr.Markdown("Enter a PDB ID to 'repaint' its binder interface using ProteinMPNN.")
 
 
 
 
308
 
309
+ pdb_input = gr.Textbox(label="Target PDB ID", placeholder="e.g., 3kas", value="")
310
+ load_pdb_btn = gr.Button("πŸ“₯ Load PDB", variant="secondary")
311
+ pdb_status = gr.Markdown("πŸ’‘ Enter a PDB ID and click 'Load PDB' to begin")
312
+
313
+ with gr.Column():
314
+ gr.Markdown("### βš™οΈ Design Parameters")
315
+
316
+ # Temperature (T) is the most critical knob for sequence recovery
317
+ sampling_temp = gr.Slider(
318
+ minimum=0.05, maximum=1.0, value=0.1, step=0.05,
319
+ label="Sampling Temperature (T)",
320
+ info="T=0.1 for high-fidelity; T=0.3 for natural diversity"
321
+ )
322
+
323
+ # Dynamic Chain Handling
324
+ chain_options = gr.CheckboxGroup(
325
+ choices=[],
326
+ label="Chains to Redesign",
327
+ info="Identify which chains ProteinMPNN should modify (will populate after loading PDB)"
328
+ )
329
+
330
+ chain_warning = gr.Markdown("πŸ’‘ Select at least one chain to enable generation", visible=True)
331
+
332
+ # Hidden state to track if we've successfully parsed the PDB
333
+ pdb_state = gr.State()
334
+
335
+ # Legacy text inputs (hidden but kept for backward compatibility)
336
+ with gr.Row(visible=False):
337
+ f_chains = gr.Textbox(label="Fixed Chains (Lock)", value="A")
338
+ v_chains = gr.Textbox(label="Variable Chains (Key)", value="B")
339
+
340
+ # Generate button (initially disabled)
341
+ gen_btn = gr.Button("πŸš€ Generate Optimized Sequences", variant="primary", interactive=False)
342
+
343
+ # Load PDB and extract chains when button is clicked
344
+ load_pdb_btn.click(
345
+ fn=load_pdb_and_extract_chains,
346
+ inputs=[pdb_input],
347
+ outputs=[chain_options, pdb_status, gen_btn, pdb_state]
348
+ )
349
+
350
+ # Validate chain selection and update button state
351
+ chain_options.change(
352
+ fn=validate_chain_selection,
353
+ inputs=[chain_options, pdb_state],
354
+ outputs=[gen_btn, chain_warning, pdb_state]
355
+ )
356
 
357
  # Stack components vertically
358
  fa_output = gr.Code(
 
368
  gr.Markdown("### System Status")
369
  status1 = gr.Markdown()
370
 
371
+ gen_btn.click(run_part1, inputs=[pdb_input, f_chains, v_chains, sampling_temp, chain_options],
372
  outputs=[fa_output, plot_output, status1])
373
 
374
  # TAB 2: STRUCTURAL VALIDATION
 
426
 
427
  # Launch the app
428
  if __name__ == "__main__":
429
+ # Docker deployment for HuggingFace Spaces
430
+ demo.launch(server_name="0.0.0.0", server_port=7860)
 
requirements.txt CHANGED
@@ -33,7 +33,7 @@ requests
33
  safetensors==0.7.0
34
  sympy==1.13.1
35
  tokenizers==0.22.1
36
- torch==2.5.1
37
  torchvision==0.20.1
38
  tqdm==4.67.1
39
  transformers==4.57.3
 
33
  safetensors==0.7.0
34
  sympy==1.13.1
35
  tokenizers==0.22.1
36
+ torch==2.8.0
37
  torchvision==0.20.1
38
  tqdm==4.67.1
39
  transformers==4.57.3