import gradio as gr
import os
import json
import tempfile
import traceback
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Tuple, Dict, Any
import torch
import time
import io
import base64
import zipfile
from datetime import datetime
import gradio, fastapi, pydantic, starlette
# Log the versions of the key web-stack packages at startup; these four
# frequently have compatibility issues on Hugging Face Spaces, so having the
# versions in the boot log makes deployment failures much easier to diagnose.
for _pkg_name, _pkg in (("gradio", gradio), ("fastapi", fastapi),
                        ("pydantic", pydantic), ("starlette", starlette)):
    print(_pkg_name, _pkg.__version__)
# Set up paths and imports for different deployment environments
import sys
BASE_DIR = Path(__file__).parent
# Smart import handling for different environments
def setup_imports():
    """Smart import setup for different deployment environments.

    Tries three layouts in order and binds the module-level globals
    ``AntigenChain`` and ``PROJECT_BASE_DIR`` on the first success:

      1. ``src.bce`` with the app directory on ``sys.path`` (local development)
      2. ``bce`` with ``src/`` itself on ``sys.path`` (Hugging Face Spaces)
      3. ``bce`` as an installed package

    Returns:
        bool: True if the BCE package was imported; False if every attempt
        failed, in which case ``PROJECT_BASE_DIR`` falls back to ``BASE_DIR``.
    """
    global AntigenChain, PROJECT_BASE_DIR
    import importlib  # local import: only needed by this bootstrap helper

    src_path = BASE_DIR / "src"
    # Each attempt: (sys.path entry to prepend or None, package prefix,
    #                success message, failure message template).
    # The messages are kept identical to the original three code paths.
    attempts = []
    if src_path.exists():
        attempts.append((str(BASE_DIR), "src.bce",
                         "✅ Successfully imported from src/ directory",
                         "❌ Failed to import from src/: {}"))
        attempts.append((str(src_path), "bce",
                         "✅ Successfully imported from src/ added to path",
                         "❌ Failed to import with src/ in path: {}"))
    attempts.append((None, "bce",
                     "✅ Successfully imported from installed package",
                     "❌ Failed to import from installed package: {}"))

    for path_entry, prefix, ok_msg, fail_msg in attempts:
        # Fix: avoid polluting sys.path with duplicate entries when this
        # bootstrap runs more than once (the original inserted unconditionally).
        if path_entry is not None and path_entry not in sys.path:
            sys.path.insert(0, path_entry)
        try:
            antigen_mod = importlib.import_module(prefix + ".antigen.antigen")
            constants_mod = importlib.import_module(prefix + ".utils.constants")
            AntigenChain = antigen_mod.AntigenChain
            PROJECT_BASE_DIR = constants_mod.BASE_DIR
            print(ok_msg)
            return True
        except ImportError as e:
            print(fail_msg.format(e))

    # If all methods fail, use default settings
    print("⚠️ All import methods failed, using fallback settings")
    PROJECT_BASE_DIR = BASE_DIR
    return False
# Execute import setup; abort immediately if the BCE package cannot be found,
# since nothing below can work without it.
import_success = setup_imports()
if not import_success:
    print("❌ Critical: Could not import BCE modules. Please check the file structure.")
    print("Expected structure:")
    print("- src/bce/antigen/antigen.py")
    print("- src/bce/utils/constants.py")
    print("- src/bce/model/ReCEP.py")
    print("- src/bce/data/utils.py")
    sys.exit(1)

# Configuration
DEFAULT_MODEL_PATH = os.getenv("BCE_MODEL_PATH", str(PROJECT_BASE_DIR / "models" / "ReCEP" / "20250626_110438" / "best_mcc_model.bin"))
# SECURITY FIX: the previous revision shipped a hard-coded API token as the
# fallback value here, exposing the credential to anyone with source access.
# The token must now be provided via the ESM_TOKEN environment variable
# (e.g. a Hugging Face Spaces secret); the leaked token should be revoked.
ESM_TOKEN = os.getenv("ESM_TOKEN", "")
if not ESM_TOKEN:
    print("⚠️ ESM_TOKEN environment variable is not set; ESM-dependent features may fail.")

