"""EpiRNA: Gradio front-end for EBCS boundary scoring of 41-nt sequences.

A DNABERT-6 sequence classifier (fine-tuned checkpoint ``final_model_100k.pt``) is
probed by sliding a run of ``N`` bases over the sequence; the per-position
contrast between left- and right-masked predictions is plotted and reported.
"""

import os
import re
import tempfile
import time
import warnings

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# --- 1. Load Model ---
SEQ_LEN = 41  # every analysed window must be exactly this many nucleotides
MAX_TOKENS = 128  # tokenizer padding length used for every forward pass

model_path = "final_model_100k.pt"
tokenizer = AutoTokenizer.from_pretrained("armheb/DNA_bert_6")
model = AutoModelForSequenceClassification.from_pretrained("armheb/DNA_bert_6", num_labels=2)

# The training checkpoint names the encoder "backbone." and the task head
# "task_classifier."; the Hugging Face classifier expects "bert." / "classifier.".
_KEY_PREFIX_MAP = {"backbone.": "bert.", "task_classifier.": "classifier."}


def _remap_key(key):
    """Rename a checkpoint key's leading prefix to the Hugging Face module name."""
    for old, new in _KEY_PREFIX_MAP.items():
        if key.startswith(old):
            return new + key[len(old):]
    return key


def _load_finetuned_weights(net, checkpoint_path):
    """Load the fine-tuned checkpoint into *net* and warn about missing weights.

    Keys the model does not know (e.g. training-only heads) are ignored via
    ``strict=False``; weights the checkpoint fails to provide would otherwise stay
    at their pretrained/random values silently, so they are reported.
    The state dicts are locals so they are freed once loading finishes.
    """
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted
    # checkpoints. If this file is a plain tensor state dict, weights_only=True is safer.
    raw_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    new_state_dict = {_remap_key(key): value for key, value in raw_state_dict.items()}
    report = net.load_state_dict(new_state_dict, strict=False)
    if report.missing_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path!r} is missing {len(report.missing_keys)} "
            f"model weights (first few: {report.missing_keys[:10]})"
        )
    return report


_load_finetuned_weights(model, model_path)
model.eval()

# --- 2. Logic Functions ---
_VALID_SEQ_RE = re.compile(r"[ACGTN]+")
# Zero-width lookahead so overlapping DRACH sites (e.g. GGACAGACT) are all found.
_DRACH_RE = re.compile(r"(?=([AGT][AG]AC[ACT]))")


def _normalize_sequence(sequence):
    """Remove all whitespace, upper-case, and map RNA ``U`` to DNA ``T``.

    The model is the DNA k-mer model ``armheb/DNA_bert_6`` and the DRACH pattern is
    written with ``T``, so RNA input has to be converted before analysis.
    """
    return re.sub(r"\s+", "", sequence).upper().replace("U", "T")


def seq2kmer(seq, k=6):
    """Return the overlapping k-mers of *seq* joined by single spaces."""
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))


def find_drach_motifs(sequence, open_tag="**", close_tag="**"):
    """Locate DRACH motifs (D=A/G/T, R=A/G, H=A/C/T) in *sequence*.

    Args:
        sequence: upper-case DNA string.
        open_tag / close_tag: markers wrapped around each highlighted region
            (defaults keep the original ``**motif**`` form).

    Returns:
        ``(motifs_text, highlighted_seq)``: a "MOTIF (Pos i)" list with 0-based
        positions (or "None detected."), and the sequence with motif regions
        wrapped in the markers. Overlapping motifs share one highlighted region.
    """
    matches = [(m.start(), m.group(1)) for m in _DRACH_RE.finditer(sequence)]

    # Merge overlapping motif spans so the inserted markers stay balanced.
    spans = []
    for start, motif in matches:
        end = start + len(motif)
        if spans and start < spans[-1][1]:
            spans[-1][1] = max(spans[-1][1], end)
        else:
            spans.append([start, end])

    pieces, cursor = [], 0
    for start, end in spans:
        pieces.append(sequence[cursor:start])
        pieces.append(f"{open_tag}{sequence[start:end]}{close_tag}")
        cursor = end
    pieces.append(sequence[cursor:])
    highlighted_seq = "".join(pieces)

    if matches:
        motifs_text = ", ".join(f"{motif} (Pos {pos})" for pos, motif in matches)
    else:
        motifs_text = "None detected."
    return motifs_text, highlighted_seq


