import os
import re
import tempfile
import time
import warnings

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
|
|
| |
# Local checkpoint whose weights replace those of the pretrained model below.
model_path = "final_model_100k.pt"
tokenizer = AutoTokenizer.from_pretrained("armheb/DNA_bert_6")

# Hugging Face architecture with a 2-label sequence-classification head; its
# weights are overwritten from the local checkpoint right after.
model = AutoModelForSequenceClassification.from_pretrained("armheb/DNA_bert_6", num_labels=2)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints you trust; weights_only=True is safer if the file holds
# plain tensors (confirm before switching).
raw_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Checkpoint key prefix -> Hugging Face key prefix. The training code presumably
# named the encoder `backbone` and the task head `task_classifier`.
_CHECKPOINT_PREFIX_MAP = {
    'backbone.': 'bert.',
    'task_classifier.': 'classifier.',
}


def _rename_checkpoint_key(key):
    """Return *key* with a known leading prefix swapped for its HF equivalent.

    Only the leading prefix is rewritten. ``str.replace`` would also touch any
    later occurrence of the prefix text. Keys with no known prefix are
    returned unchanged.
    """
    for old_prefix, new_prefix in _CHECKPOINT_PREFIX_MAP.items():
        if key.startswith(old_prefix):
            return new_prefix + key[len(old_prefix):]
    return key


new_state_dict = {
    _rename_checkpoint_key(key): value for key, value in raw_state_dict.items()
}

# strict=False tolerates checkpoint-only entries (presumably auxiliary training
# heads; TODO confirm which). A *missing* key, however, means that parameter kept
# the weights from from_pretrained (pretrained or freshly initialised). Surface
# it instead of letting predictions silently degrade.
_load_report = model.load_state_dict(new_state_dict, strict=False)
if _load_report.missing_keys:
    _missing = _load_report.missing_keys
    warnings.warn(
        f"{len(_missing)} model parameter(s) not found in {model_path}; they keep "
        f"their from_pretrained values: {', '.join(_missing[:10])}"
        + (" ..." if len(_missing) > 10 else ""),
        RuntimeWarning,
    )
model.eval()
|
|
| |
def seq2kmer(seq, k=6):
    """Tokenise *seq* into overlapping k-mers for the DNA_bert_6 tokenizer.

    Produces the ``len(seq) - k + 1`` windows of width ``k`` (stride 1),
    joined by single spaces. A sequence shorter than ``k`` yields ``""``.
    """
    window_count = len(seq) - k + 1
    words = []
    for offset in range(window_count):
        words.append(seq[offset:offset + k])
    return " ".join(words)
|
|
def find_drach_motifs(sequence):
    """Find canonical DRACH motifs and build HTML for displaying them.

    DRACH is D (A/G/T), R (A/G), A, C, H (A/C/T). ``U`` is accepted wherever
    ``T`` is, so RNA-alphabet input also matches. The search uses a zero-width
    lookahead so overlapping motifs are all reported. In ``GGACAGACT`` both
    ``GGACA`` (pos 0) and ``AGACT`` (pos 4) are found; a plain
    ``re.finditer`` over the pattern would skip the second.

    Args:
        sequence: Upper-case nucleotide string.

    Returns:
        ``(motifs_text, highlighted_seq)``.

        ``motifs_text`` is a comma-separated HTML list of ``MOTIF (Pos i)``
        entries (0-based ``i``), or a "None detected." span.

        ``highlighted_seq`` is ``sequence`` with every motif region wrapped
        in a bold, styled ``<span>``. Overlapping motifs are merged into one
        region so the markup stays well-formed. Bold is written as ``<b>``
        because the result is shown in ``gr.HTML``, which does not interpret
        markdown ``**``.
    """
    motif_len = 5
    matches = [
        (m.start(), m.group(1))
        for m in re.finditer(r'(?=([AGTU][AG]AC[ACTU]))', sequence)
    ]

    # Matches arrive in ascending start order; fold overlapping windows together.
    regions = []
    for start, _ in matches:
        end = start + motif_len
        if regions and start < regions[-1][1]:
            regions[-1][1] = end
        else:
            regions.append([start, end])

    # Rebuild the sequence left to right, wrapping each region exactly once.
    pieces = []
    cursor = 0
    for start, end in regions:
        pieces.append(sequence[cursor:start])
        pieces.append(
            "<b><span style='color:#000000; background:#f3f4f6; padding:2px 4px; "
            "border-radius:4px; border:1px solid #d1d5db;'>"
            f"{sequence[start:end]}</span></b>"
        )
        cursor = end
    pieces.append(sequence[cursor:])
    highlighted_seq = "".join(pieces)

    if matches:
        motifs_text = ", ".join(
            f"<span style='color:#111827;'>{motif} (Pos {pos})</span>"
            for pos, motif in matches
        )
    else:
        motifs_text = "<span style='color:#111827;'>None detected.</span>"
    return motifs_text, highlighted_seq
