#!/usr/bin/env python3
"""
miRBind2 Gradio Interface
Interactive web app for miRNA-mRNA binding prediction with explainability
"""

import sys
import os
from pathlib import Path

# Add code directory to path
CODE_DIR = Path(__file__).parent / "miRBind_2.0-main" / "code" / "pairwise_binding_site_model"
sys.path.insert(0, str(CODE_DIR))

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.backends.backend_pdf import PdfPages
import io
from PIL import Image
from datetime import datetime
import tempfile
from sklearn.metrics import average_precision_score

from alignment_plot import plot_alignment_image
from shared.models import load_model
from shared.constants import get_pair_to_index, get_num_pairs, get_device, NUCLEOTIDE_COLORS
from shared.encoding import pad_or_trim, encode_complementarity

# Try to import Captum for SHAP, but make it optional
try:
    from captum.attr import GradientShap
    CAPTUM_AVAILABLE = True
except ImportError:
    CAPTUM_AVAILABLE = False
    print("Warning: Captum not installed. SHAP explainability will not be available.")
    print("Install with: pip install captum")

# Global model variables (populated once by load_pretrained_model()).
MODEL = None
MODEL_PARAMS = None
DEVICE = None
PAIR_TO_INDEX = None
NUM_PAIRS = None

# Global batch results cache.
# NOTE(review): this module-level state is shared by every session of the app,
# so concurrent batch uploads from different users overwrite each other.
BATCH_RESULTS = None
BATCH_SHAP_CACHE = {}

MODEL_FILENAME = "pairwise_onehot_model_20260105_200141.pt"
MODEL_REPO_ID = os.getenv("MIRBIND2_MODEL_REPO", "dimostzim/mirbind2-weights")

DEFAULT_MIRNA_SEQUENCE = "UAGCUUAUCAGACUGAUGUUGA"
DEFAULT_TARGET_SEQUENCE = "GGGCACUUUUUCAACAUCAGUCUGAUAAGCUAAGUGUCUUCCAGGGAAUU"

# Diverging blue -> white -> red palette shared by every SHAP heatmap.
SHAP_HEATMAP_COLORS = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
                       '#fddbc7', '#f4a582', '#d6604d', '#b2182b']


# ============================================================================
# MODEL LOADING / ENCODING / PREDICTION
# ============================================================================

def resolve_model_path():
    """Resolve model path locally, or download from Hugging Face Hub if missing.

    Returns:
        Path to the model checkpoint file.

    Raises:
        FileNotFoundError: if the file is not present locally and
            ``huggingface_hub`` is not installed.
    """
    local_model_path = Path(__file__).parent / "miRBind_2.0-main" / "models" / MODEL_FILENAME
    if local_model_path.exists():
        return local_model_path

    print(f"Local model not found at {local_model_path}")
    print(f"Downloading model from Hugging Face Hub repo: {MODEL_REPO_ID}")
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as exc:
        raise FileNotFoundError(
            "Model file missing locally and huggingface_hub is not installed."
        ) from exc

    downloaded_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        repo_type="model",
    )
    return Path(downloaded_path)


def load_pretrained_model():
    """Load the pre-trained model once at startup and populate the module globals."""
    global MODEL, MODEL_PARAMS, DEVICE, PAIR_TO_INDEX, NUM_PAIRS

    model_path = resolve_model_path()

    DEVICE = get_device()
    PAIR_TO_INDEX = get_pair_to_index()
    NUM_PAIRS = get_num_pairs()

    print(f"Loading model from {model_path}")
    print(f"Using device: {DEVICE}")

    MODEL, checkpoint = load_model(str(model_path), "pairwise_onehot", DEVICE)
    MODEL_PARAMS = checkpoint['model_params']

    print("Model loaded successfully!")
    print(f"Model parameters: {MODEL_PARAMS}")

    return True


def validate_sequence(seq, seq_type="sequence"):
    """Validate and normalise a nucleotide sequence.

    All whitespace (including line breaks from a pasted, wrapped sequence) is
    removed, the sequence is upper-cased, and U is converted to T.

    Args:
        seq: Raw user-supplied sequence string.
        seq_type: Human-readable name used in error messages.

    Returns:
        ``(True, normalised_sequence)`` on success, or ``(False, error_message)``.
    """
    if not seq:
        return False, f"{seq_type} cannot be empty"

    seq = "".join(seq.split()).upper()
    if not seq:
        return False, f"{seq_type} cannot be empty"

    valid_nucleotides = set('ATCGUN')
    invalid = set(seq) - valid_nucleotides
    if invalid:
        return False, f"{seq_type} contains invalid characters: {invalid}. Only A, T, C, G, U, N are allowed."

    # Convert U to T for consistency
    return True, seq.replace('U', 'T')


def encode_sequence_pair(target_seq, mirna_seq):
    """Encode a target-miRNA sequence pair for the model.

    Both sequences are padded/trimmed to the lengths stored in MODEL_PARAMS,
    encoded to pair indices, and one-hot encoded with NUM_PAIRS + 1 classes.

    Returns:
        (one_hot_tensor_with_batch_dim, padded_target_seq, padded_mirna_seq)
    """
    target_length = MODEL_PARAMS['target_length']
    mirna_length = MODEL_PARAMS['mirna_length']

    # Pad or trim sequences
    target_seq = pad_or_trim(target_seq, target_length)
    mirna_seq = pad_or_trim(mirna_seq, mirna_length)

    # Encode as integer indices
    indices = encode_complementarity(
        target_seq, mirna_seq,
        target_length, mirna_length,
        PAIR_TO_INDEX, NUM_PAIRS
    )

    # Convert to one-hot encoding
    indices_tensor = torch.tensor(indices, dtype=torch.long)
    X_onehot = F.one_hot(indices_tensor, num_classes=NUM_PAIRS + 1).float()

    # Add batch dimension
    X_onehot = X_onehot.unsqueeze(0)

    return X_onehot, target_seq, mirna_seq


