import os
import re
import tempfile
import time
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- 1. Load Model ---
model_path = "final_model_100k.pt"
tokenizer = AutoTokenizer.from_pretrained("armheb/DNA_bert_6")
model = AutoModelForSequenceClassification.from_pretrained("armheb/DNA_bert_6", num_labels=2)

# Checkpoint key prefixes -> prefixes used by the Hugging Face model above.
# Only a *leading* prefix is rewritten; str.replace would also rewrite any later
# occurrence of the same substring inside a key.
_STATE_DICT_PREFIX_MAP = (
    ("backbone.", "bert."),
    ("task_classifier.", "classifier."),
)


def _remap_state_dict_key(key):
    """Return *key* with a known checkpoint prefix renamed; other keys pass through unchanged."""
    for old_prefix, new_prefix in _STATE_DICT_PREFIX_MAP:
        if key.startswith(old_prefix):
            return new_prefix + key[len(old_prefix):]
    return key


# NOTE(review): torch.load unpickles arbitrary Python objects; only point
# model_path at a trusted checkpoint (torch>=1.13 also accepts weights_only=True).
raw_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
new_state_dict = {_remap_state_dict_key(key): value for key, value in raw_state_dict.items()}

# strict=False tolerates checkpoint entries with no counterpart in this model
# (e.g. extra heads), but a *missing* key means that parameter silently keeps
# its fresh initialisation -- surface that instead of serving garbage scores.
_load_result = model.load_state_dict(new_state_dict, strict=False)
if _load_result.missing_keys:
    warnings.warn(
        f"{len(_load_result.missing_keys)} model parameter(s) were not found in "
        f"{model_path} and keep their initial values: {_load_result.missing_keys[:10]}",
        RuntimeWarning,
    )
model.eval()
# --- 2. Logic Functions ---
def seq2kmer(seq, k=6):
    """Split *seq* into overlapping k-mers joined by single spaces.

    This is the whitespace-separated k-mer text that the tokenizer is fed
    elsewhere in this file. A sequence shorter than *k* yields ``""``.

    >>> seq2kmer("ACGTACG")
    'ACGTAC CGTACG'
    """
    window_count = len(seq) - k + 1
    windows = (seq[offset:offset + k] for offset in range(window_count))
    return " ".join(windows)
# DRACH consensus: D = A/G/T(U), R = A/G, H = A/C/T(U). The pattern sits inside a
# zero-width lookahead so overlapping occurrences are reported too. Two DRACH
# motifs can overlap by one base (e.g. GGACA + AGACT), which a plain
# re.finditer would skip. U is accepted so RNA input is scanned as well.
_DRACH_PATTERN = re.compile(r'(?=([AGTU][AG]AC[ACTU]))')


def find_drach_motifs(sequence):
    """Find DRACH motifs in *sequence* and return a highlighted copy.

    Matching is case-sensitive; callers in this file pass upper-case input.

    Args:
        sequence: Nucleotide string (DNA or RNA alphabet).

    Returns:
        ``(motifs_text, highlighted_seq)``:

        * ``motifs_text`` is a comma-separated list ``"<motif> (Pos <start>)"``
          in position order, including overlapping hits. It is
          ``"None detected."`` when there are none.
        * ``highlighted_seq`` is *sequence* with motif regions wrapped in
          ``**...**``. Overlapping motifs are merged into one highlighted
          region so no base is duplicated. Motifs that only touch
          (back-to-back) stay separately wrapped.
    """
    matches = [(m.start(), m.group(1)) for m in _DRACH_PATTERN.finditer(sequence)]

    # Merge overlapping [start, end) spans; matches arrive in ascending start order.
    spans = []
    for start, motif in matches:
        end = start + len(motif)
        if spans and start < spans[-1][1]:
            spans[-1][1] = max(spans[-1][1], end)
        else:
            spans.append([start, end])

    pieces = []
    cursor = 0
    for start, end in spans:
        pieces.append(sequence[cursor:start])
        pieces.append(f"**{sequence[start:end]}**")
        cursor = end
    pieces.append(sequence[cursor:])
    highlighted_seq = "".join(pieces)

    if matches:
        motifs_text = ", ".join(f"{motif} (Pos {start})" for start, motif in matches)
    else:
        motifs_text = "None detected."
    return motifs_text, highlighted_seq
