| | import gradio as gr |
| | import os |
| | import json |
| | import tempfile |
| | import traceback |
| | import numpy as np |
| | import pandas as pd |
| | from pathlib import Path |
| | from typing import Optional, Tuple, Dict, Any |
| | import torch |
| | import time |
| | import io |
| | import base64 |
| | import zipfile |
| | from datetime import datetime |
| |
|
# Startup diagnostics: report the versions of the web-stack packages so that
# dependency mismatches are visible in the deployment log.
import gradio
import fastapi
import pydantic
import starlette

for _web_pkg in (gradio, fastapi, pydantic, starlette):
    print(_web_pkg.__name__, _web_pkg.__version__)

import sys

# Directory that contains this script; used to locate the bundled sources.
BASE_DIR = Path(__file__).parent
| |
|
| | |
def setup_imports():
    """Locate and import the BCE modules, trying several layouts in turn.

    The three strategies are attempted in order and the first that works wins:

    1. ``src`` treated as a package below ``BASE_DIR``.
    2. ``BASE_DIR / "src"`` itself put on ``sys.path``.
    3. ``bce`` already importable (e.g. installed as a package).

    On success the globals ``AntigenChain`` and ``PROJECT_BASE_DIR`` are bound
    and ``True`` is returned.  If every strategy fails, ``PROJECT_BASE_DIR``
    falls back to ``BASE_DIR`` and ``False`` is returned.
    """
    global AntigenChain, PROJECT_BASE_DIR

    source_dir = BASE_DIR / "src"

    # Strategy 1: project root on the path, import through the "src" package.
    if source_dir.exists():
        sys.path.insert(0, str(BASE_DIR))
        try:
            from src.bce.antigen.antigen import AntigenChain
            from src.bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
        except ImportError as err:
            print(f"❌ Failed to import from src/: {err}")
        else:
            print("✅ Successfully imported from src/ directory")
            return True

    # Strategy 2: the src directory itself on the path.
    if source_dir.exists():
        sys.path.insert(0, str(source_dir))
        try:
            from bce.antigen.antigen import AntigenChain
            from bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
        except ImportError as err:
            print(f"❌ Failed to import with src/ in path: {err}")
        else:
            print("✅ Successfully imported from src/ added to path")
            return True

    # Strategy 3: rely on whatever is already importable.
    try:
        from bce.antigen.antigen import AntigenChain
        from bce.utils.constants import BASE_DIR as PROJECT_BASE_DIR
    except ImportError as err:
        print(f"❌ Failed to import from installed package: {err}")
    else:
        print("✅ Successfully imported from installed package")
        return True

    # Nothing worked: keep the module importable with a fallback base dir.
    print("⚠️ All import methods failed, using fallback settings")
    PROJECT_BASE_DIR = BASE_DIR
    return False
| |
|
| | |
# Resolve the BCE imports once at module load; the app cannot run without them.
import_success = setup_imports()

if not import_success:
    # Tell the operator which files are expected, then abort with a non-zero
    # exit status.
    for _hint in (
        "❌ Critical: Could not import BCE modules. Please check the file structure.",
        "Expected structure:",
        "- src/bce/antigen/antigen.py",
        "- src/bce/utils/constants.py",
        "- src/bce/model/ReCEP.py",
        "- src/bce/data/utils.py",
    ):
        print(_hint)
    sys.exit(1)
| |
|
| | |
# Model checkpoint used for prediction; override with the BCE_MODEL_PATH
# environment variable.
_bundled_checkpoint = PROJECT_BASE_DIR.joinpath(
    "models", "ReCEP", "20250626_110438", "best_mcc_model.bin"
)
DEFAULT_MODEL_PATH = os.environ.get("BCE_MODEL_PATH", str(_bundled_checkpoint))

# NOTE(review): the fallback value below looks like a real credential committed
# to source control. Confirm with the owner, rotate it, and supply the token
# only through the ESM_TOKEN environment variable.
ESM_TOKEN = os.environ.get("ESM_TOKEN", "1mzAo8l1uxaU8UfVcGgV7B")

# Directory where uploaded PDB copies and generated visualizations are stored;
# created at import time if it does not exist yet.
PDB_DATA_DIR = PROJECT_BASE_DIR / "data" / "pdb"
PDB_DATA_DIR.mkdir(parents=True, exist_ok=True)
| |
|
def validate_pdb_id(pdb_id: str) -> bool:
    """Return True when *pdb_id* is exactly four ASCII letters/digits.

    Anything that is not a string (including ``None``) is rejected instead of
    raising.  The ASCII check matters because ``str.isalnum`` alone also
    accepts non-ASCII alphanumerics such as full-width digits, which can never
    be a valid identifier.
    """
    if not isinstance(pdb_id, str) or len(pdb_id) != 4:
        return False
    return pdb_id.isascii() and pdb_id.isalnum()
| |
|
def validate_chain_id(chain_id: str) -> bool:
    """Return True when *chain_id* is a single ASCII letter or digit.

    Anything that is not a string (including ``None``) is rejected instead of
    raising.  The ASCII check matters because ``str.isalnum`` alone also
    accepts non-ASCII alphanumerics.
    """
    if not isinstance(chain_id, str) or len(chain_id) != 1:
        return False
    return chain_id.isascii() and chain_id.isalnum()