def predict_binding(target_seq, mirna_seq):
    """Run binding prediction on a sequence pair and return the scalar model score."""
    MODEL.eval()
    with torch.no_grad():
        X, _, _ = encode_sequence_pair(target_seq, mirna_seq)
        X = X.to(DEVICE)
        output = MODEL(X)
        score = output.item()
    return score


def compute_shap_values(target_seq, mirna_seq):
    """Compute GradientSHAP attribution values for explainability.

    Returns:
        A 2-D numpy array (attributions of the first batch item summed over the
        last, pair-encoding axis), or None when Captum is not installed.
    """
    if not CAPTUM_AVAILABLE:
        return None

    MODEL.eval()
    X, _, _ = encode_sequence_pair(target_seq, mirna_seq)
    X = X.to(DEVICE)
    X.requires_grad = True

    # Single all-zeros baseline for the one-hot input.
    baseline = torch.zeros_like(X)

    explainer = GradientShap(MODEL)
    attributions = explainer.attribute(X, baselines=baseline, target=0)

    # Convert to numpy and reduce from 3D to 2D by summing over pair dimension
    shap_3d = attributions[0].cpu().detach().numpy()
    shap_2d = np.sum(shap_3d, axis=2)  # Shape: [mirna_length, target_length]

    return shap_2d


# ============================================================================
# PLOTTING HELPERS
# ============================================================================

def _shap_colormap():
    """Return the diverging colormap used for every SHAP heatmap."""
    return LinearSegmentedColormap.from_list('shap', SHAP_HEATMAP_COLORS, N=256)


def _symmetric_color_limit(shap_2d):
    """Return a positive, symmetric colour limit for a SHAP matrix.

    Falls back to 1.0 when the matrix is empty, all zeros or non-finite, so
    imshow/colorbar never receive a degenerate ``vmin == vmax`` range.
    """
    if np.size(shap_2d) == 0:
        return 1.0
    vmax = float(np.nanmax(np.abs(shap_2d)))
    if not np.isfinite(vmax) or vmax <= 0:
        return 1.0
    return vmax