def calc_gc_content(sequence, window=15):
    """Return the G+C fraction of a centred *window* around each position.

    Windows are clipped at the sequence ends, so edge values use fewer bases.
    """
    half = window // 2
    length = len(sequence)
    gc_vals = []
    for i in range(length):
        sub = sequence[max(0, i - half):min(length, i + half + 1)]
        gc_vals.append((sub.count("G") + sub.count("C")) / len(sub))
    return gc_vals


def _positive_probs(sequences):
    """Return the classifier's class-1 softmax probability for each raw sequence."""
    inputs = tokenizer(
        [seq2kmer(s) for s in sequences],
        return_tensors="pt",
        padding="max_length",
        max_length=MAX_TOKENS,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    return F.softmax(logits, dim=1)[:, 1].numpy()


def run_ebcs_core(sequence, k_mask=6):
    """Compute per-position EBCS scores for *sequence*.

    For each position ``i`` two variants are scored: the (up to) ``k_mask`` bases
    just before ``i`` replaced by ``N`` ("left"), and the (up to) ``k_mask`` bases
    starting at ``i`` replaced by ``N`` ("right"). Both variants keep the original
    length and alignment. The score is ``|p_left - p_right| / p_baseline``.

    Returns:
        ``(scores, base_prob)``: a float array with one score per position and the
        unmasked class-1 probability.
    """
    k = int(k_mask)  # slider values may arrive as floats
    length = len(sequence)
    base_prob = float(_positive_probs([sequence])[0])

    variants = []
    for i in range(length):
        left_start = max(0, i - k)
        # Mask exactly [left_start, i): never insert more Ns than bases removed,
        # otherwise everything downstream shifts position.
        left = sequence[:left_start] + "N" * (i - left_start) + sequence[i:]
        right_end = min(length, i + k)
        right = sequence[:i] + "N" * (right_end - i) + sequence[right_end:]
        variants.extend([left, right])

    probs = _positive_probs(variants).astype(np.float64).reshape(length, 2)
    scores = np.abs(probs[:, 0] - probs[:, 1]) / (base_prob + 1e-8)
    return scores, base_prob


# --- 3. UI Handlers ---
def _status(text):
    """Wrap a message as one HTML paragraph for a gr.HTML component."""
    return f"<p>{text}</p>"


def _plot_ebcs(sequence, scores, gc_vals, peak_idx):
    """Build the EBCS profile figure with a smoothed-GC twin axis.

    Uses ``Figure`` directly rather than pyplot so concurrent requests do not
    share global state and figures are not retained after the request.
    """
    positions = range(len(sequence))
    fig = Figure(figsize=(9, 4.5))
    ax1 = fig.add_subplot(1, 1, 1)

    # GLASSMORPHISM: make the Matplotlib background fully transparent
    fig.patch.set_alpha(0.0)
    ax1.patch.set_alpha(0.0)

    ax1.set_title(
        f"EBCS Profile - Peak Boundary Detected at Position {peak_idx} ({sequence[peak_idx]})",
        fontweight="bold",
        color="#111827",
        pad=15,
    )
    ax1.plot(positions, scores, color="#4f46e5", linewidth=2.5, marker="o",
             markersize=5, label="EBCS Delta")
    ax1.fill_between(positions, scores, color="#4f46e5", alpha=0.08)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(list(sequence), fontsize=9, color="#111827")
    ax1.set_xlabel("Spatial Nucleotide Resolution", fontweight="600", color="#4b5563")
    ax1.set_ylabel("Boundary Contrast Delta", color="#4f46e5", fontweight="600")
    ax1.tick_params(axis="y", labelcolor="#4f46e5", color="#d1d5db")
    ax1.spines["top"].set_visible(False)
    ax1.spines["right"].set_visible(False)
    ax1.spines["left"].set_color("#d1d5db")
    ax1.spines["bottom"].set_color("#d1d5db")
    ax1.grid(True, linestyle="--", color="#f3f4f6")

    ax2 = ax1.twinx()
    ax2.patch.set_alpha(0.0)  # transparent secondary axis
    ax2.plot(positions, gc_vals, color="#9ca3af", linestyle="-", linewidth=2,
             alpha=0.3, label="Local GC%")
    ax2.set_ylabel("GC Content (Smoothed)", color="#9ca3af")
    ax2.tick_params(axis="y", labelcolor="#9ca3af", color="#d1d5db")
    for side in ("top", "right", "left", "bottom"):
        ax2.spines[side].set_visible(False)

    ax1.axvline(x=peak_idx, color="#e11d48", linestyle=":", linewidth=2, alpha=0.8)
    fig.tight_layout()
    return fig


def process_search(sequence, k_mask=6):
    """Stream status updates, then yield ``(figure, result_html, motif_html)``.

    Positions in the output are 0-based.
    """
    yield None, _status("⏳ Initializing request... (ETA: ~3s)"), ""
    sequence = _normalize_sequence(sequence or "")
    if len(sequence) != SEQ_LEN:
        yield None, _status(
            f"❌ Error: Sequence must be exactly {SEQ_LEN} bp (got {len(sequence)})."
        ), ""
        return
    if not _VALID_SEQ_RE.fullmatch(sequence):
        yield None, _status("❌ Error: Invalid sequence. Only standard nucleotides allowed."), ""
        return

    time.sleep(0.4)  # cosmetic pause between streamed status messages
    yield None, _status("🧬 Scanning sequence & computing EBCS... (ETA: ~2s)"), ""
    motifs_text, highlighted_seq = find_drach_motifs(sequence, open_tag="<b>", close_tag="</b>")
    gc_vals = calc_gc_content(sequence)
    scores, base_prob = run_ebcs_core(sequence, k_mask)

    time.sleep(0.4)  # cosmetic pause between streamed status messages
    yield None, _status("📊 Rendering spatial maps... (ETA: ~1s)"), ""
    peak_idx = int(np.argmax(scores))
    fig = _plot_ebcs(sequence, scores, gc_vals, peak_idx)

    res = (
        f"<p><b>🎯 Target: {sequence[peak_idx]} at Pos {peak_idx}</b></p>"
        f"<p>Max Contrast: {scores[peak_idx]:.4f} &nbsp;|&nbsp; "
        f"Baseline Confidence: {base_prob:.4f}</p>"
        f"<p>Sequence Map: <code>{highlighted_seq}</code></p>"
    )
    mot = _status(f"Canonical DRACH Motifs: {motifs_text}")
    yield fig, res, mot


def _read_batch_sequences(path):
    """Collect candidate SEQ_LEN-nt sequences from a CSV / FASTA / text file.

    FASTA header lines (``>``) are skipped; each remaining line is split on commas,
    semicolons and whitespace so ``id,sequence`` CSV rows also work. Quotes are
    stripped and sequences are normalised (upper-case, U→T).
    """
    sequences = []
    with open(path, "r", encoding="utf-8-sig", errors="replace") as handle:
        for line in handle:
            line = line.strip()
            if not line or line.startswith(">"):
                continue
            for field in re.split(r"[,;\s]+", line):
                candidate = _normalize_sequence(field.strip("\"'"))
                if len(candidate) == SEQ_LEN:
                    sequences.append(candidate)
    return sequences


def process_batch(file_obj, k_mask=6):
    """Score every sequence in an uploaded file; yield ``(csv_path, status_html)``."""
    if file_obj is None:
        yield None, _status("❌ Please upload a CSV or FASTA file.")
        return
    yield None, _status("⏳ Parsing uploaded file...")

    # Depending on the Gradio version this is a path string or a tempfile wrapper.
    path = getattr(file_obj, "name", file_obj)
    sequences = _read_batch_sequences(path)
    if not sequences:
        yield None, _status(f"❌ No valid {SEQ_LEN}-bp sequences found.")
        return

    results = []
    skipped = 0
    total = len(sequences)
    for idx, seq in enumerate(sequences):
        if idx % 5 == 0 or idx < 3:
            yield None, _status(f"🧬 Computing sequence {idx + 1} of {total}...")
        if not _VALID_SEQ_RE.fullmatch(seq):
            skipped += 1
            continue
        scores, base_prob = run_ebcs_core(seq, k_mask)
        peak_idx = int(np.argmax(scores))
        motifs, _ = find_drach_motifs(seq)
        results.append({
            "Sequence": seq,
            "Baseline_Confidence": round(base_prob, 4),
            "Peak_Position": peak_idx,
            "Peak_Base": seq[peak_idx],
            "Max_EBCS_Score": round(float(scores[peak_idx]), 4),
            "DRACH_Motifs": motifs,
        })

    if not results:
        yield None, _status(f"❌ No valid {SEQ_LEN}-bp sequences found.")
        return

    yield None, _status("📊 Formatting output CSV...")
    out_path = os.path.join(tempfile.mkdtemp(), "EpiRNA_Batch_Results.csv")
    pd.DataFrame(results).to_csv(out_path, index=False)
    summary = f"✅ Successfully processed {len(results)} sequences."
    if skipped:
        summary += f" Skipped {skipped} with invalid characters."
    yield out_path, _status(summary)


# --- 4. THEME & CSS (Native Gradio Dark-Mode Override & Glassmorphism) ---
# This forces Dark Mode variables to be Light/Translucent
glass_theme = gr.themes.Soft(
    primary_hue="indigo",
    neutral_hue="slate",
).set(
    body_background_fill="#f8fafc",
    body_background_fill_dark="#f8fafc",
    background_fill_primary="rgba(255, 255, 255, 0.7)",
    background_fill_primary_dark="rgba(255, 255, 255, 0.7)",
    background_fill_secondary="rgba(255, 255, 255, 0.4)",
    background_fill_secondary_dark="rgba(255, 255, 255, 0.4)",
    border_color_primary="rgba(203, 213, 225, 0.6)",
    border_color_primary_dark="rgba(203, 213, 225, 0.6)",
    block_background_fill="rgba(255, 255, 255, 0.6)",
    block_background_fill_dark="rgba(255, 255, 255, 0.6)",
    block_title_text_color="#111827",
    block_title_text_color_dark="#111827",
    block_label_text_color="#374151",
    block_label_text_color_dark="#374151",
    body_text_color="#111827",
    body_text_color_dark="#111827",
    input_background_fill="#ffffff",
    input_background_fill_dark="#ffffff",
)

custom_css = """
/* Glassmorphism Structural Layout */
body { background: linear-gradient(135deg, #f8fafc 0%, #e0e7ff 100%) !important; }
.main-container { max-width: 1400px; margin: 0 auto; padding: 40px 20px; }
.glass-panel {
    background: rgba(255, 255, 255, 0.5) !important;
    backdrop-filter: blur(12px) !important;
    border-radius: 16px;
    border: 1px solid rgba(255, 255, 255, 0.6) !important;
    padding: 24px;
    box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05);
}
/* Hide Gradio Footer */
footer { display: none !important; }
/* Buttons & Inputs */
textarea, input {
    background-color: #ffffff !important;
    color: #111827 !important;
    border: 1px solid #cbd5e1 !important;
}
button.primary {
    background-color: #111827 !important;
    color: #ffffff !important;
    border-radius: 8px !important;
    transition: all 0.2s ease;
}
button.primary:hover { background-color: #4f46e5 !important; transform: translateY(-1px); }
/* Tabs Styling */
.tabs { border: none !important; background: transparent !important; }
.tab-nav { border-bottom: 1px solid rgba(0,0,0,0.1) !important; }
.tab-nav button { color: #4b5563 !important; font-weight: 500 !important; background: transparent !important; }
.tab-nav button.selected { color: #4f46e5 !important; border-bottom: 2px solid #4f46e5 !important; }
/* Pro Tooltips */
.pro-tooltip {
    position: relative; display: inline-block; cursor: help;
    border-bottom: 1px dashed #4f46e5; font-weight: 600; color: #4f46e5;
}
.pro-tooltip .tooltip-text {
    visibility: hidden; width: 280px; background-color: #111827; color: #ffffff !important;
    text-align: left; padding: 12px 16px; border-radius: 8px;
    position: absolute; z-index: 100; bottom: 130%; left: 50%;
    transform: translateX(-50%) translateY(10px); opacity: 0; transition: all 0.2s ease;
    font-size: 0.9rem; font-weight: 400; line-height: 1.5; pointer-events: none;
}
.pro-tooltip:hover .tooltip-text { visibility: visible; opacity: 1; transform: translateX(-50%) translateY(0); }
"""


def _tooltip(label, explanation):
    """Return an inline tooltip span styled by the .pro-tooltip CSS rules."""
    return (
        f'<span class="pro-tooltip">{label}'
        f'<span class="tooltip-text">{explanation}</span></span>'
    )


_SCIENCE_HTML = (
    '<h3>The "Clever Hans" Effect in Epitranscriptomics</h3>'
    "<p>Traditional deep learning models for RNA modifications overfit to lab-specific "
    "technical noise (like "
    + _tooltip(
        "GC-content bias",
        "A common laboratory artifact where sequencing machines preferentially read "
        "sequences rich in Guanine (G) and Cytosine (C), tricking AI models into "
        "correlating GC% with RNA modifications.",
    )
    + "). They fail to generalize across unseen datasets.</p>"
    "<h3>The Zero-Shot Solution</h3>"
    "<p>EpiRNA leverages a "
    + _tooltip("DANN", "Domain Adversarial Neural Network.")
    + " trained on "
    + _tooltip("SSB", "Synthetic Sandbox Bootstrapping.")
    + ". By mathematically stripping away technical batch artifacts, it learns true "
    "causal biology.</p>"
    "<h3>What is EBCS?</h3>"
    "<p>Epitranscriptomic Boundary Contrast Scoring ("
    + _tooltip(
        "EBCS",
        "A zero-shot mathematical probe that calculates the exact single-nucleotide "
        "derivative of an AI model's confidence.",
    )
    + ") slides a synthetic mask across the sequence to calculate the mathematical "
    "derivative of the model's confidence. The "
    + _tooltip("peak contrast delta", "The highest point on the blue graph line.")
    + " reveals the exact single-nucleotide catalytic boundary the AI relies upon.</p>"
)

with gr.Blocks(theme=glass_theme, css=custom_css, title="EpiRNA") as app:
    with gr.Row(elem_classes="main-container"):
        # ================= LEFT COLUMN: STATIC SEARCH ENGINE =================
        with gr.Column(scale=4, elem_classes="glass-panel"):
            gr.HTML(
                "<h1>EpiRNA</h1>"
                "<p>Decoding RNA Catalytic Boundaries at Single-Nucleotide Resolution</p>"
            )
            d_in = gr.Textbox(
                label="Target Sequence",
                info="Must be exactly 41 nucleotides long.",
                placeholder="Paste exactly 41-bp of RNA/DNA sequence...",
                lines=3,
            )
            with gr.Accordion("Advanced Settings", open=False):
                d_k_slider = gr.Slider(
                    minimum=2, maximum=10, step=2, value=6,
                    label="EBCS Mask Resolution (k-mer)",
                )
            d_btn = gr.Button("Analyze Sequence", variant="primary")

        # ================= RIGHT COLUMN: OUTPUTS & BATCH & ABOUT =================
        with gr.Column(scale=8, elem_classes="glass-panel"):
            with gr.Tabs():
                # Right Tab 1: Single Search Graph Results
                with gr.Tab("Spatial EBCS Map"):
                    out_res = gr.HTML(_status("⏳ Waiting for sequence input..."))
                    out_plot = gr.Plot()
                    out_mot = gr.HTML()

                # Right Tab 2: Batch Processing Upload
                with gr.Tab("Batch Processing Engine"):
                    gr.HTML(
                        "<h3>Bulk Sequence Analysis</h3>"
                        "<p>Upload a .csv or .fasta file containing sequences (one per line).</p>"
                    )
                    batch_file = gr.File(label="Upload Dataset", file_types=[".csv", ".fasta", ".txt"])
                    batch_k_slider = gr.Slider(
                        minimum=2, maximum=10, step=2, value=6,
                        label="Mask Resolution (k-mer)",
                    )
                    batch_btn = gr.Button("Run Batch EBCS", variant="primary")
                    gr.HTML("<br>")  # spacer
                    batch_status = gr.HTML(_status("Ready for upload."))
                    batch_download = gr.File(label="Download Processed Results", interactive=False)

                # Right Tab 3: HTML Science / FAQ with Interactive Tooltips
                with gr.Tab("The Science"):
                    gr.HTML(_SCIENCE_HTML)

    # --- Event Wiring ---
    d_btn.click(process_search, inputs=[d_in, d_k_slider], outputs=[out_plot, out_res, out_mot])
    batch_btn.click(process_batch, inputs=[batch_file, batch_k_slider], outputs=[batch_download, batch_status])

app.queue().launch()