def calc_gc_content(sequence, window=15):
    """Return the sliding-window GC fraction at every position of *sequence*.

    The window is centred on each position and spans ``window // 2`` bases on
    either side. It is clipped at the sequence ends, so edge positions average
    over fewer bases. Only upper-case ``'G'`` and ``'C'`` are counted.

    Args:
        sequence: Nucleotide string.
        window: Nominal window width in bases.

    Returns:
        A list of floats in ``[0, 1]``, one per position (empty for ``""``).
    """
    seq_len = len(sequence)
    radius = window // 2

    def gc_fraction(centre):
        segment = sequence[max(0, centre - radius):min(seq_len, centre + radius + 1)]
        return (segment.count('G') + segment.count('C')) / len(segment)

    return [gc_fraction(pos) for pos in range(seq_len)]
def run_ebcs_core(sequence, k_mask=6):
    """Compute the per-position EBCS (boundary contrast) profile of *sequence*.

    For every position ``i``, two masked copies of the sequence are scored:

    * a *left* copy with positions ``[i - k_mask, i)`` replaced by ``N``;
    * a *right* copy with positions ``[i, i + k_mask)`` replaced by ``N``.

    Windows are clipped at the sequence ends, so every copy keeps the original
    length and alignment. The score at ``i`` is ``|P_left - P_right| / P_base``,
    where ``P`` is the softmax probability of label 1 from ``model`` and
    ``P_base`` is that of the unmasked sequence.

    Args:
        sequence: Nucleotide string (callers pass 41 bp). Case-insensitive.
            RNA ``U`` is read as ``T`` because the DNA_bert_6 k-mer vocabulary
            is written in the DNA alphabet; U-containing k-mers would otherwise
            be out of vocabulary.
        k_mask: Mask width in bases. Coerced to a non-negative ``int``, since
            UI sliders may deliver it as a float.

    Returns:
        ``(scores, base_prob)``: a float64 NumPy array with one score per
        position, and the unmasked sequence's label-1 probability as a float.
    """
    k_mask = max(0, int(k_mask))
    sequence = sequence.upper().replace("U", "T")
    seq_len = len(sequence)

    def _masked(start, end):
        # Replace [start, end), clipped to the sequence, with N; length is preserved.
        # (The previous "insert k Ns then truncate" form shifted the whole
        # sequence right whenever the left window ran past position 0.)
        start, end = max(0, start), min(seq_len, end)
        return sequence[:start] + "N" * (end - start) + sequence[end:]

    baseline_inputs = tokenizer(
        [seq2kmer(sequence)], return_tensors="pt", padding="max_length", max_length=128
    )
    with torch.no_grad():
        base_logits = model(**baseline_inputs).logits
    base_prob = F.softmax(base_logits, dim=1)[0][1].item()

    # Interleaved [left_0, right_0, left_1, right_1, ...], scored in one batch.
    mutated_seqs = []
    for i in range(seq_len):
        mutated_seqs.append(seq2kmer(_masked(i - k_mask, i)))
        mutated_seqs.append(seq2kmer(_masked(i, i + k_mask)))
    if not mutated_seqs:
        return np.zeros(0), base_prob

    batch_inputs = tokenizer(
        mutated_seqs, return_tensors="pt", padding="max_length", max_length=128
    )
    with torch.no_grad():
        batch_logits = model(**batch_inputs).logits
    batch_probs = F.softmax(batch_logits, dim=1)[:, 1].numpy()

    left_probs, right_probs = batch_probs[0::2], batch_probs[1::2]
    # 1e-8 guards against division by zero when the baseline probability is 0.
    scores = (np.abs(left_probs - right_probs) / (base_prob + 1e-8)).astype(np.float64)
    return scores, base_prob
