shafiqvai's picture
Update app.py
c73bcb1 verified
import torch
import joblib
import gradio as gr
import numpy as np
import pandas as pd
from transformers import T5Tokenizer, T5EncoderModel
import os
# --- Configuration ---
# Inference device and the on-disk locations of the serialized classifier
# and feature scaler (both paths are relative to the working directory).
DEVICE = torch.device("cpu")
MODEL_PATH = "dna_binding_DNAstackBP_model.pkl"
SCALER_PATH = "dna_binding_scaler.pkl"

# --- Load DNAstackBP Assets ---
# Everything below runs once at import time so each request can reuse the
# already-loaded objects.
print("Initializing DNAstackBP Dashboard...")

_PROT_T5_CHECKPOINT = "Rostlab/prot_t5_xl_uniref50"

# do_lower_case=False keeps the residue letters exactly as they are passed in.
tokenizer = T5Tokenizer.from_pretrained(_PROT_T5_CHECKPOINT, do_lower_case=False)

# NOTE(review): the encoder weights are loaded as float16 while DEVICE is the
# CPU — confirm the deployed torch build supports (and is acceptably fast at)
# half-precision CPU inference; float32 would need roughly twice the memory.
model = T5EncoderModel.from_pretrained(_PROT_T5_CHECKPOINT, torch_dtype=torch.float16)
model = model.to(DEVICE)
model.eval()  # switch to inference mode

# joblib.load unpickles arbitrary objects, so these files must come from a
# trusted source.
svm_model = joblib.load(MODEL_PATH)
scaler = joblib.load(SCALER_PATH)
def _extract_protein_sequence(raw_text):
    """Return the first sequence in raw or FASTA text as one upper-case string.

    Lines starting with ``>`` are treated as FASTA headers and dropped. If a
    second header appears after residues have been collected, parsing stops so
    that multiple records are never concatenated. All whitespace (including
    line breaks inside the sequence) is removed.

    Args:
        raw_text: Text pasted by the user; ``None`` and non-string values are
            treated as empty input.

    Returns:
        The cleaned sequence, or ``""`` when no residues were found.
    """
    if not isinstance(raw_text, str):
        return ""

    chunks = []
    for line in raw_text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if chunks:
                # A new record begins: only the first one is analysed.
                break
            continue
        chunks.append(line)

    # split()/join removes whitespace that may remain inside a line.
    return "".join("".join(chunks).split()).upper()


def predict_and_report(sequence):
    """Classify a protein sequence and write a one-row CSV report.

    The sequence is embedded with the module-level ProtT5 encoder, mean-pooled
    over the attention mask, scaled with ``scaler`` and classified with
    ``svm_model``.

    Args:
        sequence: Protein sequence as raw letters or FASTA text.

    Returns:
        A 4-tuple ``(result_text, confidence, report_path, status_message)``.
        For input that is empty, shorter than 5 residues, or contains
        non-letter characters the tuple is
        ``("Invalid Sequence", 0.0, None, "Please enter a valid protein sequence.")``.
    """
    import tempfile

    clean_seq = _extract_protein_sequence(sequence)
    if len(clean_seq) < 5 or not (clean_seq.isascii() and clean_seq.isalpha()):
        return "Invalid Sequence", 0.0, None, "Please enter a valid protein sequence."

    # The ProtT5 model card maps the rare residues U, Z, O and B to X before
    # tokenisation; the report below still shows the user's original letters.
    model_seq = clean_seq.translate(str.maketrans("UZOB", "XXXX"))
    # The tokenizer expects residues separated by single spaces.
    seq_fmt = " ".join(model_seq)

    encoded = tokenizer(seq_fmt, return_tensors="pt", padding=True)
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Pool in float32: summing many float16 values loses precision. The mask
    # is taken from the tensor already on DEVICE so both operands share a device.
    last_hidden = outputs.last_hidden_state.squeeze(0).float()
    mask = attention_mask.squeeze(0).unsqueeze(-1).to(last_hidden.dtype)
    mean_embedding = torch.sum(last_hidden * mask, dim=0) / torch.sum(mask)

    feat = scaler.transform(mean_embedding.cpu().numpy().reshape(1, -1))
    label = svm_model.predict(feat)[0]
    prob = svm_model.predict_proba(feat)[0]

    result_text = "DNA-Binding Protein" if label == 1 else "Non-Binding Protein"
    # predict_proba columns follow classes_, so look up the column of the
    # predicted label instead of assuming the label value is its own index.
    class_index = list(svm_model.classes_).index(label)
    conf_score = float(prob[class_index])
    conf_pct = f"{conf_score:.2%}"

    # Each request gets its own directory so concurrent users cannot overwrite
    # one another's report; the download keeps its familiar file name.
    report_dir = tempfile.mkdtemp(prefix="dnastackbp_")
    report_file = os.path.join(report_dir, "DNAstackBP_Analysis_Report.csv")
    report_df = pd.DataFrame([{
        "Serial No.": 1,
        "Tool": "DNAstackBP",
        "Input Sequence": clean_seq,
        "Sequence Length": len(clean_seq),
        "Result": result_text,
        "Confidence": conf_pct,
    }])
    report_df.to_csv(report_file, index=False)

    status_msg = f"Analysis Complete! Confidence: {conf_pct}."
    return result_text, conf_score, report_file, status_msg