def _figure_to_image(fig, dpi=120):
    """Render *fig* to a PNG-backed PIL image and always close the figure."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _draw_shap_heatmap(fig, ax, shap_2d, mirna_seq, target_seq, x_fontsize, y_fontsize):
    """Draw a SHAP matrix with colourbar, per-position sequence tick labels and grid.

    Rows are labelled with miRNA positions/nucleotides and columns with target
    positions/nucleotides; axis labels and title are left to the caller.
    """
    vmax = _symmetric_color_limit(shap_2d)
    im = ax.imshow(shap_2d, cmap=_shap_colormap(), aspect='auto', vmin=-vmax, vmax=vmax)

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('SHAP Attribution Value', rotation=270, labelpad=20, fontsize=11)

    mirna_ticks = list(range(len(mirna_seq)))
    target_ticks = list(range(len(target_seq)))
    ax.set_yticks(mirna_ticks)
    ax.set_yticklabels([f"{i:>2} {mirna_seq[i]}" for i in mirna_ticks],
                       fontsize=y_fontsize, fontfamily='monospace')
    ax.set_xticks(target_ticks)
    ax.set_xticklabels([f"{target_seq[i]}\n{i}" for i in target_ticks],
                       fontsize=x_fontsize, rotation=0)
    ax.tick_params(axis='x', pad=6)
    ax.tick_params(axis='y', pad=6)
    ax.grid(True, alpha=0.2, linewidth=0.5)
    return im


def plot_shap_heatmap(shap_2d, mirna_seq, target_seq, prediction_score):
    """Create the SHAP position-pair heatmap visualization as a PIL image.

    Args:
        shap_2d: Attribution matrix indexed [miRNA position, target position].
        mirna_seq: miRNA sequence used for the row labels.
        target_seq: Target sequence used for the column labels.
        prediction_score: Model score shown in the title.
    """
    fig_width = max(16, len(target_seq) * 0.32)
    fig_height = max(8, len(mirna_seq) * 0.3)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    _draw_shap_heatmap(
        fig, ax, shap_2d, mirna_seq, target_seq,
        x_fontsize=max(5, min(8, 300 / max(len(target_seq), 1))),
        y_fontsize=max(6, min(9, 240 / max(len(mirna_seq), 1))),
    )
    ax.set_xlabel('mRNA/Target Position', fontsize=12)
    ax.set_ylabel('miRNA Position', fontsize=12)
    ax.set_title(f'miRNA-mRNA Target Site Explainability (Prediction: {prediction_score:.3f})',
                 fontsize=14, pad=20)

    fig.tight_layout()
    return _figure_to_image(fig)


def _draw_plain_sequence(ax, sequence, label, title):
    """Draw *sequence* as a row of coloured nucleotide boxes without importance bars."""
    for i, nuc in enumerate(sequence):
        color = NUCLEOTIDE_COLORS.get(nuc, 'gray')
        ax.text(i, 0, nuc, ha='center', va='center', fontsize=9,
                bbox=dict(boxstyle='round,pad=0.5', facecolor=color, alpha=0.3))
    ax.text(-2, 0, label, ha='right', va='center', fontsize=11, fontweight='bold')
    ax.set_xlim(-3, len(sequence))
    ax.set_ylim(-0.5, 0.5)
    ax.axis('off')
    ax.set_title(title, fontsize=11, pad=10)


def plot_sequence_explainability(mirna_seq, target_seq, shap_2d=None):
    """Create sequence explainability visualization with SHAP importance bars.

    Without SHAP values only the two sequences are drawn; with SHAP values each
    position gets a bar equal to the max |SHAP| across the other sequence.
    """
    if shap_2d is None:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 4))
        _draw_plain_sequence(ax1, mirna_seq, 'miRNA:', 'miRNA Sequence')
        _draw_plain_sequence(ax2, target_seq, 'mRNA:', 'mRNA/Target Sequence')
    else:
        # Max |SHAP| per miRNA position (across target positions - axis 1)
        mirna_importance = np.max(np.abs(shap_2d), axis=1)
        # Max |SHAP| per target position (across miRNA positions - axis 0)
        target_importance = np.max(np.abs(shap_2d), axis=0)

        fig_width = max(16, len(target_seq) * 0.28)
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(fig_width, 6.5))

        _draw_sequence_importance_axis(
            ax1, mirna_seq, mirna_importance, 'steelblue', 'miRNA Position',
            'miRNA Sequence Importance (Max SHAP per position)', label='Max SHAP Score'
        )
        _draw_sequence_importance_axis(
            ax2, target_seq, target_importance, 'coral', 'mRNA/Target Position',
            'mRNA/Target Sequence Importance (Max SHAP per position)', label='Max SHAP Score'
        )

    fig.tight_layout()
    return _figure_to_image(fig)


def _draw_sequence_importance_axis(ax, sequence, importance, bar_color, xlabel, title, label=None,
                                   title_fontsize=11, label_fontsize=9, nucleotide_fontsize=8, y_fontsize=10):
    """Render per-position importance with nucleotide boxes and separate position numbers."""
    x_positions = np.arange(len(sequence))
    bar_kwargs = {} if label is None else {'label': label}
    ax.bar(x_positions, importance, alpha=0.6, color=bar_color, width=0.8, **bar_kwargs)

    # Reserve space below zero for the nucleotide boxes.
    max_importance = float(np.max(importance)) if len(importance) else 0.0
    y_top = max(max_importance * 1.15, 0.1)
    nucleotide_y = -0.08 * y_top
    ax.set_ylim(-0.22 * y_top, y_top)

    for i, nuc in enumerate(sequence):
        color = NUCLEOTIDE_COLORS.get(nuc, 'gray')
        ax.text(i, nucleotide_y, nuc, ha='center', va='top', fontsize=nucleotide_fontsize,
                fontweight='bold', bbox=dict(boxstyle='round,pad=0.3', facecolor=color, alpha=0.4))

    ax.set_xlim(-0.5, len(sequence) - 0.5)
    ax.set_xticks(x_positions)
    ax.set_xticklabels([str(i) for i in x_positions],
                       fontsize=max(5, min(8, 300 / max(len(sequence), 1))))
    ax.tick_params(axis='x', length=0, pad=10)
    ax.set_xlabel(xlabel, fontsize=10, labelpad=12)
    ax.set_ylabel('Max SHAP Score', fontsize=y_fontsize)
    ax.set_title(title, fontsize=title_fontsize, pad=10)
    ax.grid(True, alpha=0.3, axis='y')
    if label is not None:
        ax.legend(loc='upper right', fontsize=label_fontsize)


def create_shap_visualizations(mirna_seq, target_seq, score, shap_2d):
    """Build the explainability visualizations used across the app.

    Returns:
        (explainability_img, heatmap_img, alignment_img); the last two are None
        when *shap_2d* is None, and alignment_img is None if its plot fails.
    """
    explainability_img = plot_sequence_explainability(mirna_seq, target_seq, shap_2d)
    heatmap_img = None
    alignment_img = None

    if shap_2d is not None:
        heatmap_img = plot_shap_heatmap(shap_2d, mirna_seq, target_seq, score)
        try:
            alignment_img = plot_alignment_image(mirna_seq, target_seq, shap_2d)
        except Exception as alignment_error:
            print(f"Warning: alignment plot generation failed: {alignment_error}")

    return explainability_img, heatmap_img, alignment_img


# ============================================================================
# PDF REPORT
# ============================================================================

def _pdf_summary_page(pdf, mirna_seq, target_seq, score, binding_prediction, confidence):
    """Append the title page (results, input sequences, model info) to *pdf*."""
    fig = plt.figure(figsize=(8.5, 11))
    try:
        ax = fig.add_subplot(111)
        ax.axis('off')

        def put(x, y, text, **kwargs):
            # Coordinates are axes fractions (0-1) on the blank page.
            ax.text(x, y, text, transform=ax.transAxes, **kwargs)

        # Title and date
        put(0.5, 0.95, "miRBind2 Target Site Prediction Report",
            ha='center', va='top', fontsize=20, fontweight='bold')
        put(0.5, 0.91, f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            ha='center', va='top', fontsize=10, color='gray')

        # Divider
        ax.plot([0.1, 0.9], [0.89, 0.89], 'k-', linewidth=1, transform=ax.transAxes)

        # Results section
        y_pos = 0.85
        put(0.5, y_pos, "Prediction Results", ha='center', va='top', fontsize=16, fontweight='bold')

        y_pos -= 0.06
        put(0.1, y_pos, "Binding Score:", ha='left', va='top', fontsize=12, fontweight='bold')
        put(0.5, y_pos, f"{score:.4f}", ha='left', va='top', fontsize=12)

        y_pos -= 0.04
        put(0.1, y_pos, "Prediction:", ha='left', va='top', fontsize=12, fontweight='bold')
        pred_color = 'green' if binding_prediction == "BINDING" else 'red'
        put(0.5, y_pos, binding_prediction, ha='left', va='top', fontsize=12,
            color=pred_color, fontweight='bold')

        y_pos -= 0.04
        put(0.1, y_pos, "Confidence:", ha='left', va='top', fontsize=12, fontweight='bold')
        put(0.5, y_pos, f"{confidence:.1%}", ha='left', va='top', fontsize=12)

        # Sequences section
        y_pos -= 0.08
        put(0.5, y_pos, "Input Sequences", ha='center', va='top', fontsize=16, fontweight='bold')

        y_pos -= 0.06
        put(0.1, y_pos, "miRNA Sequence:", ha='left', va='top', fontsize=12, fontweight='bold')
        y_pos -= 0.03
        put(0.1, y_pos, mirna_seq, ha='left', va='top', fontsize=10, family='monospace')
        y_pos -= 0.03
        put(0.1, y_pos, f"(Length: {len(mirna_seq)} nt, padded to {MODEL_PARAMS['mirna_length']} nt)",
            ha='left', va='top', fontsize=9, color='gray')

        y_pos -= 0.05
        put(0.1, y_pos, "mRNA/Target Sequence:", ha='left', va='top', fontsize=12, fontweight='bold')
        y_pos -= 0.03
        # Wrap long sequences at 60 nt; advance by exactly the number of lines drawn.
        wrapped_lines = [target_seq[i:i + 60] for i in range(0, len(target_seq), 60)]
        put(0.1, y_pos, '\n'.join(wrapped_lines), ha='left', va='top', fontsize=10, family='monospace')
        y_pos -= 0.03 * max(1, len(wrapped_lines))
        put(0.1, y_pos, f"(Length: {len(target_seq)} nt, padded to {MODEL_PARAMS['target_length']} nt)",
            ha='left', va='top', fontsize=9, color='gray')

        # Model info
        y_pos -= 0.06
        put(0.5, y_pos, "Model Information", ha='center', va='top', fontsize=16, fontweight='bold')
        y_pos -= 0.05
        put(0.1, y_pos, "Model: miRBind2 v1.0", ha='left', va='top', fontsize=10)
        y_pos -= 0.03
        put(0.1, y_pos, "Training Data: AGO2 eCLIP (Manakov et al. 2022)", ha='left', va='top', fontsize=10)

        pdf.savefig(fig, bbox_inches='tight')
    finally:
        plt.close(fig)


def _pdf_importance_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d):
    """Append the per-position importance page (max |SHAP| per position)."""
    mirna_importance = np.max(np.abs(shap_2d), axis=1)
    target_importance = np.max(np.abs(shap_2d), axis=0)

    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(max(12, len(target_seq_padded) * 0.24), 10)
    )
    try:
        _draw_sequence_importance_axis(
            ax1, mirna_seq_padded, mirna_importance, 'steelblue', 'miRNA Position',
            'miRNA Sequence Importance (Max SHAP per position)',
            title_fontsize=13, nucleotide_fontsize=7, y_fontsize=11
        )
        _draw_sequence_importance_axis(
            ax2, target_seq_padded, target_importance, 'coral', 'mRNA/Target Position',
            'mRNA/Target Sequence Importance (Max SHAP per position)',
            title_fontsize=13, nucleotide_fontsize=7, y_fontsize=11
        )
        fig.tight_layout()
        pdf.savefig(fig, bbox_inches='tight')
    finally:
        plt.close(fig)


def _pdf_heatmap_page(pdf, mirna_seq_padded, target_seq_padded, score, shap_2d):
    """Append the SHAP position-pair heatmap page with a colour legend note."""
    fig, ax = plt.subplots(
        figsize=(max(12, len(target_seq_padded) * 0.24), max(10, len(mirna_seq_padded) * 0.3))
    )
    try:
        _draw_shap_heatmap(
            fig, ax, shap_2d, mirna_seq_padded, target_seq_padded,
            x_fontsize=max(5, min(7, 260 / max(len(target_seq_padded), 1))),
            y_fontsize=max(6, min(8, 220 / max(len(mirna_seq_padded), 1))),
        )
        ax.set_xlabel('mRNA/Target Position', fontsize=11)
        ax.set_ylabel('miRNA Position', fontsize=11)
        ax.set_title(f'SHAP Explainability Heatmap\nTarget site prediction: {score:.3f}',
                     fontsize=13, fontweight='bold', pad=15)

        explanation = ("Red regions: Positive contribution to binding\n"
                       "Blue regions: Negative contribution to binding\n"
                       "White regions: Minimal contribution")
        fig.text(0.5, 0.02, explanation, ha='center', va='bottom', fontsize=9, style='italic',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

        fig.tight_layout()
        pdf.savefig(fig, bbox_inches='tight')
    finally:
        plt.close(fig)


def _pdf_alignment_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d):
    """Append the SHAP-guided alignment page; failures are logged and the page skipped."""
    fig = None
    try:
        alignment_img = plot_alignment_image(mirna_seq_padded, target_seq_padded, shap_2d)
        fig, ax = plt.subplots(figsize=(max(12, alignment_img.width / 110), 4.5))
        ax.imshow(alignment_img)
        ax.axis('off')
        ax.set_title('SHAP-Guided Target Site Alignment', fontsize=13, fontweight='bold', pad=12)
        fig.tight_layout()
        pdf.savefig(fig, bbox_inches='tight')
    except Exception as alignment_error:
        print(f"Warning: alignment plot PDF page failed: {alignment_error}")
    finally:
        if fig is not None:
            plt.close(fig)


def generate_pdf_report(mirna_seq, target_seq, score, binding_prediction, confidence, shap_2d=None):
    """Generate a PDF report with all visualizations and results.

    The report contains a summary page and, when *shap_2d* is given, an
    importance page, a heatmap page and (best effort) an alignment page.

    Returns:
        Path of the generated temporary ``.pdf`` file (caller owns cleanup).

    Raises:
        Any exception raised while rendering; the partially written temp file
        is removed before re-raising.
    """
    # Pad sequences to model length
    mirna_seq_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
    target_seq_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])

    # Create temporary file for PDF
    pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pdf')
    pdf_path = pdf_file.name
    pdf_file.close()

    try:
        with PdfPages(pdf_path) as pdf:
            _pdf_summary_page(pdf, mirna_seq, target_seq, score, binding_prediction, confidence)
            if shap_2d is not None:
                _pdf_importance_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d)
                _pdf_heatmap_page(pdf, mirna_seq_padded, target_seq_padded, score, shap_2d)
                _pdf_alignment_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d)
    except Exception:
        # Do not leave a half-written temp file behind (delete=False above).
        try:
            os.remove(pdf_path)
        except OSError:
            pass
        raise

    return pdf_path


# ============================================================================
# SINGLE PREDICTION
# ============================================================================

SHAP_EXPLANATION_TEXT = """
### SHAP Explainability

