Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| miRBind2 Gradio Interface | |
| Interactive web app for miRNA-mRNA binding prediction with explainability | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # Add code directory to path | |
| CODE_DIR = Path(__file__).parent / "miRBind_2.0-main" / "code" / "pairwise_binding_site_model" | |
| sys.path.insert(0, str(CODE_DIR)) | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from matplotlib.backends.backend_pdf import PdfPages | |
| import io | |
| from PIL import Image | |
| from datetime import datetime | |
| import tempfile | |
| from sklearn.metrics import average_precision_score | |
| from alignment_plot import plot_alignment_image | |
| from shared.models import load_model | |
| from shared.constants import get_pair_to_index, get_num_pairs, get_device, NUCLEOTIDE_COLORS | |
| from shared.encoding import pad_or_trim, encode_complementarity | |
| # Try to import Captum for SHAP, but make it optional | |
| try: | |
| from captum.attr import GradientShap | |
| CAPTUM_AVAILABLE = True | |
| except ImportError: | |
| CAPTUM_AVAILABLE = False | |
| print("Warning: Captum not installed. SHAP explainability will not be available.") | |
| print("Install with: pip install captum") | |
# ---------------------------------------------------------------------------
# Mutable runtime state
# ---------------------------------------------------------------------------
# Model handle and encoding metadata; all start empty and are filled in by
# load_pretrained_model().
MODEL = MODEL_PARAMS = DEVICE = PAIR_TO_INDEX = NUM_PAIRS = None

# Batch-tab state: the last results DataFrame and per-row SHAP matrices keyed
# by the input row index.
BATCH_RESULTS = None
BATCH_SHAP_CACHE = dict()

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_FILENAME = "pairwise_onehot_model_20260105_200141.pt"
# Hub repository used when the weights are not bundled locally; overridable
# through the MIRBIND2_MODEL_REPO environment variable.
MODEL_REPO_ID = os.getenv("MIRBIND2_MODEL_REPO", "dimostzim/mirbind2-weights")

# Example inputs pre-filled in the single-prediction tab.
DEFAULT_MIRNA_SEQUENCE = "UAGCUUAUCAGACUGAUGUUGA"
DEFAULT_TARGET_SEQUENCE = "GGGCACUUUUUCAACAUCAGUCUGAUAAGCUAAGUGUCUUCCAGGGAAUU"
def resolve_model_path():
    """Return the path to the model weights, downloading them when not bundled.

    The copy under ``miRBind_2.0-main/models`` next to this file is preferred.
    Otherwise ``MODEL_FILENAME`` is fetched from the Hugging Face Hub model
    repository ``MODEL_REPO_ID``.

    Returns:
        pathlib.Path to the weights file.

    Raises:
        FileNotFoundError: the local file is missing and ``huggingface_hub``
            cannot be imported.
    """
    candidate = Path(__file__).parent.joinpath("miRBind_2.0-main", "models", MODEL_FILENAME)
    if candidate.exists():
        return candidate

    print(f"Local model not found at {candidate}")
    print(f"Downloading model from Hugging Face Hub repo: {MODEL_REPO_ID}")
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as missing_hub:
        raise FileNotFoundError(
            "Model file missing locally and huggingface_hub is not installed."
        ) from missing_hub

    fetched = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME, repo_type="model")
    return Path(fetched)
def load_pretrained_model():
    """Load the pairwise one-hot model and its encoding metadata into module globals.

    Sets MODEL, MODEL_PARAMS (from the checkpoint's ``model_params``), DEVICE,
    PAIR_TO_INDEX and NUM_PAIRS. Any loading error propagates to the caller.

    Returns:
        True once the model is loaded.
    """
    global MODEL, MODEL_PARAMS, DEVICE, PAIR_TO_INDEX, NUM_PAIRS

    weights_path = resolve_model_path()
    DEVICE = get_device()
    PAIR_TO_INDEX = get_pair_to_index()
    NUM_PAIRS = get_num_pairs()

    print(f"Loading model from {weights_path}")
    print(f"Using device: {DEVICE}")
    MODEL, checkpoint = load_model(str(weights_path), "pairwise_onehot", DEVICE)
    MODEL_PARAMS = checkpoint['model_params']

    print("Model loaded successfully!")
    print(f"Model parameters: {MODEL_PARAMS}")
    return True
def validate_sequence(seq, seq_type="sequence"):
    """Validate and normalise a nucleotide sequence.

    All whitespace is removed, including line breaks from multi-line text
    boxes or pasted sequences. The result is upper-cased and U is converted to T.

    Args:
        seq: Raw user-supplied sequence; may be None or empty.
        seq_type: Human-readable name used in error messages.

    Returns:
        ``(True, normalised_sequence)`` on success, or ``(False, error_message)``
        when the sequence is empty or contains characters other than
        A, T, C, G, U, N.
    """
    if not seq or not seq.strip():
        return False, f"{seq_type} cannot be empty"

    # Drop *all* whitespace, not only the ends: multi-line inputs and pasted
    # sequences routinely contain embedded newlines/spaces, which previously
    # surfaced as a confusing "invalid characters: {'\n'}" error.
    seq = "".join(seq.split()).upper()

    invalid = set(seq) - set('ATCGUN')
    if invalid:
        # Sorted so the message is deterministic (set iteration order is not).
        shown = "{" + ", ".join(repr(ch) for ch in sorted(invalid)) + "}"
        return False, f"{seq_type} contains invalid characters: {shown}. Only A, T, C, G, U, N are allowed."

    # Convert U to T for consistency
    return True, seq.replace('U', 'T')
def encode_sequence_pair(target_seq, mirna_seq):
    """Convert a (target, miRNA) pair into the one-hot tensor the model consumes.

    Both sequences are padded/trimmed to the lengths in MODEL_PARAMS. They are
    mapped to pair indices with ``encode_complementarity`` and one-hot encoded
    with ``NUM_PAIRS + 1`` classes.

    Returns:
        Tuple ``(tensor with a leading batch dimension of size 1,
        padded/trimmed target, padded/trimmed miRNA)``.
    """
    t_len = MODEL_PARAMS['target_length']
    m_len = MODEL_PARAMS['mirna_length']

    fitted_target = pad_or_trim(target_seq, t_len)
    fitted_mirna = pad_or_trim(mirna_seq, m_len)

    pair_indices = torch.tensor(
        encode_complementarity(fitted_target, fitted_mirna, t_len, m_len, PAIR_TO_INDEX, NUM_PAIRS),
        dtype=torch.long,
    )
    # Indexing with None prepends the batch axis (same as unsqueeze(0)).
    encoded = F.one_hot(pair_indices, num_classes=NUM_PAIRS + 1).float()[None, ...]
    return encoded, fitted_target, fitted_mirna
