| | import gradio as gr |
| | import torch |
| | from huggingface_hub import hf_hub_download |
| | from encoder import MutationEncoder |
| | from model import MutationPredictorCNN |
| |
|
| | |
# --- Module-level model setup (runs once at import / app start-up) ---

# Download the trained weights from the Hugging Face Hub; returns a local
# cache path to the checkpoint file.
MODEL_PATH = hf_hub_download(
    repo_id="nileshhanotia/mutation-pathogenicity-predictor",
    filename="pytorch_model.pth"
)

# Inference is pinned to CPU — no GPU is assumed in the deployment environment.
device = torch.device("cpu")

model = MutationPredictorCNN().to(device)

# map_location ensures GPU-trained checkpoints load cleanly on CPU.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# Checkpoint is a dict wrapping the weights under "model_state_dict"
# (training-script convention — TODO confirm against the training code).
model.load_state_dict(checkpoint["model_state_dict"])

# Evaluation mode: disables dropout / freezes batch-norm statistics.
model.eval()

# Encoder that turns a (reference, mutated) sequence pair into the
# model's input tensor; see MutationEncoder in encoder.py.
encoder = MutationEncoder()
|
def generate_explainability(ref_seq, mut_seq, importance, encoded_tensor):
    """
    Build a plain-text explainability view of a single mutation.

    The mutation position is recovered from the encoded tensor itself
    (the slice [990:1089] holds the per-position difference mask), so the
    visualization matches exactly what the model sees rather than a
    re-diff of the input strings.

    Args:
        ref_seq: Reference DNA sequence (string).
        mut_seq: Mutated DNA sequence (string).
        importance: Scalar importance score produced by the model.
        encoded_tensor: 1-D encoded mutation tensor from MutationEncoder.

    Returns:
        A multi-line string showing the mutated sequence, a caret marking
        the mutation site, the substitution, and the importance score —
        or a short message when the encoding contains no mutation.
    """
    # Last 99 entries of the encoding are the ref/mut difference mask;
    # the (first) maximal entry marks the mutated position.
    difference_mask = encoded_tensor[990:1089]
    site = torch.argmax(difference_mask).item()

    # An all-zero mask means argmax picked index 0 arbitrarily — no mutation.
    if difference_mask[site].item() == 0:
        return "No mutation detected in encoding"

    # Normalize the sequences the same way predict() does, so indexing
    # and display stay consistent.
    reference = ref_seq.strip().upper()
    mutant = mut_seq.strip().upper()

    # Caret aligned under the mutated base of the displayed sequence.
    marker = f"{'':>{site}}^" if site else "^"

    # Resolve the base change, guarding against a mask index that falls
    # outside either (possibly short) sequence.
    if site < len(reference) and site < len(mutant):
        change = f"{reference[site]}>{mutant[site]}"
    else:
        change = "Unknown"

    return (
        f"Mutated sequence:\n{mutant}\n{marker}\n\n"
        f"Mutation position: {site}\n"
        f"Substitution: {change}\n"
        f"Importance score: {importance:.4f}"
    )
| |
|
| |
|
def predict(ref_seq, mut_seq):
    """
    Run the pathogenicity model on a reference/mutated sequence pair.

    Args:
        ref_seq: Reference DNA sequence entered by the user.
        mut_seq: Mutated DNA sequence entered by the user.

    Returns:
        A 4-tuple of (label, probability, importance, explainability text).
        On invalid input or any runtime failure the label is "Error" and
        the numeric outputs are 0.0, with the message in the last slot.
    """
    reference = ref_seq.strip().upper()
    mutant = mut_seq.strip().upper()

    # Guard clauses: both sequences required, and lengths must match.
    if not reference or not mutant:
        return "Error", 0.0, 0.0, "Please provide both reference and mutated sequences"
    if len(reference) != len(mutant):
        return "Error", 0.0, 0.0, f"Sequences must be same length (ref: {len(reference)}, mut: {len(mutant)})"

    try:
        # Encode the pair exactly as the model expects, then add a batch dim.
        encoded = encoder.encode_mutation(reference, mutant)
        batch = encoded.unsqueeze(0).to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            logit, importance = model(batch)
            probability = logit.item()
            importance_val = importance.item()

        # Threshold the (presumed already-sigmoid) model output at 0.5.
        label = "Pathogenic" if probability >= 0.5 else "Benign"

        explain = generate_explainability(reference, mutant, importance_val, encoded)

        return label, probability, importance_val, explain
    except Exception as e:
        # Surface any failure in the UI rather than crashing the app.
        return "Error", 0.0, 0.0, f"Error during prediction: {str(e)}"
| |
|
| |
|
| | |
# --- Gradio UI definition (component order below fixes input/output wiring) ---
with gr.Blocks(title="DNA Mutation Pathogenicity Predictor") as demo:

    gr.Markdown("""
    # 🧬 Explainable Mutation Pathogenicity Predictor

    Predict whether a DNA mutation is pathogenic or benign with explainability
    showing the mutation position and importance.
    """)

    with gr.Row():
        # Left column: sequence inputs and action buttons.
        with gr.Column():
            ref_input = gr.Textbox(
                label="Reference sequence (99bp)",
                placeholder="Enter reference DNA sequence (A, T, G, C)",
                lines=3
            )

            mut_input = gr.Textbox(
                label="Mutated sequence (99bp)",
                placeholder="Enter mutated DNA sequence (A, T, G, C)",
                lines=3
            )

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit = gr.Button("Predict", variant="primary")

        # Right column: read-only prediction outputs.
        with gr.Column():
            prediction = gr.Textbox(
                label="Prediction",
                interactive=False
            )

            probability = gr.Number(
                label="Pathogenic Probability",
                interactive=False
            )

            importance = gr.Number(
                label="Mutation Importance Score",
                interactive=False
            )

    # Full-width text area for the caret-aligned explainability view.
    explainability = gr.Textbox(
        label="Explainability Visualization",
        lines=8,
        interactive=False
    )

    # Clickable example sequence pairs that populate the two inputs.
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            [
                "AAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
                "AAAAAAAAAATAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
            ],
            [
                "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA",
                "ATCGATCGATGGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA"
            ]
        ],
        inputs=[ref_input, mut_input],
        label="Click an example to load"
    )

    # Predict: outputs tuple order must match predict()'s return order.
    submit.click(
        fn=predict,
        inputs=[ref_input, mut_input],
        outputs=[prediction, probability, importance, explainability]
    )

    # Clear: one reset value per output component, in the same order.
    clear_btn.click(
        fn=lambda: ("", "", "", 0.0, 0.0, ""),
        outputs=[ref_input, mut_input, prediction, probability, importance, explainability]
    )
| |
|
| |
|
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()