dimostzim's picture
Revert "Rename HF model repo to miRBind2-weights"
2c7fcdf
#!/usr/bin/env python3
"""
miRBind2 Gradio Interface
Interactive web app for miRNA-mRNA binding prediction with explainability
"""
import sys
import os
from pathlib import Path
# Add code directory to path
CODE_DIR = Path(__file__).parent / "miRBind_2.0-main" / "code" / "pairwise_binding_site_model"
sys.path.insert(0, str(CODE_DIR))
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.backends.backend_pdf import PdfPages
import io
from PIL import Image
from datetime import datetime
import tempfile
from sklearn.metrics import average_precision_score
from alignment_plot import plot_alignment_image
from shared.models import load_model
from shared.constants import get_pair_to_index, get_num_pairs, get_device, NUCLEOTIDE_COLORS
from shared.encoding import pad_or_trim, encode_complementarity
# Try to import Captum for SHAP, but make it optional
try:
    from captum.attr import GradientShap
    CAPTUM_AVAILABLE = True
except ImportError:
    # The app keeps working without Captum; functions that need it check
    # CAPTUM_AVAILABLE and skip the SHAP computation.
    CAPTUM_AVAILABLE = False
    print("Warning: Captum not installed. SHAP explainability will not be available.")
    print("Install with: pip install captum")
# Global model variables (all assigned by load_pretrained_model())
MODEL = None
MODEL_PARAMS = None
DEVICE = None
PAIR_TO_INDEX = None
NUM_PAIRS = None
# Global batch results cache
# BATCH_RESULTS holds the DataFrame produced by the most recent batch run;
# BATCH_SHAP_CACHE maps a row index of that run to its SHAP matrix.
BATCH_RESULTS = None
BATCH_SHAP_CACHE = {}
# Checkpoint file name, looked up locally first and then on the Hugging Face Hub.
MODEL_FILENAME = "pairwise_onehot_model_20260105_200141.pt"
# Hub repository holding the weights; overridable through MIRBIND2_MODEL_REPO.
MODEL_REPO_ID = os.getenv("MIRBIND2_MODEL_REPO", "dimostzim/mirbind2-weights")
# Example sequences pre-filled in the single-prediction tab.
DEFAULT_MIRNA_SEQUENCE = "UAGCUUAUCAGACUGAUGUUGA"
DEFAULT_TARGET_SEQUENCE = "GGGCACUUUUUCAACAUCAGUCUGAUAAGCUAAGUGUCUUCCAGGGAAUU"
def resolve_model_path():
    """Return the checkpoint path, using the bundled file when present.

    When the bundled checkpoint is absent, the file is fetched from the
    Hugging Face Hub repository named by MODEL_REPO_ID.

    Raises:
        FileNotFoundError: the checkpoint is missing locally and
            ``huggingface_hub`` cannot be imported.
    """
    bundled_checkpoint = (
        Path(__file__).parent / "miRBind_2.0-main" / "models" / MODEL_FILENAME
    )
    if not bundled_checkpoint.exists():
        print(f"Local model not found at {bundled_checkpoint}")
        print(f"Downloading model from Hugging Face Hub repo: {MODEL_REPO_ID}")
        # Imported lazily so the dependency is only needed for the fallback path.
        try:
            from huggingface_hub import hf_hub_download
        except ImportError as exc:
            raise FileNotFoundError(
                "Model file missing locally and huggingface_hub is not installed."
            ) from exc
        hub_file = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
            repo_type="model",
        )
        return Path(hub_file)
    return bundled_checkpoint
def load_pretrained_model():
    """Populate the module-level model state; intended to run once at startup.

    Sets MODEL, MODEL_PARAMS, DEVICE, PAIR_TO_INDEX and NUM_PAIRS.

    Returns:
        True once the checkpoint has been loaded.
    """
    global MODEL, MODEL_PARAMS, DEVICE, PAIR_TO_INDEX, NUM_PAIRS
    checkpoint_path = resolve_model_path()
    DEVICE = get_device()
    PAIR_TO_INDEX = get_pair_to_index()
    NUM_PAIRS = get_num_pairs()
    print(f"Loading model from {checkpoint_path}")
    print(f"Using device: {DEVICE}")
    loaded = load_model(str(checkpoint_path), "pairwise_onehot", DEVICE)
    MODEL, checkpoint = loaded
    MODEL_PARAMS = checkpoint['model_params']
    print("Model loaded successfully!")
    print(f"Model parameters: {MODEL_PARAMS}")
    return True
def validate_sequence(seq, seq_type="sequence"):
    """Validate and normalise a nucleotide sequence.

    All whitespace (including line breaks inside the sequence) is removed
    before validation, so sequences pasted over several lines into the
    multi-line input boxes are accepted. The sequence is upper-cased and
    ``U`` is converted to ``T``.

    Args:
        seq: Raw user input.
        seq_type: Human-readable name used in error messages.

    Returns:
        ``(True, normalised_sequence)`` on success, otherwise
        ``(False, error_message)``.
    """
    if not seq or not seq.strip():
        return False, f"{seq_type} cannot be empty"
    # str.split() with no argument splits on any whitespace run, so joining
    # the pieces drops spaces, tabs and newlines wherever they occur.
    seq = "".join(seq.split()).upper()
    valid_nucleotides = set('ATCGUN')
    invalid = set(seq) - valid_nucleotides
    if invalid:
        return False, f"{seq_type} contains invalid characters: {invalid}. Only A, T, C, G, U, N are allowed."
    # Convert U to T for consistency
    seq = seq.replace('U', 'T')
    return True, seq
def encode_sequence_pair(target_seq, mirna_seq):
    """Turn a target/miRNA pair into the one-hot tensor passed to the model.

    Both sequences are first brought to the lengths stored in MODEL_PARAMS.

    Returns:
        ``(X_onehot, target_seq, mirna_seq)`` where the sequences are the
        padded/trimmed versions and ``X_onehot`` is a float tensor with a
        leading batch dimension of size 1.
    """
    target_length = MODEL_PARAMS['target_length']
    mirna_length = MODEL_PARAMS['mirna_length']
    # Bring both sequences to the fixed model lengths.
    target_seq = pad_or_trim(target_seq, target_length)
    mirna_seq = pad_or_trim(mirna_seq, mirna_length)
    # Integer pair indices for every position combination.
    pair_indices = encode_complementarity(
        target_seq,
        mirna_seq,
        target_length,
        mirna_length,
        PAIR_TO_INDEX,
        NUM_PAIRS,
    )
    # One-hot over NUM_PAIRS + 1 classes, then prepend the batch axis.
    onehot = F.one_hot(
        torch.tensor(pair_indices, dtype=torch.long),
        num_classes=NUM_PAIRS + 1,
    )
    batched = onehot.float().unsqueeze(0)
    return batched, target_seq, mirna_seq