# Directory where uploaded/generated PDB files are stored; created eagerly so
# later writes cannot fail on a missing directory.
PDB_DATA_DIR = PROJECT_BASE_DIR / "data" / "pdb"
PDB_DATA_DIR.mkdir(parents=True, exist_ok=True)
def validate_pdb_id(pdb_id: str) -> bool:
    """Return True iff *pdb_id* is a non-empty, 4-character alphanumeric code."""
    return bool(pdb_id) and len(pdb_id) == 4 and pdb_id.isalnum()
def validate_chain_id(chain_id: str) -> bool:
    """Return True iff *chain_id* is exactly one alphanumeric character."""
    return bool(chain_id) and len(chain_id) == 1 and chain_id.isalnum()
def create_pdb_visualization_html(pdb_data: str, predicted_epitopes: list,
                                  predictions: dict, protein_id: str, top_k_regions: list = None) -> str:
    """Create HTML with 3Dmol.js visualization compatible with Gradio - enhanced version with more features.

    Args:
        pdb_data: PDB-format structure text to render.
        predicted_epitopes: residue numbers predicted as epitopes.
        predictions: per-residue prediction scores.
        protein_id: label shown in the page title (e.g. "5I9Q_A").
        top_k_regions: optional list of region dicts from the predictor.

    Returns:
        str: a standalone HTML document as text.

    NOTE(review): in this copy of the file the HTML template below appears
    truncated by extraction — `epitope_residues`, `processed_regions` and
    `viewer_id` are prepared but never interpolated into it. Confirm against
    the full template before relying on this output.
    """
    # Prepare data for JavaScript
    epitope_residues = predicted_epitopes
    # Normalize top_k_regions into a uniform list of plain dicts so the
    # JavaScript side does not have to handle alternative key names.
    processed_regions = []
    if top_k_regions:
        for i, region in enumerate(top_k_regions):
            if isinstance(region, dict):
                processed_regions.append({
                    'center_idx': region.get('center_idx', 0),
                    # Fall back to center_idx when center_residue is absent.
                    'center_residue': region.get('center_residue', region.get('center_idx', 0)),
                    # Accept either covered_residues or covered_indices.
                    'covered_residues': region.get('covered_residues', region.get('covered_indices', [])),
                    'radius': 18.0,  # Default radius
                    'predicted_value': region.get('graph_pred', 0.0)
                })
    # Create a unique ID for this visualization to avoid conflicts when
    # several viewers are embedded in the same page.
    import uuid
    viewer_id = f"viewer_{uuid.uuid4().hex[:8]}"
    html_content = f"""
3D Structure Visualization - {protein_id}
Loading 3Dmol.js...
"""
    return html_content
