File size: 19,217 Bytes
32c275c
 
178c45c
 
 
 
32c275c
 
 
 
f2a576a
7ce6f00
178c45c
32c275c
 
 
f2a576a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178c45c
f2a576a
 
 
178c45c
f2a576a
 
 
178c45c
f2a576a
 
 
 
 
 
178c45c
f2a576a
 
 
 
 
 
 
178c45c
f2a576a
 
 
 
178c45c
f2a576a
 
178c45c
f2a576a
 
7ce6f00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32c275c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2a576a
32c275c
 
 
 
 
f2a576a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178c45c
f2a576a
178c45c
f2a576a
 
 
 
 
 
 
 
 
 
 
178c45c
f2a576a
32c275c
 
178c45c
 
f2a576a
32c275c
7ce6f00
32c275c
7ce6f00
32c275c
 
7ce6f00
 
 
32c275c
 
 
7ce6f00
 
 
 
 
 
178c45c
 
7ce6f00
178c45c
7ce6f00
32c275c
7ce6f00
32c275c
178c45c
32c275c
178c45c
 
 
 
 
32c275c
178c45c
 
 
32c275c
7ce6f00
178c45c
 
 
 
 
 
 
 
7ce6f00
178c45c
 
 
 
 
 
 
 
 
 
 
 
 
7ce6f00
178c45c
 
7ce6f00
178c45c
7ce6f00
178c45c
7ce6f00
f2a576a
178c45c
7ce6f00
178c45c
f2a576a
178c45c
f2a576a
32c275c
178c45c
 
 
ef184f7
178c45c
32c275c
178c45c
7ce6f00
 
 
 
32c275c
 
 
e497b51
 
 
 
7ce6f00
 
e497b51
 
 
 
 
7ce6f00
 
e497b51
7ce6f00
e497b51
 
 
 
 
 
 
 
7ce6f00
e497b51
 
7ce6f00
e497b51
 
 
 
 
7ce6f00
 
 
e497b51
 
 
 
178c45c
e497b51
 
32c275c
 
 
 
 
 
f2a576a
178c45c
 
f2a576a
 
178c45c
f2a576a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178c45c
f2a576a
 
 
 
 
 
 
 
 
 
178c45c
f2a576a
 
 
 
 
 
 
 
 
 
 
 
 
 
7ce6f00
 
 
 
 
 
 
 
 
 
 
 
 
 
32c275c
f2a576a
7ce6f00
32c275c
ef184f7
 
178c45c
7ce6f00
 
 
 
 
 
 
 
 
 
 
178c45c
e497b51
7ce6f00
32c275c
178c45c
 
 
 
 
 
 
ef184f7
32c275c
178c45c
2a3e7bf
7ce6f00
178c45c
 
7ce6f00
178c45c
 
7ce6f00
 
178c45c
7ce6f00
178c45c
7ce6f00
178c45c
7ce6f00
178c45c
7ce6f00
178c45c
7ce6f00
178c45c
7ce6f00
 
 
 
 
178c45c
7ce6f00
 
178c45c
7ce6f00
 
32c275c
 
f2a576a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
import gradio as gr # type: ignore
import os
from gradio_molecule3d import Molecule3D #type: ignore
from Bio.PDB import PDBParser #type: ignore

import time

# Import your custom modules from the /scripts folder
from scripts.download import download_and_clean_pdb
from scripts.generator import run_broteinshake_generator
from scripts.refine import polish_design, process_results
from scripts.visualize import create_design_plot
from scripts.foldprotein import fold_protein_sequence

# --- HELPER FUNCTIONS ---

def get_pdb_chains(pdb_file):
    """Return the sorted, de-duplicated chain IDs present in a PDB file.

    Chains from every model in the structure are pooled. An empty list is
    returned when no path is given, the path does not exist, or parsing fails
    for any reason (the error is printed rather than raised).
    """
    if not (pdb_file and os.path.exists(pdb_file)):
        return []
    try:
        structure = PDBParser(QUIET=True).get_structure("temp", pdb_file)
        unique_ids = {chain.id for model in structure for chain in model}
        return sorted(unique_ids)
    except Exception as exc:
        print(f"Error extracting chains: {exc}")
        return []