|
|
def calc_gc_content(sequence, window=15):
    """Return the sliding-window GC fraction centred on every base of *sequence*.

    For position ``i`` the window is ``[i - window // 2, i + window // 2]``,
    clipped at the sequence ends, so edge windows are shorter. Each value is
    the share of ``G``/``C`` characters in that window. Any other character
    (``A``, ``T``, ``N``, ...) counts as non-GC.
    """
    total = len(sequence)
    radius = window // 2

    def gc_fraction(centre):
        chunk = sequence[max(0, centre - radius):min(total, centre + radius + 1)]
        return sum(base in "GC" for base in chunk) / len(chunk)

    return [gc_fraction(pos) for pos in range(total)]
|
|
def run_ebcs_core(sequence, k_mask=6):
    """Compute the per-position EBCS (boundary contrast) profile of *sequence*.

    For every position ``i`` two masked copies of the sequence are scored by
    the model:

    - "left": the bases in ``[i - k_mask, i)`` are replaced by ``N``;
    - "right": the bases in ``[i, i + k_mask)`` are replaced by ``N``.

    Both windows are clipped to the sequence, so every copy keeps the original
    length and every unmasked base stays at its original position. The score
    at ``i`` is ``|P(left) - P(right)| / P(unmasked)``, where ``P`` is the
    softmax probability of class 1.

    Args:
        sequence: Nucleotide string (the UI validates it to 41 bp).
        k_mask: Mask width in bases. Coerced to ``int``.

    Returns:
        ``(scores, base_prob)``. ``scores`` is a float64 ndarray with one
        entry per position. ``base_prob`` is the class-1 probability of the
        unmasked sequence.
    """
    # Slider values may arrive as floats (e.g. 6.0); string repetition needs an int.
    k_mask = int(k_mask)
    length = len(sequence)

    # padding=True pads only to the longest input in the batch rather than a fixed
    # 128 tokens. Padding is excluded via the attention mask, so it only cost compute.
    baseline_inputs = tokenizer([seq2kmer(sequence)], return_tensors="pt", padding=True)
    with torch.no_grad():
        base_logits = model(**baseline_inputs).logits
    base_prob = F.softmax(base_logits, dim=1)[0][1].item()

    # Two masked variants per position, interleaved as [left_0, right_0, left_1, ...].
    mutated_seqs = []
    for i in range(length):
        # Clip both windows at the ends. For i < k_mask the left window is
        # shorter than k_mask. Inserting a full k_mask of Ns there and truncating
        # would shift the whole sequence instead of masking it.
        left_start = max(0, i - k_mask)
        right_end = min(length, i + k_mask)
        left = sequence[:left_start] + "N" * (i - left_start) + sequence[i:]
        right = sequence[:i] + "N" * (right_end - i) + sequence[right_end:]
        mutated_seqs.extend([seq2kmer(left), seq2kmer(right)])

    batch_inputs = tokenizer(mutated_seqs, return_tensors="pt", padding=True)
    with torch.no_grad():
        batch_logits = model(**batch_inputs).logits
    pair_probs = (
        F.softmax(batch_logits, dim=1)[:, 1].numpy().astype(np.float64).reshape(length, 2)
    )

    # Relative contrast between masking left vs. right of each position; the
    # epsilon guards against division by zero when the baseline probability is 0.
    scores = np.abs(pair_probs[:, 0] - pair_probs[:, 1]) / (base_prob + 1e-8)
    return scores, base_prob
