# EpiRNA / app.py
# supzammy's picture
# Upload 3 files
# ce8c078 verified
import os
import re
import tempfile
import time
import warnings

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- 1. Load Model ---
# Checkpoint whose weights are loaded on top of the DNABERT-6 base architecture below.
model_path = "final_model_100k.pt"
tokenizer = AutoTokenizer.from_pretrained("armheb/DNA_bert_6")
model = AutoModelForSequenceClassification.from_pretrained("armheb/DNA_bert_6", num_labels=2)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code; only point
# model_path at a trusted checkpoint (torch >= 1.13 also offers weights_only=True).
raw_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# The checkpoint names the encoder "backbone." and the head "task_classifier."; the Hugging
# Face model calls them "bert." and "classifier.". Only the leading prefix is rewritten
# (str.replace would also rewrite any later occurrence inside the key).
_CHECKPOINT_PREFIX_MAP = {"backbone.": "bert.", "task_classifier.": "classifier."}


def _remap_checkpoint_key(key):
    """Translate one checkpoint parameter name to the HF model's naming; other keys pass through."""
    for old_prefix, new_prefix in _CHECKPOINT_PREFIX_MAP.items():
        if key.startswith(old_prefix):
            return new_prefix + key[len(old_prefix):]
    return key


new_state_dict = {_remap_checkpoint_key(key): value for key, value in raw_state_dict.items()}

# strict=False tolerates checkpoint entries the inference model has no slot for (presumably
# training-only heads -- TODO confirm), but weights the model expects and does not receive
# would silently keep their from_pretrained initialisation, so report them.
_load_report = model.load_state_dict(new_state_dict, strict=False)
if _load_report.missing_keys:
    warnings.warn(
        f"{len(_load_report.missing_keys)} model weight(s) not found in {model_path}; "
        f"they keep their from_pretrained initialisation "
        f"(first few: {_load_report.missing_keys[:5]})"
    )
model.eval()
# --- 2. Logic Functions ---
def seq2kmer(seq, k=6):
    """Return *seq* as space-separated overlapping k-mers.

    This is the text format handed to the tokenizer. Example:
    ``seq2kmer("ACGTACG", k=6) == "ACGTAC CGTACG"``. A sequence shorter than
    *k* yields an empty string.
    """
    window_starts = range(len(seq) - k + 1)
    kmers = (seq[start:start + k] for start in window_starts)
    return " ".join(kmers)
# DRACH consensus (IUPAC): D = A/G/U(T), R = A/G, then A, C, H = A/C/U(T). U is accepted
# alongside T because the input validators allow RNA. The pattern sits inside a lookahead so
# finditer also reports motifs overlapping a previous one (e.g. GGACAGACT holds GGACA and AGACT).
_DRACH_PATTERN = re.compile(r'(?=([AGTU][AG]AC[ACTU]))')
_DRACH_LEN = 5


def find_drach_motifs(sequence):
    """Locate DRACH motifs and build HTML for the motif list and the highlighted sequence.

    Args:
        sequence: Upper-case nucleotide string.

    Returns:
        ``(motifs_text, highlighted_seq)``:

        - ``motifs_text`` is an HTML, comma-separated list of ``MOTIF (Pos i)`` entries
          (0-based start positions), or a "None detected." span.
        - ``highlighted_seq`` is ``sequence`` with motif bases wrapped in a bold badge span.
          Overlapping motifs share one badge so the markup stays well-formed.
    """
    matches = [(m.start(), m.group(1)) for m in _DRACH_PATTERN.finditer(sequence)]
    if not matches:
        return "<span style='color:#111827;'>None detected.</span>", sequence

    # Merge overlapping motif intervals; merely adjacent motifs keep separate badges.
    spans = []
    for start, _ in matches:
        end = start + _DRACH_LEN
        if spans and start < spans[-1][1]:
            spans[-1][1] = max(spans[-1][1], end)
        else:
            spans.append([start, end])

    # <b> rather than Markdown ** because the result is shown inside gr.HTML, which renders
    # raw HTML and would print the asterisks literally.
    pieces = []
    cursor = 0
    for start, end in spans:
        pieces.append(sequence[cursor:start])
        pieces.append(
            "<b><span style='color:#000000; background:#f3f4f6; padding:2px 4px; "
            "border-radius:4px; border:1px solid #d1d5db;'>"
            f"{sequence[start:end]}</span></b>"
        )
        cursor = end
    pieces.append(sequence[cursor:])

    motifs_text = ", ".join(
        f"<span style='color:#111827;'>{motif} (Pos {pos})</span>" for pos, motif in matches
    )
    return motifs_text, "".join(pieces)