def predict_binding(target_seq, mirna_seq):
    """Score one target/miRNA pair with the loaded model.

    Returns:
        The model output for the pair as a Python float.
    """
    MODEL.eval()
    # Inference only: no gradients are needed here.
    with torch.no_grad():
        encoded, _, _ = encode_sequence_pair(target_seq, mirna_seq)
        prediction = MODEL(encoded.to(DEVICE))
        return prediction.item()
def compute_shap_values(target_seq, mirna_seq):
    """Compute a 2D GradientShap attribution matrix for one sequence pair.

    Returns:
        A NumPy array obtained by summing the attributions over the last
        (pair one-hot) axis, or ``None`` when Captum is not installed.
    """
    if not CAPTUM_AVAILABLE:
        return None
    MODEL.eval()
    encoded, _, _ = encode_sequence_pair(target_seq, mirna_seq)
    encoded = encoded.to(DEVICE)
    encoded.requires_grad = True
    # An all-zero tensor of the same shape serves as the baseline.
    zero_baseline = torch.zeros_like(encoded)
    attributions = GradientShap(MODEL).attribute(
        encoded, baselines=zero_baseline, target=0
    )
    # Drop the batch axis, move to NumPy, then collapse the pair axis.
    per_pair = attributions[0].cpu().detach().numpy()
    return np.sum(per_pair, axis=2)
def plot_shap_heatmap(shap_2d, mirna_seq, target_seq, prediction_score):
    """Render the SHAP matrix as a diverging heatmap and return it as a PIL image.

    Rows are labelled with miRNA positions/nucleotides and columns with
    target positions/nucleotides; the colour scale is symmetric around zero.
    """
    n_mirna = len(mirna_seq)
    n_target = len(target_seq)
    figure_size = (max(16, n_target * 0.32), max(8, n_mirna * 0.3))
    fig, ax = plt.subplots(figsize=figure_size)

    # Diverging blue -> white -> red palette.
    palette = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
               '#fddbc7', '#f4a582', '#d6604d', '#b2182b']
    cmap = LinearSegmentedColormap.from_list('shap', palette, N=256)

    # Symmetric limits keep zero attribution at the palette midpoint.
    limit = np.abs(shap_2d).max()
    im = ax.imshow(shap_2d, cmap=cmap, aspect='auto', vmin=-limit, vmax=limit)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('SHAP Attribution Value', rotation=270, labelpad=20, fontsize=11)

    ax.set_xlabel('mRNA/Target Position', fontsize=12)
    ax.set_ylabel('miRNA Position', fontsize=12)
    ax.set_title(
        f'miRNA-mRNA Target Site Explainability (Prediction: {prediction_score:.3f})',
        fontsize=14,
        pad=20,
    )

    # One tick per position, labelled with index and nucleotide.
    x_fontsize = max(5, min(8, 300 / max(n_target, 1)))
    y_fontsize = max(6, min(9, 240 / max(n_mirna, 1)))
    ax.set_yticks(list(range(n_mirna)))
    ax.set_yticklabels(
        [f"{pos:>2} {mirna_seq[pos]}" for pos in range(n_mirna)],
        fontsize=y_fontsize,
        fontfamily='monospace',
    )
    ax.set_xticks(list(range(n_target)))
    ax.set_xticklabels(
        [f"{target_seq[pos]}\n{pos}" for pos in range(n_target)],
        fontsize=x_fontsize,
        rotation=0,
    )
    ax.tick_params(axis='x', pad=6)
    ax.tick_params(axis='y', pad=6)
    ax.grid(True, alpha=0.2, linewidth=0.5)
    plt.tight_layout()

    # Serialise the figure to an in-memory PNG and hand it back as an image.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', dpi=120, bbox_inches='tight')
    png_buffer.seek(0)
    rendered = Image.open(png_buffer)
    plt.close()
    return rendered
def plot_sequence_explainability(mirna_seq, target_seq, shap_2d=None):
    """Draw the per-sequence explainability figure and return it as a PIL image.

    Without ``shap_2d`` the two sequences are drawn as coloured nucleotide
    boxes only. With ``shap_2d`` each sequence gets a bar chart of the
    maximum absolute SHAP value per position.
    """
    if shap_2d is not None:
        # Largest absolute attribution along each axis of the matrix.
        absolute = np.abs(shap_2d)
        per_mirna = np.max(absolute, axis=1)
        per_target = np.max(np.abs(shap_2d), axis=0)
        fig, (mirna_ax, target_ax) = plt.subplots(
            2, 1, figsize=(max(16, len(target_seq) * 0.28), 6.5)
        )
        _draw_sequence_importance_axis(
            mirna_ax,
            mirna_seq,
            per_mirna,
            'steelblue',
            'miRNA Position',
            'miRNA Sequence Importance (Max SHAP per position)',
            label='Max SHAP Score'
        )
        _draw_sequence_importance_axis(
            target_ax,
            target_seq,
            per_target,
            'coral',
            'mRNA/Target Position',
            'mRNA/Target Sequence Importance (Max SHAP per position)',
            label='Max SHAP Score'
        )
    else:
        # No attributions available: show the plain sequences.
        fig, (mirna_ax, target_ax) = plt.subplots(2, 1, figsize=(14, 4))
        panels = (
            (mirna_ax, mirna_seq, 'miRNA:', 'miRNA Sequence'),
            (target_ax, target_seq, 'mRNA:', 'mRNA/Target Sequence'),
        )
        for axis, sequence, row_label, panel_title in panels:
            for position, nucleotide in enumerate(sequence):
                box_color = NUCLEOTIDE_COLORS.get(nucleotide, 'gray')
                axis.text(position, 0, nucleotide, ha='center', va='center', fontsize=9,
                          bbox=dict(boxstyle='round,pad=0.5', facecolor=box_color, alpha=0.3))
            axis.text(-2, 0, row_label, ha='right', va='center', fontsize=11, fontweight='bold')
            axis.set_xlim(-3, len(sequence))
            axis.set_ylim(-0.5, 0.5)
            axis.axis('off')
            axis.set_title(panel_title, fontsize=11, pad=10)

    plt.tight_layout()
    # Serialise the figure to an in-memory PNG and hand it back as an image.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format='png', dpi=120, bbox_inches='tight')
    png_buffer.seek(0)
    rendered = Image.open(png_buffer)
    plt.close()
    return rendered