def predict_binding(target_seq, mirna_seq):
    """Score one target/miRNA pair with the loaded model.

    The model is put in eval mode and run without gradient tracking. The single
    output element is returned as a Python number.
    """
    MODEL.eval()
    with torch.no_grad():
        batch = encode_sequence_pair(target_seq, mirna_seq)[0].to(DEVICE)
        return MODEL(batch).item()
def compute_shap_values(target_seq, mirna_seq):
    """Compute GradientShap attributions for one pair, collapsed to a 2D map.

    The baseline is an all-zero input and the attribution target is 0. The
    per-pair-class axis (axis 2 of the first batch item) is summed away.

    Returns:
        numpy array, or None when Captum is unavailable.
    """
    if CAPTUM_AVAILABLE:
        MODEL.eval()
        inputs = encode_sequence_pair(target_seq, mirna_seq)[0].to(DEVICE)
        inputs.requires_grad = True

        attributions = GradientShap(MODEL).attribute(
            inputs, baselines=torch.zeros_like(inputs), target=0
        )
        per_pair = attributions[0].cpu().detach().numpy()
        # Sum over the pair-encoding axis -> [mirna_length, target_length]
        return per_pair.sum(axis=2)
    return None
def plot_shap_heatmap(shap_2d, mirna_seq, target_seq, prediction_score):
    """Render the SHAP attribution matrix as a labelled heatmap image.

    Args:
        shap_2d: 2D attribution array. Rows are labelled with ``mirna_seq`` and
            columns with ``target_seq`` (one tick per character), so its shape
            is expected to be ``(len(mirna_seq), len(target_seq))``.
        mirna_seq: miRNA sequence used for the y-axis labels.
        target_seq: Target sequence used for the x-axis labels.
        prediction_score: Model score shown in the title.

    Returns:
        PIL Image of the PNG-rendered figure (120 dpi).
    """
    fig_width = max(16, len(target_seq) * 0.32)
    fig_height = max(8, len(mirna_seq) * 0.3)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    try:
        # Diverging blue -> white -> red colormap; white sits at zero attribution.
        colors = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
                  '#fddbc7', '#f4a582', '#d6604d', '#b2182b']
        cmap = LinearSegmentedColormap.from_list('shap', colors, N=256)

        # Symmetric limits keep 0 at the white midpoint. When every attribution
        # is 0 (or non-finite), vmin == vmax makes matplotlib map every cell to
        # the lowest colour (dark blue). That falsely suggests a strong negative
        # contribution, so fall back to +/-1.
        vmax = float(np.abs(shap_2d).max())
        if not np.isfinite(vmax) or vmax <= 0:
            vmax = 1.0
        im = ax.imshow(shap_2d, cmap=cmap, aspect='auto', vmin=-vmax, vmax=vmax)

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('SHAP Attribution Value', rotation=270, labelpad=20, fontsize=11)

        ax.set_xlabel('mRNA/Target Position', fontsize=12)
        ax.set_ylabel('miRNA Position', fontsize=12)
        ax.set_title(f'miRNA-mRNA Target Site Explainability (Prediction: {prediction_score:.3f})',
                     fontsize=14, pad=20)

        # One tick per nucleotide; shrink the font as sequences get longer.
        mirna_length = len(mirna_seq)
        target_length = len(target_seq)
        mirna_ticks = list(range(mirna_length))
        target_ticks = list(range(target_length))
        x_fontsize = max(5, min(8, 300 / max(target_length, 1)))
        y_fontsize = max(6, min(9, 240 / max(mirna_length, 1)))
        ax.set_yticks(mirna_ticks)
        ax.set_yticklabels([f"{i:>2} {mirna_seq[i]}" for i in mirna_ticks],
                           fontsize=y_fontsize, fontfamily='monospace')
        ax.set_xticks(target_ticks)
        ax.set_xticklabels([f"{target_seq[i]}\n{i}" for i in target_ticks],
                           fontsize=x_fontsize, rotation=0)
        ax.tick_params(axis='x', pad=6)
        ax.tick_params(axis='y', pad=6)
        ax.grid(True, alpha=0.2, linewidth=0.5)

        fig.tight_layout()
        buf = io.BytesIO()
        # Save and close *this* figure explicitly rather than pyplot's implicit
        # "current figure": Gradio handlers can run concurrently, and pyplot's
        # global state is not thread-safe.
        fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def plot_sequence_explainability(mirna_seq, target_seq, shap_2d=None):
    """Render the miRNA and target sequences as a two-row image.

    Without SHAP values each row shows only the coloured nucleotides. With
    ``shap_2d`` each row is a bar chart of per-position importance. For a miRNA
    position this is the max of |SHAP| across axis 1; for a target position it
    is the max of |SHAP| across axis 0.

    Args:
        mirna_seq: miRNA sequence (top row).
        target_seq: Target sequence (bottom row).
        shap_2d: Optional 2D attribution array indexed
            [miRNA position, target position].

    Returns:
        PIL Image of the PNG-rendered figure (120 dpi).
    """
    if shap_2d is None:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 4))
    else:
        fig_width = max(16, len(target_seq) * 0.28)
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_width, 6.5))

    try:
        if shap_2d is None:
            # No SHAP values: just show the coloured sequences.
            rows = (
                (ax1, mirna_seq, 'miRNA:', 'miRNA Sequence'),
                (ax2, target_seq, 'mRNA:', 'mRNA/Target Sequence'),
            )
            for ax, seq, tag, title in rows:
                for i, nuc in enumerate(seq):
                    color = NUCLEOTIDE_COLORS.get(nuc, 'gray')
                    ax.text(i, 0, nuc, ha='center', va='center', fontsize=9,
                            bbox=dict(boxstyle='round,pad=0.5', facecolor=color, alpha=0.3))
                ax.text(-2, 0, tag, ha='right', va='center', fontsize=11, fontweight='bold')
                ax.set_xlim(-3, len(seq))
                ax.set_ylim(-0.5, 0.5)
                ax.axis('off')
                ax.set_title(title, fontsize=11, pad=10)
        else:
            abs_shap = np.abs(shap_2d)
            mirna_importance = np.max(abs_shap, axis=1)   # max across target positions
            target_importance = np.max(abs_shap, axis=0)  # max across miRNA positions
            _draw_sequence_importance_axis(
                ax1,
                mirna_seq,
                mirna_importance,
                'steelblue',
                'miRNA Position',
                'miRNA Sequence Importance (Max SHAP per position)',
                label='Max SHAP Score'
            )
            _draw_sequence_importance_axis(
                ax2,
                target_seq,
                target_importance,
                'coral',
                'mRNA/Target Position',
                'mRNA/Target Sequence Importance (Max SHAP per position)',
                label='Max SHAP Score'
            )

        fig.tight_layout()
        buf = io.BytesIO()
        # Explicit figure save/close: pyplot's implicit current figure is global
        # state shared across concurrent Gradio requests.
        fig.savefig(buf, format='png', dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def _draw_sequence_importance_axis(ax, sequence, importance, bar_color, xlabel, title,
                                   label=None, title_fontsize=11, label_fontsize=9,
                                   nucleotide_fontsize=8, y_fontsize=10):
    """Draw per-position importance bars with nucleotide boxes beneath them.

    Bars sit above zero. Each nucleotide is drawn in a coloured box just below
    the axis, and position numbers are drawn as x tick labels. A legend is
    added only when ``label`` is given.
    """
    positions = np.arange(len(sequence))

    bar_kwargs = dict(alpha=0.6, color=bar_color, width=0.8)
    if label is not None:
        bar_kwargs['label'] = label
    ax.bar(positions, importance, **bar_kwargs)

    # Leave headroom above the tallest bar and a band below zero for the letters.
    peak = float(np.max(importance)) if len(importance) else 0.0
    upper = max(peak * 1.15, 0.1)
    letter_y = -0.08 * upper
    ax.set_ylim(-0.22 * upper, upper)

    for pos, base in enumerate(sequence):
        ax.text(pos, letter_y, base, ha='center', va='top',
                fontsize=nucleotide_fontsize, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.3',
                          facecolor=NUCLEOTIDE_COLORS.get(base, 'gray'), alpha=0.4))

    tick_size = max(5, min(8, 300 / max(len(sequence), 1)))
    ax.set_xlim(-0.5, len(sequence) - 0.5)
    ax.set_xticks(positions)
    ax.set_xticklabels([str(p) for p in positions], fontsize=tick_size)
    ax.tick_params(axis='x', length=0, pad=10)

    ax.set_xlabel(xlabel, fontsize=10, labelpad=12)
    ax.set_ylabel('Max SHAP Score', fontsize=y_fontsize)
    ax.set_title(title, fontsize=title_fontsize, pad=10)
    ax.grid(True, alpha=0.3, axis='y')
    if label is not None:
        ax.legend(loc='upper right', fontsize=label_fontsize)