def calc_gc_content(sequence, window=15):
    """Return the GC fraction of a centred sliding window at every position of *sequence*.

    The window reaches ``window // 2`` bases to each side of the position and is
    clipped at the sequence ends, so edge values are computed over fewer bases.
    Only upper-case 'G' and 'C' count as GC; every other character counts only
    towards the window length.
    """
    radius = window // 2
    n_bases = len(sequence)

    def _gc_fraction(center):
        lo, hi = max(0, center - radius), min(n_bases, center + radius + 1)
        segment = sequence[lo:hi]
        return sum(base in "GC" for base in segment) / len(segment)

    return [_gc_fraction(pos) for pos in range(n_bases)]
def _flank_masks(sequence, i, k_mask):
    """Return ``(left, right)`` copies of ``sequence`` masked with ``N`` beside position ``i``.

    ``left`` masks ``sequence[i - k_mask:i]`` and ``right`` masks ``sequence[i:i + k_mask]``.
    Each window is clipped to the sequence. Masking is in place, so both copies keep the
    original length and every unmasked base stays at its original position.
    """
    lo = max(0, i - k_mask)
    hi = min(len(sequence), i + k_mask)
    left = sequence[:lo] + "N" * (i - lo) + sequence[i:]
    right = sequence[:i] + "N" * (hi - i) + sequence[hi:]
    return left, right


def run_ebcs_core(sequence, k_mask=6):
    """Compute the EBCS contrast profile of ``sequence``.

    For every position ``i`` the model scores two masked copies: ``k_mask`` bases
    replaced by ``N`` immediately left of ``i``, and ``k_mask`` bases immediately
    right of ``i``. The score is ``|P_left - P_right| / P_baseline``, where ``P`` is
    the softmax probability of class 1 and ``P_baseline`` is that probability on the
    unmasked sequence.

    Args:
        sequence: Nucleotide string (the UI handlers pass 41 bp). ``U`` is read as ``T``.
        k_mask: Mask width in bases. Coerced to ``int`` in case slider values arrive
            as floats (``"N" * 6.0`` would raise).

    Returns:
        ``(scores, base_prob)``: a float64 array with one score per position, and the
        baseline probability as a Python float.
    """
    k_mask = int(k_mask)
    # The backbone is "armheb/DNA_bert_6" fed 6-mer text; presumably its vocabulary has no
    # U-containing k-mers, so RNA input is mapped to DNA before tokenization.
    model_seq = sequence.replace("U", "T")

    inputs = tokenizer([seq2kmer(model_seq)], return_tensors="pt", padding="max_length", max_length=128)
    with torch.no_grad():
        base_out = model(**inputs).logits
    base_prob = F.softmax(base_out, dim=1)[0][1].item()

    # Interleave [left_0, right_0, left_1, right_1, ...] so one batched forward pass scores all masks.
    mutated_seqs = []
    for i in range(len(model_seq)):
        left, right = _flank_masks(model_seq, i, k_mask)
        mutated_seqs.extend([seq2kmer(left), seq2kmer(right)])
    batch_inputs = tokenizer(mutated_seqs, return_tensors="pt", padding="max_length", max_length=128)
    with torch.no_grad():
        batch_out = model(**batch_inputs).logits
    batch_probs = F.softmax(batch_out, dim=1)[:, 1].numpy().astype(np.float64)

    # 1e-8 keeps the division finite when the baseline probability is ~0.
    scores = np.abs(batch_probs[0::2] - batch_probs[1::2]) / (base_prob + 1e-8)
    return scores, base_prob