| |
|
def create_pdb_visualization_html(pdb_data: str, predicted_epitopes: list,
                                  predictions: dict, protein_id: str,
                                  top_k_regions: Optional[list] = None) -> str:
    """Build an HTML fragment with an interactive 3Dmol.js viewer.

    Args:
        pdb_data: Structure in PDB text format; handed to the viewer's
            ``addModel(..., 'pdb')``.
        predicted_epitopes: Residue numbers highlighted in "Predicted
            Epitopes" mode.
        predictions: Mapping of residue number to score, used for the
            "Probability Gradient" colouring.
        protein_id: Label shown in the heading and used as the prefix of the
            saved image's file name.
        top_k_regions: Optional list of region dicts.  The keys read are
            ``center_idx``, ``center_residue`` (falls back to ``center_idx``),
            ``covered_residues`` (falls back to ``covered_indices``) and
            ``graph_pred``.  Entries that are not dicts are ignored.

    Returns:
        The HTML controls, the viewer container and the accompanying
        ``<script>`` as one string.  All element ids and JS function names
        carry a random per-call suffix so several viewers can share a page.
    """
    import html
    import uuid

    def to_js(value) -> str:
        # JSON text is a valid JavaScript literal.  "</" is rewritten so that
        # embedded data can never close the surrounding <script> element.
        return json.dumps(value).replace("</", "<\\/")

    epitope_residues = predicted_epitopes

    # Normalise the region dicts to the field names the JavaScript expects.
    processed_regions = [
        {
            'center_idx': region.get('center_idx', 0),
            'center_residue': region.get('center_residue', region.get('center_idx', 0)),
            'covered_residues': region.get('covered_residues', region.get('covered_indices', [])),
            'radius': 18.0,
            'predicted_value': region.get('graph_pred', 0.0),
        }
        for region in (top_k_regions or [])
        if isinstance(region, dict)
    ]

    viewer_id = f"viewer_{uuid.uuid4().hex[:8]}"

    # Everything interpolated into the page is escaped for its context:
    # HTML-escaped for markup, JSON-encoded for JavaScript.
    safe_title = html.escape(str(protein_id))
    download_name_js = to_js(f"{protein_id}_structure.png")
    pdb_js = to_js(pdb_data)
    epitopes_js = to_js(epitope_residues)
    predictions_js = to_js(predictions)
    regions_js = to_js(processed_regions)

    html_content = f"""
    <div style="width: 100%; height: 600px; border: 1px solid #ddd; border-radius: 8px; overflow: hidden;">
        <div style="padding: 10px; background: #f8f9fa; border-bottom: 1px solid #ddd;">
            <h3 style="margin: 0 0 10px 0; color: #333;">3D Structure Visualization - {safe_title}</h3>
            <div style="display: flex; gap: 15px; align-items: center; flex-wrap: wrap;">
                <div>
                    <label style="font-weight: bold; margin-right: 5px;">Display Mode:</label>
                    <select id="vizMode_{viewer_id}" onchange="updateVisualization_{viewer_id}()" style="padding: 4px;">
                        <option value="prediction">Predicted Epitopes</option>
                        <option value="probability">Probability Gradient</option>
                        <option value="regions">Top-k Regions</option>
                    </select>
                </div>
                <div>
                    <label style="font-weight: bold; margin-right: 5px;">Style:</label>
                    <select id="vizStyle_{viewer_id}" onchange="updateVisualization_{viewer_id}()" style="padding: 4px;">
                        <option value="cartoon">Cartoon</option>
                        <option value="surface">Surface</option>
                        <option value="stick">Stick</option>
                        <option value="sphere">Sphere</option>
                    </select>
                </div>
                <div>
                    <label style="font-weight: bold; margin-right: 5px;">
                        <input type="checkbox" id="showSpheres_{viewer_id}" onchange="updateVisualization_{viewer_id}()" style="margin-right: 3px;"> Show Spheres
                    </label>
                </div>
                <div>
                    <label style="font-weight: bold; margin-right: 5px;">Sphere Display:</label>
                    <select id="sphereCount_{viewer_id}" onchange="handleSphereCountChange_{viewer_id}()" style="padding: 4px;">
                        <option value="1">Top 1</option>
                        <option value="2">Top 2</option>
                        <option value="3">Top 3</option>
                        <option value="4">Top 4</option>
                        <option value="5" selected>Top 5</option>
                        <option value="6">Top 6</option>
                        <option value="7">Top 7</option>
                        <option value="all">All Spheres</option>
                        <option value="custom">Custom Selection</option>
                    </select>
                </div>
                <div id="customSphereSelection_{viewer_id}" style="display: none; margin-top: 10px; padding: 10px; background: #f9f9f9; border-radius: 5px; max-height: 120px; overflow-y: auto;">
                    <label style="font-weight: bold; margin-bottom: 5px; display: block;">Select Spheres to Display:</label>
                    <div id="sphereCheckboxes_{viewer_id}" style="display: flex; flex-wrap: wrap; gap: 8px; max-height: 80px; overflow-y: auto;">
                        <!-- Checkboxes will be dynamically generated -->
                    </div>
                </div>
                <div>
                    <button onclick="resetView_{viewer_id}()" style="padding: 4px 8px; margin-right: 5px;">Reset View</button>
                    <button onclick="saveImage_{viewer_id}()" style="padding: 4px 8px;">Save Image</button>
                </div>
            </div>
        </div>
        <div id="{viewer_id}" style="width: 100%; height: 520px; min-height: 400px; position: relative; background: #f0f0f0;">
            <div style="position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%); text-align: center;">
                <p id="status_{viewer_id}" style="color: #666;">Loading 3Dmol.js...</p>
            </div>
        </div>
    </div>

    <script src="https://unpkg.com/3dmol@2.0.4/build/3Dmol-min.js"></script>
    <script>
    // Global variables for this viewer instance
    window.viewer_{viewer_id} = null;
    window.pdbData_{viewer_id} = {pdb_js};
    window.predictedEpitopes_{viewer_id} = {epitopes_js};
    window.predictions_{viewer_id} = {predictions_js};
    window.topKRegions_{viewer_id} = {regions_js};

    // Wait for 3Dmol to be available with timeout
    function wait3Dmol_{viewer_id}(attempts = 0) {{
        if (typeof $3Dmol !== 'undefined') {{
            console.log('3Dmol.js loaded successfully for {viewer_id}');
            document.getElementById('status_{viewer_id}').textContent = 'Initializing 3D viewer...';
            setTimeout(() => initializeViewer_{viewer_id}(), 100);
        }} else if (attempts < 50) {{ // 5 second timeout
            console.log(`Waiting for 3Dmol.js... attempt ${{attempts + 1}}`);
            setTimeout(() => wait3Dmol_{viewer_id}(attempts + 1), 100);
        }} else {{
            console.error('Failed to load 3Dmol.js after 5 seconds');
            document.getElementById('status_{viewer_id}').textContent = 'Failed to load 3Dmol.js. Please refresh the page.';
            document.getElementById('status_{viewer_id}').style.color = 'red';
        }}
    }}

    function initializeViewer_{viewer_id}() {{
        try {{
            const element = document.getElementById('{viewer_id}');
            if (!element) {{
                console.error('Viewer element not found: {viewer_id}');
                return;
            }}

            document.getElementById('status_{viewer_id}').textContent = 'Creating viewer...';

            window.viewer_{viewer_id} = $3Dmol.createViewer(element, {{
                defaultcolors: $3Dmol.rasmolElementColors
            }});

            document.getElementById('status_{viewer_id}').textContent = 'Loading structure...';

            window.viewer_{viewer_id}.addModel(window.pdbData_{viewer_id}, 'pdb');

            // Hide status message
            const statusEl = document.getElementById('status_{viewer_id}');
            if (statusEl) statusEl.style.display = 'none';

            updateVisualization_{viewer_id}();

            // Initialize sphere checkboxes if data is available
            if (window.topKRegions_{viewer_id} && window.topKRegions_{viewer_id}.length > 0) {{
                generateSphereCheckboxes_{viewer_id}();
            }}

            console.log('3D viewer initialized successfully for {viewer_id}');
        }} catch (error) {{
            console.error('Error initializing 3D viewer:', error);
            const statusEl = document.getElementById('status_{viewer_id}');
            if (statusEl) {{
                statusEl.textContent = 'Error loading 3D viewer: ' + error.message;
                statusEl.style.color = 'red';
            }}
        }}
    }}

    function updateVisualization_{viewer_id}() {{
        if (!window.viewer_{viewer_id}) return;

        try {{
            const mode = document.getElementById('vizMode_{viewer_id}').value;
            const style = document.getElementById('vizStyle_{viewer_id}').value;
            const showSpheres = document.getElementById('showSpheres_{viewer_id}').checked;

            // Clear everything
            window.viewer_{viewer_id}.removeAllShapes();
            window.viewer_{viewer_id}.removeAllSurfaces();
            window.viewer_{viewer_id}.setStyle({{}}, {{}});

            // Base style
            const baseStyle = {{}};
            if (style === 'surface') {{
                baseStyle['cartoon'] = {{ hidden: true }};
            }} else {{
                baseStyle[style] = {{ color: '#e6e6f7' }};
            }}
            window.viewer_{viewer_id}.setStyle({{}}, baseStyle);

            if (mode === 'prediction') {{
                // Highlight predicted epitopes
                if (window.predictedEpitopes_{viewer_id}.length > 0 && style !== 'surface') {{
                    const epitopeStyle = {{}};
                    epitopeStyle[style] = {{ color: '#9C6ADE' }};
                    window.viewer_{viewer_id}.setStyle({{ resi: window.predictedEpitopes_{viewer_id} }}, epitopeStyle);
                }}

                // Add surface for epitopes if surface mode
                if (style === 'surface') {{
                    window.viewer_{viewer_id}.addSurface($3Dmol.SurfaceType.VDW, {{
                        opacity: 1.0,
                        color: '#e6e6f7'
                    }});

                    if (window.predictedEpitopes_{viewer_id}.length > 0) {{
                        window.viewer_{viewer_id}.addSurface($3Dmol.SurfaceType.VDW, {{
                            opacity: 1.0,
                            color: '#9C6ADE'
                        }}, {{ resi: window.predictedEpitopes_{viewer_id} }});
                    }}
                }}
            }} else if (mode === 'probability') {{
                // Color by probability scores
                if (window.predictions_{viewer_id} && Object.keys(window.predictions_{viewer_id}).length > 0) {{
                    const allProbs = Object.values(window.predictions_{viewer_id}).filter(p => p !== undefined);
                    const minProb = Math.min(...allProbs);
                    const maxProb = Math.max(...allProbs);

                    Object.entries(window.predictions_{viewer_id}).forEach(([resnum, score]) => {{
                        const normalizedProb = maxProb > minProb ? (score - minProb) / (maxProb - minProb) : 0.5;
                        const color = interpolateColor('#E6F3FF', '#DC143C', normalizedProb);
                        const probStyle = {{}};
                        if (style !== 'surface') {{
                            probStyle[style] = {{ color: color }};
                            window.viewer_{viewer_id}.setStyle({{ resi: parseInt(resnum) }}, probStyle);
                        }}
                    }});

                    if (style === 'surface') {{
                        window.viewer_{viewer_id}.addSurface($3Dmol.SurfaceType.VDW, {{
                            opacity: 0.8,
                            color: '#e6e6f7'
                        }});

                        Object.entries(window.predictions_{viewer_id}).forEach(([resnum, score]) => {{
                            const normalizedProb = maxProb > minProb ? (score - minProb) / (maxProb - minProb) : 0.5;
                            const color = interpolateColor('#E6F3FF', '#DC143C', normalizedProb);
                            window.viewer_{viewer_id}.addSurface($3Dmol.SurfaceType.VDW, {{
                                opacity: 1.0,
                                color: color
                            }}, {{ resi: parseInt(resnum) }});
                        }});
                    }}
                }}
            }} else if (mode === 'regions') {{
                // Color top-k regions
                const colors = ['#FF6B6B', '#96CEB4', '#4ECDC4', '#45B7D1', '#FFEAA7', '#DDA0DD', '#87CEEB'];

                if (window.topKRegions_{viewer_id} && window.topKRegions_{viewer_id}.length > 0) {{
                    window.topKRegions_{viewer_id}.forEach((region, index) => {{
                        const color = colors[index % colors.length];
                        const regionStyle = {{}};
                        if (style !== 'surface') {{
                            regionStyle[style] = {{ color: color }};
                            window.viewer_{viewer_id}.setStyle({{ resi: region.covered_residues }}, regionStyle);
                        }}
                    }});

                    if (style === 'surface') {{
                        window.viewer_{viewer_id}.addSurface($3Dmol.SurfaceType.VDW, {{
                            opacity: 0.8,
                            color: '#e6e6f7'
                        }});

                        window.topKRegions_{viewer_id}.forEach((region, index) => {{
                            const color = colors[index % colors.length];
                            window.viewer_{viewer_id}.addSurface($3Dmol.SurfaceType.VDW, {{
                                opacity: 1.0,
                                color: color
                            }}, {{ resi: region.covered_residues }});
                        }});
                    }}
                }}
            }}

            // Add spheres if requested
            if (showSpheres && window.topKRegions_{viewer_id} && window.topKRegions_{viewer_id}.length > 0) {{
                const colors = ['#FF6B6B', '#96CEB4', '#4ECDC4', '#45B7D1', '#FFEAA7', '#DDA0DD', '#87CEEB'];
                const sphereCount = document.getElementById('sphereCount_{viewer_id}').value;

                // Determine which spheres to show
                let spheresToShow = [];
                if (sphereCount === 'custom') {{
                    const selectedIndices = getSelectedSphereIndices_{viewer_id}();
                    spheresToShow = selectedIndices.map(idx => ({{ region: window.topKRegions_{viewer_id}[idx], index: idx }}));
                }} else {{
                    let numSpheres = sphereCount === 'all' ? window.topKRegions_{viewer_id}.length : parseInt(sphereCount);
                    numSpheres = Math.min(numSpheres, window.topKRegions_{viewer_id}.length);
                    spheresToShow = window.topKRegions_{viewer_id}.slice(0, numSpheres).map((region, index) => ({{ region, index }}));
                }}

                spheresToShow.forEach(({{ region, index }}) => {{
                    const color = colors[index % colors.length];
                    const centerResidues = window.viewer_{viewer_id}.getModel(0).selectedAtoms({{
                        resi: region.center_residue,
                        atom: 'CA'
                    }});

                    if (centerResidues.length > 0) {{
                        const centerAtom = centerResidues[0];
                        const centerCoords = {{ x: centerAtom.x, y: centerAtom.y, z: centerAtom.z }};

                        // Add wireframe sphere
                        window.viewer_{viewer_id}.addSphere({{
                            center: centerCoords,
                            radius: region.radius,
                            color: color,
                            wireframe: true,
                            linewidth: 2.0
                        }});

                        // Add center point
                        window.viewer_{viewer_id}.addSphere({{
                            center: centerCoords,
                            radius: 0.7,
                            color: '#FFD700',
                            wireframe: false
                        }});
                    }}
                }});
            }}

            window.viewer_{viewer_id}.zoomTo();
            window.viewer_{viewer_id}.render();
        }} catch (error) {{
            console.error('Error updating visualization:', error);
        }}
    }}

    // Color interpolation helper functions
    function interpolateColor(color1, color2, factor) {{
        const c1 = hexToRgb(color1);
        const c2 = hexToRgb(color2);

        const r = Math.round(c1.r + factor * (c2.r - c1.r));
        const g = Math.round(c1.g + factor * (c2.g - c1.g));
        const b = Math.round(c1.b + factor * (c2.b - c1.b));

        return rgbToHex(r, g, b);
    }}

    function hexToRgb(hex) {{
        const result = /^#?([a-f\\d]{{2}})([a-f\\d]{{2}})([a-f\\d]{{2}})$/i.exec(hex);
        return result ? {{
            r: parseInt(result[1], 16),
            g: parseInt(result[2], 16),
            b: parseInt(result[3], 16)
        }} : null;
    }}

    function rgbToHex(r, g, b) {{
        return "#" + ((1 << 24) + (r << 16) + (g << 8) + b).toString(16).slice(1);
    }}

    function resetView_{viewer_id}() {{
        if (window.viewer_{viewer_id}) {{
            window.viewer_{viewer_id}.zoomTo();
            window.viewer_{viewer_id}.render();
        }}
    }}

    function saveImage_{viewer_id}() {{
        if (window.viewer_{viewer_id}) {{
            window.viewer_{viewer_id}.pngURI(function(uri) {{
                const link = document.createElement('a');
                link.href = uri;
                link.download = {download_name_js};
                link.click();
            }});
        }}
    }}

    // Handle sphere count selection change
    function handleSphereCountChange_{viewer_id}() {{
        const sphereCount = document.getElementById('sphereCount_{viewer_id}').value;
        const customSelectionDiv = document.getElementById('customSphereSelection_{viewer_id}');

        if (sphereCount === 'custom') {{
            customSelectionDiv.style.display = 'block';
            generateSphereCheckboxes_{viewer_id}();
        }} else {{
            customSelectionDiv.style.display = 'none';
        }}

        updateVisualization_{viewer_id}();
    }}

    // Generate sphere checkboxes for custom selection
    function generateSphereCheckboxes_{viewer_id}() {{
        if (!window.topKRegions_{viewer_id} || window.topKRegions_{viewer_id}.length === 0) {{
            return;
        }}

        const regions = window.topKRegions_{viewer_id};
        const container = document.getElementById('sphereCheckboxes_{viewer_id}');
        container.innerHTML = '';

        regions.forEach((region, index) => {{
            const sphereNum = index + 1;
            // Each checkbox needs its own id so its label targets the right box
            const checkboxId = `sphere_${{sphereNum}}_{viewer_id}`;
            const colors = ['#FF6B6B', '#96CEB4', '#4ECDC4', '#45B7D1', '#FFEAA7', '#DDA0DD', '#87CEEB'];
            const sphereColor = colors[index % colors.length];

            const checkboxContainer = document.createElement('div');
            checkboxContainer.style.cssText = `
                display: flex;
                align-items: center;
                padding: 5px 10px;
                border: 1px solid #ddd;
                border-radius: 4px;
                background: white;
                cursor: pointer;
                user-select: none;
            `;
            checkboxContainer.setAttribute('data-sphere', sphereNum);

            const checkbox = document.createElement('input');
            checkbox.type = 'checkbox';
            checkbox.id = checkboxId;
            checkbox.checked = sphereNum <= 5; // Default: show first 5
            checkbox.style.marginRight = '5px';

            const colorBox = document.createElement('div');
            colorBox.style.cssText = `
                width: 16px;
                height: 16px;
                background-color: ${{sphereColor}};
                border: 1px solid #333;
                border-radius: 2px;
                margin-right: 5px;
            `;

            const label = document.createElement('label');
            label.setAttribute('for', checkboxId);
            label.textContent = `Sphere ${{sphereNum}} (R${{region.center_residue}})`;
            label.style.cursor = 'pointer';
            label.style.fontSize = '14px';

            checkboxContainer.appendChild(checkbox);
            checkboxContainer.appendChild(colorBox);
            checkboxContainer.appendChild(label);
            container.appendChild(checkboxContainer);

            // Add click handler
            checkboxContainer.addEventListener('click', function(e) {{
                // A label click is followed by a native click on its checkbox,
                // which reaches this handler again; toggling here as well
                // would cancel it out.
                if (e.target === label) {{
                    return;
                }}

                if (e.target !== checkbox) {{
                    checkbox.checked = !checkbox.checked;
                }}

                if (checkbox.checked) {{
                    checkboxContainer.style.backgroundColor = '#f0f8ff';
                    checkboxContainer.style.borderColor = '#4a90e2';
                }} else {{
                    checkboxContainer.style.backgroundColor = 'white';
                    checkboxContainer.style.borderColor = '#ddd';
                }}

                updateVisualization_{viewer_id}();
            }});

            // Initialize visual state
            if (checkbox.checked) {{
                checkboxContainer.style.backgroundColor = '#f0f8ff';
                checkboxContainer.style.borderColor = '#4a90e2';
            }}
        }});
    }}

    // Get selected sphere indices for custom mode
    function getSelectedSphereIndices_{viewer_id}() {{
        const selected = [];
        const checkboxes = document.querySelectorAll('#sphereCheckboxes_{viewer_id} input[type="checkbox"]:checked');
        checkboxes.forEach(function(checkbox) {{
            // Get sphere number from the data-sphere attribute of the container
            const container = checkbox.closest('[data-sphere]');
            if (container) {{
                const sphereNum = parseInt(container.getAttribute('data-sphere'));
                selected.push(sphereNum - 1); // Convert to 0-based index
            }}
        }});
        return selected;
    }}

    // Start initialization
    wait3Dmol_{viewer_id}();
    </script>
    """

    return html_content