def predict_epitopes(pdb_id: str, pdb_file, chain_id: str, radius: float, k: int,
                     encoder: str, device_config: str, use_threshold: bool, threshold: float,
                     auto_cleanup: bool, progress: gr.Progress = None) -> Tuple[str, str, str, str, str, str]:
    """
    Main prediction function that handles the epitope prediction workflow.

    Pipeline: validate inputs -> load structure (upload or PDB ID) ->
    run AntigenChain.predict -> format text summaries -> write JSON/CSV/HTML
    artifacts -> optionally clean up the saved upload.

    Args:
        pdb_id: 4-character PDB identifier (may be empty when a file is uploaded).
        pdb_file: Uploaded file; may arrive as a Gradio file object (has
            .name), a file-like object (has .read), or a plain path string —
            all three shapes are handled below.
        chain_id: Single-character chain identifier.
        radius: Spherical region radius passed through to the predictor.
        k: Number of top-scoring regions to keep.
        encoder: Sequence encoder name ("esmc" or "esm2" in the UI).
        device_config: "CPU Only" or "GPU <n>", parsed into a device id.
        use_threshold: When True, `threshold` overrides the model default.
        threshold: Classification threshold (used only if use_threshold).
        auto_cleanup: Delete the saved upload after prediction.
        progress: Optional Gradio progress callback.

    Returns:
        Tuple of (summary markdown, epitope text, binding-region text,
        HTML visualization file path or None, JSON file path, CSV file path).
        On any error the first element carries the error message and the
        remaining five are empty strings.
    """
    try:
        # Input validation
        if not pdb_file and not pdb_id:
            return "Error: Please provide either a PDB ID or upload a PDB file", "", "", "", "", ""
        if pdb_id and not validate_pdb_id(pdb_id):
            return "Error: PDB ID must be exactly 4 characters (letters and numbers)", "", "", "", "", ""
        if not validate_chain_id(chain_id):
            return "Error: Chain ID must be exactly 1 character", "", "", "", "", ""
        # Update progress
        if progress:
            progress(0.1, desc="Initializing prediction...")
        # Process device configuration: "CPU Only" -> -1, "GPU <n>" -> n
        device_id = -1 if device_config == "CPU Only" else int(device_config.split(" ")[1])
        use_gpu = device_id >= 0
        # Load protein structure
        if progress:
            progress(0.2, desc="Loading protein structure...")
        antigen_chain = None
        temp_file_path = None  # set only when an upload is saved to disk
        try:
            if pdb_file:
                # Handle uploaded file
                if progress:
                    progress(0.25, desc="Processing uploaded PDB file...")
                # Debug: print type and attributes of pdb_file
                print(f"🔍 Debug: pdb_file type = {type(pdb_file)}")
                print(f"🔍 Debug: pdb_file attributes = {dir(pdb_file)}")
                # Extract PDB ID from filename if not provided
                # (e.g. "1abc_A.pdb" -> "1abc")
                if not pdb_id:
                    if hasattr(pdb_file, 'name'):
                        pdb_id = Path(pdb_file.name).stem.split('_')[0][:4]
                    else:
                        pdb_id = "UNKN"  # Default fallback
                # Save uploaded file to data/pdb/ directory with proper naming;
                # the timestamp keeps concurrent uploads from colliding.
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                filename = f"{pdb_id}_{chain_id}_{timestamp}.pdb"
                temp_file_path = PDB_DATA_DIR / filename
                # Properly read and write the uploaded file; Gradio versions
                # differ in what they pass, hence the three-way dispatch.
                try:
                    if hasattr(pdb_file, 'name') and os.path.isfile(pdb_file.name):
                        # pdb_file is a file object with .name attribute
                        print(f"📁 Processing file object: {pdb_file.name}")
                        with open(pdb_file.name, "rb") as src:
                            with open(temp_file_path, "wb") as dst:
                                dst.write(src.read())
                    elif hasattr(pdb_file, 'read'):
                        # pdb_file is a file-like object
                        print(f"📄 Processing file-like object")
                        with open(temp_file_path, "wb") as f:
                            f.write(pdb_file.read())
                    else:
                        # pdb_file is a string (file path)
                        print(f"📍 Processing file path: {pdb_file}")
                        with open(str(pdb_file), "rb") as src:
                            with open(temp_file_path, "wb") as dst:
                                dst.write(src.read())
                    print(f"✅ PDB file saved to: {temp_file_path}")
                except Exception as file_error:
                    print(f"❌ Error processing uploaded file: {file_error}")
                    return f"Error processing uploaded file: {str(file_error)}", "", "", "", "", ""
                antigen_chain = AntigenChain.from_pdb(
                    path=str(temp_file_path),
                    chain_id=chain_id,
                    id=pdb_id
                )
            else:
                # Load from PDB ID (AntigenChain fetches the structure itself)
                if progress:
                    progress(0.25, desc=f"Downloading PDB structure {pdb_id}...")
                antigen_chain = AntigenChain.from_pdb(
                    chain_id=chain_id,
                    id=pdb_id
                )
        except Exception as e:
            return f"Error loading protein structure: {str(e)}", "", "", "", "", ""
        if antigen_chain is None:
            return "Error: Failed to load protein structure", "", "", "", "", ""
        # Run prediction
        if progress:
            progress(0.4, desc="Running epitope prediction...")
        try:
            # Use threshold only if checkbox is checked; None lets the model
            # apply its own default.
            final_threshold = threshold if use_threshold else None
            predict_results = antigen_chain.predict(
                model_path=DEFAULT_MODEL_PATH,
                device_id=device_id,
                radius=radius,
                k=k,
                threshold=final_threshold,
                verbose=True,
                encoder=encoder,
                use_gpu=use_gpu,
                auto_cleanup=auto_cleanup
            )
        except Exception as e:
            error_msg = f"Error during prediction: {str(e)}"
            print(f"Prediction error: {error_msg}")
            import traceback
            traceback.print_exc()
            return error_msg, "", "", "", "", ""
        if progress:
            progress(0.8, desc="Processing results...")
        # Process results
        if not predict_results:
            return "Error: No prediction results generated", "", "", "", "", ""
        # Extract prediction data; every key is read defensively with a default.
        predicted_epitopes = predict_results.get("predicted_epitopes", [])
        predictions = predict_results.get("predictions", {})
        top_k_centers = predict_results.get("top_k_centers", [])
        top_k_region_residues = predict_results.get("top_k_region_residues", [])
        top_k_regions = predict_results.get("top_k_regions", [])
        # Calculate summary statistics
        protein_length = len(antigen_chain.sequence)
        epitope_count = len(predicted_epitopes)
        region_count = len(top_k_regions)
        top_k_region_residues_count = len(top_k_region_residues)
        # Percentage of the chain covered by the union of top-k regions.
        coverage_rate = (len(top_k_region_residues) / protein_length) * 100 if protein_length > 0 else 0
        # Create summary text (markdown rendered by the Gradio frontend)
        summary_text = f"""
## Prediction Results for {pdb_id}_{chain_id}
### Protein Information
- **PDB ID**: {pdb_id}
- **Chain**: {chain_id}
- **Length**: {protein_length} residues
- **Sequence**:
{antigen_chain.sequence}
### Prediction Summary
- **Number of Predicted Epitope Residues**: {epitope_count}
- **Top-k Regions**: {region_count}
- **Number of Residues in Predicted Binding Regions**: {top_k_region_residues_count}
### Top-k Region Centers
{', '.join(map(str, top_k_centers))}
### Predicted Epitope Residues
{', '.join(map(str, predicted_epitopes))}
### Binding Region Residues (Top-k Union)
{', '.join(map(str, top_k_region_residues))}
"""
        # Create epitope list text with residue names
        epitope_text = f"Predicted Epitope Residues ({len(predicted_epitopes)}):\n"
        epitope_lines = []
        for res in predicted_epitopes:
            # Map author residue number -> sequence index to recover the
            # one-letter residue type; fall back to the bare number if the
            # residue is not in the mapping.
            if res in antigen_chain.resnum_to_index:
                res_idx = antigen_chain.resnum_to_index[res]
                res_name = antigen_chain.sequence[res_idx]
                epitope_lines.append(f"Residue {res} ({res_name})")
            else:
                epitope_lines.append(f"Residue {res}")
        epitope_text += "\n".join(epitope_lines)
        # Create binding region text with residue names (same scheme as above)
        binding_text = f"Binding Region Residues ({len(top_k_region_residues)}):\n"
        binding_lines = []
        for res in top_k_region_residues:
            if res in antigen_chain.resnum_to_index:
                res_idx = antigen_chain.resnum_to_index[res]
                res_name = antigen_chain.sequence[res_idx]
                binding_lines.append(f"Residue {res} ({res_name})")
            else:
                binding_lines.append(f"Residue {res}")
        binding_text += "\n".join(binding_lines)
        # Create downloadable files
        if progress:
            progress(0.9, desc="Preparing download files...")
        # JSON file: full machine-readable record of inputs and outputs.
        json_data = {
            "protein_info": {
                "id": pdb_id,
                "chain_id": chain_id,
                "length": protein_length,
                "sequence": antigen_chain.sequence
            },
            "prediction": {
                "predicted_epitopes": predicted_epitopes,
                "predictions": predictions,
                "top_k_centers": top_k_centers,
                "top_k_region_residues": top_k_region_residues,
                "top_k_regions": [
                    {
                        "center_idx": region.get('center_idx', 0),
                        "graph_pred": region.get('graph_pred', 0),
                        "covered_indices": region.get('covered_indices', [])
                    }
                    for region in top_k_regions
                ],
                "coverage_rate": coverage_rate,
                "mean_region_value": 0  # No longer calculated
            },
            "parameters": {
                "radius": radius,
                "k": k,
                "encoder": encoder,
                "device_config": device_config,
                "use_threshold": use_threshold,
                "threshold": final_threshold,
                "auto_cleanup": auto_cleanup
            }
        }
        # Save JSON file
        # NOTE(review): tempfile.mktemp is deprecated and race-prone;
        # consider NamedTemporaryFile(delete=False) instead.
        json_file_path = tempfile.mktemp(suffix=".json")
        with open(json_file_path, "w") as f:
            json.dump(json_data, f, indent=2)
        # CSV file: one row per residue with score and membership flags.
        csv_data = []
        for i, residue_num in enumerate(antigen_chain.residue_index):
            residue_num = int(residue_num)
            csv_data.append({
                "Residue_Number": residue_num,
                "Residue_Type": antigen_chain.sequence[i],
                "Prediction_Probability": predictions.get(residue_num, 0.0),
                "Is_Predicted_Epitope": 1 if residue_num in predicted_epitopes else 0,
                "Is_In_TopK_Regions": 1 if residue_num in top_k_region_residues else 0
            })
        csv_df = pd.DataFrame(csv_data)
        csv_file_path = tempfile.mktemp(suffix=".csv")
        csv_df.to_csv(csv_file_path, index=False)
        # Create 3D visualization
        if progress:
            progress(0.95, desc="Creating 3D visualization...")
        # Generate PDB string for visualization HTML file; this step is
        # best-effort — a failure here must not lose the prediction results.
        html_file_path = None
        try:
            pdb_str = generate_pdb_string(antigen_chain)
            html_content = create_pdb_visualization_html(
                pdb_str, predicted_epitopes, predictions, f"{pdb_id}_{chain_id}", top_k_regions
            )
            # Save HTML file to data directory for download
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            html_filename = f"{pdb_id}_{chain_id}_visualization_{timestamp}.html"
            html_file_path = PDB_DATA_DIR / html_filename
            with open(html_file_path, "w", encoding='utf-8') as f:
                f.write(html_content)
            print(f"✅ 3D visualization HTML saved to: {html_file_path}")
        except Exception as e:
            html_file_path = None
            print(f"Warning: Could not create 3D visualization: {str(e)}")
        # Clean up the saved upload if auto_cleanup is enabled
        if auto_cleanup and temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"🧹 Cleaned up temporary file: {temp_file_path}")
        elif temp_file_path and os.path.exists(temp_file_path):
            print(f"📁 PDB file retained at: {temp_file_path}")
        if progress:
            progress(1.0, desc="Prediction completed!")
        # Return all results including HTML file path for download
        return (
            summary_text,
            epitope_text,
            binding_text,
            str(html_file_path) if html_file_path else None,  # HTML file moved to 4th position
            json_file_path,
            csv_file_path
        )
    except Exception as e:
        # Catch-all boundary for the Gradio callback: surface the traceback
        # in the UI instead of crashing the worker.
        import traceback
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        return error_msg, "", "", "", "", ""
def generate_pdb_string(antigen_chain) -> str:
    """Serialize *antigen_chain* into minimal PDB-format text for 3D viewing.

    Walks the chain residue by residue, emitting one ATOM record for every
    atom present in the atom37 mask, wrapped in a single MODEL/ENDMDL pair.

    NOTE(review): verify the ATOM record column spacing against the PDB
    format specification — fixed-width columns matter to strict parsers.
    """
    from esm.utils import residue_constants as RC

    records = ["MODEL 1\n"]
    serial = 1
    for idx, one_letter in enumerate(antigen_chain.sequence):
        resname = antigen_chain.convert_letter_1to3(one_letter)
        resnum = antigen_chain.residue_index[idx]
        mask = antigen_chain.atom37_mask[idx]
        coords = antigen_chain.atom37_positions[idx][mask]
        # Names of only the atoms actually present for this residue.
        present = [name for name, exists in zip(RC.atom_types, mask) if exists]
        for atom_name, coord in zip(present, coords):
            x, y, z = coord
            records.append(
                f"ATOM {serial:5d} {atom_name:<3s} {resname:>3s} {antigen_chain.chain_id:1s}{resnum:4d}"
                f" {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00\n"
            )
            serial += 1
    records.append("ENDMDL\n")
    return "".join(records)