def load_pdb_and_extract_chains(pdb_id):
    """Download a PDB entry and populate the chain selector.

    Args:
        pdb_id: PDB identifier typed by the user; surrounding whitespace is ignored.

    Returns:
        A 4-tuple matching the Gradio outputs
        ``(chain_options update, status message, gen_btn update, chain-list state)``.
        On any failure the chain selector is cleared, the generate button is
        disabled and the state is an empty list.
    """
    # Normalise once so the download, the status text and the error text all
    # refer to the same identifier (messages previously used the raw input).
    clean_id = (pdb_id or "").strip()
    if not clean_id:
        return gr.update(choices=[], value=[]), "⚠️ Please enter a PDB ID", gr.update(interactive=False), []

    label = clean_id.upper()
    try:
        # Download the PDB, then read its chain IDs
        pdb_path = download_and_clean_pdb(clean_id, data_dir="data")
        chains = get_pdb_chains(pdb_path)

        if not chains:
            return gr.update(choices=[], value=[]), f"⚠️ No chains found in {label}", gr.update(interactive=False), []

        if len(chains) == 1:
            # Single-chain proteins: the only chain is pre-selected for redesign
            # and generation is enabled straight away.
            status_msg = f"Loaded {label}: Single-chain protein - will redesign chain {chains[0]}"
            return gr.update(choices=chains, value=chains), status_msg, gr.update(interactive=True), chains

        # Multi-chain: every chain starts selected, which would leave no chain
        # fixed, so the button starts disabled until the user deselects at
        # least one chain.
        status_msg = f"Loaded {label}: Found {len(chains)} chain(s) - {', '.join(chains)}"
        return gr.update(choices=chains, value=chains), status_msg, gr.update(interactive=False), chains
    except Exception as e:
        error_msg = f"Error loading {label}: {str(e)}"
        print(error_msg)
        return gr.update(choices=[], value=[]), error_msg, gr.update(interactive=False), []

def validate_chain_selection(selected_chains, available_chains_state):
    """Check the chain checkbox selection and toggle the generate button.

    Rules applied, in order:
      * nothing selected -> button disabled;
      * single-chain protein -> its lone chain may be redesigned, button enabled;
      * multi-chain protein with every chain selected -> button disabled,
        because at least one chain must stay fixed;
      * otherwise -> button enabled.

    Returns:
        ``(gen_btn update, status/warning text, available_chains_state)``; the
        state is passed through unchanged.
    """
    if not selected_chains:
        return gr.update(interactive=False), "Please select at least one chain to redesign", available_chains_state

    known_chains = available_chains_state or []

    if len(known_chains) == 1:
        message = f"Single-chain protein: Will redesign chain {known_chains[0]}"
        return gr.update(interactive=True), message, available_chains_state

    leaves_nothing_fixed = bool(known_chains) and len(selected_chains) >= len(known_chains)
    if leaves_nothing_fixed:
        message = f"Cannot select all chains - at least one chain must remain fixed. Selected: {', '.join(selected_chains)}"
        return gr.update(interactive=False), message, available_chains_state

    message = f"{len(selected_chains)} chain(s) selected for redesign: {', '.join(selected_chains)}"
    return gr.update(interactive=True), message, available_chains_state