| |
|
def _prediction_error(message: str) -> Tuple[str, str, str, str, str, str]:
    """Build the failure result: the message followed by five empty outputs."""
    return message, "", "", "", "", ""


def _store_uploaded_pdb(pdb_file, destination: Path) -> None:
    """Copy an upload (object with a real ``.name`` path, file-like object, or path) to *destination*."""
    source_name = getattr(pdb_file, "name", None)
    if source_name is not None and os.path.isfile(source_name):
        print(f"📁 Processing file object: {source_name}")
        destination.write_bytes(Path(source_name).read_bytes())
    elif hasattr(pdb_file, "read"):
        print("📄 Processing file-like object")
        payload = pdb_file.read()
        if isinstance(payload, str):
            # Text-mode streams yield str; the copy is written as bytes.
            payload = payload.encode("utf-8")
        destination.write_bytes(payload)
    else:
        print(f"📍 Processing file path: {pdb_file}")
        destination.write_bytes(Path(str(pdb_file)).read_bytes())


def _release_uploaded_copy(temp_file_path, auto_cleanup: bool) -> None:
    """Delete the stored upload when *auto_cleanup* is set, otherwise report where it was kept."""
    if temp_file_path is None or not os.path.exists(temp_file_path):
        return
    if not auto_cleanup:
        print(f"📁 PDB file retained at: {temp_file_path}")
        return
    try:
        os.remove(temp_file_path)
        print(f"🧹 Cleaned up temporary file: {temp_file_path}")
    except OSError as cleanup_error:
        # Runs from a ``finally`` clause: never let cleanup mask the real result.
        print(f"Warning: Could not remove temporary file {temp_file_path}: {cleanup_error}")