def _draw_sequence_importance_axis(ax, sequence, importance, bar_color, xlabel, title,
                                   label=None, title_fontsize=11, label_fontsize=9,
                                   nucleotide_fontsize=8, y_fontsize=10):
    """Draw an importance bar chart with a nucleotide box under every bar.

    Args:
        ax: Matplotlib axis to draw on.
        sequence: Nucleotide string; one bar is drawn per character.
        importance: Bar heights, one per position.
        bar_color: Colour of the bars.
        xlabel: X-axis label.
        title: Axis title.
        label: Optional legend entry; a legend is drawn only when given.
        title_fontsize, label_fontsize, nucleotide_fontsize, y_fontsize:
            Font sizes for the title, legend, nucleotide boxes and y label.
    """
    positions = np.arange(len(sequence))
    bar_options = dict(alpha=0.6, color=bar_color, width=0.8)
    if label is not None:
        bar_options['label'] = label
    ax.bar(positions, importance, **bar_options)

    # Leave headroom above the tallest bar and room below zero for the boxes.
    tallest = float(np.max(importance)) if len(importance) else 0.0
    upper = max(tallest * 1.15, 0.1)
    box_y = -0.08 * upper
    ax.set_ylim(-0.22 * upper, upper)

    for position, nucleotide in enumerate(sequence):
        ax.text(
            position, box_y, nucleotide, ha='center', va='top',
            fontsize=nucleotide_fontsize, fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.3',
                      facecolor=NUCLEOTIDE_COLORS.get(nucleotide, 'gray'),
                      alpha=0.4),
        )

    ax.set_xlim(-0.5, len(sequence) - 0.5)
    ax.set_xticks(positions)
    tick_fontsize = max(5, min(8, 300 / max(len(sequence), 1)))
    ax.set_xticklabels([str(position) for position in positions], fontsize=tick_fontsize)
    # Hide tick marks and push the numbers below the nucleotide boxes.
    ax.tick_params(axis='x', length=0, pad=10)
    ax.set_xlabel(xlabel, fontsize=10, labelpad=12)
    ax.set_ylabel('Max SHAP Score', fontsize=y_fontsize)
    ax.set_title(title, fontsize=title_fontsize, pad=10)
    ax.grid(True, alpha=0.3, axis='y')
    if label is not None:
        ax.legend(loc='upper right', fontsize=label_fontsize)
def create_shap_visualizations(mirna_seq, target_seq, score, shap_2d):
    """Produce the three explainability images shared by both tabs.

    Returns:
        ``(explainability_img, heatmap_img, alignment_img)``. The heatmap and
        alignment images are ``None`` when ``shap_2d`` is ``None``; the
        alignment image is also ``None`` if its rendering raises.
    """
    explainability_img = plot_sequence_explainability(mirna_seq, target_seq, shap_2d)
    if shap_2d is None:
        return explainability_img, None, None

    heatmap_img = plot_shap_heatmap(shap_2d, mirna_seq, target_seq, score)
    alignment_img = None
    # The alignment view is best-effort: a failure must not hide the other plots.
    try:
        alignment_img = plot_alignment_image(mirna_seq, target_seq, shap_2d)
    except Exception as alignment_error:
        print(f"Warning: alignment plot generation failed: {alignment_error}")
    return explainability_img, heatmap_img, alignment_img