def _read_fasta_records(fasta_file: str) -> "list[tuple[str, str]]":
    """Parse *fasta_file* into ``(header, sequence)`` tuples in file order.

    Headers keep their leading ``>``. Consecutive sequence lines are joined,
    so a wrapped (multi-line) sequence stays one record. Strict two-line
    pairing would instead shift every following header/sequence pair.
    Blank lines, text before the first header, and headers with no sequence
    lines are ignored.
    """
    records = []
    header = None
    chunks = []
    with open(fasta_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if header is not None and chunks:
                    records.append((header, "".join(chunks)))
                header, chunks = line, []
            elif header is not None:
                chunks.append(line)
    if header is not None and chunks:
        records.append((header, "".join(chunks)))
    return records


def _fasta_header_score(header: str) -> float:
    """Return the numeric value of the exact ``score=`` field of a FASTA header.

    The key is matched exactly, so a field such as ``global_score`` is never
    mistaken for ``score`` regardless of field order.

    Raises:
        ValueError: if there is no ``score`` field or its value is not numeric.
    """
    for field in header.lstrip(">").split(","):
        key, sep, value = field.partition("=")
        if sep and key.strip() == "score":
            return float(value)
    raise ValueError(f"No score field in header: {header}")


def get_all_sequences(fasta_file: str) -> str:
    """Return every designed sequence in *fasta_file* as FASTA text.

    Only records whose header contains ``sample`` are kept; this skips the
    original native sequence entry. Each kept record is rendered as its header
    line followed by its (unwrapped) sequence line, and records are separated
    by one blank line.

    Raises:
        ValueError: if the file contains no ``sample`` records.
    """
    designs = [
        f"{header}\n{sequence}"
        for header, sequence in _read_fasta_records(fasta_file)
        if "sample" in header
    ]
    if not designs:
        raise ValueError(f"No valid designs found in {fasta_file}")
    return "\n\n".join(designs)


def extract_best_sequence(fasta_file: str) -> str:
    """Return the design with the lowest ``score`` as a two-line FASTA string.

    Only ``sample`` records are considered (the native entry is skipped), and
    records whose header has no parseable ``score`` field are ignored. On a
    tie the earliest record wins.

    Raises:
        ValueError: if no ``sample`` record has a parseable score.
    """
    best_score = float("inf")
    best_record = None
    for header, sequence in _read_fasta_records(fasta_file):
        if "sample" not in header:
            continue
        try:
            score = _fasta_header_score(header)
        except ValueError:
            continue
        if score < best_score:
            best_score = score
            best_record = (header, sequence)

    if best_record is None:
        raise ValueError(f"No valid designs found in {fasta_file}")
    best_header, best_seq = best_record
    return f"{best_header}\n{best_seq}"

def _resolve_chain_roles(all_chains, fixed_chains, variable_chains, selected_chains):
    """Decide which chains ProteinMPNN keeps fixed and which it redesigns.

    Checkbox selections (*selected_chains*) take precedence: selected chains
    are redesigned and every other chain is fixed. Without a selection the
    legacy text inputs are used; a single-chain protein with no variable
    chain given defaults to redesigning its only chain.

    Returns:
        ``(fixed_chains, variable_chains)`` as concatenated chain-ID strings.

    Raises:
        ValueError: if a multi-chain protein would have every chain redesigned.
    """
    is_single_chain = len(all_chains) == 1

    if selected_chains:
        variable_chains = "".join(selected_chains)
        fixed_chains = "".join(c for c in all_chains if c not in selected_chains)
        # Single-chain proteins legitimately have no fixed chain; multi-chain
        # ones must keep at least one chain fixed.
        if not is_single_chain and not fixed_chains:
            raise ValueError(
                "Cannot redesign all chains - at least one chain must remain fixed. "
                f"Selected: {', '.join(selected_chains)}, Available: {', '.join(all_chains)}"
            )
        if is_single_chain:
            print(f"Single-chain mode: Redesigning chain {variable_chains}")
        else:
            print(f"Using chain selector: Fixed={fixed_chains}, Variable={variable_chains}")
        return fixed_chains, variable_chains

    # Legacy text-input path
    if is_single_chain and not variable_chains:
        variable_chains = all_chains[0]
        fixed_chains = ""
    elif not is_single_chain and variable_chains:
        # Reject only when the *variable* chains cover every chain. Comparing
        # fixed + variable against all chains (the previous check) rejected the
        # normal case, e.g. fixed "A" / variable "B" on a two-chain complex.
        if set(variable_chains) >= set(all_chains):
            raise ValueError("Cannot redesign all chains - at least one chain must remain fixed.")
    print(f"Using text inputs: Fixed={fixed_chains}, Variable={variable_chains}")
    return fixed_chains, variable_chains


def run_part1(pdb_id, fixed_chains, variable_chains, temperature=0.1, selected_chains=None):
    """Download the target PDB, run ProteinMPNN design and summarise the results.

    Args:
        pdb_id: Target PDB identifier; surrounding whitespace is ignored.
        fixed_chains: Legacy text input of chain IDs to keep fixed; only used
            when *selected_chains* is empty.
        variable_chains: Legacy text input of chain IDs to redesign; only used
            when *selected_chains* is empty.
        temperature: Sampling temperature forwarded to the generator.
        selected_chains: Chain IDs ticked in the UI; when non-empty they are
            redesigned and all remaining chains are fixed.

    Returns:
        ``(all_sequences, evolution_plot, status_message)``. Errors are not
        raised; they are reported as ``("", None, "Error in Part 1: ...")``.
    """
    try:
        # Strip like load_pdb_and_extract_chains does, so the download and the
        # generated/<id>/ FASTA path below use the same identifier.
        pdb_id = (pdb_id or "").strip()
        if not pdb_id:
            raise ValueError("Please enter a PDB ID")

        # Step 1: Secure the template and read its chains
        pdb_path = download_and_clean_pdb(pdb_id, data_dir="data")
        all_chains = get_pdb_chains(pdb_path)
        if not all_chains:
            raise ValueError(f"No chains found in {pdb_id.upper()}")

        fixed_chains, variable_chains = _resolve_chain_roles(
            all_chains, fixed_chains, variable_chains, selected_chains
        )

        # Step 2: Generate optimized sequences (creates the .fa file read below)
        print(f"Parameters: Fixed chains={fixed_chains}, Variable chains={variable_chains}, Temp={temperature}")
        run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=temperature)

        # Collect all designs, the best one, and the dashboard plot
        fa_file = os.path.join("generated", pdb_id.lower(), "seqs", f"{pdb_id.lower()}_clones.fa")
        all_sequences = get_all_sequences(fa_file)
        best_sequence = extract_best_sequence(fa_file)
        evolution_plot = create_design_plot(fa_file)

        # Parse the score from the best design's header for the status message
        best_header = best_sequence.split("\n", 1)[0]
        score_part = [p for p in best_header.split(",") if "score" in p][0]
        best_score = float(score_part.split("=")[1])

        num_designs = len([s for s in all_sequences.split("\n\n") if s.strip()])

        status_message = (
            f"Design Complete! {num_designs} designs generated.\n\n"
            f"Lead Candidate (Best Score: {best_score:.4f}):\n"
            f"{best_sequence}\n\n"
        )
        return all_sequences, evolution_plot, status_message
    except Exception as e:
        return "", None, f"Error in Part 1: {str(e)}"