|
|
| |
def _plot_ebcs_profile(sequence, scores, gc_vals, peak_idx):
    """Draw the EBCS curve (left axis) over the local GC fraction (right axis)."""
    positions = range(len(sequence))
    fig, ax1 = plt.subplots(figsize=(9, 4.5))

    # Transparent backgrounds so the plot blends into the glass panels.
    fig.patch.set_alpha(0.0)
    ax1.patch.set_alpha(0.0)

    ax1.set_title(f"EBCS Profile - Peak Boundary Detected at Position {peak_idx} ({sequence[peak_idx]})", fontweight='bold', color='#111827', pad=15)
    ax1.plot(positions, scores, color='#4f46e5', linewidth=2.5, marker='o', markersize=5, label='EBCS Delta')
    ax1.fill_between(positions, scores, color='#4f46e5', alpha=0.08)

    # One tick per base, labelled with the nucleotide itself.
    ax1.set_xticks(positions)
    ax1.set_xticklabels(list(sequence), fontsize=9, color='#111827')
    ax1.set_xlabel("Spatial Nucleotide Resolution", fontweight='600', color='#4b5563')
    ax1.set_ylabel("Boundary Contrast Delta", color='#4f46e5', fontweight='600')

    ax1.tick_params(axis='y', labelcolor='#4f46e5', color='#d1d5db')
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['left'].set_color('#d1d5db')
    ax1.spines['bottom'].set_color('#d1d5db')
    ax1.grid(True, linestyle='--', color='#f3f4f6')

    # Secondary axis: faint smoothed GC content for visual comparison.
    ax2 = ax1.twinx()
    ax2.patch.set_alpha(0.0)
    ax2.plot(positions, gc_vals, color='#9ca3af', linestyle='-', linewidth=2, alpha=0.3, label='Local GC%')
    ax2.set_ylabel("GC Content (Smoothed)", color='#9ca3af')
    ax2.tick_params(axis='y', labelcolor='#9ca3af', color='#d1d5db')
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.spines['left'].set_visible(False)
    ax2.spines['bottom'].set_visible(False)

    ax1.axvline(x=peak_idx, color='#e11d48', linestyle=':', linewidth=2, alpha=0.8)
    fig.tight_layout()

    # This runs once per request in a long-lived server. Unclosed pyplot figures
    # accumulate in pyplot's global registry, so detach it. The Figure object
    # itself remains usable for rendering/saving.
    plt.close(fig)
    return fig