def _add_summary_page(pdf, mirna_seq, target_seq, score, binding_prediction, confidence):
    """Append the title page (scores, input sequences, model info) to *pdf*."""
    fig = plt.figure(figsize=(8.5, 11))
    try:
        ax = fig.add_subplot(111)
        ax.axis('off')
        # Title
        ax.text(0.5, 0.95, "miRBind2 Target Site Prediction Report", ha='center', va='top',
                fontsize=20, fontweight='bold', transform=ax.transAxes)
        # Date
        date_text = f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
        ax.text(0.5, 0.91, date_text, ha='center', va='top',
                fontsize=10, color='gray', transform=ax.transAxes)
        # Divider
        ax.plot([0.1, 0.9], [0.89, 0.89], 'k-', linewidth=1, transform=ax.transAxes)
        # Results section
        y_pos = 0.85
        ax.text(0.5, y_pos, "Prediction Results", ha='center', va='top',
                fontsize=16, fontweight='bold', transform=ax.transAxes)
        y_pos -= 0.06
        ax.text(0.1, y_pos, "Binding Score:", ha='left', va='top',
                fontsize=12, fontweight='bold', transform=ax.transAxes)
        ax.text(0.5, y_pos, f"{score:.4f}", ha='left', va='top',
                fontsize=12, transform=ax.transAxes)
        y_pos -= 0.04
        ax.text(0.1, y_pos, "Prediction:", ha='left', va='top',
                fontsize=12, fontweight='bold', transform=ax.transAxes)
        # Color code the prediction
        pred_color = 'green' if binding_prediction == "BINDING" else 'red'
        ax.text(0.5, y_pos, binding_prediction, ha='left', va='top',
                fontsize=12, color=pred_color, fontweight='bold', transform=ax.transAxes)
        y_pos -= 0.04
        ax.text(0.1, y_pos, "Confidence:", ha='left', va='top',
                fontsize=12, fontweight='bold', transform=ax.transAxes)
        ax.text(0.5, y_pos, f"{confidence:.1%}", ha='left', va='top',
                fontsize=12, transform=ax.transAxes)
        # Sequences section
        y_pos -= 0.08
        ax.text(0.5, y_pos, "Input Sequences", ha='center', va='top',
                fontsize=16, fontweight='bold', transform=ax.transAxes)
        y_pos -= 0.06
        ax.text(0.1, y_pos, "miRNA Sequence:", ha='left', va='top',
                fontsize=12, fontweight='bold', transform=ax.transAxes)
        y_pos -= 0.03
        ax.text(0.1, y_pos, mirna_seq, ha='left', va='top',
                fontsize=10, family='monospace', transform=ax.transAxes)
        y_pos -= 0.03
        ax.text(0.1, y_pos,
                f"(Length: {len(mirna_seq)} nt, padded to {MODEL_PARAMS['mirna_length']} nt)",
                ha='left', va='top', fontsize=9, color='gray', transform=ax.transAxes)
        y_pos -= 0.05
        ax.text(0.1, y_pos, "mRNA/Target Sequence:", ha='left', va='top',
                fontsize=12, fontweight='bold', transform=ax.transAxes)
        y_pos -= 0.03
        # Wrap long sequences at 60 characters per line
        wrapped_lines = [target_seq[i:i + 60] for i in range(0, len(target_seq), 60)]
        ax.text(0.1, y_pos, '\n'.join(wrapped_lines), ha='left', va='top',
                fontsize=10, family='monospace', transform=ax.transAxes)
        # Advance by the number of lines actually drawn (at least one), so a
        # length that is an exact multiple of 60 does not add a blank line.
        lines_used = max(1, len(wrapped_lines))
        y_pos -= 0.03 * lines_used
        ax.text(0.1, y_pos,
                f"(Length: {len(target_seq)} nt, padded to {MODEL_PARAMS['target_length']} nt)",
                ha='left', va='top', fontsize=9, color='gray', transform=ax.transAxes)
        # Model info
        y_pos -= 0.06
        ax.text(0.5, y_pos, "Model Information", ha='center', va='top',
                fontsize=16, fontweight='bold', transform=ax.transAxes)
        y_pos -= 0.05
        ax.text(0.1, y_pos, "Model: miRBind2 v1.0", ha='left', va='top',
                fontsize=10, transform=ax.transAxes)
        y_pos -= 0.03
        ax.text(0.1, y_pos, "Training Data: AGO2 eCLIP (Manakov et al. 2022)", ha='left', va='top',
                fontsize=10, transform=ax.transAxes)
        pdf.savefig(fig, bbox_inches='tight')
    finally:
        # Close this specific figure even if drawing or saving failed.
        plt.close(fig)


def _add_importance_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d):
    """Append the per-position importance bar charts to *pdf*."""
    mirna_importance = np.max(np.abs(shap_2d), axis=1)
    target_importance = np.max(np.abs(shap_2d), axis=0)
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(max(12, len(target_seq_padded) * 0.24), 10)
    )
    try:
        _draw_sequence_importance_axis(
            ax1,
            mirna_seq_padded,
            mirna_importance,
            'steelblue',
            'miRNA Position',
            'miRNA Sequence Importance (Max SHAP per position)',
            title_fontsize=13,
            nucleotide_fontsize=7,
            y_fontsize=11
        )
        _draw_sequence_importance_axis(
            ax2,
            target_seq_padded,
            target_importance,
            'coral',
            'mRNA/Target Position',
            'mRNA/Target Sequence Importance (Max SHAP per position)',
            title_fontsize=13,
            nucleotide_fontsize=7,
            y_fontsize=11
        )
        fig.tight_layout()
        pdf.savefig(fig, bbox_inches='tight')
    finally:
        plt.close(fig)


def _add_heatmap_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d, score):
    """Append the SHAP heatmap page, with a short colour legend, to *pdf*."""
    fig, ax = plt.subplots(
        figsize=(max(12, len(target_seq_padded) * 0.24), max(10, len(mirna_seq_padded) * 0.3))
    )
    try:
        colors = ['#2166ac', '#4393c3', '#92c5de', '#d1e5f0', '#f7f7f7',
                  '#fddbc7', '#f4a582', '#d6604d', '#b2182b']
        cmap = LinearSegmentedColormap.from_list('shap', colors, N=256)
        # Symmetric limits keep zero attribution at the palette midpoint.
        vmax = np.abs(shap_2d).max()
        im = ax.imshow(shap_2d, cmap=cmap, aspect='auto', vmin=-vmax, vmax=vmax)
        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('SHAP Attribution Value', rotation=270, labelpad=20, fontsize=11)
        ax.set_xlabel('mRNA/Target Position', fontsize=11)
        ax.set_ylabel('miRNA Position', fontsize=11)
        ax.set_title(f'SHAP Explainability Heatmap\nTarget site prediction: {score:.3f}',
                     fontsize=13, fontweight='bold', pad=15)
        # Add sequence labels
        x_fontsize = max(5, min(7, 260 / max(len(target_seq_padded), 1)))
        y_fontsize = max(6, min(8, 220 / max(len(mirna_seq_padded), 1)))
        ax.set_yticks(list(range(len(mirna_seq_padded))))
        ax.set_yticklabels([f"{i:>2} {nuc}" for i, nuc in enumerate(mirna_seq_padded)],
                           fontsize=y_fontsize, fontfamily='monospace')
        ax.set_xticks(list(range(len(target_seq_padded))))
        ax.set_xticklabels([f"{nuc}\n{i}" for i, nuc in enumerate(target_seq_padded)],
                           fontsize=x_fontsize, rotation=0)
        ax.tick_params(axis='x', pad=6)
        ax.tick_params(axis='y', pad=6)
        ax.grid(True, alpha=0.2, linewidth=0.5)
        # Add explanation text
        explanation = ("Red regions: Positive contribution to binding\n"
                       "Blue regions: Negative contribution to binding\n"
                       "White regions: Minimal contribution")
        fig.text(0.5, 0.02, explanation, ha='center', va='bottom',
                 fontsize=9, style='italic',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))
        fig.tight_layout()
        pdf.savefig(fig, bbox_inches='tight')
    finally:
        plt.close(fig)