# --- 3. UI Handlers ---
def _plot_ebcs_profile(sequence, scores, gc_vals):
    """Render the EBCS profile with a smoothed GC-content overlay on a twin axis.

    Args:
        sequence: Sequence whose bases label the x axis (one tick per base).
        scores: Per-position EBCS values, same length as *sequence*.
        gc_vals: Per-position GC fraction, same length as *sequence*.

    Returns:
        A ``matplotlib.figure.Figure`` with transparent backgrounds and a
        dotted marker at the peak score.
    """
    positions = range(len(sequence))
    peak_idx = int(np.argmax(scores))

    # Figure() instead of plt.subplots(): pyplot keeps a global reference to
    # every figure it creates until plt.close(). In this long-running server
    # that would leak one figure per request.
    fig = Figure(figsize=(9, 4.5))
    ax1 = fig.subplots()

    # GLASSMORPHISM: fully transparent backgrounds so the panel shows through.
    fig.patch.set_alpha(0.0)
    ax1.patch.set_alpha(0.0)

    ax1.set_title(
        f"EBCS Profile - Peak Boundary Detected at Position {peak_idx} ({sequence[peak_idx]})",
        fontweight='bold', color='#111827', pad=15,
    )
    ax1.plot(positions, scores, color='#4f46e5', linewidth=2.5, marker='o', markersize=5, label='EBCS Delta')
    ax1.fill_between(positions, scores, color='#4f46e5', alpha=0.08)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(list(sequence), fontsize=9, color='#111827')
    ax1.set_xlabel("Spatial Nucleotide Resolution", fontweight='600', color='#4b5563')
    ax1.set_ylabel("Boundary Contrast Delta", color='#4f46e5', fontweight='600')
    ax1.tick_params(axis='y', labelcolor='#4f46e5', color='#d1d5db')
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['left'].set_color('#d1d5db')
    ax1.spines['bottom'].set_color('#d1d5db')
    ax1.grid(True, linestyle='--', color='#f3f4f6')

    ax2 = ax1.twinx()
    ax2.patch.set_alpha(0.0)  # transparent secondary axis
    ax2.plot(positions, gc_vals, color='#9ca3af', linestyle='-', linewidth=2, alpha=0.3, label='Local GC%')
    ax2.set_ylabel("GC Content (Smoothed)", color='#9ca3af')
    ax2.tick_params(axis='y', labelcolor='#9ca3af', color='#d1d5db')
    for side in ('top', 'right', 'left', 'bottom'):
        ax2.spines[side].set_visible(False)

    ax1.axvline(x=peak_idx, color='#e11d48', linestyle=':', linewidth=2, alpha=0.8)
    fig.tight_layout()
    return fig


def process_search(sequence, k_mask=6):
    """Analyse a single 41-bp sequence, streaming progress to the UI.

    Generator wired to ``d_btn.click``. Every ``yield`` is a
    ``(plot, result_html, motif_html)`` triple matching the outputs
    ``[out_plot, out_res, out_mot]``.

    Args:
        sequence: Raw textbox contents. Whitespace is stripped and the text is
            upper-cased before validation (exactly 41 characters from ``ACGTUN``).
        k_mask: EBCS mask width forwarded to ``run_ebcs_core``.

    Yields:
        ``(None, status, "")`` progress/error messages. On success, a final
        ``(figure, summary_html, motifs_html)``.
    """
    yield None, "⏳ Initializing request... (ETA: ~3s)", ""
    # Gradio may hand over None for an untouched textbox; treat it as empty.
    sequence = (sequence or "").strip().upper()
    if len(sequence) != 41:
        yield None, f"❌ Error: Sequence must be exactly 41 bp (got {len(sequence)}).", ""
        return
    if not re.fullmatch(r'[ACGTUN]+', sequence):
        yield None, "❌ Error: Invalid sequence. Only standard nucleotides allowed.", ""
        return

    time.sleep(0.4)  # presumably a deliberate pause so each progress stage is visible
    yield None, "🧬 Scanning sequence & computing EBCS... (ETA: ~2s)", ""
    motifs_text, highlighted_seq = find_drach_motifs(sequence)
    gc_vals = calc_gc_content(sequence)
    scores, base_prob = run_ebcs_core(sequence, k_mask)

    time.sleep(0.4)
    yield None, "📊 Rendering spatial maps... (ETA: ~1s)", ""
    fig = _plot_ebcs_profile(sequence, scores, gc_vals)

    peak_idx = int(np.argmax(scores))
    peak_score = scores[peak_idx]
    peak_base = sequence[peak_idx]
    # The outputs are gr.HTML components, which neither render Markdown nor keep
    # newlines. Turn find_drach_motifs' **motif** markers into <b> tags and break
    # lines with <br>. Interpolation is safe: sequence was validated to ACGTUN above.
    highlighted_html = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", highlighted_seq)
    res = (
        f"🎯 Target: {peak_base} at Pos {peak_idx}<br>"
        f"Max Contrast: {peak_score:.4f} | Baseline Confidence: {base_prob:.4f}<br>"
        f"Sequence Map: {highlighted_html}"
    )
    mot = f"Canonical DRACH Motifs: {motifs_text}"
    yield fig, res, mot