def create_shap_visualizations(mirna_seq, target_seq, score, shap_2d):
    """Build the (sequence, heatmap, alignment) images shown in the UI.

    The sequence image is always produced. The heatmap and alignment images
    require ``shap_2d``; an alignment failure is logged and yields None instead
    of raising.

    Returns:
        Tuple ``(explainability_img, heatmap_img, alignment_img)``; the last two
        may be None.
    """
    sequence_img = plot_sequence_explainability(mirna_seq, target_seq, shap_2d)
    if shap_2d is None:
        return sequence_img, None, None

    heatmap = plot_shap_heatmap(shap_2d, mirna_seq, target_seq, score)
    try:
        alignment = plot_alignment_image(mirna_seq, target_seq, shap_2d)
    except Exception as alignment_error:
        print(f"Warning: alignment plot generation failed: {alignment_error}")
        alignment = None
    return sequence_img, heatmap, alignment
def generate_pdf_report(mirna_seq, target_seq, score, binding_prediction, confidence, shap_2d=None):
    """Generate a PDF report for a single prediction.

    Page 1 is a text summary: score, prediction, confidence, input sequences
    and model info. When ``shap_2d`` is given, further pages show the
    per-position importance bars, the SHAP heatmap and, if it can be rendered,
    the SHAP-guided alignment.

    Args:
        mirna_seq: Validated miRNA sequence (unpadded).
        target_seq: Validated target sequence (unpadded).
        score: Model score.
        binding_prediction: Prediction label; "BINDING" is drawn green,
            anything else red.
        confidence: Fraction rendered as a percentage.
        shap_2d: Optional 2D attribution array indexed
            [miRNA position, target position], or None for a summary-only report.

    Returns:
        Path (str) of the PDF, created in the system temp directory with
        ``delete=False``.

    Raises:
        Any exception raised while rendering. The partially written PDF is
        deleted first so it is not orphaned in the temp directory.
    """
    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as pdf_file:
        pdf_path = pdf_file.name

    try:
        mirna_length = MODEL_PARAMS['mirna_length']
        target_length = MODEL_PARAMS['target_length']
        # Pad sequences to model length so plot labels line up with shap_2d.
        mirna_seq_padded = pad_or_trim(mirna_seq, mirna_length)
        target_seq_padded = pad_or_trim(target_seq, target_length)

        with PdfPages(pdf_path) as pdf:
            # ---------------- Page 1: title and summary ----------------
            fig = plt.figure(figsize=(8.5, 11))
            ax = fig.add_subplot(111)
            ax.axis('off')

            ax.text(0.5, 0.95, "miRBind2 Target Site Prediction Report", ha='center', va='top',
                    fontsize=20, fontweight='bold', transform=ax.transAxes)
            date_text = f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
            ax.text(0.5, 0.91, date_text, ha='center', va='top',
                    fontsize=10, color='gray', transform=ax.transAxes)
            ax.plot([0.1, 0.9], [0.89, 0.89], 'k-', linewidth=1, transform=ax.transAxes)

            # Results section
            y_pos = 0.85
            ax.text(0.5, y_pos, "Prediction Results", ha='center', va='top',
                    fontsize=16, fontweight='bold', transform=ax.transAxes)
            y_pos -= 0.06
            ax.text(0.1, y_pos, "Binding Score:", ha='left', va='top',
                    fontsize=12, fontweight='bold', transform=ax.transAxes)
            ax.text(0.5, y_pos, f"{score:.4f}", ha='left', va='top',
                    fontsize=12, transform=ax.transAxes)
            y_pos -= 0.04
            ax.text(0.1, y_pos, "Prediction:", ha='left', va='top',
                    fontsize=12, fontweight='bold', transform=ax.transAxes)
            pred_color = 'green' if binding_prediction == "BINDING" else 'red'
            ax.text(0.5, y_pos, binding_prediction, ha='left', va='top',
                    fontsize=12, color=pred_color, fontweight='bold', transform=ax.transAxes)
            y_pos -= 0.04
            ax.text(0.1, y_pos, "Confidence:", ha='left', va='top',
                    fontsize=12, fontweight='bold', transform=ax.transAxes)
            ax.text(0.5, y_pos, f"{confidence:.1%}", ha='left', va='top',
                    fontsize=12, transform=ax.transAxes)

            # Sequences section
            y_pos -= 0.08
            ax.text(0.5, y_pos, "Input Sequences", ha='center', va='top',
                    fontsize=16, fontweight='bold', transform=ax.transAxes)
            y_pos -= 0.06
            ax.text(0.1, y_pos, "miRNA Sequence:", ha='left', va='top',
                    fontsize=12, fontweight='bold', transform=ax.transAxes)
            y_pos -= 0.03
            ax.text(0.1, y_pos, mirna_seq, ha='left', va='top',
                    fontsize=10, family='monospace', transform=ax.transAxes)
            y_pos -= 0.03
            ax.text(0.1, y_pos, f"(Length: {len(mirna_seq)} nt, padded to {mirna_length} nt)",
                    ha='left', va='top', fontsize=9, color='gray', transform=ax.transAxes)
            y_pos -= 0.05
            ax.text(0.1, y_pos, "mRNA/Target Sequence:", ha='left', va='top',
                    fontsize=12, fontweight='bold', transform=ax.transAxes)
            y_pos -= 0.03
            # Wrap long sequences at 60 nt per line.
            target_lines = [target_seq[i:i + 60] for i in range(0, len(target_seq), 60)]
            ax.text(0.1, y_pos, '\n'.join(target_lines), ha='left', va='top',
                    fontsize=10, family='monospace', transform=ax.transAxes)
            # Advance by the number of wrapped lines actually drawn. The old
            # `len // 60 + 1` over-counted by one for exact multiples of 60.
            y_pos -= 0.03 * max(1, len(target_lines))
            ax.text(0.1, y_pos, f"(Length: {len(target_seq)} nt, padded to {target_length} nt)",
                    ha='left', va='top', fontsize=9, color='gray', transform=ax.transAxes)

            # Model info
            y_pos -= 0.06
            ax.text(0.5, y_pos, "Model Information", ha='center', va='top',
                    fontsize=16, fontweight='bold', transform=ax.transAxes)
            y_pos -= 0.05
            ax.text(0.1, y_pos, "Model: miRBind2 v1.0", ha='left', va='top',
                    fontsize=10, transform=ax.transAxes)
            y_pos -= 0.03
            ax.text(0.1, y_pos, "Training Data: AGO2 eCLIP (Manakov et al. 2022)", ha='left', va='top',
                    fontsize=10, transform=ax.transAxes)

            pdf.savefig(fig, bbox_inches='tight')
            plt.close(fig)

            if shap_2d is not None:
                # ------------- Page 2: sequence explainability -------------
                abs_shap = np.abs(shap_2d)
                mirna_importance = np.max(abs_shap, axis=1)
                target_importance = np.max(abs_shap, axis=0)

                fig, (ax1, ax2) = plt.subplots(
                    2, 1, figsize=(max(12, len(target_seq_padded) * 0.24), 10)
                )
                _draw_sequence_importance_axis(
                    ax1,
                    mirna_seq_padded,
                    mirna_importance,
                    'steelblue',
                    'miRNA Position',
                    'miRNA Sequence Importance (Max SHAP per position)',
                    title_fontsize=13,
                    nucleotide_fontsize=7,
                    y_fontsize=11
                )
                _draw_sequence_importance_axis(
                    ax2,
                    target_seq_padded,
                    target_importance,
                    'coral',
                    'mRNA/Target Position',
                    'mRNA/Target Sequence Importance (Max SHAP per position)',
                    title_fontsize=13,
                    nucleotide_fontsize=7,
                    y_fontsize=11
                )
                fig.tight_layout()
                pdf.savefig(fig, bbox_inches='tight')
                plt.close(fig)

                # ------------------ Page 3: SHAP heatmap ------------------
                fig, ax = plt.subplots(
                    figsize=(max(12, len(target_seq_padded) * 0.24), max(10, len(mirna_seq_padded) * 0.3))
                )
                colors = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
                          '#fddbc7', '#f4a582', '#d6604d', '#b2182b']
                cmap = LinearSegmentedColormap.from_list('shap', colors, N=256)
                # Symmetric limits centre 0 on white. Guard vmax == 0: with
                # vmin == vmax matplotlib paints every cell in the lowest
                # (dark blue) colour.
                vmax = float(abs_shap.max())
                if not np.isfinite(vmax) or vmax <= 0:
                    vmax = 1.0
                im = ax.imshow(shap_2d, cmap=cmap, aspect='auto', vmin=-vmax, vmax=vmax)
                cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                cbar.set_label('SHAP Attribution Value', rotation=270, labelpad=20, fontsize=11)
                ax.set_xlabel('mRNA/Target Position', fontsize=11)
                ax.set_ylabel('miRNA Position', fontsize=11)
                ax.set_title(f'SHAP Explainability Heatmap\nTarget site prediction: {score:.3f}',
                             fontsize=13, fontweight='bold', pad=15)

                mirna_ticks = list(range(len(mirna_seq_padded)))
                target_ticks = list(range(len(target_seq_padded)))
                x_fontsize = max(5, min(7, 260 / max(len(target_seq_padded), 1)))
                y_fontsize = max(6, min(8, 220 / max(len(mirna_seq_padded), 1)))
                ax.set_yticks(mirna_ticks)
                ax.set_yticklabels([f"{i:>2} {mirna_seq_padded[i]}" for i in mirna_ticks],
                                   fontsize=y_fontsize, fontfamily='monospace')
                ax.set_xticks(target_ticks)
                ax.set_xticklabels([f"{target_seq_padded[i]}\n{i}" for i in target_ticks],
                                   fontsize=x_fontsize, rotation=0)
                ax.tick_params(axis='x', pad=6)
                ax.tick_params(axis='y', pad=6)
                ax.grid(True, alpha=0.2, linewidth=0.5)

                explanation = ("Red regions: Positive contribution to binding\n"
                               "Blue regions: Negative contribution to binding\n"
                               "White regions: Minimal contribution")
                fig.text(0.5, 0.02, explanation, ha='center', va='bottom',
                         fontsize=9, style='italic', bbox=dict(boxstyle='round',
                                                               facecolor='wheat', alpha=0.3))
                fig.tight_layout()
                pdf.savefig(fig, bbox_inches='tight')
                plt.close(fig)

                # ------------- Page 4: alignment (best effort) -------------
                fig = None
                try:
                    alignment_img = plot_alignment_image(mirna_seq_padded, target_seq_padded, shap_2d)
                    fig, ax = plt.subplots(figsize=(max(12, alignment_img.width / 110), 4.5))
                    ax.imshow(alignment_img)
                    ax.axis('off')
                    ax.set_title('SHAP-Guided Target Site Alignment', fontsize=13, fontweight='bold', pad=12)
                    fig.tight_layout()
                    pdf.savefig(fig, bbox_inches='tight')
                except Exception as alignment_error:
                    print(f"Warning: alignment plot PDF page failed: {alignment_error}")
                finally:
                    # Close the figure even if rendering failed part-way.
                    if fig is not None:
                        plt.close(fig)
    except Exception:
        # Do not leave a half-written PDF behind in the temp directory.
        try:
            os.remove(pdf_path)
        except OSError:
            pass
        raise

    return pdf_path