def _add_alignment_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d):
    """Append the SHAP-guided alignment image as its own page of *pdf*."""
    alignment_img = plot_alignment_image(mirna_seq_padded, target_seq_padded, shap_2d)
    fig, ax = plt.subplots(figsize=(max(12, alignment_img.width / 110), 4.5))
    try:
        ax.imshow(alignment_img)
        ax.axis('off')
        ax.set_title('SHAP-Guided Target Site Alignment', fontsize=13, fontweight='bold', pad=12)
        fig.tight_layout()
        pdf.savefig(fig, bbox_inches='tight')
    finally:
        plt.close(fig)


def generate_pdf_report(mirna_seq, target_seq, score, binding_prediction, confidence, shap_2d=None):
    """Generate a PDF report with all visualizations and results.

    Page 1 summarises the prediction and the input sequences. When
    ``shap_2d`` is given, further pages hold the per-position importance
    charts, the SHAP heatmap and (best-effort) the alignment view.

    Args:
        mirna_seq: Validated miRNA sequence (unpadded).
        target_seq: Validated target sequence (unpadded).
        score: Model output for the pair.
        binding_prediction: Display label; "BINDING" is drawn in green,
            anything else in red.
        confidence: Confidence value, formatted as a percentage.
        shap_2d: Optional SHAP matrix.

    Returns:
        Path of the generated PDF; the file is created with ``delete=False``
        and is not removed by this function on success.

    Raises:
        Exception: any rendering error is re-raised after the partially
            written temporary file has been removed.
    """
    # Reserve a temporary path; PdfPages reopens it by name.
    pdf_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pdf')
    pdf_path = pdf_file.name
    pdf_file.close()
    # Pad sequences to model length
    mirna_seq_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
    target_seq_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])
    try:
        with PdfPages(pdf_path) as pdf:
            _add_summary_page(pdf, mirna_seq, target_seq, score, binding_prediction, confidence)
            if shap_2d is not None:
                _add_importance_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d)
                _add_heatmap_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d, score)
                # The alignment page is optional: keep the report if it fails.
                try:
                    _add_alignment_page(pdf, mirna_seq_padded, target_seq_padded, shap_2d)
                except Exception as alignment_error:
                    print(f"Warning: alignment plot PDF page failed: {alignment_error}")
    except Exception:
        # Do not leave a broken temporary file behind when rendering fails.
        try:
            os.unlink(pdf_path)
        except OSError:
            pass
        raise
    return pdf_path
def run_prediction(mirna_input, target_input, show_shap=True):
    """Gradio handler for the single-prediction tab.

    Returns:
        A 6-tuple ``(result_text, explainability_img, shap_img,
        alignment_img, shap_text, pdf_path)``. On a validation or runtime
        error the first element carries the message and the rest are ``None``.
    """
    # Validate both inputs before touching the model.
    mirna_ok, mirna_seq = validate_sequence(mirna_input, "miRNA sequence")
    if not mirna_ok:
        return mirna_seq, None, None, None, None, None
    target_ok, target_seq = validate_sequence(target_input, "mRNA sequence")
    if not target_ok:
        return target_seq, None, None, None, None, None

    mirna_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
    target_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])
    try:
        score = predict_binding(target_seq, mirna_seq)
        # 0.5 is the decision threshold; confidence is the distance to the other class.
        is_binding = score >= 0.5
        if is_binding:
            binding = "BINDING"
            confidence = score
        else:
            binding = "NO BINDING"
            confidence = 1 - score
        result_text = f"""
### Prediction Results
**Binding Score:** {score:.4f}
**Prediction:** {binding} (Confidence: {confidence:.1%})
**Model:** miRBind2 v1.0
**Sequences Used:**
- miRNA length: {len(mirna_seq)} nt (padded to {MODEL_PARAMS['mirna_length']})
- mRNA length: {len(target_seq)} nt (padded to {MODEL_PARAMS['target_length']})
"""
        # Optional SHAP attributions.
        shap_2d = None
        shap_text = ""
        if show_shap:
            if CAPTUM_AVAILABLE:
                shap_text = "Computing SHAP values..."
                shap_2d = compute_shap_values(target_seq, mirna_seq)
                if shap_2d is not None:
                    shap_text = """
### SHAP Explainability
**Position Pairs Heatmap** shows which miRNA-mRNA position combinations contribute most:
- **Red regions:** Positive contribution to binding
- **Blue regions:** Negative contribution to binding
- **White regions:** Minimal contribution
**Sequence Explainability** above shows per-nucleotide importance (max SHAP score across all positions).
**Alignment View** shows a SHAP-guided target site alignment between the miRNA and target sequence.
This helps identify the specific target sites and key nucleotides driving the prediction.
"""
            else:
                shap_text = "SHAP visualization requires Captum library. Install with: pip install captum"

        explainability_img, shap_img, alignment_img = create_shap_visualizations(
            mirna_padded, target_padded, score, shap_2d
        )

        # The PDF report is only produced when SHAP values exist.
        pdf_path = None
        if shap_2d is not None:
            try:
                pdf_path = generate_pdf_report(
                    mirna_seq, target_seq, score, binding, confidence, shap_2d
                )
            except Exception as pdf_error:
                print(f"Warning: PDF generation failed: {pdf_error}")
                pdf_path = None
        return result_text, explainability_img, shap_img, alignment_img, shap_text, pdf_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None, None, None, None
# ============================================================================
# BATCH PREDICTION FUNCTIONS
# ============================================================================
def _header_looks_like_sequences(col_names):
    """Return True when the first two column names are nucleotide strings.

    Used to recognise header-less files, where pandas has consumed the first
    data row as column names. A ``.N`` suffix (added by pandas to duplicate
    names) is ignored.
    """
    nucleotides = set('ATCGUN')
    leading = col_names[:2]
    if len(leading) < 2:
        return False
    for name in leading:
        text = str(name).split('.')[0].strip().upper()
        if not text or not set(text) <= nucleotides:
            return False
    return True