# --- Custom Styling ---
# Soft base theme with a cyan accent on slate neutrals for a clean "lab" look.
_font_stack = [
    gr.themes.GoogleFont("Inter"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Soft(
    primary_hue="cyan",
    secondary_hue="slate",
    neutral_hue="slate",
    font=_font_stack,
)
# Fine-tune page/block backgrounds, the block border and the primary button.
_theme_overrides = {
    "body_background_fill": "*neutral_50",
    "block_background_fill": "white",
    "block_border_width": "1px",
    "button_primary_background_fill": "*primary_600",
    "button_primary_background_fill_hover": "*primary_500",
}
theme = theme.set(**_theme_overrides)
# Static header banner (logo, title and tagline) shown above the tabs.
_HEADER_HTML = """
<div style="display: flex; align-items: center; gap: 20px; padding: 20px; background: white; border-bottom: 3px solid #0891b2; border-radius: 8px;">
<img src="https://cdn-icons-png.flaticon.com/512/3011/3011406.png" width="80" height="80" style="filter: drop-shadow(2px 4px 6px #ccc);">
<div>
<h1 style="margin: 0; color: #0f172a; font-size: 2.5rem;">DNAstackBP</h1>
<p style="margin: 0; color: #64748b; font-size: 1.1rem;">High-Precision DNA-Binding Protein Prediction System</p>
</div>
</div>
"""

# Body text of the "About DNAstackBP" tab.
_ABOUT_MARKDOWN = """
### Technical Methodology
**DNAstackBP** utilizes a multi-stage deep learning pipeline:
1. **Embedding Stage:** Uses **ProtT5-XL-UniRef50** for high-dimensional feature extraction.
2. **Baseline Dataset:** Generates probability values from curated baseline models.
3. **Meta-Classification:** Final analysis via a **CNN-1D Meta-Model** architecture.
### Developer Profile
**Md. Shafiqul Islam** **ID: 221-15-4656**
*Bioinformatics Analysis Tool - v1.0*
"""

# Build the dashboard: a header, an analysis tab (input on the left, results
# on the right) and an about tab. Event handlers are registered inside the
# Blocks context so they are attached to this app.
with gr.Blocks(theme=theme) as demo:
    # Header with DNA graphic
    with gr.Row():
        gr.HTML(_HEADER_HTML)

    with gr.Tab("Analysis Dashboard"):
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 🧬 Input Sequence")
                input_text = gr.Textbox(
                    label="Paste Protein Sequence (Raw or FASTA)",
                    lines=12,
                    placeholder="Example: MASG...",
                )
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                    submit_btn = gr.Button("🔍 Run Prediction", variant="primary")
            with gr.Column(scale=1):
                gr.Markdown("### 📝 Results")
                output_status = gr.Markdown("*System Ready*")
                output_text = gr.Label(label="Prediction Result")
                output_num = gr.Number(label="Confidence Level", interactive=False)
                output_file = gr.File(label="📥 Export CSV Report")

    with gr.Tab("About DNAstackBP"):
        gr.Markdown(_ABOUT_MARKDOWN)

    submit_btn.click(
        fn=predict_and_report,
        inputs=input_text,
        outputs=[output_text, output_num, output_file, output_status],
    )

    # Reset the input and *every* result widget. The confidence number is
    # included so a stale score does not linger after clearing.
    clear_btn.click(
        lambda: ["", None, None, None, "*System Ready*"],
        outputs=[input_text, output_text, output_num, output_file, output_status],
    )
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()