**Position Pairs Heatmap** shows which miRNA-mRNA position combinations contribute most:
- **Red regions:** Positive contribution to binding
- **Blue regions:** Negative contribution to binding
- **White regions:** Minimal contribution

**Sequence Explainability** above shows per-nucleotide importance (max SHAP score across all positions).

**Alignment View** shows a SHAP-guided target site alignment between the miRNA and target sequence.

This helps identify the specific target sites and key nucleotides driving the prediction.
"""


def run_prediction(mirna_input, target_input, show_shap=True):
    """Main prediction function called by Gradio.

    Returns:
        (result_markdown, explainability_img, heatmap_img, alignment_img,
         shap_markdown, pdf_path). On invalid input or error the first element
        is the message and the rest are None.
    """
    valid_mirna, mirna_result = validate_sequence(mirna_input, "miRNA sequence")
    if not valid_mirna:
        return mirna_result, None, None, None, None, None

    valid_target, target_result = validate_sequence(target_input, "mRNA sequence")
    if not valid_target:
        return target_result, None, None, None, None, None

    mirna_seq = mirna_result
    target_seq = target_result
    mirna_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
    target_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])

    try:
        score = predict_binding(target_seq, mirna_seq)

        # Determine binding prediction
        binding = "BINDING" if score >= 0.5 else "NO BINDING"
        confidence = score if score >= 0.5 else (1 - score)

        result_text = f"""