def parse_batch_file(file_path):
    """Parse an uploaded TSV/CSV file with miRNA-target pairs.

    The file is read as TSV first and as CSV when that yields one column.
    Column 1 is the target and column 2 the miRNA. The label is the column
    named ``label`` or ``class`` (third position or later) if there is one,
    otherwise the third column; without either, ``label`` is set to -1.
    Files without a header row are supported: when the would-be header
    itself consists of nucleotide sequences, the file is re-read with the
    first row kept as data.

    Args:
        file_path: Path of the uploaded file.

    Returns:
        ``(df, None)`` with columns ``target_seq``, ``mirna_seq`` and
        ``label`` (plus any extra columns) on success, or
        ``(None, error_message)`` on failure.
    """
    try:
        # Try reading as TSV first
        sep = '\t'
        df = pd.read_csv(file_path, sep=sep)
        # If only one column, try CSV
        if len(df.columns) == 1:
            sep = ','
            df = pd.read_csv(file_path, sep=sep)
        # Check for required columns
        if len(df.columns) < 2:
            return None, "File must have at least 2 columns (target, miRNA)"
        col_names = df.columns.tolist()

        # Header-less file: the first data row was taken as the header.
        headerless = _header_looks_like_sequences(col_names)
        if headerless:
            df = pd.read_csv(file_path, sep=sep, header=None)
            col_names = df.columns.tolist()

        # Locate the label column: by name when a header exists, else third.
        label_idx = None
        if not headerless:
            for candidate in ('label', 'class'):
                if candidate in col_names[2:]:
                    label_idx = col_names.index(candidate, 2)
                    break
        if label_idx is None and len(col_names) >= 3:
            label_idx = 2

        # Rename in place so each column keeps its position and no name
        # is assigned twice.
        new_names = list(col_names)
        new_names[0] = 'target_seq'
        new_names[1] = 'mirna_seq'
        if label_idx is not None:
            new_names[label_idx] = 'label'
        df.columns = new_names
        if label_idx is None:
            df['label'] = -1  # Unknown
        return df, None
    except Exception as e:
        return None, f"Error parsing file: {str(e)}"
def run_batch_predictions(df, compute_shap=False, progress=gr.Progress()):
    """Run predictions on all pairs in the dataframe.

    Exactly one result row is produced per input row. Missing (NaN) sequence
    cells are reported as invalid instead of being scored as the text "nan",
    missing labels are stored as -1 (unknown), and a SHAP failure does not
    turn an already successful prediction into an error row.

    Args:
        df: DataFrame with ``target_seq`` and ``mirna_seq`` columns and an
            optional ``label`` column.
        compute_shap: When true (and Captum is available) SHAP matrices are
            computed and stored in BATCH_SHAP_CACHE under the row index.
        progress: Gradio progress tracker (the ``gr.Progress()`` default is
            the idiom Gradio uses to inject one).

    Returns:
        The results DataFrame, which is also stored in BATCH_RESULTS.
    """
    global BATCH_RESULTS, BATCH_SHAP_CACHE
    results = []
    BATCH_SHAP_CACHE = {}

    def record(idx, mirna_seq, target_seq, true_label, score, pred_class, status):
        # Single place that defines the schema of a result row.
        results.append({
            'index': idx,
            'mirna_seq': mirna_seq,
            'target_seq': target_seq,
            'true_label': true_label,
            'prediction_score': score,
            'predicted_class': pred_class,
            'status': status,
        })

    total = len(df)
    for idx, row in progress.tqdm(df.iterrows(), total=total, desc="Processing pairs"):
        raw_target = row['target_seq']
        raw_mirna = row['mirna_seq']
        # str(NaN) is "nan", which upper-cases to the valid-looking "NAN";
        # map missing cells to the empty string so validation rejects them.
        target_seq = "" if pd.isna(raw_target) else str(raw_target)
        mirna_seq = "" if pd.isna(raw_mirna) else str(raw_mirna)
        true_label = row.get('label', -1)
        if pd.isna(true_label):
            true_label = -1

        # Validate sequences
        valid_target, target_result = validate_sequence(target_seq, "Target")
        valid_mirna, mirna_result = validate_sequence(mirna_seq, "miRNA")
        if not valid_target or not valid_mirna:
            record(idx, mirna_seq, target_seq, true_label, None, None, 'Invalid sequence')
            continue
        target_seq = target_result
        mirna_seq = mirna_result

        # Run prediction
        try:
            score = predict_binding(target_seq, mirna_seq)
        except Exception as e:
            record(idx, mirna_seq, target_seq, true_label, None, None, f'Error: {str(e)}')
            continue
        pred_class = 1 if score >= 0.5 else 0
        record(idx, mirna_seq, target_seq, true_label, score, pred_class, 'Success')

        # Compute SHAP if requested (cached for detail view). Kept outside the
        # prediction try-block so a SHAP failure cannot add a second row.
        if compute_shap and CAPTUM_AVAILABLE:
            try:
                BATCH_SHAP_CACHE[idx] = compute_shap_values(target_seq, mirna_seq)
            except Exception as shap_error:
                print(f"Warning: SHAP computation failed for pair {idx}: {shap_error}")

    BATCH_RESULTS = pd.DataFrame(results)
    return BATCH_RESULTS
def format_results_table(results_df):
    """Format results for display in Gradio table.

    Missing scores are shown as "N/A" whether they are stored as ``None`` or
    as NaN (pandas converts ``None`` to NaN in a numeric column that also
    holds real scores).

    Args:
        results_df: DataFrame with the columns ``index``, ``mirna_seq``,
            ``target_seq``, ``true_label``, ``prediction_score``,
            ``predicted_class`` and ``status``.

    Returns:
        A display DataFrame with the columns ``index``, ``miRNA``,
        ``Target``, ``True Label``, ``Score``, ``Prediction`` and ``status``,
        or ``None`` when there is nothing to show.
    """
    if results_df is None or len(results_df) == 0:
        return None
    display_df = results_df.copy()

    def class_name(value, unknown):
        # NaN compares unequal to both 0 and 1, so it falls through to `unknown`.
        if value == 1:
            return "BINDING"
        if value == 0:
            return "NO BINDING"
        return unknown

    def truncate(text, limit):
        text = str(text)
        return text[:limit] + "..." if len(text) > limit else text

    # Format columns
    display_df['Score'] = display_df['prediction_score'].apply(
        lambda x: "N/A" if pd.isna(x) else f"{x:.4f}"
    )
    display_df['Prediction'] = display_df['predicted_class'].apply(
        lambda x: class_name(x, "N/A")
    )
    display_df['True Label'] = display_df['true_label'].apply(
        lambda x: class_name(x, "Unknown")
    )
    # Truncate sequences for display
    display_df['miRNA'] = display_df['mirna_seq'].apply(lambda x: truncate(x, 20))
    display_df['Target'] = display_df['target_seq'].apply(lambda x: truncate(x, 30))
    # Select and reorder columns
    display_cols = ['index', 'miRNA', 'Target', 'True Label', 'Score', 'Prediction', 'status']
    return display_df[display_cols]