def run_prediction(mirna_input, target_input, show_shap=True):
    """Gradio handler for the single-prediction tab.

    Args:
        mirna_input: Raw miRNA sequence text.
        target_input: Raw mRNA/target sequence text.
        show_shap: Whether to compute SHAP explainability (requires Captum).

    Returns:
        6-tuple ``(result markdown, sequence image, SHAP heatmap image,
        alignment image, SHAP explanation markdown, PDF path)``. On validation
        or prediction errors the first element is the error message and the
        rest are None.
    """
    valid_mirna, mirna_result = validate_sequence(mirna_input, "miRNA sequence")
    if not valid_mirna:
        return mirna_result, None, None, None, None, None
    valid_target, target_result = validate_sequence(target_input, "mRNA sequence")
    if not valid_target:
        return target_result, None, None, None, None, None

    mirna_seq = mirna_result
    target_seq = target_result

    try:
        mirna_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
        target_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])

        score = predict_binding(target_seq, mirna_seq)
        binding = "BINDING" if score >= 0.5 else "NO BINDING"
        confidence = score if score >= 0.5 else (1 - score)

        result_text = f"""
### Prediction Results
**Binding Score:** {score:.4f}
**Prediction:** {binding} (Confidence: {confidence:.1%})
**Model:** miRBind2 v1.0
**Sequences Used:**
- miRNA length: {len(mirna_seq)} nt (padded to {MODEL_PARAMS['mirna_length']})
- mRNA length: {len(target_seq)} nt (padded to {MODEL_PARAMS['target_length']})
"""

        shap_2d = None
        shap_text = ""
        if show_shap and CAPTUM_AVAILABLE:
            # SHAP is supplementary: if attribution fails, still return the
            # (already successful) prediction instead of replacing it with an error.
            try:
                shap_2d = compute_shap_values(target_seq, mirna_seq)
            except Exception as shap_error:
                print(f"Warning: SHAP computation failed: {shap_error}")
                shap_text = f"SHAP explainability could not be computed: {shap_error}"
            if shap_2d is not None:
                shap_text = """
### SHAP Explainability
**Position Pairs Heatmap** shows which miRNA-mRNA position combinations contribute most:
- **Red regions:** Positive contribution to binding
- **Blue regions:** Negative contribution to binding
- **White regions:** Minimal contribution
**Sequence Explainability** above shows per-nucleotide importance (max SHAP score across all positions).
**Alignment View** shows a SHAP-guided target site alignment between the miRNA and target sequence.
This helps identify the specific target sites and key nucleotides driving the prediction.
"""
        elif show_shap:
            shap_text = "SHAP visualization requires Captum library. Install with: pip install captum"

        explainability_img, shap_img, alignment_img = create_shap_visualizations(
            mirna_padded,
            target_padded,
            score,
            shap_2d
        )

        # The PDF report is only produced when SHAP values are available.
        pdf_path = None
        if shap_2d is not None:
            try:
                pdf_path = generate_pdf_report(
                    mirna_seq, target_seq, score, binding, confidence, shap_2d
                )
            except Exception as pdf_error:
                print(f"Warning: PDF generation failed: {pdf_error}")
                pdf_path = None

        return result_text, explainability_img, shap_img, alignment_img, shap_text, pdf_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None, None, None, None
