"""Gradio app: explainable DNA mutation pathogenicity prediction.

Loads a pretrained CNN from the Hugging Face Hub, encodes a (reference,
mutated) 99bp sequence pair, and reports a pathogenic/benign call with a
position-level explainability readout.
"""

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from encoder import MutationEncoder
from model import MutationPredictorCNN

# --- Model loading (runs once, at import time) ---
MODEL_PATH = hf_hub_download(
    repo_id="nileshhanotia/mutation-pathogenicity-predictor",
    filename="pytorch_model.pth",
)

device = torch.device("cpu")
model = MutationPredictorCNN().to(device)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects. Tolerable here because the checkpoint comes from a pinned repo,
# but consider weights_only=True if the checkpoint is a plain state dict.
checkpoint = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

encoder = MutationEncoder()

# Slice of the flat encoding that holds the per-position difference mask.
# 990 == 10 * 99 and 1089 == 11 * 99, so this looks like the 11th channel of
# a 99-position encoding — TODO confirm against MutationEncoder's layout.
DIFF_MASK_START = 990
DIFF_MASK_END = 1089


def generate_explainability(ref_seq, mut_seq, importance, encoded_tensor):
    """Build a plain-text explainability panel for one prediction.

    Reads the mutation position out of the encoded tensor's difference-mask
    slice (so it matches exactly what the model saw), then renders the
    mutated sequence with a caret under the mutated base plus the
    substitution and importance score.

    Args:
        ref_seq: Reference DNA sequence (string of A/T/G/C).
        mut_seq: Mutated DNA sequence, same length as ``ref_seq``.
        importance: Scalar importance score from the model head.
        encoded_tensor: 1-D encoding produced by ``encoder.encode_mutation``.

    Returns:
        A multi-line string, or a short notice when no mutation was encoded.
    """
    # Locate the mutated position from the encoding's difference mask.
    diff_mask = encoded_tensor[DIFF_MASK_START:DIFF_MASK_END]
    mutation_pos = torch.argmax(diff_mask).item()

    # argmax of an all-zero mask returns 0; detect that case explicitly.
    if diff_mask[mutation_pos].item() == 0:
        return "No mutation detected in encoding"

    # Normalize the raw user input the same way predict() does.
    ref_seq = ref_seq.strip().upper()
    mut_seq = mut_seq.strip().upper()

    # Caret aligned under the mutated base (assumes monospace rendering).
    pointer = " " * mutation_pos + "^"

    # Report the actual base change when the position is in range.
    if mutation_pos < len(ref_seq) and mutation_pos < len(mut_seq):
        ref_base = ref_seq[mutation_pos]
        mut_base = mut_seq[mutation_pos]
        substitution = f"{ref_base}>{mut_base}"
    else:
        substitution = "Unknown"

    explainability_text = (
        "Mutated sequence:\n"
        + mut_seq
        + "\n"
        + pointer
        + "\n\n"
        + f"Mutation position: {mutation_pos}\n"
        + f"Substitution: {substitution}\n"
        + f"Importance score: {importance:.4f}"
    )
    return explainability_text


def predict(ref_seq, mut_seq):
    """Predict pathogenicity for a reference/mutated sequence pair.

    Args:
        ref_seq: Reference DNA sequence.
        mut_seq: Mutated DNA sequence (must match ``ref_seq`` in length).

    Returns:
        Tuple ``(label, probability, importance, explainability_text)``;
        on any validation or runtime failure, ``("Error", 0.0, 0.0, msg)``.
    """
    ref_seq = ref_seq.strip().upper()
    mut_seq = mut_seq.strip().upper()

    # Validate before touching the model: both fields present, equal length.
    if not ref_seq or not mut_seq:
        return "Error", 0.0, 0.0, "Please provide both reference and mutated sequences"
    if len(ref_seq) != len(mut_seq):
        return (
            "Error",
            0.0,
            0.0,
            f"Sequences must be same length (ref: {len(ref_seq)}, mut: {len(mut_seq)})",
        )

    try:
        # Encode the pair, add a batch dimension, and run inference.
        encoded = encoder.encode_mutation(ref_seq, mut_seq)
        tensor = encoded.unsqueeze(0).to(device)

        with torch.no_grad():
            logit, importance = model(tensor)
            # Model already outputs sigmoid — this is a probability, not a logit.
            probability = logit.item()
            importance_val = importance.item()

        label = "Pathogenic" if probability >= 0.5 else "Benign"

        # Explain using the encoded tensor so the readout matches model input.
        explain = generate_explainability(ref_seq, mut_seq, importance_val, encoded)
        return label, probability, importance_val, explain

    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        error_msg = f"Error during prediction: {str(e)}"
        return "Error", 0.0, 0.0, error_msg


# --- UI ---
with gr.Blocks(title="DNA Mutation Pathogenicity Predictor") as demo:
    gr.Markdown("""
    # 🧬 Explainable Mutation Pathogenicity Predictor

    Predict whether a DNA mutation is pathogenic or benign
    with explainability showing the mutation position and importance.
    """)

    with gr.Row():
        with gr.Column():
            ref_input = gr.Textbox(
                label="Reference sequence (99bp)",
                placeholder="Enter reference DNA sequence (A, T, G, C)",
                lines=3,
            )
            mut_input = gr.Textbox(
                label="Mutated sequence (99bp)",
                placeholder="Enter mutated DNA sequence (A, T, G, C)",
                lines=3,
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit = gr.Button("Predict", variant="primary")

        with gr.Column():
            prediction = gr.Textbox(label="Prediction", interactive=False)
            probability = gr.Number(label="Pathogenic Probability", interactive=False)
            importance = gr.Number(label="Mutation Importance Score", interactive=False)
            # Explainability visualization (caret-annotated sequence).
            explainability = gr.Textbox(
                label="Explainability Visualization",
                lines=8,
                interactive=False,
            )

    # Examples
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            [
                "AAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
                "AAAAAAAAAATAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
            ],
            [
                # NOTE(review): this pair looks shorter than the 99bp the labels
                # advertise — verify MutationEncoder accepts this length.
                "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA",
                "ATCGATCGATGGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA",
            ],
        ],
        inputs=[ref_input, mut_input],
        label="Click an example to load",
    )

    # Button actions
    submit.click(
        fn=predict,
        inputs=[ref_input, mut_input],
        outputs=[prediction, probability, importance, explainability],
    )
    clear_btn.click(
        fn=lambda: ("", "", "", 0.0, 0.0, ""),
        outputs=[ref_input, mut_input, prediction, probability, importance, explainability],
    )


if __name__ == "__main__":
    demo.launch()