### Prediction Results

**Binding Score:** {score:.4f}

**Prediction:** {binding} (Confidence: {confidence:.1%})

**Model:** miRBind2 v1.0

**Sequences Used:**
- miRNA length: {len(mirna_seq)} nt (padded to {MODEL_PARAMS['mirna_length']})
- mRNA length: {len(target_seq)} nt (padded to {MODEL_PARAMS['target_length']})
"""

        # Compute SHAP if requested
        shap_2d = None
        shap_text = ""
        if show_shap and CAPTUM_AVAILABLE:
            shap_2d = compute_shap_values(target_seq, mirna_seq)
            if shap_2d is not None:
                shap_text = SHAP_EXPLANATION_TEXT
        elif show_shap:
            shap_text = "SHAP visualization requires Captum library. Install with: pip install captum"

        explainability_img, shap_img, alignment_img = create_shap_visualizations(
            mirna_padded, target_padded, score, shap_2d
        )

        # Generate PDF report (only when SHAP results are available)
        pdf_path = None
        if shap_2d is not None:
            try:
                pdf_path = generate_pdf_report(
                    mirna_seq, target_seq, score, binding, confidence, shap_2d
                )
            except Exception as pdf_error:
                print(f"Warning: PDF generation failed: {pdf_error}")
                pdf_path = None

        return result_text, explainability_img, shap_img, alignment_img, shap_text, pdf_path

    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None, None, None, None


# ============================================================================
# BATCH PREDICTION FUNCTIONS
# ============================================================================

def _is_missing(value):
    """Return True for None/NaN-like scalars (pandas' notion of a missing value)."""
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # Non-scalar values (lists, arrays) are never treated as missing.
        return False


def _cell_to_str(value):
    """Convert a DataFrame cell to str, mapping missing cells to "".

    Without this, an empty cell becomes "nan" -> "NAN", which passes
    validate_sequence (N and A are allowed) and yields a bogus prediction.
    """
    return "" if _is_missing(value) else str(value)


def _read_delimited(file_path, header):
    """Read *file_path* as TSV, falling back to CSV when only one column is found."""
    df = pd.read_csv(file_path, sep='\t', header=header)
    if len(df.columns) == 1:
        df = pd.read_csv(file_path, header=header)
    return df


def _header_is_data_row(columns):
    """Return True when the first two column names are themselves valid sequences.

    That means the file had no header row and pandas consumed the first pair
    as column names.
    """
    names = [str(name) for name in list(columns)[:2]]
    return len(names) == 2 and all(validate_sequence(name)[0] for name in names)


def parse_batch_file(file_path):
    """Parse uploaded TSV/CSV file with miRNA-target pairs.

    Columns are interpreted by position: target, miRNA, optional label. A
    header row is optional; headerless files are detected and re-read so the
    first pair is not lost.

    Returns:
        (DataFrame with 'target_seq', 'mirna_seq', 'label' columns, None) on
        success, or (None, error_message).
    """
    try:
        df = _read_delimited(file_path, header='infer')
        if _header_is_data_row(df.columns):
            df = _read_delimited(file_path, header=None)

        if len(df.columns) < 2:
            return None, "File must have at least 2 columns (target, miRNA)"

        col_names = df.columns.tolist()
        has_labels = 'label' in col_names or 'class' in col_names or len(col_names) >= 3

        # Rename columns for consistency
        if has_labels:
            df.columns = ['target_seq', 'mirna_seq', 'label'] + col_names[3:]
        else:
            df.columns = ['target_seq', 'mirna_seq'] + col_names[2:]
            df['label'] = -1  # Unknown

        return df, None

    except Exception as e:
        return None, f"Error parsing file: {str(e)}"


def run_batch_predictions(df, compute_shap=False, progress=gr.Progress()):
    """Run predictions on all pairs in the dataframe.

    Produces exactly one result row per input row and stores the results in
    BATCH_RESULTS (and SHAP matrices in BATCH_SHAP_CACHE keyed by row index).
    """
    global BATCH_RESULTS, BATCH_SHAP_CACHE

    results = []
    BATCH_SHAP_CACHE = {}

    total = len(df)
    for idx, row in progress.tqdm(df.iterrows(), total=total, desc="Processing pairs"):
        true_label = row.get('label', -1)
        if _is_missing(true_label):
            true_label = -1

        record = {
            'index': idx,
            'mirna_seq': _cell_to_str(row['mirna_seq']),
            'target_seq': _cell_to_str(row['target_seq']),
            'true_label': true_label,
            'prediction_score': None,
            'predicted_class': None,
            'status': 'Invalid sequence',
        }

        valid_target, target_result = validate_sequence(record['target_seq'], "Target")
        valid_mirna, mirna_result = validate_sequence(record['mirna_seq'], "miRNA")
        if not (valid_target and valid_mirna):
            results.append(record)
            continue

        record['target_seq'] = target_result
        record['mirna_seq'] = mirna_result

        try:
            score = predict_binding(target_result, mirna_result)
        except Exception as e:
            record['status'] = f'Error: {str(e)}'
            results.append(record)
            continue

        record['prediction_score'] = score
        record['predicted_class'] = 1 if score >= 0.5 else 0
        record['status'] = 'Success'
        results.append(record)

        # SHAP failures must not add a second row for this pair; log and move on.
        if compute_shap and CAPTUM_AVAILABLE:
            try:
                BATCH_SHAP_CACHE[idx] = compute_shap_values(target_result, mirna_result)
            except Exception as shap_error:
                print(f"Warning: SHAP computation failed for pair {idx}: {shap_error}")

    BATCH_RESULTS = pd.DataFrame(results)
    return BATCH_RESULTS


def format_results_table(results_df):
    """Format results for display in Gradio table (None for empty results)."""
    if results_df is None or len(results_df) == 0:
        return None

    display_df = results_df.copy()

    # Failed rows hold NaN (not None) once pandas builds a float column.
    display_df['Score'] = display_df['prediction_score'].apply(
        lambda x: "N/A" if _is_missing(x) else f"{x:.4f}"
    )
    display_df['Prediction'] = display_df['predicted_class'].apply(
        lambda x: "BINDING" if x == 1 else "NO BINDING" if x == 0 else "N/A"
    )
    display_df['True Label'] = display_df['true_label'].apply(
        lambda x: "BINDING" if x == 1 else "NO BINDING" if x == 0 else "Unknown"
    )

    # Truncate sequences for display
    display_df['miRNA'] = display_df['mirna_seq'].apply(
        lambda x: x[:20] + "..." if len(x) > 20 else x
    )
    display_df['Target'] = display_df['target_seq'].apply(
        lambda x: x[:30] + "..." if len(x) > 30 else x
    )

    display_cols = ['index', 'miRNA', 'Target', 'True Label', 'Score', 'Prediction', 'status']
    return display_df[display_cols]


def show_batch_detail_view(current_table_state, evt: gr.SelectData):
    """Show detailed view for selected row from batch results.

    Args:
        current_table_state: The current state of the displayed table (including any sorting)
        evt: Selection event containing row/column info

    Returns:
        (markdown, explainability_img, heatmap_img, alignment_img, explanation_markdown)
    """
    if BATCH_RESULTS is None:
        return "No results available", None, None, None, None

    if current_table_state is None or len(current_table_state) == 0:
        return "No data in table", None, None, None, None

    # Visual row position in the currently displayed (possibly sorted) table
    visual_row_idx = evt.index[0]
    if visual_row_idx >= len(current_table_state):
        return "Invalid selection", None, None, None, None

    # Use the 'index' column of the DISPLAYED row so sorting does not matter.
    try:
        actual_idx = int(current_table_state.iloc[visual_row_idx]['index'])
    except (KeyError, TypeError, ValueError):
        return "Invalid selection", None, None, None, None

    matches = BATCH_RESULTS[BATCH_RESULTS['index'] == actual_idx]
    if matches.empty:
        return "Invalid selection", None, None, None, None
    row = matches.iloc[0]

    mirna_seq = row['mirna_seq']
    target_seq = row['target_seq']
    score = row['prediction_score']
    pred_class = row['predicted_class']
    true_label = row['true_label']
    idx = actual_idx

    # Failed rows hold None or NaN depending on the column dtype.
    if _is_missing(score):
        return f"**Error:** {row['status']}", None, None, None, None

    binding = "BINDING" if pred_class == 1 else "NO BINDING"
    confidence = score if pred_class == 1 else (1 - score)
    true_label_str = "BINDING" if true_label == 1 else "NO BINDING" if true_label == 0 else "Unknown"

    result_text = f"""
### Detailed Results for Pair #{idx}

**Binding Score:** {score:.4f}

**Prediction:** {binding} (Confidence: {confidence:.1%})

**True Label:** {true_label_str}

**Sequences:**
- **miRNA:** {mirna_seq}
- **mRNA:** {target_seq}
"""

    shap_2d = BATCH_SHAP_CACHE.get(idx)
    if shap_2d is None:
        explanation = "⚠️ **SHAP not computed.** Re-run batch with 'Compute SHAP' enabled to see explainability."
        return result_text, None, None, None, explanation

    mirna_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
    target_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])
    explainability_img, heatmap_img, alignment_img = create_shap_visualizations(
        mirna_padded, target_padded, score, shap_2d
    )
    explanation = (
        "**SHAP values** computed during batch processing. "
        "The alignment view is derived from the same SHAP matrix."
    )
    return result_text, explainability_img, heatmap_img, alignment_img, explanation