| # ============================================================================ | |
| # BATCH PREDICTION FUNCTIONS | |
| # ============================================================================ | |
def parse_batch_file(file_path):
    """Parse an uploaded TSV/CSV file of target/miRNA pairs.

    The file is first read as tab-separated. If that yields a single column it
    is re-read as comma-separated.

    The label column is the first column named ``label`` or ``class``
    (case-insensitive), wherever it appears. If no column has either name and
    the file has 3+ columns, the third column is used, which matches the
    original positional behaviour. The first two remaining columns are taken,
    in order, as target and miRNA.

    Returns:
        ``(df, None)`` on success. ``df`` has columns ``target_seq``,
        ``mirna_seq`` and ``label`` (-1 when the file has no label column),
        followed by any other original columns. Returns
        ``(None, error_message)`` on failure.
    """
    try:
        # Try reading as TSV first; if only one column, try CSV
        df = pd.read_csv(file_path, sep='\t')
        if len(df.columns) == 1:
            df = pd.read_csv(file_path)

        if len(df.columns) < 2:
            return None, "File must have at least 2 columns (target, miRNA)"

        col_names = df.columns.tolist()
        # Prefer an explicitly named label column. The previous purely
        # positional rename mislabelled files whose label was not the third
        # column, and could produce duplicate 'label' columns.
        label_col = next(
            (c for c in col_names if str(c).strip().lower() in ('label', 'class')),
            None,
        )
        if label_col is None and len(col_names) >= 3:
            label_col = col_names[2]

        seq_cols = [c for c in col_names if c != label_col]
        if len(seq_cols) < 2:
            return None, "File must have at least 2 columns (target, miRNA)"
        target_col, mirna_col = seq_cols[0], seq_cols[1]

        parsed = pd.DataFrame({
            'target_seq': df[target_col],
            'mirna_seq': df[mirna_col],
            'label': df[label_col] if label_col is not None else -1,  # -1 = unknown
        })
        # Keep any extra columns, without clobbering the standardized ones.
        for col in seq_cols[2:]:
            if col not in parsed.columns:
                parsed[col] = df[col]
        return parsed, None
    except Exception as e:
        return None, f"Error parsing file: {str(e)}"