def process_search(sequence, k_mask=6):
    """Stream the single-sequence EBCS analysis for the "Spatial EBCS Map" tab.

    This is a generator. Every ``yield`` is a ``(figure, result_html,
    motif_html)`` triple. Early yields carry only a status message (figure
    ``None``); the last one carries the rendered plot and summaries.

    Args:
        sequence: Raw text from the input box. All whitespace is removed, the
            text is upper-cased and RNA ``U`` is rewritten as ``T``.
        k_mask: Mask width forwarded to ``run_ebcs_core``.
    """
    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>⏳ Initializing request... (ETA: ~3s)</h3>", ""

    # The input box is multi-line, so pasted sequences may contain line breaks;
    # remove every whitespace character, not just the ends. Map U -> T because
    # the model is a DNA BERT (armheb/DNA_bert_6).
    # NOTE(review): assumes its k-mer vocabulary has no U; confirm.
    sequence = re.sub(r"\s+", "", sequence or "").upper().replace("U", "T")
    if len(sequence) != 41:
        yield None, "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ Error: Sequence must be exactly 41 bp.</h3>", ""
        return

    if not re.fullmatch(r'[ACGTN]+', sequence):
        yield None, "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ Error: Invalid sequence. Only standard nucleotides allowed.</h3>", ""
        return

    # Short pauses keep each progress message visible before the next one.
    time.sleep(0.4)
    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>🧬 Scanning sequence & computing EBCS... (ETA: ~2s)</h3>", ""

    motifs_text, highlighted_seq = find_drach_motifs(sequence)
    gc_vals = calc_gc_content(sequence)
    scores, base_prob = run_ebcs_core(sequence, k_mask)

    time.sleep(0.4)
    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>📊 Rendering spatial maps... (ETA: ~1s)</h3>", ""

    peak_idx = int(np.argmax(scores))
    peak_score = scores[peak_idx]
    peak_base = sequence[peak_idx]
    fig = _plot_ebcs_profile(sequence, scores, gc_vals, peak_idx)

    res = f"""
<div style="color: #111827; font-size: 1.05rem; font-family: sans-serif;">
<h3 style="margin-top: 0; color: #111827;">🎯 Target: <span style="background:#e0e7ff; padding:2px 6px; border-radius:4px; color:#4f46e5;">{peak_base}</span> at Pos <b>{peak_idx}</b></h3>
<p style="margin: 8px 0; color: #111827;"><b>Max Contrast:</b> {peak_score:.4f} | <b>Baseline Confidence:</b> {base_prob:.4f}</p>
<p style="margin: 8px 0; color: #111827;"><b>Sequence Map:</b> {highlighted_seq}</p>
</div>
"""
    mot = f"<div style='color: #111827; font-size: 1.05rem; font-family: sans-serif;'><p style='color: #111827; margin:0;'><b>Canonical DRACH Motifs:</b> {motifs_text}</p></div>"

    yield fig, res, mot
|
|
def _read_batch_sequences(path):
    """Collect candidate sequences from an uploaded CSV, FASTA or plain-text file.

    Blank lines and FASTA ``>`` header lines are skipped. Every other line is
    split on commas, semicolons, quotes and whitespace, and each non-empty
    field becomes a candidate. One-sequence-per-line files and
    ``id,sequence`` CSV rows therefore both work. Candidates are upper-cased
    with RNA ``U`` mapped to ``T``. Length and alphabet filtering is left to
    the caller.
    """
    candidates = []
    with open(path, "r", encoding="utf-8", errors="replace") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line or line.startswith(">"):
                continue
            for field in re.split(r"[,;\s\"']+", line):
                if field:
                    candidates.append(field.upper().replace("U", "T"))
    return candidates