def process_batch(file_obj, k_mask=6):
    """Run EBCS over every 41-bp sequence in an uploaded file and export a CSV.

    Generator wired to ``batch_btn.click``. Every ``yield`` is a
    ``(download_path, status_html)`` pair for ``[batch_download, batch_status]``.

    Parsing rules:

    * Blank lines and FASTA headers (lines starting with ``>``) are skipped.
    * For CSV-style rows only the first comma-separated field is used, with
      surrounding quotes stripped.
    * A candidate is kept when it is exactly 41 characters long.
    * 41-character candidates containing characters outside ``ACGTUN`` are
      skipped and reported in the final status message.

    Args:
        file_obj: Value from ``gr.File``. Either a filepath string or an object
            exposing the path as ``.name`` (depends on the Gradio version).
        k_mask: EBCS mask width forwarded to ``run_ebcs_core``.

    Yields:
        ``(None, status)`` while working. On success, a final
        ``(csv_path, summary)``.
    """
    if file_obj is None:
        yield None, "❌ Please upload a CSV or FASTA file."
        return
    yield None, "⏳ Parsing uploaded file..."

    path = getattr(file_obj, "name", file_obj)
    sequences = []
    skipped = 0
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith(">"):
                continue
            candidate = line.split(",", 1)[0].strip().strip("\"'").upper()
            if len(candidate) != 41:
                continue
            if re.fullmatch(r'[ACGTUN]+', candidate):
                sequences.append(candidate)
            else:
                skipped += 1

    if not sequences:
        yield None, "❌ No valid 41-bp sequences found."
        return

    results = []
    total = len(sequences)
    for idx, seq in enumerate(sequences):
        # Throttle UI updates: the first three sequences, then every fifth.
        if idx % 5 == 0 or idx < 3:
            yield None, f"🧬 Computing sequence {idx + 1} of {total}..."
        scores, base_prob = run_ebcs_core(seq, k_mask)
        peak_idx = int(np.argmax(scores))
        motifs, _ = find_drach_motifs(seq)
        results.append({
            "Sequence": seq,
            "Baseline_Confidence": round(base_prob, 4),
            "Peak_Position": peak_idx,
            "Peak_Base": seq[peak_idx],
            "Max_EBCS_Score": round(float(scores[peak_idx]), 4),
            "DRACH_Motifs": motifs,
        })

    yield None, "📊 Formatting output CSV..."
    df = pd.DataFrame(results)
    # A fresh directory per run gives every request its own copy of the
    # fixed output filename.
    temp_dir = tempfile.mkdtemp()
    out_path = os.path.join(temp_dir, "EpiRNA_Batch_Results.csv")
    df.to_csv(out_path, index=False)

    summary = f"✅ Successfully processed {len(results)} sequences."
    if skipped:
        summary += f" Skipped {skipped} sequence(s) containing invalid characters."
    yield out_path, summary