def _batch_result_row(idx, mirna_seq, target_seq, true_label, score, pred_class, status):
    """Build one record of the batch results table (key order = column order)."""
    return {
        'index': idx,
        'mirna_seq': mirna_seq,
        'target_seq': target_seq,
        'true_label': true_label,
        'prediction_score': score,
        'predicted_class': pred_class,
        'status': status,
    }


def run_batch_predictions(df, compute_shap=False, progress=gr.Progress()):
    """Run predictions on all pairs in ``df`` and store them in BATCH_RESULTS.

    Args:
        df: DataFrame with ``target_seq``, ``mirna_seq`` and optionally
            ``label`` columns (as produced by ``parse_batch_file``).
        compute_shap: When True and Captum is available, SHAP matrices are
            cached in BATCH_SHAP_CACHE keyed by the row index.
        progress: Gradio progress tracker.

    Returns:
        DataFrame with one row per input pair. Status is 'Success',
        'Invalid sequence' or 'Error: ...'.
    """
    global BATCH_RESULTS, BATCH_SHAP_CACHE

    results = []
    BATCH_SHAP_CACHE = {}
    total = len(df)

    for idx, row in progress.tqdm(df.iterrows(), total=total, desc="Processing pairs"):
        raw_target = row['target_seq']
        raw_mirna = row['mirna_seq']
        # Empty cells are read as NaN. str(NaN) == "nan" upper-cases to "NAN",
        # which passes nucleotide validation, so map missing cells to "".
        target_seq = "" if pd.isna(raw_target) else str(raw_target)
        mirna_seq = "" if pd.isna(raw_mirna) else str(raw_mirna)
        true_label = row.get('label', -1)
        if pd.isna(true_label):
            true_label = -1  # blank label cell -> unknown

        valid_target, target_result = validate_sequence(target_seq, "Target")
        valid_mirna, mirna_result = validate_sequence(mirna_seq, "miRNA")
        if not valid_target or not valid_mirna:
            results.append(_batch_result_row(
                idx, mirna_seq, target_seq, true_label, None, None, 'Invalid sequence'
            ))
            continue

        target_seq = target_result
        mirna_seq = mirna_result

        try:
            score = predict_binding(target_seq, mirna_seq)
        except Exception as e:
            results.append(_batch_result_row(
                idx, mirna_seq, target_seq, true_label, None, None, f'Error: {str(e)}'
            ))
            continue

        pred_class = 1 if score >= 0.5 else 0
        results.append(_batch_result_row(
            idx, mirna_seq, target_seq, true_label, score, pred_class, 'Success'
        ))

        # SHAP is cached for the detail view. It gets its own try, so a SHAP
        # failure no longer appends a second (error) row for the same index
        # after the success row.
        if compute_shap and CAPTUM_AVAILABLE:
            try:
                BATCH_SHAP_CACHE[idx] = compute_shap_values(target_seq, mirna_seq)
            except Exception as shap_error:
                print(f"Warning: SHAP computation failed for pair {idx}: {shap_error}")

    BATCH_RESULTS = pd.DataFrame(results)
    return BATCH_RESULTS
def format_results_table(results_df):
    """Format batch results for display in the Gradio table.

    Args:
        results_df: DataFrame produced by ``run_batch_predictions``.

    Returns:
        DataFrame with columns index, miRNA, Target, True Label, Score,
        Prediction and status, or None when there are no results. Sequences
        are truncated to 20 (miRNA) / 30 (target) characters.
    """
    if results_df is None or len(results_df) == 0:
        return None

    display_df = results_df.copy()

    # pandas stores a None placed in an otherwise numeric column as NaN, so
    # `x is not None` let missing scores through and rendered them as "nan".
    display_df['Score'] = display_df['prediction_score'].apply(
        lambda x: f"{x:.4f}" if pd.notna(x) else "N/A"
    )
    display_df['Prediction'] = display_df['predicted_class'].apply(
        lambda x: "BINDING" if x == 1 else "NO BINDING" if x == 0 else "N/A"
    )
    display_df['True Label'] = display_df['true_label'].apply(
        lambda x: "BINDING" if x == 1 else "NO BINDING" if x == 0 else "Unknown"
    )

    # Truncate sequences for display
    display_df['miRNA'] = display_df['mirna_seq'].apply(
        lambda x: x[:20] + "..." if len(x) > 20 else x
    )
    display_df['Target'] = display_df['target_seq'].apply(
        lambda x: x[:30] + "..." if len(x) > 30 else x
    )

    display_cols = ['index', 'miRNA', 'Target', 'True Label', 'Score', 'Prediction', 'status']
    return display_df[display_cols]