def run_part2(pdb_id, sequence):
    """Fold a designed sequence and align it to the target structure.

    1. Folds the input sequence using ESM Atlas (API) via ``fold_protein_sequence``.
    2. Aligns the folded structure to the target PDB (``polish_design``) and
       builds a text report (``process_results``).

    Args:
        pdb_id: Target PDB identifier; surrounding whitespace is ignored.
        sequence: Amino-acid sequence. All whitespace/line breaks are removed
            and a single FASTA header line (``>...``), as shown in Tab 1's
            output, is dropped. More than one header is rejected.

    Returns:
        ``(final_pdb_path, report)`` on success, or ``(None, error_message)``.
        Exceptions are caught and reported, never raised.
    """
    raw_fold_path = None
    final_pdb_path = None
    try:
        # --- 1. Validate Inputs ---
        lines = (sequence or "").splitlines()
        is_header = [line.lstrip().startswith(">") for line in lines]
        if sum(is_header) > 1:
            return None, "Error: Please paste a single sequence (found multiple FASTA records)."
        # Drop the FASTA header (if any) and every whitespace character so text
        # copied straight from Tab 1 can be pasted as-is.
        sequence = "".join(
            "".join(line.split()) for line, header in zip(lines, is_header) if not header
        )
        if not sequence:
            return None, "Error: Please enter a protein sequence."

        pdb_id = (pdb_id or "").strip()
        if not pdb_id:
            return None, "Error: PDB ID is required."

        print(f"Starting Pipeline for {pdb_id}...")
        print(f"   - Sequence Length: {len(sequence)} residues")

        # --- 2. Fold Sequence (scripts/foldprotein.py) ---
        pdb_content = fold_protein_sequence(sequence)
        if not pdb_content:
            return None, "Error: Protein folding failed. The API might be down or the sequence is invalid."

        # Save the raw folded structure to a temp file
        raw_fold_path = f"temp_fold_{pdb_id}_{int(time.time())}.pdb"
        with open(raw_fold_path, "w") as f:
            f.write(pdb_content)
        if not os.path.exists(raw_fold_path):
            return None, f"Error: Could not save folded PDB to {raw_fold_path}"

        # --- 3. Align & Polish ---
        print(f"Aligning folded structure to target {pdb_id}...")
        final_pdb_path, global_rmsd, core_rmsd, high_conf_rmsd = polish_design(pdb_id, raw_fold_path)

        # --- 4. Validate Alignment Output ---
        if not final_pdb_path or not os.path.exists(final_pdb_path):
            return None, f"Error: Alignment failed - output file not created: {final_pdb_path}"
        if high_conf_rmsd is None:
            return None, "Error: Alignment failed - RMSD calculation returned None"

        def _fmt_rmsd(value):
            # Only high_conf_rmsd is guaranteed non-None above; don't let a None
            # global/core RMSD crash a log line.
            return "n/a" if value is None else f"{value:.3f}A"

        print(f"Success: Global RMSD={_fmt_rmsd(global_rmsd)} | Core RMSD={_fmt_rmsd(core_rmsd)}")

        # --- 5. Generate Report ---
        report = process_results(pdb_id, final_pdb_path, global_rmsd, high_conf_rmsd)
        return final_pdb_path, report

    except Exception as e:
        error_msg = f"Unexpected Error: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg
    finally:
        # Remove the raw, unaligned fold on every exit path. Previously it was
        # deleted only on success and leaked on alignment errors/exceptions.
        # The file is never deleted if it is itself the returned result.
        if raw_fold_path and raw_fold_path != final_pdb_path and os.path.exists(raw_fold_path):
            try:
                os.remove(raw_fold_path)
            except OSError as cleanup_error:
                print(f"Warning: could not remove {raw_fold_path}: {cleanup_error}")