# --- 4. THEME & CSS (Native Gradio Dark-Mode Override & Glassmorphism) ---
# Each theme variable below is applied to BOTH its light name and its
# ``*_dark`` twin with the same value. Dark mode therefore renders with the
# same light, translucent palette.
_GLASS_LIGHT_PALETTE = {
    "body_background_fill": "#f8fafc",
    "background_fill_primary": "rgba(255, 255, 255, 0.7)",
    "background_fill_secondary": "rgba(255, 255, 255, 0.4)",
    "border_color_primary": "rgba(203, 213, 225, 0.6)",
    "block_background_fill": "rgba(255, 255, 255, 0.6)",
    "block_title_text_color": "#111827",
    "block_label_text_color": "#374151",
    "body_text_color": "#111827",
    "input_background_fill": "#ffffff",
}
_glass_overrides = dict(_GLASS_LIGHT_PALETTE)
_glass_overrides.update(
    {f"{name}_dark": value for name, value in _GLASS_LIGHT_PALETTE.items()}
)
glass_theme = gr.themes.Soft(primary_hue="indigo", neutral_hue="slate").set(**_glass_overrides)
# Global stylesheet passed to gr.Blocks(css=...). ``.main-container`` and
# ``.glass-panel`` are attached via ``elem_classes`` in the layout; the
# ``.pro-tooltip`` / ``.tooltip-text`` pair styles hover tooltips in HTML blocks.
# ``-webkit-backdrop-filter`` is included because Safari before 18 only
# recognises the prefixed property, and would otherwise drop the blur.
custom_css = """
/* Glassmorphism Structural Layout */
body { background: linear-gradient(135deg, #f8fafc 0%, #e0e7ff 100%) !important; }
.main-container { max-width: 1400px; margin: 0 auto; padding: 40px 20px; }
.glass-panel {
background: rgba(255, 255, 255, 0.5) !important;
-webkit-backdrop-filter: blur(12px) !important;
backdrop-filter: blur(12px) !important;
border-radius: 16px;
border: 1px solid rgba(255, 255, 255, 0.6) !important;
padding: 24px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05);
}
/* Hide Gradio Footer */
footer { display: none !important; }
/* Buttons & Inputs */
textarea, input { background-color: #ffffff !important; color: #111827 !important; border: 1px solid #cbd5e1 !important; }
button.primary { background-color: #111827 !important; color: #ffffff !important; border-radius: 8px !important; transition: all 0.2s ease;}
button.primary:hover { background-color: #4f46e5 !important; transform: translateY(-1px);}
/* Tabs Styling */
.tabs { border: none !important; background: transparent !important; }
.tab-nav { border-bottom: 1px solid rgba(0,0,0,0.1) !important; }
.tab-nav button { color: #4b5563 !important; font-weight: 500 !important; background: transparent !important; }
.tab-nav button.selected { color: #4f46e5 !important; border-bottom: 2px solid #4f46e5 !important; }
/* Pro Tooltips */
.pro-tooltip { position: relative; display: inline-block; cursor: help; border-bottom: 1px dashed #4f46e5; font-weight: 600; color: #4f46e5;}
.pro-tooltip .tooltip-text { visibility: hidden; width: 280px; background-color: #111827; color: #ffffff !important; text-align: left; padding: 12px 16px; border-radius: 8px; position: absolute; z-index: 100; bottom: 130%; left: 50%; transform: translateX(-50%) translateY(10px); opacity: 0; transition: all 0.2s ease; font-size: 0.9rem; font-weight: 400; line-height: 1.5; pointer-events: none; }
.pro-tooltip:hover .tooltip-text { visibility: visible; opacity: 1; transform: translateX(-50%) translateY(0); }
"""
with gr.Blocks(theme=glass_theme, css=custom_css, title="EpiRNA") as app:
    with gr.Row(elem_classes="main-container"):
        # ================= LEFT COLUMN: STATIC SEARCH ENGINE =================
        with gr.Column(scale=4, elem_classes="glass-panel"):
            gr.HTML("""
<h1>EpiRNA</h1>
<p>Decoding RNA Catalytic Boundaries at Single-Nucleotide Resolution</p>
""")
            d_in = gr.Textbox(
                label="Target Sequence",
                info="Must be exactly 41 nucleotides long.",
                placeholder="Paste exactly 41-bp of RNA/DNA sequence...",
                lines=3,
            )
            with gr.Accordion("Advanced Settings", open=False):
                d_k_slider = gr.Slider(
                    minimum=2, maximum=10, step=2, value=6,
                    label="EBCS Mask Resolution (k-mer)",
                )
            d_btn = gr.Button("Analyze Sequence", variant="primary")

        # ================= RIGHT COLUMN: OUTPUTS & BATCH & ABOUT =================
        with gr.Column(scale=8, elem_classes="glass-panel"):
            with gr.Tabs():
                # Right Tab 1: Single Search Graph Results
                with gr.Tab("Spatial EBCS Map"):
                    out_res = gr.HTML("⏳ Waiting for sequence input...")
                    out_plot = gr.Plot()
                    out_mot = gr.HTML()

                # Right Tab 2: Batch Processing Upload
                with gr.Tab("Batch Processing Engine"):
                    gr.HTML("""
<h3>Bulk Sequence Analysis</h3>
<p>Upload a .csv or .fasta file containing sequences (one per line).</p>
""")
                    batch_file = gr.File(label="Upload Dataset", file_types=[".csv", ".fasta", ".txt"])
                    batch_k_slider = gr.Slider(minimum=2, maximum=10, step=2, value=6, label="Mask Resolution (k-mer)")
                    batch_btn = gr.Button("Run Batch EBCS", variant="primary")
                    gr.HTML("<br>")  # spacer between the controls and the status line
                    batch_status = gr.HTML("Ready for upload.")
                    batch_download = gr.File(label="Download Processed Results", interactive=False)

                # Right Tab 3: HTML Science / FAQ with interactive tooltips
                # (.pro-tooltip / .tooltip-text are styled in custom_css).
                with gr.Tab("The Science"):
                    gr.HTML("""
<h3>The "Clever Hans" Effect in Epitranscriptomics</h3>
<p>Traditional deep learning models for RNA modifications overfit to lab-specific technical noise (like <span class="pro-tooltip">GC-content bias<span class="tooltip-text">A common laboratory artifact where sequencing machines preferentially read sequences rich in Guanine (G) and Cytosine (C), tricking AI models into correlating GC% with RNA modifications.</span></span>). They fail to generalize across unseen datasets.</p>
<h3>The Zero-Shot Solution</h3>
<p>EpiRNA leverages a <span class="pro-tooltip">DANN<span class="tooltip-text">Domain Adversarial Neural Network.</span></span> trained on <span class="pro-tooltip">SSB<span class="tooltip-text">Synthetic Sandbox Bootstrapping.</span></span>. By mathematically stripping away technical batch artifacts, it learns true causal biology.</p>
<h3>What is EBCS?</h3>
<p>Epitranscriptomic Boundary Contrast Scoring (<span class="pro-tooltip">EBCS<span class="tooltip-text">A zero-shot mathematical probe that calculates the exact single-nucleotide derivative of an AI model's confidence.</span></span>) slides a synthetic mask across the sequence to calculate the mathematical derivative of the model's confidence. The <span class="pro-tooltip">peak contrast delta<span class="tooltip-text">The highest point on the blue graph line.</span></span> reveals the exact single-nucleotide catalytic boundary the AI relies upon.</p>
""")

    # --- Event Wiring ---
    # Listeners are registered inside the Blocks context, where Gradio expects them.
    d_btn.click(process_search, inputs=[d_in, d_k_slider], outputs=[out_plot, out_res, out_mot])
    batch_btn.click(process_batch, inputs=[batch_file, batch_k_slider], outputs=[batch_download, batch_status])

if __name__ == "__main__":
    # queue() lets the generator handlers stream their intermediate yields
    # (required for generators on Gradio 3.x).
    app.queue().launch()