def show_batch_detail_view(current_table_state, evt: gr.SelectData):
    """Show detailed view for selected row from batch results.

    Args:
        current_table_state: The current state of the displayed table (including any sorting)
        evt: Selection event containing row/column info

    Returns:
        5-tuple ``(result markdown, sequence image, heatmap image,
        alignment image, explanation markdown)``. Images are None when
        unavailable.
    """
    global BATCH_RESULTS, BATCH_SHAP_CACHE

    if BATCH_RESULTS is None:
        return "No results available", None, None, None, None
    if current_table_state is None or len(current_table_state) == 0:
        return "No data in table", None, None, None, None

    # Visual row position in the currently displayed (possibly sorted) table.
    visual_row_idx = evt.index[0]
    if visual_row_idx >= len(current_table_state):
        return "Invalid selection", None, None, None, None

    # Read the pair's 'index' from the DISPLAYED table so sorting is respected.
    actual_idx = current_table_state.iloc[visual_row_idx]['index']
    mask = BATCH_RESULTS['index'] == actual_idx
    if not mask.any():
        # Fall back to a string comparison in case the displayed index comes
        # back with a different dtype than the stored one.
        mask = BATCH_RESULTS['index'].astype(str) == str(actual_idx)
    matches = BATCH_RESULTS[mask]
    if matches.empty:
        return "Invalid selection", None, None, None, None
    row = matches.iloc[0]

    mirna_seq = row['mirna_seq']
    target_seq = row['target_seq']
    score = row['prediction_score']
    pred_class = row['predicted_class']
    true_label = row['true_label']
    idx = row['index']

    # Missing scores are NaN once stored in the DataFrame; `is None` misses them.
    if pd.isna(score):
        return f"**Error:** {row['status']}", None, None, None, None

    binding = "BINDING" if pred_class == 1 else "NO BINDING"
    confidence = score if pred_class == 1 else (1 - score)
    true_label_str = "BINDING" if true_label == 1 else "NO BINDING" if true_label == 0 else "Unknown"

    result_text = f"""
### Detailed Results for Pair #{idx}
**Binding Score:** {score:.4f}
**Prediction:** {binding} (Confidence: {confidence:.1%})
**True Label:** {true_label_str}
**Sequences:**
- **miRNA:** {mirna_seq}
- **mRNA:** {target_seq}
"""

    # SHAP is only available if it was computed during the batch run.
    shap_2d = BATCH_SHAP_CACHE.get(idx)
    if shap_2d is not None:
        mirna_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
        target_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])
        explainability_img, heatmap_img, alignment_img = create_shap_visualizations(
            mirna_padded,
            target_padded,
            score,
            shap_2d
        )
        explanation = (
            "**SHAP values** computed during batch processing. "
            "The alignment view is derived from the same SHAP matrix."
        )
    else:
        explainability_img = None
        heatmap_img = None
        alignment_img = None
        # Fixed mis-encoded warning sign (previously rendered as "β οΈ").
        explanation = "⚠️ **SHAP not computed.** Re-run batch with 'Compute SHAP' enabled to see explainability."

    return result_text, explainability_img, heatmap_img, alignment_img, explanation
def process_uploaded_file(file, compute_shap):
    """Gradio handler for the batch tab: parse an upload, predict every pair, summarise.

    Args:
        file: Uploaded file from Gradio. Either a path (str / os.PathLike) or
            an object exposing the temp path as ``.name``.
        compute_shap: Whether to compute and cache SHAP values per pair.

    Returns:
        ``(summary markdown, display DataFrame or None)``. When rows have known
        labels (not -1 / blank), the summary includes accuracy and, when
        computable, average precision.
    """
    if file is None:
        return "Please upload a file", None

    # Accept plain paths as well as file objects carrying the path in .name.
    file_path = file if isinstance(file, (str, os.PathLike)) else file.name
    df, error = parse_batch_file(file_path)
    if error:
        return error, None

    results_df = run_batch_predictions(df, compute_shap=compute_shap)
    display_df = format_results_table(results_df)

    total = len(results_df)
    if total == 0:
        # A results frame built from zero rows has no columns at all, so the
        # 'status' lookup below would raise KeyError.
        return "No sequence pairs found in the uploaded file.", None
    successful = int((results_df['status'] == 'Success').sum())

    lines = [
        "",
        "### Batch Processing Complete β",
        f"**Total pairs:** {total}",
        f"**Successful:** {successful}",
        f"**Failed:** {total - successful}",
    ]

    known_labels = None
    if 'true_label' in results_df.columns:
        labels = results_df['true_label']
        # Blank label cells arrive as NaN rather than -1; treat them as unknown
        # too, otherwise average_precision_score rejects the NaN input.
        known_labels = results_df[labels.notna() & (labels != -1)]

    if known_labels is not None and len(known_labels) > 0:
        correct = int((known_labels['predicted_class'] == known_labels['true_label']).sum())
        accuracy = correct / len(known_labels) * 100
        scored = known_labels.dropna(subset=['prediction_score'])
        aps_line = ""
        if len(scored) > 0:
            try:
                aps = average_precision_score(scored['true_label'], scored['prediction_score'])
            except ValueError as aps_error:
                # e.g. non-binary or mixed-type labels: report accuracy only
                # instead of failing the whole batch.
                print(f"Warning: could not compute APS: {aps_error}")
            else:
                aps_line = f"\n**APS (Average Precision):** {aps:.4f}"
        lines.append(f"**Accuracy:** {accuracy:.1f}% ({correct}/{len(known_labels)} correct){aps_line}")
        lines.append("π Click any row in the table below to view detailed SHAP analysis.")
    else:
        lines.append("π Click any row in the table below to view detailed results.")

    summary = "\n".join(lines) + "\n"
    return summary, display_df
def _sequence_length_limits():
    """Return ``(mirna_length, target_length)`` used for padding/trimming.

    Reads ``MODEL_PARAMS`` when a model has been loaded, so the UI text matches
    the lengths actually used for inference. Before loading, it falls back to
    the 28/50 nt values the UI previously hard-coded.
    """
    if MODEL_PARAMS is None:
        return 28, 50
    return MODEL_PARAMS['mirna_length'], MODEL_PARAMS['target_length']