def _new_temp_path(suffix: str) -> str:
    """Create an empty temporary file securely and return its path."""
    handle, path = tempfile.mkstemp(suffix=suffix)
    os.close(handle)
    return path


def _membership_lookup(values):
    """Return a set for O(1) membership tests, or *values* itself if its items are unhashable."""
    try:
        return set(values)
    except TypeError:
        return values


def _format_residue_listing(title: str, residues, antigen_chain) -> str:
    """Render ``title (count):`` followed by one ``Residue N (X)`` line per residue."""
    lines = []
    for res in residues:
        if res in antigen_chain.resnum_to_index:
            res_name = antigen_chain.sequence[antigen_chain.resnum_to_index[res]]
            lines.append(f"Residue {res} ({res_name})")
        else:
            lines.append(f"Residue {res}")
    return f"{title} ({len(residues)}):\n" + "\n".join(lines)


def predict_epitopes(pdb_id: str, pdb_file, chain_id: str, radius: float, k: int,
                     encoder: str, device_config: str, use_threshold: bool, threshold: float,
                     auto_cleanup: bool, progress: gr.Progress = None) -> Tuple[str, str, str, str, str, str]:
    """Run the full epitope-prediction workflow for one protein chain.

    Steps: validate the inputs, load the structure (from an uploaded file or
    by PDB ID), run ``AntigenChain.predict``, format the text results, write
    JSON and CSV exports to temporary files and save a 3D visualization page.

    Args:
        pdb_id: Four-character PDB ID.  May be empty when *pdb_file* is given,
            in which case an ID is derived from the file name.
        pdb_file: Uploaded structure: an object with a ``name`` path, a
            file-like object, or a path.  Takes precedence over downloading.
        chain_id: Single-character chain identifier.
        radius, k, encoder: Forwarded unchanged to ``AntigenChain.predict``.
        device_config: ``"CPU Only"`` or a string whose second
            space-separated token is the GPU index.
        use_threshold: When false, ``None`` is passed as the threshold.
        threshold: Threshold passed to ``predict`` when *use_threshold* is set.
        auto_cleanup: Forwarded to ``predict``; also controls whether the
            stored copy of an uploaded file is deleted before returning.
        progress: Optional callable ``progress(fraction, desc=...)``.

    Returns:
        ``(summary_markdown, epitope_text, binding_text, html_path,
        json_path, csv_path)``.  ``html_path`` is ``None`` if the
        visualization could not be produced.  On failure the first item is an
        error message and the remaining five are empty strings.
    """

    def report(fraction: float, message: str) -> None:
        if progress is not None:
            progress(fraction, desc=message)

    temp_file_path = None
    try:
        # Text boxes often deliver stray whitespace; normalise before validating.
        pdb_id = (pdb_id or "").strip()
        chain_id = (chain_id or "").strip()

        if not pdb_file and not pdb_id:
            return _prediction_error("Error: Please provide either a PDB ID or upload a PDB file")

        if pdb_id and not validate_pdb_id(pdb_id):
            return _prediction_error("Error: PDB ID must be exactly 4 characters (letters and numbers)")

        if not validate_chain_id(chain_id):
            return _prediction_error("Error: Chain ID must be exactly 1 character")

        report(0.1, "Initializing prediction...")

        device_id = -1 if device_config == "CPU Only" else int(device_config.split(" ")[1])
        use_gpu = device_id >= 0

        report(0.2, "Loading protein structure...")

        antigen_chain = None
        try:
            if pdb_file:
                report(0.25, "Processing uploaded PDB file...")
                print(f"🔍 Debug: pdb_file type = {type(pdb_file)}")

                if not pdb_id:
                    if hasattr(pdb_file, 'name'):
                        # First "_"-separated token of the file stem, at most four characters.
                        pdb_id = Path(pdb_file.name).stem.split('_')[0][:4] or "UNKN"
                    else:
                        pdb_id = "UNKN"

                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                temp_file_path = PDB_DATA_DIR / f"{pdb_id}_{chain_id}_{timestamp}.pdb"

                try:
                    _store_uploaded_pdb(pdb_file, temp_file_path)
                    print(f"✅ PDB file saved to: {temp_file_path}")
                except Exception as file_error:
                    print(f"❌ Error processing uploaded file: {file_error}")
                    return _prediction_error(f"Error processing uploaded file: {str(file_error)}")

                antigen_chain = AntigenChain.from_pdb(
                    path=str(temp_file_path),
                    chain_id=chain_id,
                    id=pdb_id
                )
            else:
                report(0.25, f"Downloading PDB structure {pdb_id}...")
                antigen_chain = AntigenChain.from_pdb(
                    chain_id=chain_id,
                    id=pdb_id
                )
        except Exception as e:
            return _prediction_error(f"Error loading protein structure: {str(e)}")

        if antigen_chain is None:
            return _prediction_error("Error: Failed to load protein structure")

        report(0.4, "Running epitope prediction...")

        final_threshold = threshold if use_threshold else None
        try:
            predict_results = antigen_chain.predict(
                model_path=DEFAULT_MODEL_PATH,
                device_id=device_id,
                radius=radius,
                k=k,
                threshold=final_threshold,
                verbose=True,
                encoder=encoder,
                use_gpu=use_gpu,
                auto_cleanup=auto_cleanup
            )
        except Exception as e:
            error_msg = f"Error during prediction: {str(e)}"
            print(f"Prediction error: {error_msg}")
            traceback.print_exc()
            return _prediction_error(error_msg)

        report(0.8, "Processing results...")

        if not predict_results:
            return _prediction_error("Error: No prediction results generated")

        predicted_epitopes = predict_results.get("predicted_epitopes", [])
        predictions = predict_results.get("predictions", {})
        top_k_centers = predict_results.get("top_k_centers", [])
        top_k_region_residues = predict_results.get("top_k_region_residues", [])
        top_k_regions = predict_results.get("top_k_regions", [])

        protein_length = len(antigen_chain.sequence)
        epitope_count = len(predicted_epitopes)
        region_count = len(top_k_regions)
        top_k_region_residues_count = len(top_k_region_residues)
        coverage_rate = (top_k_region_residues_count / protein_length) * 100 if protein_length > 0 else 0

        sequence_box_style = (
            "word-wrap: break-word; word-break: break-all; white-space: pre-wrap; "
            "max-width: 100%; font-family: monospace; background: #f5f5f5; padding: 8px; "
            "border-radius: 4px; margin: 5px 0; display: inline-block;"
        )
        # Built line by line so that no source indentation leaks into the Markdown.
        summary_text = "\n".join([
            "",
            f"## Prediction Results for {pdb_id}_{chain_id}",
            "",
            "### Protein Information",
            f"- **PDB ID**: {pdb_id}",
            f"- **Chain**: {chain_id}",
            f"- **Length**: {protein_length} residues",
            f'- **Sequence**: <div style="{sequence_box_style}">{antigen_chain.sequence}</div>',
            "",
            "### Prediction Summary",
            f"- **Number of Predicted Epitope Residues**: {epitope_count}",
            f"- **Top-k Regions**: {region_count}",
            f"- **Number of Residues in Predicted Binding Regions**: {top_k_region_residues_count}",
            "",
            "### Top-k Region Centers",
            ", ".join(map(str, top_k_centers)),
            "",
            "### Predicted Epitope Residues",
            ", ".join(map(str, predicted_epitopes)),
            "",
            "### Binding Region Residues (Top-k Union)",
            ", ".join(map(str, top_k_region_residues)),
            "",
        ])

        epitope_text = _format_residue_listing(
            "Predicted Epitope Residues", predicted_epitopes, antigen_chain
        )
        binding_text = _format_residue_listing(
            "Binding Region Residues", top_k_region_residues, antigen_chain
        )

        report(0.9, "Preparing download files...")

        json_data = {
            "protein_info": {
                "id": pdb_id,
                "chain_id": chain_id,
                "length": protein_length,
                "sequence": antigen_chain.sequence
            },
            "prediction": {
                "predicted_epitopes": predicted_epitopes,
                "predictions": predictions,
                "top_k_centers": top_k_centers,
                "top_k_region_residues": top_k_region_residues,
                "top_k_regions": [
                    {
                        "center_idx": region.get('center_idx', 0),
                        "graph_pred": region.get('graph_pred', 0),
                        "covered_indices": region.get('covered_indices', [])
                    }
                    for region in top_k_regions
                ],
                "coverage_rate": coverage_rate,
                "mean_region_value": 0
            },
            "parameters": {
                "radius": radius,
                "k": k,
                "encoder": encoder,
                "device_config": device_config,
                "use_threshold": use_threshold,
                "threshold": final_threshold,
                "auto_cleanup": auto_cleanup
            }
        }

        # mkstemp creates the file atomically; the deprecated mktemp only
        # returned a name that another process could claim first.
        json_file_path = _new_temp_path(".json")
        with open(json_file_path, "w") as f:
            json.dump(json_data, f, indent=2)

        epitope_lookup = _membership_lookup(predicted_epitopes)
        region_lookup = _membership_lookup(top_k_region_residues)
        csv_rows = []
        for position, residue_num in enumerate(antigen_chain.residue_index):
            residue_num = int(residue_num)
            csv_rows.append({
                "Residue_Number": residue_num,
                "Residue_Type": antigen_chain.sequence[position],
                "Prediction_Probability": predictions.get(residue_num, 0.0),
                "Is_Predicted_Epitope": 1 if residue_num in epitope_lookup else 0,
                "Is_In_TopK_Regions": 1 if residue_num in region_lookup else 0
            })

        csv_file_path = _new_temp_path(".csv")
        pd.DataFrame(csv_rows).to_csv(csv_file_path, index=False)

        report(0.95, "Creating 3D visualization...")

        # The visualization is best-effort: a failure here must not discard
        # the prediction results.
        html_file_path = None
        try:
            pdb_str = generate_pdb_string(antigen_chain)
            html_content = create_pdb_visualization_html(
                pdb_str, predicted_epitopes, predictions, f"{pdb_id}_{chain_id}", top_k_regions
            )

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            html_file_path = PDB_DATA_DIR / f"{pdb_id}_{chain_id}_visualization_{timestamp}.html"

            with open(html_file_path, "w", encoding='utf-8') as f:
                f.write(html_content)

            print(f"✅ 3D visualization HTML saved to: {html_file_path}")
        except Exception as e:
            html_file_path = None
            print(f"Warning: Could not create 3D visualization: {str(e)}")

        report(1.0, "Prediction completed!")

        return (
            summary_text,
            epitope_text,
            binding_text,
            str(html_file_path) if html_file_path else None,
            json_file_path,
            csv_file_path
        )

    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        return _prediction_error(error_msg)
    finally:
        # Runs on every exit path so a failed run does not leave the uploaded
        # copy behind when the caller asked for cleanup.
        _release_uploaded_copy(temp_file_path, auto_cleanup)