def show_batch_detail_view(current_table_state, evt: gr.SelectData):
    """Show detailed view for selected row from batch results.

    Args:
        current_table_state: The current state of the displayed table (including any sorting)
        evt: Selection event containing row/column info

    Returns:
        A 5-tuple ``(result_text, explainability_img, heatmap_img,
        alignment_img, explanation)``. When the row cannot be resolved or
        its prediction failed, only the first element is set.
    """
    if BATCH_RESULTS is None:
        return "No results available", None, None, None, None
    if current_table_state is None or len(current_table_state) == 0:
        return "No data in table", None, None, None, None
    # Get the visual row position in the currently displayed (possibly sorted) table
    visual_row_idx = evt.index[0]
    if visual_row_idx >= len(current_table_state):
        return "Invalid selection", None, None, None, None
    # Get the index from the DISPLAYED table at this position
    # This works correctly even if the table is sorted
    actual_idx = current_table_state.iloc[visual_row_idx]['index']
    # Find the row in the original results using the actual index
    matches = BATCH_RESULTS[BATCH_RESULTS['index'] == actual_idx]
    if matches.empty:
        # The displayed table no longer corresponds to the cached results.
        return "Invalid selection", None, None, None, None
    row = matches.iloc[0]
    mirna_seq = row['mirna_seq']
    target_seq = row['target_seq']
    score = row['prediction_score']
    pred_class = row['predicted_class']
    true_label = row['true_label']
    idx = row['index']
    # Failed rows store a missing score; pandas may hold it as None or NaN.
    if pd.isna(score):
        return f"**Error:** {row['status']}", None, None, None, None
    # Format result text
    binding = "BINDING" if pred_class == 1 else "NO BINDING"
    confidence = score if pred_class == 1 else (1 - score)
    true_label_str = "BINDING" if true_label == 1 else "NO BINDING" if true_label == 0 else "Unknown"
    result_text = f"""
### Detailed Results for Pair #{idx}
**Binding Score:** {score:.4f}
**Prediction:** {binding} (Confidence: {confidence:.1%})
**True Label:** {true_label_str}
**Sequences:**
- **miRNA:** {mirna_seq}
- **mRNA:** {target_seq}
"""
    # SHAP values are only available if they were computed during the batch run
    shap_2d = BATCH_SHAP_CACHE.get(idx)
    if shap_2d is not None:
        # Pad sequences
        mirna_padded = pad_or_trim(mirna_seq, MODEL_PARAMS['mirna_length'])
        target_padded = pad_or_trim(target_seq, MODEL_PARAMS['target_length'])
        explainability_img, heatmap_img, alignment_img = create_shap_visualizations(
            mirna_padded,
            target_padded,
            score,
            shap_2d
        )
        explanation = (
            "**SHAP values** computed during batch processing. "
            "The alignment view is derived from the same SHAP matrix."
        )
    else:
        explainability_img = None
        heatmap_img = None
        alignment_img = None
        explanation = "⚠️ **SHAP not computed.** Re-run batch with 'Compute SHAP' enabled to see explainability."
    return result_text, explainability_img, heatmap_img, alignment_img, explanation
def process_uploaded_file(file, compute_shap, progress=gr.Progress()):
    """Process uploaded file and run predictions.

    Args:
        file: Value of the upload component; either a path string or an
            object exposing the path as ``.name``.
        compute_shap: Whether SHAP matrices are computed for every pair.
        progress: Gradio progress tracker, declared on the event handler and
            forwarded to the batch loop.

    Returns:
        ``(summary_markdown, display_df)``; ``display_df`` is ``None`` when
        nothing could be processed.
    """
    if file is None:
        return "Please upload a file", None
    # Accept both a plain path and a file-like wrapper carrying `.name`.
    file_path = getattr(file, "name", file)
    # Parse file
    df, error = parse_batch_file(file_path)
    if error:
        return error, None
    # Run predictions
    results_df = run_batch_predictions(df, compute_shap=compute_shap, progress=progress)
    # A file without data rows yields an empty frame with no columns.
    if results_df is None or len(results_df) == 0:
        return "No sequence pairs found in the uploaded file", None
    # Format for display
    display_df = format_results_table(results_df)
    # Summary statistics
    total = len(results_df)
    successful = len(results_df[results_df['status'] == 'Success'])
    header = f"""
### Batch Processing Complete ✅
**Total pairs:** {total}
**Successful:** {successful}
**Failed:** {total - successful}
"""
    no_label_footer = "👇 Click any row in the table below to view detailed results.\n"
    if 'true_label' not in results_df.columns:
        return header + no_label_footer, display_df
    known_labels = results_df[results_df['true_label'] != -1]
    if len(known_labels) == 0:
        return header + no_label_footer, display_df

    correct = len(known_labels[known_labels['predicted_class'] == known_labels['true_label']])
    accuracy = correct / len(known_labels) * 100
    scored = known_labels.dropna(subset=['prediction_score'])
    aps_line = ""
    if len(scored) > 0:
        # The label column may not be binary (it can be chosen by position),
        # in which case the metric is skipped rather than failing the batch.
        try:
            aps = average_precision_score(scored['true_label'], scored['prediction_score'])
            aps_line = f"\n**APS (Average Precision):** {aps:.4f}"
        except (ValueError, TypeError) as aps_error:
            print(f"Warning: average precision could not be computed: {aps_error}")
    summary = header + (
        f"**Accuracy:** {accuracy:.1f}% ({correct}/{len(known_labels)} correct){aps_line}\n"
        "👇 Click any row in the table below to view detailed SHAP analysis.\n"
    )
    return summary, display_df
def create_gradio_interface():
    """Create the Gradio web interface with tabs.

    Builds a "Single Prediction" tab wired to ``run_prediction`` and a
    "Batch Predictions" tab wired to ``process_uploaded_file`` and
    ``show_batch_detail_view``, followed by an About section that shows the
    current value of DEVICE.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="miRBind2: miRNA-mRNA Target Site Predictor") as app:
        gr.Markdown("""