# --- 3. UI Handlers ---
def _render_ebcs_figure(sequence, scores, gc_vals, peak_idx):
    """Draw the EBCS profile with the smoothed GC-content curve on a twin y-axis.

    Uses matplotlib's object-oriented ``Figure`` instead of ``pyplot``. A pyplot
    figure stays registered in pyplot's global figure manager until ``plt.close``,
    so one leaked per request. Pyplot's global state would also be shared between
    concurrently handled requests.
    """
    positions = range(len(sequence))
    peak_base = sequence[peak_idx]
    fig = Figure(figsize=(9, 4.5))
    ax1 = fig.subplots()
    # GLASSMORPHISM: make the figure and axes backgrounds fully transparent.
    fig.patch.set_alpha(0.0)
    ax1.patch.set_alpha(0.0)

    ax1.set_title(f"EBCS Profile - Peak Boundary Detected at Position {peak_idx} ({peak_base})", fontweight='bold', color='#111827', pad=15)
    ax1.plot(positions, scores, color='#4f46e5', linewidth=2.5, marker='o', markersize=5, label='EBCS Delta')
    ax1.fill_between(positions, scores, color='#4f46e5', alpha=0.08)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(list(sequence), fontsize=9, color='#111827')
    ax1.set_xlabel("Spatial Nucleotide Resolution", fontweight='600', color='#4b5563')
    ax1.set_ylabel("Boundary Contrast Delta", color='#4f46e5', fontweight='600')
    ax1.tick_params(axis='y', labelcolor='#4f46e5', color='#d1d5db')
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['left'].set_color('#d1d5db')
    ax1.spines['bottom'].set_color('#d1d5db')
    ax1.grid(True, linestyle='--', color='#f3f4f6')

    ax2 = ax1.twinx()
    ax2.patch.set_alpha(0.0)  # transparent secondary axis
    ax2.plot(positions, gc_vals, color='#9ca3af', linestyle='-', linewidth=2, alpha=0.3, label='Local GC%')
    ax2.set_ylabel("GC Content (Smoothed)", color='#9ca3af')
    ax2.tick_params(axis='y', labelcolor='#9ca3af', color='#d1d5db')
    for side in ('top', 'right', 'left', 'bottom'):
        ax2.spines[side].set_visible(False)

    ax1.axvline(x=peak_idx, color='#e11d48', linestyle=':', linewidth=2, alpha=0.8)
    fig.tight_layout()
    return fig


def process_search(sequence, k_mask=6):
    """Validate one sequence, run EBCS on it, and stream progress to the UI.

    Generator bound to the "Analyze Sequence" button. Every yield is a
    ``(figure, result_html, motif_html)`` triple for ``[out_plot, out_res, out_mot]``.
    Only the last yield carries a figure.

    Args:
        sequence: Raw textbox contents. All whitespace (including line breaks from
            pasting into the multi-line box) is removed and the rest upper-cased;
            ``None`` is treated as empty.
        k_mask: Mask width forwarded to ``run_ebcs_core``.
    """
    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>⏳ Initializing request... (ETA: ~3s)</h3>", ""
    sequence = "".join((sequence or "").split()).upper()
    if len(sequence) != 41:
        yield None, "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ Error: Sequence must be exactly 41 bp.</h3>", ""
        return
    if not re.fullmatch(r'[ACGTUN]+', sequence):
        yield None, "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ Error: Invalid sequence. Only standard nucleotides allowed.</h3>", ""
        return

    # Short pauses, presumably so each progress banner stays visible long enough to read.
    time.sleep(0.4)
    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>🧬 Scanning sequence & computing EBCS... (ETA: ~2s)</h3>", ""
    motifs_text, highlighted_seq = find_drach_motifs(sequence)
    gc_vals = calc_gc_content(sequence)
    scores, base_prob = run_ebcs_core(sequence, k_mask)
    time.sleep(0.4)
    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>📊 Rendering spatial maps... (ETA: ~1s)</h3>", ""

    peak_idx = int(np.argmax(scores))
    peak_score = scores[peak_idx]
    peak_base = sequence[peak_idx]
    fig = _render_ebcs_figure(sequence, scores, gc_vals, peak_idx)

    res = f"""
<div style="color: #111827; font-size: 1.05rem; font-family: sans-serif;">
<h3 style="margin-top: 0; color: #111827;">🎯 Target: <span style="background:#e0e7ff; padding:2px 6px; border-radius:4px; color:#4f46e5;">{peak_base}</span> at Pos <b>{peak_idx}</b></h3>
<p style="margin: 8px 0; color: #111827;"><b>Max Contrast:</b> {peak_score:.4f} &nbsp;|&nbsp; <b>Baseline Confidence:</b> {base_prob:.4f}</p>
<p style="margin: 8px 0; color: #111827;"><b>Sequence Map:</b> {highlighted_seq}</p>
</div>
"""
    mot = f"<div style='color: #111827; font-size: 1.05rem; font-family: sans-serif;'><p style='color: #111827; margin:0;'><b>Canonical DRACH Motifs:</b> {motifs_text}</p></div>"
    yield fig, res, mot