# --- GRADIO INTERFACE ---

# 1. Dark theme: slate neutrals with blue (primary) and emerald (secondary)
#    accents, plus explicit colours for the page/block backgrounds, body text
#    and primary buttons.
dark_biohub = gr.themes.Base(
    neutral_hue="slate",
    secondary_hue="emerald",
    primary_hue="blue",
).set(
    button_primary_text_color="#ffffff",
    button_primary_background_fill="#10b981",
    body_text_color="#f1f5f9",
    block_background_fill="#1e293b",
    body_background_fill="#0f172a",
)

# 2. Custom CSS passed to gr.Blocks(css=...): hides the Gradio footer, restyles
#    the 3D viewer (the component created with elem_id="molecule-viewer") and
#    the header banner (the HTML div with id "biohub-header").
biohub_css = """
/* Remove the footer for a clean portfolio look */
footer {display: none !important;}

/* Fix the 3D viewer background to match the dark theme */
#molecule-viewer {
    background-color: #111827 !important;
    border: 1px solid #374151 !important;
    border-radius: 12px;
}

/* Header Styling */
#biohub-header {
    background: linear-gradient(135deg, #064e3b 0%, #1e40af 100%);
    padding: 1.5rem;
    border-radius: 12px;
    border: 1px solid #10b981;
    margin-bottom: 1rem;
}
"""