def create_interface():
    """Create and return the Gradio Blocks interface.

    Builds the two-column layout: input widgets (file upload / PDB ID,
    chain, advanced parameters) on the left, result panels on the right.

    NOTE(review): in this copy of the file the predict button is never wired
    via ``predict_btn.click`` and no output components are created —
    presumably that wiring lives in a part of the file not shown here.
    Confirm before shipping.
    """
    with gr.Blocks(css="""
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 30px;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
}
.header h1 {
font-size: 2.5em;
margin-bottom: 10px;
}
.form-row {
display: flex;
gap: 20px;
align-items: end;
}
.form-row > * {
flex: 1;
}
.section {
margin: 20px 0;
padding: 15px;
background: #f8f9fa;
border-radius: 8px;
border-left: 4px solid #007bff;
}
.section h2 {
color: #333;
margin-bottom: 15px;
}
.results-section {
margin-top: 30px;
padding: 20px;
background: #f0f8ff;
border-radius: 8px;
border: 1px solid #e0e8f0;
}
.download-section {
margin-top: 20px;
padding: 15px;
background: #f9f9f9;
border-radius: 8px;
}
.download-section h3 {
color: #333;
margin-bottom: 10px;
}
""") as interface:
        # Header banner
        gr.HTML("""
🧬 B-cell Epitope Prediction Server
Predict epitopes using the RoBep model
""")
        with gr.Row():
            with gr.Column(scale=1):
                # FIX: this literal was unterminated (opened with a single
                # double-quote spanning several lines); normalized to a
                # triple-quoted string so the module parses.
                gr.HTML("""
📋 Input Protein Structure
""")
                input_method = gr.Radio(
                    choices=["Upload PDB File", "PDB ID"],
                    value="Upload PDB File",
                    label="Input Method"
                )
                pdb_file = gr.File(
                    label="Upload PDB File (Recommended)",
                    file_types=[".pdb", ".ent"],
                    visible=True
                )
                pdb_id = gr.Textbox(
                    label="PDB ID",
                    # FIX: corrected typo "defalt" -> "default" in user-facing text
                    info="You can use default PDB ID 5i9q for demo. However, this function is not available for other PDB IDs now, since Hugging Face is not supported to fetch PDB files from website now",
                    placeholder="e.g., 5I9Q",
                    max_lines=1,
                    visible=False
                )
                chain_id = gr.Textbox(
                    label="Chain ID",
                    value="A",
                    max_lines=1
                )
                with gr.Accordion("🔧 Advanced Parameters", open=False):
                    radius = gr.Slider(
                        label="Radius (Å)",
                        minimum=1.0,
                        maximum=50.0,
                        step=0.1,
                        value=18.0
                    )
                    k = gr.Slider(
                        label="Top-k Regions",
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=7
                    )
                    encoder = gr.Dropdown(
                        label="Encoder",
                        choices=["esmc", "esm2"],
                        value="esmc"
                    )
                    device_config = gr.Dropdown(
                        label="Device Configuration",
                        choices=["CPU Only", "GPU 0", "GPU 1", "GPU 2", "GPU 3"],
                        value="CPU Only"
                    )
                    use_threshold = gr.Checkbox(
                        label="Use Custom Threshold",
                        value=False
                    )
                    # NOTE(review): hidden by default; no visibility toggle
                    # driven by use_threshold is visible in this file — confirm.
                    threshold = gr.Number(
                        label="Threshold Value",
                        value=0.366,
                        visible=False
                    )
                    auto_cleanup = gr.Checkbox(
                        label="Auto-cleanup Generated Data",
                        value=True
                    )
                predict_btn = gr.Button("🧮 Predict Epitopes", variant="primary", size="lg")
            with gr.Column(scale=2):
                # FIX: normalized another unterminated literal to triple quotes.
                gr.HTML("""
📊 Prediction Results
""")
                # 3D Visualization download (moved to top)
                # FIX: this literal opened with a single quote character but
                # closed with a triple quote; normalized to triple quotes.
                gr.HTML("""
🧬 3D Visualization
You can download the HTML to visualize the prediction results and the spheres used.
Features: PDB ID/File support • 3D visualization • Multiple export formats
""")
    return interface
if __name__ == "__main__":
    # Build the UI and launch the server; any startup failure is reported
    # with a full traceback rather than a bare crash.
    try:
        app = create_interface()
        # Hugging Face Spaces sets SPACE_ID in the environment.
        on_spaces = os.getenv("SPACE_ID") is not None
        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=on_spaces,  # Use share=True on Spaces, False locally
            show_error=True,
            max_threads=4 if on_spaces else 8,
        )
    except Exception as e:
        print(f"Error launching application: {e}")
        print("Please ensure all dependencies are installed correctly.")
        import traceback
        traceback.print_exc()