def _read_upload_text(file_obj):
    """Return the decoded text of an uploaded file.

    Accepts a filesystem path, or any object exposing the path as ``.name``
    (older Gradio releases passed tempfile wrappers rather than plain paths).
    """
    path = getattr(file_obj, "name", file_obj)
    with open(path, "r", encoding="utf-8", errors="replace") as handle:
        return handle.read()


def _extract_sequences(text, length=41):
    """Return the upper-cased candidate sequences of exactly ``length`` characters in ``text``.

    FASTA input (any line starting with ``>``): each record's sequence lines are joined,
    so line-wrapped records are reassembled. Text before the first header is ignored.

    Any other input is read as plain or delimited text. Every comma-, semicolon-, tab- or
    space-separated field (surrounding quotes stripped) is a candidate. This covers
    one-sequence-per-line files as well as multi-column CSVs with a header row.
    """
    lines = [line.strip() for line in text.splitlines()]
    if any(line.startswith(">") for line in lines):
        candidates = []
        record = None
        for line in lines:
            if line.startswith(">"):
                if record is not None:
                    candidates.append("".join(record))
                record = []
            elif record is not None and line:
                record.append(line)
        if record is not None:
            candidates.append("".join(record))
    else:
        candidates = [
            field.strip().strip("\"'")
            for line in lines
            for field in re.split(r"[,;\t ]+", line)
        ]
    return [candidate.upper() for candidate in candidates if len(candidate) == length]


def _strip_html_tags(markup):
    """Drop HTML tags so a CSV cell holds plain text (``find_drach_motifs`` returns HTML)."""
    return re.sub(r"<[^>]+>", "", markup)


def process_batch(file_obj, k_mask=6):
    """Run EBCS on every 41-bp sequence in an uploaded file and offer a results CSV.

    Generator bound to the "Run Batch EBCS" button. Each yield is
    ``(download_path_or_None, status_html)`` for ``[batch_download, batch_status]``.
    Candidates containing characters other than A/C/G/T/U/N are skipped.

    Args:
        file_obj: The ``gr.File`` value (a path, or an object with ``.name``); ``None``
            when nothing was uploaded.
        k_mask: Mask width forwarded to ``run_ebcs_core``.
    """
    no_valid_msg = "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ No valid 41-bp sequences found.</h3>"
    if file_obj is None:
        yield None, "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ Please upload a CSV or FASTA file.</h3>"
        return
    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>⏳ Parsing uploaded file...</h3>"

    sequences = _extract_sequences(_read_upload_text(file_obj))
    if not sequences:
        yield None, no_valid_msg
        return

    results = []
    total = len(sequences)
    for idx, seq in enumerate(sequences):
        # Report progress for the first three sequences, then every fifth one.
        if idx % 5 == 0 or idx < 3:
            yield None, f"<h3 style='color:#111827; margin:0; font-family: sans-serif;'>🧬 Computing sequence {idx+1} of {total}...</h3>"
        if not re.fullmatch(r'[ACGTUN]+', seq):
            continue
        scores, base_prob = run_ebcs_core(seq, k_mask)
        peak_idx = int(np.argmax(scores))
        motifs_html, _ = find_drach_motifs(seq)
        results.append({
            "Sequence": seq,
            "Baseline_Confidence": round(base_prob, 4),
            "Peak_Position": peak_idx,
            "Peak_Base": seq[peak_idx],
            "Max_EBCS_Score": round(float(scores[peak_idx]), 4),
            "DRACH_Motifs": _strip_html_tags(motifs_html),
        })

    if not results:
        # Every candidate had invalid characters; an empty CSV would read as a success.
        yield None, no_valid_msg
        return

    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>📊 Formatting output CSV...</h3>"
    df = pd.DataFrame(results)
    temp_dir = tempfile.mkdtemp()
    out_path = os.path.join(temp_dir, "EpiRNA_Batch_Results.csv")
    df.to_csv(out_path, index=False)
    yield out_path, f"<h3 style='color:#16a34a; margin:0; font-family: sans-serif;'>✅ Successfully processed {len(results)} sequences.</h3>"