# Build the two-tab UI. Components render in the order they are created inside
# these context managers, so statement order here defines the page layout.
with gr.Blocks(theme=dark_biohub, css=biohub_css) as demo:
    # Header banner (styled by the #biohub-header rule in biohub_css)
    gr.HTML("""
        <div id='biohub-header'>
            <h1 style='color: white; margin: 0;'>BroteinShake</h1>
        </div>
    """)
    
    with gr.Tabs():
        # TAB 1: GENERATIVE DESIGN — load a PDB, pick chains, run ProteinMPNN (run_part1)
        with gr.Tab("1. Sequence Generation"):
            gr.Markdown("Enter a PDB ID to 'repaint' its binder interface using ProteinMPNN.")
            
            # NOTE: pdb_input is also read by the Tab 2 validation button below.
            pdb_input = gr.Textbox(label="Target PDB ID", placeholder="e.g., 3kas", value="")
            load_pdb_btn = gr.Button("Load PDB", variant="secondary")
            pdb_status = gr.Markdown("Enter a PDB ID and click 'Load PDB' to begin")
            
            with gr.Column():
                gr.Markdown("### Design Parameters")
                
                # Sampling temperature forwarded to the generator via run_part1.
                # Per the info text: low T for high fidelity, higher T for diversity.
                sampling_temp = gr.Slider(
                    minimum=0.05, maximum=1.0, value=0.1, step=0.05,
                    label="Sampling Temperature (T)",
                    info="T=0.1 for high-fidelity; T=0.3 for natural diversity"
                )

                # Chain selector: choices/value are filled by load_pdb_and_extract_chains;
                # checked chains are redesigned, unchecked chains stay fixed.
                chain_options = gr.CheckboxGroup(
                    choices=[], 
                    label="Chains to Redesign",
                    info="Identify which chains ProteinMPNN should modify (will populate after loading PDB)"
                )
                
                # Feedback text written by validate_chain_selection
                chain_warning = gr.Markdown("Select at least one chain to enable generation", visible=True)
                
                # Holds the list of chain IDs from the last load (empty list on
                # failure); validate_chain_selection uses it to tell single- from
                # multi-chain proteins.
                pdb_state = gr.State()
                
                # Legacy text inputs (hidden but kept for backward compatibility).
                # run_part1 only uses them when no checkbox chain is selected.
                with gr.Row(visible=False):
                    f_chains = gr.Textbox(label="Fixed Chains (Lock)", value="A")
                    v_chains = gr.Textbox(label="Variable Chains (Key)", value="B")
            
            # Generate button (initially disabled; enabled by load/validation handlers)
            gen_btn = gr.Button("Generate Optimized Sequences", variant="primary", interactive=False)
            
            # Load PDB and extract chains when button is clicked
            load_pdb_btn.click(
                fn=load_pdb_and_extract_chains,
                inputs=[pdb_input],
                outputs=[chain_options, pdb_status, gen_btn, pdb_state]
            )
            
            # Re-validate whenever the chain selection changes and update the button.
            # pdb_state is returned unchanged by validate_chain_selection.
            chain_options.change(
                fn=validate_chain_selection,
                inputs=[chain_options, pdb_state],
                outputs=[gen_btn, chain_warning, pdb_state]
            )
            
            # Results area, stacked vertically: FASTA text of all designs
            # (shown in a code box), the design plot, then the status text.
            fa_output = gr.Code(
                label="Designed Sequences", 
                language="markdown",
                lines=10,
                max_lines=20  # Prevents infinite growth when showing all 20 candidates
            )
            plot_output = gr.Plot(
                label="Design Evolution Dashboard"
            )
            
            gr.Markdown("### System Status")
            status1 = gr.Markdown()
            
            # run_part1 returns (all_sequences, evolution_plot, status_message)
            gen_btn.click(run_part1, inputs=[pdb_input, f_chains, v_chains, sampling_temp, chain_options], 
                         outputs=[fa_output, plot_output, status1])

        # TAB 2: STRUCTURAL VALIDATION — fold a pasted sequence and align it (run_part2)
        with gr.Tab("2. Structural Validation"):
            gr.Markdown("### Final Structure Preview")
            
            # Molecule3D representation: model 0 drawn as a cartoon, spectrum
            # colouring, fully opaque.
            REPS = [
                {
                    "model": 0,
                    "style": "cartoon",
                    "color": "spectrum",
                    "opacity": 1.0
                }
            ]
            
            # 3D Viewer component (styled by the #molecule-viewer rule in biohub_css)
            protein_view = Molecule3D(label="3D Structure Viewer (Refined Shuttle)", reps=REPS, elem_id="molecule-viewer")
            
            with gr.Row():
                # Sequence is pasted as text (run_part2 folds it), next to the
                # download slot for the aligned PDB.
                sequence_input = gr.Textbox(
                    label="Paste Protein Sequence", 
                    placeholder="Paste the sequence generated in Tab 1 here (e.g., MKTII...)",
                    lines=4,
                    max_lines=8
                )
                refined_download = gr.File(label="Download Aligned Lead (.pdb)")
            
            validate_btn = gr.Button("Run Structural Alignment", variant="primary")
            status2 = gr.Textbox(label="Validation Report", interactive=False, lines=5)

            def run_validation_with_view(pdb_id, sequence):
                """Run run_part2 and fan its result out to the Tab 2 outputs.

                Returns ``(download path, report text, viewer path)``; the same
                aligned PDB path (or None) feeds both the download and the 3D viewer.
                The basename check only logs a warning, it does not change the result.
                """
                try:
                    # Fold via the API, then align to the target (see run_part2)
                    final_pdb, report = run_part2(pdb_id, sequence)
                    
                    if final_pdb is not None and os.path.exists(final_pdb):
                        # Verify we're using the refined shuttle
                        if "Refined_Shuttle.pdb" in final_pdb or os.path.basename(final_pdb) == "Refined_Shuttle.pdb":
                            print(f"Visualizing refined shuttle: {final_pdb}")
                        else:
                            print(f"Warning: Expected Refined_Shuttle.pdb but got: {final_pdb}")
                    
                    # Return path for download, report text, and path for 3D viewer
                    return final_pdb, report, final_pdb
                    
                except Exception as e:
                    # run_part2 already catches its own exceptions; this is a
                    # last-resort guard for anything raised in this wrapper.
                    error_msg = f"Error generating 3D view: {str(e)}"
                    print(error_msg)
                    import traceback
                    traceback.print_exc()
                    return None, error_msg, None

            # Uses the PDB ID typed in Tab 1 together with the pasted sequence
            validate_btn.click(
                run_validation_with_view, 
                inputs=[pdb_input, sequence_input], 
                outputs=[refined_download, status2, protein_view]
            )
# Entry point: serve the UI only when this file is executed directly.
if __name__ == "__main__":
    # Listen on all interfaces, port 7860, so the app is reachable from outside
    # the container (Docker deployment for HuggingFace Spaces).
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)