def _aps_line(known_labels):
    """Return the markdown APS line for labelled, successfully scored rows ("" if unavailable)."""
    scored = known_labels.dropna(subset=['prediction_score'])
    if len(scored) == 0:
        return ""
    try:
        aps = average_precision_score(scored['true_label'], scored['prediction_score'])
    except ValueError as aps_error:
        # e.g. labels that are not binary 0/1
        print(f"Warning: APS computation failed: {aps_error}")
        return ""
    return f"\n**APS (Average Precision):** {aps:.4f}"


def _format_batch_summary(results_df):
    """Build the markdown summary (counts, plus accuracy/APS when labels are known)."""
    total = len(results_df)
    successful = len(results_df[results_df['status'] == 'Success'])

    if 'true_label' in results_df.columns:
        known_labels = results_df[results_df['true_label'] != -1]
    else:
        known_labels = results_df.iloc[0:0]

    if len(known_labels) == 0:
        return f"""
### Batch Processing Complete ✅

**Total pairs:** {total}
**Successful:** {successful}
**Failed:** {total - successful}

👇 Click any row in the table below to view detailed results.
"""

    correct = len(known_labels[known_labels['predicted_class'] == known_labels['true_label']])
    accuracy = correct / len(known_labels) * 100
    aps_line = _aps_line(known_labels)

    return f"""
### Batch Processing Complete ✅

**Total pairs:** {total}
**Successful:** {successful}
**Failed:** {total - successful}
**Accuracy:** {accuracy:.1f}% ({correct}/{len(known_labels)} correct){aps_line}

👇 Click any row in the table below to view detailed SHAP analysis.
"""