| |
|
def _pdb_atom_record(serial: int, atom_name: str, resname: str, chain_id: str,
                     resnum: int, x: float, y: float, z: float) -> str:
    """Format one fixed-column PDB ``ATOM`` line (through the temperature factor)."""
    # Names shorter than four characters start in column 14; four-character
    # names occupy columns 13-16.
    name_field = atom_name if len(atom_name) >= 4 else f" {atom_name:<3s}"
    return (
        f"ATOM  {serial:5d} {name_field:<4.4s} {resname:>3.3s} {chain_id:1.1s}{resnum:4d}    "
        f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}\n"
    )


def generate_pdb_string(antigen_chain) -> str:
    """Serialize *antigen_chain* as single-model PDB text for the 3D viewer.

    For every residue, the atoms flagged in ``atom37_mask`` are written as
    ``ATOM`` records using the coordinates in ``atom37_positions`` and the
    atom names from ``esm.utils.residue_constants.atom_types``.  Occupancy is
    written as 1.00 and the temperature factor as 0.00.

    Args:
        antigen_chain: Object providing ``sequence``, ``residue_index``,
            ``atom37_mask``, ``atom37_positions``, ``chain_id`` and
            ``convert_letter_1to3``.

    Returns:
        The PDB text, wrapped in ``MODEL`` / ``ENDMDL`` records.
    """
    from esm.utils import residue_constants as RC

    records = [f"MODEL     {1:4d}\n"]
    atom_num = 1

    for res_idx, one_letter in enumerate(antigen_chain.sequence):
        resname = antigen_chain.convert_letter_1to3(one_letter)
        # int()/float() keep the fixed-width format specs working for array scalars.
        resnum = int(antigen_chain.residue_index[res_idx])

        mask = antigen_chain.atom37_mask[res_idx]
        coords = antigen_chain.atom37_positions[res_idx][mask]
        atoms = [name for name, exists in zip(RC.atom_types, mask) if exists]

        for atom_name, coord in zip(atoms, coords):
            x, y, z = coord
            records.append(
                _pdb_atom_record(
                    atom_num, atom_name, resname, antigen_chain.chain_id,
                    resnum, float(x), float(y), float(z),
                )
            )
            atom_num += 1

    records.append("ENDMDL\n")
    return "".join(records)