# --- 4. THEME & CSS (Native Gradio Dark-Mode Override & Glassmorphism) ---
# Light, translucent palette for the Soft theme. Every variable is set twice: under its own
# name and under its "<name>_dark" variant with the same value. Dark mode therefore shows the
# same light look instead of switching palettes.
_GLASS_THEME_VARS = {
    "body_background_fill": "#f8fafc",
    "background_fill_primary": "rgba(255, 255, 255, 0.7)",
    "background_fill_secondary": "rgba(255, 255, 255, 0.4)",
    "border_color_primary": "rgba(203, 213, 225, 0.6)",
    "block_background_fill": "rgba(255, 255, 255, 0.6)",
    "block_title_text_color": "#111827",
    "block_label_text_color": "#374151",
    "body_text_color": "#111827",
    "input_background_fill": "#ffffff",
}
_glass_theme_overrides = dict(_GLASS_THEME_VARS)
_glass_theme_overrides.update({f"{name}_dark": value for name, value in _GLASS_THEME_VARS.items()})
glass_theme = gr.themes.Soft(primary_hue="indigo", neutral_hue="slate").set(**_glass_theme_overrides)
# Page-level stylesheet passed to gr.Blocks(css=...).
# - ".main-container" and ".glass-panel" are applied through elem_classes on the layout Row/Columns.
# - ".pro-tooltip" / ".tooltip-text" style the hover tooltips in "The Science" tab's HTML.
# - The remaining selectors (.tabs, .tab-nav, button.primary, footer, textarea, input) target
#   DOM that Gradio generates itself.
#   NOTE(review): Gradio's internal class names differ between releases -- confirm these still match.
custom_css = """
/* Glassmorphism Structural Layout */
body { background: linear-gradient(135deg, #f8fafc 0%, #e0e7ff 100%) !important; }
.main-container { max-width: 1400px; margin: 0 auto; padding: 40px 20px; }
.glass-panel {
background: rgba(255, 255, 255, 0.5) !important;
backdrop-filter: blur(12px) !important;
border-radius: 16px;
border: 1px solid rgba(255, 255, 255, 0.6) !important;
padding: 24px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05);
}
/* Hide Gradio Footer */
footer { display: none !important; }
/* Buttons & Inputs */
textarea, input { background-color: #ffffff !important; color: #111827 !important; border: 1px solid #cbd5e1 !important; }
button.primary { background-color: #111827 !important; color: #ffffff !important; border-radius: 8px !important; transition: all 0.2s ease;}
button.primary:hover { background-color: #4f46e5 !important; transform: translateY(-1px);}
/* Tabs Styling */
.tabs { border: none !important; background: transparent !important; }
.tab-nav { border-bottom: 1px solid rgba(0,0,0,0.1) !important; }
.tab-nav button { color: #4b5563 !important; font-weight: 500 !important; background: transparent !important; }
.tab-nav button.selected { color: #4f46e5 !important; border-bottom: 2px solid #4f46e5 !important; }
/* Pro Tooltips */
.pro-tooltip { position: relative; display: inline-block; cursor: help; border-bottom: 1px dashed #4f46e5; font-weight: 600; color: #4f46e5;}
.pro-tooltip .tooltip-text { visibility: hidden; width: 280px; background-color: #111827; color: #ffffff !important; text-align: left; padding: 12px 16px; border-radius: 8px; position: absolute; z-index: 100; bottom: 130%; left: 50%; transform: translateX(-50%) translateY(10px); opacity: 0; transition: all 0.2s ease; font-size: 0.9rem; font-weight: 400; line-height: 1.5; pointer-events: none; }
.pro-tooltip:hover .tooltip-text { visibility: visible; opacity: 1; transform: translateX(-50%) translateY(0); }
"""
# Two-column layout: sequence input on the left; results, batch upload and explainer tabs
# on the right. Event handlers are bound inside the Blocks context, then the app is launched
# with the request queue enabled (needed for the generator handlers to stream progress).
with gr.Blocks(theme=glass_theme, css=custom_css, title="EpiRNA") as app:
    with gr.Row(elem_classes="main-container"):
        # ================= LEFT COLUMN: STATIC SEARCH ENGINE =================
        with gr.Column(scale=4, elem_classes="glass-panel"):
            gr.HTML("""
<div style="margin-bottom: 30px;">
<h1 style="font-size: 3.5rem; font-weight: 800; margin: 0; letter-spacing: -1.5px; line-height: 1.1; color: #111827; font-family: sans-serif;">
<span style='color: #4f46e5;'>Epi</span><span style='color: #e11d48;'>RNA</span>
</h1>
<p style="font-size: 1.1rem; font-weight: 400; color: #4b5563; margin-top: 8px; font-family: sans-serif;">
Decoding RNA Catalytic Boundaries at Single-Nucleotide Resolution
</p>
</div>
""")
            d_in = gr.Textbox(
                label="Target Sequence",
                info="Must be exactly 41 nucleotides long.",
                placeholder="Paste exactly 41-bp of RNA/DNA sequence...",
                lines=3
            )
            with gr.Accordion("Advanced Settings", open=False):
                # Even mask widths 2..10; passed to process_search as k_mask.
                d_k_slider = gr.Slider(
                    minimum=2, maximum=10, step=2, value=6,
                    label="EBCS Mask Resolution (k-mer)"
                )
            d_btn = gr.Button("Analyze Sequence", variant="primary")
        # ================= RIGHT COLUMN: OUTPUTS & BATCH & ABOUT =================
        with gr.Column(scale=8, elem_classes="glass-panel"):
            with gr.Tabs():
                # Right Tab 1: Single Search Graph Results
                with gr.Tab("Spatial EBCS Map"):
                    out_res = gr.HTML("<h3 style='color:#111827; margin-top:0; font-family: sans-serif;'>⏳ Waiting for sequence input...</h3>")
                    out_plot = gr.Plot()
                    out_mot = gr.HTML()
                # Right Tab 2: Batch Processing Upload
                with gr.Tab("Batch Processing Engine"):
                    gr.HTML("""
<div style='color:#111827; font-family: sans-serif;'>
<h3 style='margin-top:0; color:#111827;'>Bulk Sequence Analysis</h3>
<p style='color:#111827;'>Upload a <code>.csv</code> or <code>.fasta</code> file containing sequences (one per line).</p>
</div>
""")
                    batch_file = gr.File(label="Upload Dataset", file_types=[".csv", ".fasta", ".txt"])
                    batch_k_slider = gr.Slider(minimum=2, maximum=10, step=2, value=6, label="Mask Resolution (k-mer)")
                    batch_btn = gr.Button("Run Batch EBCS", variant="primary")
                    gr.HTML("<hr style='border-color: #e5e7eb; margin: 20px 0;'>")
                    batch_status = gr.HTML("<h3 style='color:#111827; margin:0; font-family: sans-serif;'>Ready for upload.</h3>")
                    # Output only: receives the CSV path yielded by process_batch.
                    batch_download = gr.File(label="Download Processed Results", interactive=False)
                # Right Tab 3: HTML Science / FAQ with Interactive Tooltips
                # (the .pro-tooltip / .tooltip-text classes are styled in custom_css)
                with gr.Tab("The Science"):
                    gr.HTML("""
<div style="color: #111827; line-height: 1.6; font-size: 1.05rem; font-family: sans-serif;">
<h3 style="margin-top: 0; color: #111827; font-weight: 600;">The "Clever Hans" Effect in Epitranscriptomics</h3>
<p style="margin-top: 5px; color: #374151;">Traditional deep learning models for RNA modifications overfit to lab-specific technical noise (like <span class="pro-tooltip">GC-content bias<span class="tooltip-text">A common laboratory artifact where sequencing machines preferentially read sequences rich in Guanine (G) and Cytosine (C), tricking AI models into correlating GC% with RNA modifications.</span></span>). They fail to generalize across unseen datasets.</p>
<h3 style="margin-top: 25px; color: #111827; font-weight: 600;">The Zero-Shot Solution</h3>
<p style="margin-top: 5px; color: #374151;">EpiRNA leverages a <span class="pro-tooltip">DANN<span class="tooltip-text">Domain Adversarial Neural Network.</span></span> trained on <span class="pro-tooltip">SSB<span class="tooltip-text">Synthetic Sandbox Bootstrapping.</span></span>. By mathematically stripping away technical batch artifacts, it learns true causal biology.</p>
<h3 style="margin-top: 25px; color: #111827; font-weight: 600;">What is EBCS?</h3>
<p style="margin-top: 5px; color: #374151;">Epitranscriptomic Boundary Contrast Scoring (<span class="pro-tooltip">EBCS<span class="tooltip-text">A zero-shot mathematical probe that calculates the exact single-nucleotide derivative of an AI model's confidence.</span></span>) slides a synthetic mask across the sequence to calculate the mathematical derivative of the model's confidence. The <span class="pro-tooltip">peak contrast delta<span class="tooltip-text">The highest point on the blue graph line.</span></span> reveals the exact single-nucleotide catalytic boundary the AI relies upon.</p>
</div>
""")
    # --- Event Wiring ---
    # Both handlers are generators; the outputs lists match the tuples they yield, in order.
    d_btn.click(process_search, inputs=[d_in, d_k_slider], outputs=[out_plot, out_res, out_mot])
    batch_btn.click(process_batch, inputs=[batch_file, batch_k_slider], outputs=[batch_download, batch_status])
app.queue().launch()