def process_uploaded_file(file, compute_shap, progress=gr.Progress()):
    """Process uploaded file and run predictions.

    Declaring ``progress`` here lets Gradio attach its progress tracker to
    this event; it is forwarded to run_batch_predictions.

    Returns:
        (summary_markdown, display_dataframe_or_None)
    """
    if file is None:
        return "Please upload a file", None

    # Depending on the Gradio version, `file` is a path string or an object exposing `.name`.
    file_path = getattr(file, 'name', file)

    df, error = parse_batch_file(file_path)
    if error:
        return error, None
    if df.empty:
        return "File contains no miRNA-target pairs", None

    results_df = run_batch_predictions(df, compute_shap=compute_shap, progress=progress)
    display_df = format_results_table(results_df)

    return _format_batch_summary(results_df), display_df


# ============================================================================
# GRADIO UI
# ============================================================================

def create_gradio_interface():
    """Create the Gradio web interface with tabs.

    Sequence-length hints are taken from the loaded model's parameters
    (falling back to 28/50 nt if no model is loaded) so the UI text matches
    the padding actually applied by encode_sequence_pair.
    """
    mirna_len = MODEL_PARAMS['mirna_length'] if MODEL_PARAMS else 28
    target_len = MODEL_PARAMS['target_length'] if MODEL_PARAMS else 50

    with gr.Blocks(title="miRBind2: miRNA-mRNA Target Site Predictor") as app:
        gr.Markdown("""
        # miRBind2: miRNA-mRNA Target Site Predictor

        Predict miRNA-mRNA target sites using deep learning with explainability.

        **Model:** miRBind2 v1.0
        """)

        with gr.Tabs():
            # ================================================================
            # TAB 1: SINGLE PREDICTION
            # ================================================================
            with gr.Tab("Single Prediction"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Input Sequences")

                        mirna_input = gr.Textbox(
                            label="miRNA Sequence",
                            placeholder=f"Enter miRNA sequence (e.g., {DEFAULT_MIRNA_SEQUENCE})",
                            lines=3,
                            value=DEFAULT_MIRNA_SEQUENCE
                        )

                        target_input = gr.Textbox(
                            label="mRNA/Target Sequence",
                            placeholder="Enter mRNA target sequence",
                            lines=5,
                            # Demo target embeds a reverse-complement site for clearer default explainability.
                            value=DEFAULT_TARGET_SEQUENCE
                        )

                        show_shap = gr.Checkbox(
                            label="Show SHAP Explainability (slower)",
                            value=True
                        )

                        predict_btn = gr.Button("Predict Binding", variant="primary", size="lg")

                        gr.Markdown(f"""
                        **Instructions:**
                        1. Enter your miRNA sequence (RNA or DNA notation accepted)
                        2. Enter your mRNA target sequence
                        3. Check SHAP box for detailed explainability (recommended)
                        4. Click 'Predict Binding'

                        **Sequence length:** miRNA is padded/trimmed to {mirna_len} nt, mRNA to {target_len} nt. Sequences longer than these limits are trimmed from the 3′ end; shorter sequences are padded with N.
                        """)

                    with gr.Column(scale=2):
                        gr.Markdown("### Results")
                        result_output = gr.Markdown(label="Prediction Results")

                        gr.Markdown("#### Sequence Explainability")
                        explainability_output = gr.Image(show_label=False, type="pil")

                        gr.Markdown("#### SHAP Explainability Heatmap (Position Pairs)")
                        shap_output = gr.Image(show_label=False, type="pil")

                        gr.Markdown("#### SHAP-Guided Alignment View")
                        alignment_view_output = gr.Image(show_label=False, type="pil")

                        shap_text_output = gr.Markdown()

                        # PDF download button
                        with gr.Row():
                            pdf_download = gr.File(label="Download PDF Report", visible=True)

                predict_btn.click(
                    fn=run_prediction,
                    inputs=[mirna_input, target_input, show_shap],
                    outputs=[
                        result_output,
                        explainability_output,
                        shap_output,
                        alignment_view_output,
                        shap_text_output,
                        pdf_download,
                    ]
                )

            # ================================================================
            # TAB 2: BATCH PREDICTIONS
            # ================================================================
            with gr.Tab("Batch Predictions"):
                gr.Markdown("""
                Upload a TSV/CSV file with multiple miRNA-mRNA pairs and browse through predictions interactively.
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 📁 Upload File")
                        file_upload = gr.File(
                            label="Upload TSV/CSV File",
                            file_types=['.tsv', '.csv', '.txt']
                        )

                        gr.Markdown(f"""
                        **Expected format (column order matters, headers optional):**
                        - Column 1: mRNA/target sequence
                        - Column 2: miRNA sequence
                        - Column 3 (optional): Label (0 or 1)

                        TSV (tab-separated) or CSV format.

                        **Sequence length:** miRNA is padded/trimmed to {mirna_len} nt, mRNA to {target_len} nt. Sequences longer than these limits are trimmed from the 3′ end.

                        See `example_batch.tsv` for reference.
                        """)

                        compute_shap_batch = gr.Checkbox(
                            label="Compute SHAP for all pairs (slower)",
                            value=True,
                            info="Enable for detailed explainability. Increases processing time."
                        )

                        process_btn = gr.Button("Process File", variant="primary", size="lg")

                    with gr.Column(scale=2):
                        gr.Markdown("### 📊 Results Summary")
                        summary_output = gr.Markdown()

                gr.Markdown("### 📋 Browse Results")
                gr.Markdown("👇 Click any row to view detailed SHAP analysis")

                results_table = gr.Dataframe(
                    label="Prediction Results",
                    interactive=False,
                    wrap=True
                )

                gr.Markdown("### 🔍 Detailed View")

                with gr.Row():
                    with gr.Column(scale=1):
                        detail_text = gr.Markdown()

                    with gr.Column(scale=2):
                        gr.Markdown("#### Sequence Explainability")
                        detail_explainability = gr.Image(show_label=False, type="pil")

                        gr.Markdown("#### SHAP Heatmap")
                        detail_heatmap = gr.Image(show_label=False, type="pil")

                        gr.Markdown("#### SHAP-Guided Alignment View")
                        detail_alignment = gr.Image(show_label=False, type="pil")

                        detail_explanation = gr.Markdown()

                process_btn.click(
                    fn=process_uploaded_file,
                    inputs=[file_upload, compute_shap_batch],
                    outputs=[summary_output, results_table]
                )

                # Pass the table itself as input so we can see current state (including sorting)
                results_table.select(
                    fn=show_batch_detail_view,
                    inputs=[results_table],  # Pass current table state
                    outputs=[
                        detail_text,
                        detail_explainability,
                        detail_heatmap,
                        detail_alignment,
                        detail_explanation,
                    ]
                )

        # About section (outside tabs)
        gr.Markdown("""
        ---
        ### About

        **miRBind2** uses a convolutional neural network trained on CLIP experimental data to predict miRNA-mRNA target sites. The model learns complementarity patterns between miRNA and target sequences.

        **GitHub:** [BioGeMT/miRBind_2.0](https://github.com/BioGeMT/miRBind_2.0)

        **Device:** {}
        """.format(DEVICE))

    return app


def main():
    """Main entry point: load the model, build the UI and launch the server."""
    print("=" * 60)
    print("miRBind2 Gradio Interface")
    print("=" * 60)

    try:
        load_pretrained_model()
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    app = create_gradio_interface()

    print("\n" + "=" * 60)
    print("Launching Gradio interface...")
    print("=" * 60)

    app.launch(
        share=False,  # Set to True to create public link
        server_name="0.0.0.0",  # Required for containerized hosting (HF Spaces)
        server_port=7860,
        show_error=True,
        theme=gr.themes.Soft()
    )


if __name__ == "__main__":
    main()