# miRBind2: miRNA-mRNA Target Site Predictor
Predict miRNA-mRNA target sites using deep learning with explainability.
**Model:** miRBind2 v1.0
""")
        with gr.Tabs():
            # ================================================================
            # TAB 1: SINGLE PREDICTION
            # ================================================================
            with gr.Tab("Single Prediction"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Input Sequences")
                        mirna_input = gr.Textbox(
                            label="miRNA Sequence",
                            placeholder=f"Enter miRNA sequence (e.g., {DEFAULT_MIRNA_SEQUENCE})",
                            lines=3,
                            value=DEFAULT_MIRNA_SEQUENCE
                        )
                        target_input = gr.Textbox(
                            label="mRNA/Target Sequence",
                            placeholder="Enter mRNA target sequence",
                            lines=5,
                            # Demo target embeds a reverse-complement site for clearer default explainability.
                            value=DEFAULT_TARGET_SEQUENCE
                        )
                        show_shap = gr.Checkbox(
                            label="Show SHAP Explainability (slower)",
                            value=True
                        )
                        predict_btn = gr.Button("Predict Binding", variant="primary", size="lg")
                        gr.Markdown("""
**Instructions:**
1. Enter your miRNA sequence (RNA or DNA notation accepted)
2. Enter your mRNA target sequence
3. Check SHAP box for detailed explainability (recommended)
4. Click 'Predict Binding'
**Sequence length:** miRNA is padded/trimmed to 28 nt, mRNA to 50 nt. Sequences longer than these limits are trimmed from the 3′ end; shorter sequences are padded with N.
""")
                    with gr.Column(scale=2):
                        gr.Markdown("### Results")
                        result_output = gr.Markdown(label="Prediction Results")
                        gr.Markdown("#### Sequence Explainability")
                        explainability_output = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP Explainability Heatmap (Position Pairs)")
                        shap_output = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP-Guided Alignment View")
                        alignment_view_output = gr.Image(show_label=False, type="pil")
                        shap_text_output = gr.Markdown()
                        # PDF download button
                        with gr.Row():
                            pdf_download = gr.File(label="Download PDF Report", visible=True)
                # Connect prediction button
                predict_btn.click(
                    fn=run_prediction,
                    inputs=[mirna_input, target_input, show_shap],
                    outputs=[
                        result_output,
                        explainability_output,
                        shap_output,
                        alignment_view_output,
                        shap_text_output,
                        pdf_download,
                    ]
                )
            # ================================================================
            # TAB 2: BATCH PREDICTIONS
            # ================================================================
            with gr.Tab("Batch Predictions"):
                gr.Markdown("""
Upload a TSV/CSV file with multiple miRNA-mRNA pairs and browse through predictions interactively.
""")
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 📁 Upload File")
                        file_upload = gr.File(
                            label="Upload TSV/CSV File",
                            file_types=['.tsv', '.csv', '.txt']
                        )
                        gr.Markdown("""
**Expected format (column order matters, headers optional):**
- Column 1: mRNA/target sequence
- Column 2: miRNA sequence
- Column 3 (optional): Label (0 or 1)
TSV (tab-separated) or CSV format.
**Sequence length:** miRNA is padded/trimmed to 28 nt, mRNA to 50 nt. Sequences longer than these limits are trimmed from the 3′ end.
See `example_batch.tsv` for reference.
""")
                        compute_shap_batch = gr.Checkbox(
                            label="Compute SHAP for all pairs (slower)",
                            value=True,
                            info="Enable for detailed explainability. Increases processing time."
                        )
                        process_btn = gr.Button("Process File", variant="primary", size="lg")
                    with gr.Column(scale=2):
                        gr.Markdown("### 📊 Results Summary")
                        summary_output = gr.Markdown()
                gr.Markdown("### 📋 Browse Results")
                gr.Markdown("👇 Click any row to view detailed SHAP analysis")
                results_table = gr.Dataframe(
                    label="Prediction Results",
                    interactive=False,
                    wrap=True
                )
                gr.Markdown("### 🔍 Detailed View")
                with gr.Row():
                    with gr.Column(scale=1):
                        detail_text = gr.Markdown()
                    with gr.Column(scale=2):
                        gr.Markdown("#### Sequence Explainability")
                        detail_explainability = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP Heatmap")
                        detail_heatmap = gr.Image(show_label=False, type="pil")
                        gr.Markdown("#### SHAP-Guided Alignment View")
                        detail_alignment = gr.Image(show_label=False, type="pil")
                        detail_explanation = gr.Markdown()
                # Connect batch events
                process_btn.click(
                    fn=process_uploaded_file,
                    inputs=[file_upload, compute_shap_batch],
                    outputs=[summary_output, results_table]
                )
                # Pass the table itself as input so we can see current state (including sorting)
                results_table.select(
                    fn=show_batch_detail_view,
                    inputs=[results_table],  # Pass current table state
                    outputs=[
                        detail_text,
                        detail_explainability,
                        detail_heatmap,
                        detail_alignment,
                        detail_explanation,
                    ]
                )
        # About section (outside tabs)
        gr.Markdown("""
---
### About
**miRBind2** uses a convolutional neural network trained on CLIP experimental data to predict
miRNA-mRNA target sites. The model learns complementarity patterns between miRNA and target sequences.
**GitHub:** [BioGeMT/miRBind_2.0](https://github.com/BioGeMT/miRBind_2.0)
**Device:** {}
""".format(DEVICE))
    return app
def main():
    """Load the model, build the interface and start the Gradio server.

    Exits the process with status 1 when the model cannot be loaded.
    """
    rule = "=" * 60
    print(rule)
    print("miRBind2 Gradio Interface")
    print(rule)

    # The app is useless without weights, so a load failure is fatal.
    try:
        load_pretrained_model()
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    app = create_gradio_interface()
    print("\n" + rule)
    print("Launching Gradio interface...")
    print(rule)

    launch_options = {
        "share": False,  # Set to True to create public link
        "server_name": "0.0.0.0",  # Required for containerized hosting (HF Spaces)
        "server_port": 7860,
        "show_error": True,
        "theme": gr.themes.Soft(),
    }
    app.launch(**launch_options)


if __name__ == "__main__":
    main()