def process_batch(file_obj, k_mask=6):
    """Stream a batch EBCS run over an uploaded file.

    Generator. Every ``yield`` is ``(download_path, status_html)``. Only the
    final successful yield carries a path: a CSV with one row per valid 41-bp
    sequence.

    Args:
        file_obj: The uploaded file from ``gr.File`` (a path string, or an
            object exposing the path as ``.name``), or ``None``.
        k_mask: Mask width forwarded to ``run_ebcs_core``.
    """
    if file_obj is None:
        yield None, "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ Please upload a CSV or FASTA file.</h3>"
        return

    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>⏳ Parsing uploaded file...</h3>"
    # Depending on the Gradio version, uploads arrive as a path string or as a
    # tempfile-like wrapper whose `.name` is the path; accept both.
    path = getattr(file_obj, "name", file_obj)
    candidates = [seq for seq in _read_batch_sequences(path) if len(seq) == 41]
    # Validate up front so the "none found" check also covers files whose
    # 41-character entries are all invalid.
    sequences = [seq for seq in candidates if re.fullmatch(r'[ACGTN]+', seq)]
    skipped = len(candidates) - len(sequences)

    if not sequences:
        yield None, "<h3 style='color:#dc2626; margin:0; font-family: sans-serif;'>❌ No valid 41-bp sequences found.</h3>"
        return

    results = []
    total = len(sequences)
    for idx, seq in enumerate(sequences):
        # Throttle progress updates: the first three, then every fifth sequence.
        if idx % 5 == 0 or idx < 3:
            yield None, f"<h3 style='color:#111827; margin:0; font-family: sans-serif;'>🧬 Computing sequence {idx+1} of {total}...</h3>"

        scores, base_prob = run_ebcs_core(seq, k_mask)
        peak_idx = int(np.argmax(scores))
        motifs_html, _ = find_drach_motifs(seq)
        results.append({
            "Sequence": seq,
            "Baseline_Confidence": round(base_prob, 4),
            "Peak_Position": peak_idx,
            "Peak_Base": seq[peak_idx],
            "Max_EBCS_Score": round(float(scores[peak_idx]), 4),
            # find_drach_motifs returns display HTML; keep only its text for the CSV.
            "DRACH_Motifs": re.sub(r"<[^>]+>", "", motifs_html),
        })

    yield None, "<h3 style='color:#111827; margin:0; font-family: sans-serif;'>📊 Formatting output CSV...</h3>"
    df = pd.DataFrame(results)
    # Not removed here: the path is handed to the download component after this
    # generator yields it, so the file must outlive the handler.
    temp_dir = tempfile.mkdtemp()
    out_path = os.path.join(temp_dir, "EpiRNA_Batch_Results.csv")
    df.to_csv(out_path, index=False)

    skipped_note = f" ({skipped} invalid sequence(s) skipped.)" if skipped else ""
    yield out_path, f"<h3 style='color:#16a34a; margin:0; font-family: sans-serif;'>✅ Successfully processed {len(results)} sequences.{skipped_note}</h3>"
|
|
| |
| |
# Light "glass" palette. Each entry is also applied to its `*_dark` counterpart
# with the same value, so dark mode renders identically to light mode.
_GLASS_PALETTE = {
    "body_background_fill": "#f8fafc",
    "background_fill_primary": "rgba(255, 255, 255, 0.7)",
    "background_fill_secondary": "rgba(255, 255, 255, 0.4)",
    "border_color_primary": "rgba(203, 213, 225, 0.6)",
    "block_background_fill": "rgba(255, 255, 255, 0.6)",
    "block_title_text_color": "#111827",
    "block_label_text_color": "#374151",
    "body_text_color": "#111827",
    "input_background_fill": "#ffffff",
}