def create_gradio_interface():
    """Build the miRBind2 Gradio app (single-prediction tab, batch tab, About footer).

    Event wiring:
      * "Predict Binding" -> ``run_prediction`` (result text, three images,
        SHAP explanation text, PDF report file).
      * "Process File" -> ``process_uploaded_file`` (summary + results table).
      * Selecting a results-table row -> ``show_batch_detail_view``, which
        receives the table's current contents (so re-sorting is respected).

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    mirna_len, target_len = _sequence_length_limits()
    # Shared by both tabs and derived from the model config so it cannot drift.
    length_note = (
        f"**Sequence length:** miRNA is padded/trimmed to {mirna_len} nt, "
        f"mRNA to {target_len} nt. Sequences longer than these limits are "
        "trimmed from the 3′ end; shorter sequences are padded with N."
    )

    with gr.Blocks(title="miRBind2: miRNA-mRNA Target Site Predictor") as app:
        gr.Markdown("""
        # miRBind2: miRNA-mRNA Target Site Predictor

        Predict miRNA-mRNA target sites using deep learning with explainability.

        **Model:** miRBind2 v1.0
        """)

        with gr.Tabs():
            # ================================================================
            # TAB 1: SINGLE PREDICTION
            # ================================================================
            with gr.Tab("Single Prediction"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Input Sequences")
                        mirna_input = gr.Textbox(
                            label="miRNA Sequence",
                            placeholder=f"Enter miRNA sequence (e.g., {DEFAULT_MIRNA_SEQUENCE})",
                            lines=3,
                            value=DEFAULT_MIRNA_SEQUENCE,
                        )
                        target_input = gr.Textbox(
                            label="mRNA/Target Sequence",
                            placeholder="Enter mRNA target sequence",
                            lines=5,
                            # Demo target embeds a reverse-complement site for clearer default explainability.
                            value=DEFAULT_TARGET_SEQUENCE,
                        )
                        show_shap = gr.Checkbox(
                            label="Show SHAP Explainability (slower)",
                            value=True,
                        )
                        predict_btn = gr.Button("Predict Binding", variant="primary", size="lg")
                        gr.Markdown(f"""
                        **Instructions:**

                        1. Enter your miRNA sequence (RNA or DNA notation accepted)
                        2. Enter your mRNA target sequence
                        3. Check SHAP box for detailed explainability (recommended)
                        4. Click 'Predict Binding'

                        {length_note}
                        """)

                    with gr.Column(scale=2):
                        gr.Markdown("### Results")
                        result_output = gr.Markdown(label="Prediction Results")
                        gr.Markdown("#### Sequence Explainability")
                        explainability_output = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP Explainability Heatmap (Position Pairs)")
                        shap_output = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP-Guided Alignment View")
                        alignment_view_output = gr.Image(show_label=False, type="pil")
                        shap_text_output = gr.Markdown()
                        # PDF download button
                        with gr.Row():
                            pdf_download = gr.File(label="Download PDF Report", visible=True)

                # Output order must match the tuple returned by run_prediction.
                predict_btn.click(
                    fn=run_prediction,
                    inputs=[mirna_input, target_input, show_shap],
                    outputs=[
                        result_output,
                        explainability_output,
                        shap_output,
                        alignment_view_output,
                        shap_text_output,
                        pdf_download,
                    ],
                )

            # ================================================================
            # TAB 2: BATCH PREDICTIONS
            # ================================================================
            with gr.Tab("Batch Predictions"):
                gr.Markdown("""
                Upload a TSV/CSV file with multiple miRNA-mRNA pairs and browse through predictions interactively.
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 📁 Upload File")
                        file_upload = gr.File(
                            label="Upload TSV/CSV File",
                            file_types=['.tsv', '.csv', '.txt'],
                        )
                        gr.Markdown(f"""
                        **Expected format (column order matters, headers optional):**

                        - Column 1: mRNA/target sequence
                        - Column 2: miRNA sequence
                        - Column 3 (optional): Label (0 or 1)

                        TSV (tab-separated) or CSV format.

                        {length_note}

                        See `example_batch.tsv` for reference.
                        """)
                        compute_shap_batch = gr.Checkbox(
                            label="Compute SHAP for all pairs (slower)",
                            value=True,
                            info="Enable for detailed explainability. Increases processing time.",
                        )
                        process_btn = gr.Button("Process File", variant="primary", size="lg")

                    with gr.Column(scale=2):
                        gr.Markdown("### 📊 Results Summary")
                        summary_output = gr.Markdown()

                gr.Markdown("### 📋 Browse Results")
                gr.Markdown("👇 Click any row to view detailed SHAP analysis")
                results_table = gr.Dataframe(
                    label="Prediction Results",
                    interactive=False,
                    wrap=True,
                )

                gr.Markdown("### 🔍 Detailed View")
                with gr.Row():
                    with gr.Column(scale=1):
                        detail_text = gr.Markdown()
                    with gr.Column(scale=2):
                        gr.Markdown("#### Sequence Explainability")
                        detail_explainability = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP Heatmap")
                        detail_heatmap = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP-Guided Alignment View")
                        detail_alignment = gr.Image(show_label=False, type="pil")
                        detail_explanation = gr.Markdown()

                # Connect batch events
                process_btn.click(
                    fn=process_uploaded_file,
                    inputs=[file_upload, compute_shap_batch],
                    outputs=[summary_output, results_table],
                )
                # Pass the table itself as input so the handler sees its current
                # state (including any user sorting), not the original order.
                results_table.select(
                    fn=show_batch_detail_view,
                    inputs=[results_table],
                    outputs=[
                        detail_text,
                        detail_explainability,
                        detail_heatmap,
                        detail_alignment,
                        detail_explanation,
                    ],
                )

        # About section (outside tabs)
        gr.Markdown(f"""
        ---
        ### About

        **miRBind2** uses a convolutional neural network trained on CLIP experimental data to predict
        miRNA-mRNA target sites. The model learns complementarity patterns between miRNA and target sequences.

        **GitHub:** [BioGeMT/miRBind_2.0](https://github.com/BioGeMT/miRBind_2.0)

        **Device:** {DEVICE}
        """)

    return app
def main():
    """Load the pretrained model, build the Gradio UI and launch the server.

    Exits the process with status 1 if ``load_pretrained_model`` raises, since
    the interface cannot serve predictions without a model.
    """
    print("=" * 60)
    print("miRBind2 Gradio Interface")
    print("=" * 60)

    # Load model
    try:
        load_pretrained_model()
    except Exception as e:
        # Top-level boundary: report the full traceback on stderr (where
        # container/Space logs expect errors) instead of only the message.
        import traceback

        print(f"Error loading model: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)

    # Create and launch interface
    app = create_gradio_interface()

    print("\n" + "=" * 60)
    print("Launching Gradio interface...")
    print("=" * 60)

    app.launch(
        share=False,  # Set to True to create public link
        server_name="0.0.0.0",  # Required for containerized hosting (HF Spaces)
        server_port=7860,
        show_error=True,
        # NOTE(review): recent Gradio releases accept `theme` in launch(); older
        # ones only take it in gr.Blocks(...) — confirm against the pinned version.
        theme=gr.themes.Soft(),
    )


if __name__ == "__main__":
    main()