# Source: Hugging Face Space by nileshhanotia — "Create app.py", commit e2b8fa8 (verified).
# (Page-header text from the web view, commented out so this file is valid Python.)
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from encoder import MutationEncoder
from model import MutationPredictorCNN
# --- Model loading (runs once at import time) -------------------------------
# Download the trained checkpoint from the project's Hugging Face Hub repo;
# hf_hub_download caches locally and returns the on-disk path.
MODEL_PATH = hf_hub_download(
repo_id="nileshhanotia/mutation-pathogenicity-predictor",
filename="pytorch_model.pth"
)
# Inference is CPU-only in this Space.
device = torch.device("cpu")
model = MutationPredictorCNN().to(device)
# NOTE(review): torch.load unpickles the file — acceptable only because the
# checkpoint comes from this project's own Hub repo, not untrusted input.
checkpoint = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: freeze dropout / batch-norm statistics
# Encoder that turns a (reference, mutated) sequence pair into a model input tensor.
encoder = MutationEncoder()
def generate_explainability(ref_seq, mut_seq, importance, encoded_tensor):
    """
    Build a plain-text explainability report for a predicted mutation.

    The mutation position is recovered from the difference mask stored at
    indices 990:1089 of the encoded tensor, so the report reflects exactly
    what the model saw rather than re-deriving the position from the strings.

    Args:
        ref_seq: Reference DNA sequence.
        mut_seq: Mutated DNA sequence (same length as ref_seq).
        importance: Scalar importance score produced by the model.
        encoded_tensor: 1-D model input; slice [990:1089] is assumed to be a
            one-hot style mask marking the mutated position.

    Returns:
        A multi-line report string, or a short message when the mask
        contains no mutation.
    """
    mask = encoded_tensor[990:1089]
    pos = int(torch.argmax(mask).item())

    # An all-zero mask means the encoder found no difference between the sequences.
    if mask[pos].item() == 0:
        return "No mutation detected in encoding"

    ref_clean = ref_seq.strip().upper()
    mut_clean = mut_seq.strip().upper()

    # Caret padded left so it sits directly under the mutated base.
    caret_line = "^".rjust(pos + 1)

    # Report the base change only when the position is inside both strings.
    if pos < len(ref_clean) and pos < len(mut_clean):
        substitution = f"{ref_clean[pos]}>{mut_clean[pos]}"
    else:
        substitution = "Unknown"

    report_lines = [
        "Mutated sequence:",
        mut_clean,
        caret_line,
        "",
        f"Mutation position: {pos}",
        f"Substitution: {substitution}",
        f"Importance score: {importance:.4f}",
    ]
    return "\n".join(report_lines)
def predict(ref_seq, mut_seq):
    """
    Predict the pathogenicity of a mutation and build its explanation.

    Args:
        ref_seq: Reference DNA sequence (A/T/G/C, whitespace tolerated).
        mut_seq: Mutated DNA sequence, same length as ref_seq.

    Returns:
        Tuple of (label, probability, importance, explainability_text).
        On validation failure or any runtime error, label is "Error" and
        the last element carries a human-readable message.
    """
    # Normalize input: trim whitespace, force uppercase bases.
    ref_seq = ref_seq.strip().upper()
    mut_seq = mut_seq.strip().upper()

    # --- Input validation, with messages surfaced directly in the UI --------
    if not ref_seq or not mut_seq:
        return "Error", 0.0, 0.0, "Please provide both reference and mutated sequences"
    if len(ref_seq) != len(mut_seq):
        return "Error", 0.0, 0.0, f"Sequences must be same length (ref: {len(ref_seq)}, mut: {len(mut_seq)})"
    # Reject characters outside the DNA alphabet early, with a clearer message
    # than whatever the encoder would otherwise raise.
    invalid = set(ref_seq + mut_seq) - set("ATGC")
    if invalid:
        return "Error", 0.0, 0.0, f"Invalid characters in sequence: {', '.join(sorted(invalid))}"

    try:
        # Encode the (ref, mut) pair into the model's input representation.
        encoded = encoder.encode_mutation(ref_seq, mut_seq)
        tensor = encoded.unsqueeze(0).to(device)  # add batch dimension

        with torch.no_grad():
            logit, importance = model(tensor)

        # Model already applies sigmoid, so the first output is a probability.
        probability = logit.item()
        importance_val = importance.item()

        label = "Pathogenic" if probability >= 0.5 else "Benign"

        # Build the explanation from the same tensor the model consumed.
        explain = generate_explainability(ref_seq, mut_seq, importance_val, encoded)
        return label, probability, importance_val, explain
    except Exception as e:
        # Surface encoder/model failures as a readable UI message rather than crashing.
        return "Error", 0.0, 0.0, f"Error during prediction: {e}"
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(title="DNA Mutation Pathogenicity Predictor") as demo:
    gr.Markdown("""
    # 🧬 Explainable Mutation Pathogenicity Predictor
    Predict whether a DNA mutation is pathogenic or benign with explainability
    showing the mutation position and importance.
    """)
    with gr.Row():
        # Left column: sequence inputs and the action buttons.
        with gr.Column():
            ref_input = gr.Textbox(
                label="Reference sequence (99bp)",
                placeholder="Enter reference DNA sequence (A, T, G, C)",
                lines=3
            )
            mut_input = gr.Textbox(
                label="Mutated sequence (99bp)",
                placeholder="Enter mutated DNA sequence (A, T, G, C)",
                lines=3
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit = gr.Button("Predict", variant="primary")
        # Right column: model outputs.
        with gr.Column():
            prediction = gr.Textbox(
                label="Prediction",
                interactive=False
            )
            probability = gr.Number(
                label="Pathogenic Probability",
                interactive=False
            )
            importance = gr.Number(
                label="Mutation Importance Score",
                interactive=False
            )
            # Plain-text rendering produced by generate_explainability().
            explainability = gr.Textbox(
                label="Explainability Visualization",
                lines=8,
                interactive=False
            )
    # Clickable example (reference, mutated) sequence pairs.
    # NOTE(review): the labels say 99bp; example string lengths not re-verified here.
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            [
                "AAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
                "AAAAAAAAAATAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
            ],
            [
                "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA",
                "ATCGATCGATGGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA"
            ]
        ],
        inputs=[ref_input, mut_input],
        label="Click an example to load"
    )
    # Wire events: Predict runs the model; Clear resets inputs and all outputs.
    # Output order must match predict()'s return tuple.
    submit.click(
        fn=predict,
        inputs=[ref_input, mut_input],
        outputs=[prediction, probability, importance, explainability]
    )
    clear_btn.click(
        fn=lambda: ("", "", "", 0.0, 0.0, ""),
        outputs=[ref_input, mut_input, prediction, probability, importance, explainability]
    )

if __name__ == "__main__":
    demo.launch()