glass_theme = gr.themes.Soft(
    primary_hue="indigo",
    neutral_hue="slate",
).set(
    **_GLASS_PALETTE,
    **{f"{name}_dark": value for name, value in _GLASS_PALETTE.items()},
)
|
|
# Extra CSS injected into the page via gr.Blocks(css=...).
# `.main-container` and `.glass-panel` are attached to layout elements through
# `elem_classes`. Selectors such as `footer`, `button.primary`, `.tabs` and
# `.tab-nav` target Gradio's own DOM.
# NOTE(review): those internal class names may change between Gradio versions;
# re-check the styling after upgrading.
custom_css = """
/* Glassmorphism Structural Layout */
body { background: linear-gradient(135deg, #f8fafc 0%, #e0e7ff 100%) !important; }
.main-container { max-width: 1400px; margin: 0 auto; padding: 40px 20px; }
.glass-panel {
background: rgba(255, 255, 255, 0.5) !important;
backdrop-filter: blur(12px) !important;
border-radius: 16px;
border: 1px solid rgba(255, 255, 255, 0.6) !important;
padding: 24px;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05);
}

/* Hide Gradio Footer */
footer { display: none !important; }

/* Buttons & Inputs */
textarea, input { background-color: #ffffff !important; color: #111827 !important; border: 1px solid #cbd5e1 !important; }
button.primary { background-color: #111827 !important; color: #ffffff !important; border-radius: 8px !important; transition: all 0.2s ease;}
button.primary:hover { background-color: #4f46e5 !important; transform: translateY(-1px);}

/* Tabs Styling */
.tabs { border: none !important; background: transparent !important; }
.tab-nav { border-bottom: 1px solid rgba(0,0,0,0.1) !important; }
.tab-nav button { color: #4b5563 !important; font-weight: 500 !important; background: transparent !important; }
.tab-nav button.selected { color: #4f46e5 !important; border-bottom: 2px solid #4f46e5 !important; }

/* Pro Tooltips */
.pro-tooltip { position: relative; display: inline-block; cursor: help; border-bottom: 1px dashed #4f46e5; font-weight: 600; color: #4f46e5;}
.pro-tooltip .tooltip-text { visibility: hidden; width: 280px; background-color: #111827; color: #ffffff !important; text-align: left; padding: 12px 16px; border-radius: 8px; position: absolute; z-index: 100; bottom: 130%; left: 50%; transform: translateX(-50%) translateY(10px); opacity: 0; transition: all 0.2s ease; font-size: 0.9rem; font-weight: 400; line-height: 1.5; pointer-events: none; }
.pro-tooltip:hover .tooltip-text { visibility: visible; opacity: 1; transform: translateX(-50%) translateY(0); }
"""
|
|
# UI layout. Inside gr.Blocks, the order and nesting of the `with` statements
# below define where each component appears on the page.
with gr.Blocks(theme=glass_theme, css=custom_css, title="EpiRNA") as app:

    # Two-column page: inputs on the left (scale 4), results/tabs on the right (scale 8).
    with gr.Row(elem_classes="main-container"):

        # Left column: branding, single-sequence input and the analyse button.
        with gr.Column(scale=4, elem_classes="glass-panel"):
            gr.HTML("""
<div style="margin-bottom: 30px;">
<h1 style="font-size: 3.5rem; font-weight: 800; margin: 0; letter-spacing: -1.5px; line-height: 1.1; color: #111827; font-family: sans-serif;">
<span style='color: #4f46e5;'>Epi</span><span style='color: #e11d48;'>RNA</span>
</h1>
<p style="font-size: 1.1rem; font-weight: 400; color: #4b5563; margin-top: 8px; font-family: sans-serif;">
Decoding RNA Catalytic Boundaries at Single-Nucleotide Resolution
</p>
</div>
""")

            d_in = gr.Textbox(
                label="Target Sequence",
                info="Must be exactly 41 nucleotides long.",
                placeholder="Paste exactly 41-bp of RNA/DNA sequence...",
                lines=3
            )

            with gr.Accordion("Advanced Settings", open=False):
                # Supplies the k_mask argument of process_search (2, 4, 6, 8 or 10).
                d_k_slider = gr.Slider(
                    minimum=2, maximum=10, step=2, value=6,
                    label="EBCS Mask Resolution (k-mer)"
                )

            d_btn = gr.Button("Analyze Sequence", variant="primary")

        # Right column: result tabs.
        with gr.Column(scale=8, elem_classes="glass-panel"):

            with gr.Tabs():

                # Outputs filled by process_search (wired up at the bottom).
                with gr.Tab("Spatial EBCS Map"):
                    out_res = gr.HTML("<h3 style='color:#111827; margin-top:0; font-family: sans-serif;'>⏳ Waiting for sequence input...</h3>")
                    out_plot = gr.Plot()
                    out_mot = gr.HTML()

                # Inputs/outputs for process_batch.
                with gr.Tab("Batch Processing Engine"):
                    gr.HTML("""
<div style='color:#111827; font-family: sans-serif;'>
<h3 style='margin-top:0; color:#111827;'>Bulk Sequence Analysis</h3>
<p style='color:#111827;'>Upload a <code>.csv</code> or <code>.fasta</code> file containing sequences (one per line).</p>
</div>
""")

                    # NOTE(review): common FASTA extensions such as .fa / .fna are
                    # not in this list, so such files cannot be selected.
                    batch_file = gr.File(label="Upload Dataset", file_types=[".csv", ".fasta", ".txt"])
                    batch_k_slider = gr.Slider(minimum=2, maximum=10, step=2, value=6, label="Mask Resolution (k-mer)")
                    batch_btn = gr.Button("Run Batch EBCS", variant="primary")
                    gr.HTML("<hr style='border-color: #e5e7eb; margin: 20px 0;'>")
                    batch_status = gr.HTML("<h3 style='color:#111827; margin:0; font-family: sans-serif;'>Ready for upload.</h3>")
                    # Read-only: only populated with the CSV path yielded by process_batch.
                    batch_download = gr.File(label="Download Processed Results", interactive=False)

                # Static explanatory text; `.pro-tooltip` spans are styled by custom_css.
                with gr.Tab("The Science"):
                    gr.HTML("""
<div style="color: #111827; line-height: 1.6; font-size: 1.05rem; font-family: sans-serif;">
<h3 style="margin-top: 0; color: #111827; font-weight: 600;">The "Clever Hans" Effect in Epitranscriptomics</h3>
<p style="margin-top: 5px; color: #374151;">Traditional deep learning models for RNA modifications overfit to lab-specific technical noise (like <span class="pro-tooltip">GC-content bias<span class="tooltip-text">A common laboratory artifact where sequencing machines preferentially read sequences rich in Guanine (G) and Cytosine (C), tricking AI models into correlating GC% with RNA modifications.</span></span>). They fail to generalize across unseen datasets.</p>

<h3 style="margin-top: 25px; color: #111827; font-weight: 600;">The Zero-Shot Solution</h3>
<p style="margin-top: 5px; color: #374151;">EpiRNA leverages a <span class="pro-tooltip">DANN<span class="tooltip-text">Domain Adversarial Neural Network.</span></span> trained on <span class="pro-tooltip">SSB<span class="tooltip-text">Synthetic Sandbox Bootstrapping.</span></span>. By mathematically stripping away technical batch artifacts, it learns true causal biology.</p>

<h3 style="margin-top: 25px; color: #111827; font-weight: 600;">What is EBCS?</h3>
<p style="margin-top: 5px; color: #374151;">Epitranscriptomic Boundary Contrast Scoring (<span class="pro-tooltip">EBCS<span class="tooltip-text">A zero-shot mathematical probe that calculates the exact single-nucleotide derivative of an AI model's confidence.</span></span>) slides a synthetic mask across the sequence to calculate the mathematical derivative of the model's confidence. The <span class="pro-tooltip">peak contrast delta<span class="tooltip-text">The highest point on the blue graph line.</span></span> reveals the exact single-nucleotide catalytic boundary the AI relies upon.</p>
</div>
""")

    # Both handlers are generators. Each yielded tuple maps positionally onto
    # `outputs`: process_search yields (figure, result_html, motif_html), and
    # process_batch yields (download_path, status_html).
    d_btn.click(process_search, inputs=[d_in, d_k_slider], outputs=[out_plot, out_res, out_mot])
    batch_btn.click(process_batch, inputs=[batch_file, batch_k_slider], outputs=[batch_download, batch_status])
|
|
# Start the server only when run as a script (e.g. `python app.py`). Importing
# this module (tests, notebooks, tooling) then builds the UI and loads the model
# without starting a blocking web server.
if __name__ == "__main__":
    # queue() enables Gradio's request queue, through which the generator
    # handlers stream their intermediate progress updates.
    app.queue().launch()