| |
|
def create_interface():
    """Build (but do not launch) the Gradio UI for the prediction server.

    Layout, top to bottom:
        * a header banner;
        * a row holding the input column (input-method selector, PDB upload
          or PDB ID box, chain ID, advanced parameters, predict button) and
          the results column (3D-visualization download, summary, residue
          lists, JSON/CSV downloads);
        * a footer.

    Event wiring:
        * changing the input method shows either the PDB ID box or the
          upload widget;
        * ticking "Use Custom Threshold" reveals the threshold field;
        * the predict button calls ``predict_epitopes`` with the ten input
          components and routes its six return values to the result widgets.

    Returns:
        The configured ``gr.Blocks`` instance.
    """
    # Static markup is kept in named locals so the layout code below reads
    # as structure rather than as string literals.
    page_css = """
    .container {
        max-width: 1200px;
        margin: 0 auto;
        padding: 20px;
    }
    .header {
        text-align: center;
        margin-bottom: 30px;
        padding: 20px;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
    }
    .header h1 {
        font-size: 2.5em;
        margin-bottom: 10px;
    }
    .form-row {
        display: flex;
        gap: 20px;
        align-items: end;
    }
    .form-row > * {
        flex: 1;
    }
    .section {
        margin: 20px 0;
        padding: 15px;
        background: #f8f9fa;
        border-radius: 8px;
        border-left: 4px solid #007bff;
    }
    .section h2 {
        color: #333;
        margin-bottom: 15px;
    }
    .results-section {
        margin-top: 30px;
        padding: 20px;
        background: #f0f8ff;
        border-radius: 8px;
        border: 1px solid #e0e8f0;
    }
    .download-section {
        margin-top: 20px;
        padding: 15px;
        background: #f9f9f9;
        border-radius: 8px;
    }
    .download-section h3 {
        color: #333;
        margin-bottom: 10px;
    }
    """

    header_html = """
    <div class="header">
        <h1>🧬 B-cell Epitope Prediction Server</h1>
        <p>Predict epitopes using the RoBep model</p>
    </div>
    """

    visualization_note_html = (
        "<div style='margin: 15px 0; padding: 10px; background: #f0f8ff; "
        "border-left: 4px solid #4a90e2; border-radius: 5px;'>"
        "<h3 style='margin: 0 0 8px 0; color: #333;'>🧬 3D Visualization</h3>"
        "<p style='margin: 0; color: #666;'>You can download the HTML to visualize "
        "the prediction results and the spheres used.</p></div>"
    )

    footer_html = """
    <div style="text-align: center; margin-top: 30px; padding: 20px; background: #f0f0f0; border-radius: 10px;">
        <p>© 2024 B-cell Epitope Prediction Server | Powered by RoBep model</p>
        <p><strong>Features:</strong> PDB ID/File support • 3D visualization • Multiple export formats</p>
    </div>
    """

    with gr.Blocks(css=page_css) as interface:
        gr.HTML(header_html)

        with gr.Row():
            # ---- Left column: inputs -------------------------------------
            with gr.Column(scale=1):
                gr.HTML("<div class='section'><h2>📋 Input Protein Structure</h2></div>")

                input_method = gr.Radio(
                    choices=["Upload PDB File", "PDB ID"],
                    value="Upload PDB File",
                    label="Input Method",
                )

                # Initial visibility mirrors the default radio value above:
                # the upload widget is shown, the PDB ID box is hidden.
                pdb_file = gr.File(
                    label="Upload PDB File (Recommended)",
                    file_types=[".pdb", ".ent"],
                    visible=True,
                )
                pdb_id = gr.Textbox(
                    label="PDB ID",
                    info=(
                        "You can use the default PDB ID 5i9q for the demo. "
                        "However, this function is not available for other "
                        "PDB IDs now, since Hugging Face does not support "
                        "fetching PDB files from the web."
                    ),
                    placeholder="e.g., 5I9Q",
                    max_lines=1,
                    visible=False,
                )
                chain_id = gr.Textbox(
                    label="Chain ID",
                    value="A",
                    max_lines=1,
                )

                with gr.Accordion("🔧 Advanced Parameters", open=False):
                    radius = gr.Slider(
                        label="Radius (Å)",
                        minimum=1.0,
                        maximum=50.0,
                        step=0.1,
                        value=18.0,
                    )
                    k = gr.Slider(
                        label="Top-k Regions",
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=7,
                    )
                    encoder = gr.Dropdown(
                        label="Encoder",
                        choices=["esmc", "esm2"],
                        value="esmc",
                    )
                    device_config = gr.Dropdown(
                        label="Device Configuration",
                        choices=["CPU Only", "GPU 0", "GPU 1", "GPU 2", "GPU 3"],
                        value="CPU Only",
                    )
                    use_threshold = gr.Checkbox(
                        label="Use Custom Threshold",
                        value=False,
                    )
                    # Hidden until "Use Custom Threshold" is ticked.
                    threshold = gr.Number(
                        label="Threshold Value",
                        value=0.366,
                        visible=False,
                    )
                    auto_cleanup = gr.Checkbox(
                        label="Auto-cleanup Generated Data",
                        value=True,
                    )

                predict_btn = gr.Button("🧮 Predict Epitopes", variant="primary", size="lg")

            # ---- Right column: results -----------------------------------
            with gr.Column(scale=2):
                gr.HTML("<div class='section'><h2>📊 Prediction Results</h2></div>")

                gr.HTML(visualization_note_html)
                html_download = gr.File(
                    label="Download Interactive 3D Visualization HTML",
                    visible=True,
                )
                results_text = gr.Markdown(label="Prediction Summary", visible=True)

                with gr.Row():
                    epitope_list = gr.Textbox(
                        label="Predicted Epitope Residues",
                        max_lines=10,
                        interactive=False,
                    )
                    binding_regions = gr.Textbox(
                        label="Binding Region Residues",
                        max_lines=10,
                        interactive=False,
                    )

                gr.HTML("<div class='download-section'><h3>📥 Download Data Results</h3></div>")
                with gr.Row():
                    json_download = gr.File(
                        label="JSON Results",
                        visible=True,
                    )
                    csv_download = gr.File(
                        label="CSV Results",
                        visible=True,
                    )

        # ---- Event handlers ----------------------------------------------
        def toggle_input_method(method):
            # Returns updates in the order (pdb_id, pdb_file), matching the
            # `outputs` list of the .change() registration below.
            return (gr.update(visible=method == "PDB ID"),
                    gr.update(visible=method == "Upload PDB File"))

        def toggle_threshold(use_threshold):
            return gr.update(visible=use_threshold)

        input_method.change(toggle_input_method, inputs=[input_method], outputs=[pdb_id, pdb_file])
        use_threshold.change(toggle_threshold, inputs=[use_threshold], outputs=[threshold])

        # Component order here is passed positionally to predict_epitopes,
        # and the six outputs receive its return values in order.
        predict_btn.click(
            predict_epitopes,
            inputs=[
                pdb_id, pdb_file, chain_id, radius, k, encoder,
                device_config, use_threshold, threshold, auto_cleanup,
            ],
            outputs=[
                results_text, epitope_list, binding_regions,
                html_download, json_download, csv_download,
            ],
            show_progress=True,
        )

        gr.HTML(footer_html)

    return interface
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces,
    # port 7860.
    try:
        interface = create_interface()

        # The presence of SPACE_ID is treated as "running inside a Hugging
        # Face Space"; it only selects the worker-thread budget below.
        is_spaces = os.getenv("SPACE_ID") is not None

        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            # Previously `share=is_spaces`, which requested a share tunnel
            # only inside a Space -- the one place Gradio does not support
            # it (it warns and ignores the flag). The effective value was
            # therefore always False; say so explicitly.
            share=False,
            show_error=True,
            max_threads=4 if is_spaces else 8,
        )
    except Exception as e:
        print(f"Error launching application: {e}")
        print("Please ensure all dependencies are installed correctly.")
        # `traceback` is imported at the top of the module.
        traceback.print_exc()
        # Report the failure to the caller instead of exiting with status 0.
        raise SystemExit(1)
| |
|