nileshhanotia committed on
Commit
e2b8fa8
·
verified ·
1 Parent(s): 8a9fd31

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from huggingface_hub import hf_hub_download
4
+ from encoder import MutationEncoder
5
+ from model import MutationPredictorCNN
6
+
7
# Download the trained checkpoint file from the Hugging Face Hub.
MODEL_PATH = hf_hub_download(
    filename="pytorch_model.pth",
    repo_id="nileshhanotia/mutation-pathogenicity-predictor",
)

# All inference in this app runs on the CPU.
device = torch.device("cpu")

# NOTE(review): torch.load unpickles the downloaded file; only point
# MODEL_PATH at a repository you trust.
checkpoint = torch.load(MODEL_PATH, map_location=device)

# Build the network, restore the trained weights and switch to eval mode.
model = MutationPredictorCNN()
model = model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Encoder that turns a (reference, mutated) sequence pair into model input.
encoder = MutationEncoder()
24
+
25
+
26
def generate_explainability(
    ref_seq,
    mut_seq,
    importance,
    encoded_tensor,
    mask_start=990,
    mask_end=1089,
):
    """Build a plain-text report locating the mutation in the sequence.

    The mutation position is read from the difference-mask slice of the
    encoded tensor instead of being recomputed from the strings, so the
    report reflects the same encoding that was passed to the model.

    Args:
        ref_seq: Reference DNA sequence.
        mut_seq: Mutated DNA sequence.
        importance: Importance score to print (formatted to 4 decimals).
        encoded_tensor: Encoding of a single sample (no batch dimension);
            ``encoded_tensor[mask_start:mask_end]`` is treated as the
            per-position difference mask and is expected to be 1-D.
        mask_start: Index where the difference mask begins.
        mask_end: Index one past the end of the difference mask.

    Returns:
        A multi-line string showing the mutated sequence, a ``^`` pointer
        under the mutated base, the position, the substitution and the
        importance score; or ``"No mutation detected in encoding"`` when
        the mask slice is empty or its maximum is zero.
    """
    no_mutation = "No mutation detected in encoding"

    diff_mask = encoded_tensor[mask_start:mask_end]

    # An encoding shorter than the expected window produces an empty slice,
    # and torch.argmax raises on empty input.
    if diff_mask.numel() == 0:
        return no_mutation

    # The largest mask entry marks the mutated position; a zero maximum
    # means the mask is empty of mutations.
    mutation_pos = torch.argmax(diff_mask).item()
    if diff_mask[mutation_pos].item() == 0:
        return no_mutation

    ref_seq = ref_seq.strip().upper()
    mut_seq = mut_seq.strip().upper()

    # The pointer lines up under the mutated base when rendered monospace.
    pointer = " " * mutation_pos + "^"

    # The mask can be longer than the supplied sequences, so guard indexing.
    if mutation_pos < min(len(ref_seq), len(mut_seq)):
        substitution = f"{ref_seq[mutation_pos]}>{mut_seq[mutation_pos]}"
    else:
        substitution = "Unknown"

    return "\n".join(
        [
            "Mutated sequence:",
            mut_seq,
            pointer,
            "",
            f"Mutation position: {mutation_pos}",
            f"Substitution: {substitution}",
            f"Importance score: {importance:.4f}",
        ]
    )
65
+
66
+
67
def _clean_sequence(seq):
    """Return *seq* upper-cased with all whitespace removed ('' for None)."""
    # The input boxes are multi-line, so pasted sequences may contain
    # newlines or spaces anywhere, not only at the ends.
    return "".join((seq or "").split()).upper()


def predict(ref_seq, mut_seq):
    """Classify a mutation as pathogenic or benign and explain the result.

    Args:
        ref_seq: Reference DNA sequence as entered by the user. May be
            ``None`` or contain whitespace/newlines.
        mut_seq: Mutated DNA sequence, same conventions as ``ref_seq``.

    Returns:
        A 4-tuple ``(label, probability, importance, explanation)``.
        ``label`` is ``"Pathogenic"`` when the model output is >= 0.5 and
        ``"Benign"`` otherwise. On invalid input or any failure during
        prediction the tuple is ``("Error", 0.0, 0.0, message)`` so the UI
        always receives four values.
    """
    ref_seq = _clean_sequence(ref_seq)
    mut_seq = _clean_sequence(mut_seq)

    if not ref_seq or not mut_seq:
        return "Error", 0.0, 0.0, "Please provide both reference and mutated sequences"

    if len(ref_seq) != len(mut_seq):
        return (
            "Error",
            0.0,
            0.0,
            f"Sequences must be same length (ref: {len(ref_seq)}, mut: {len(mut_seq)})",
        )

    # UI boundary: any failure in encoding or inference is reported in the
    # explanation field instead of propagating to Gradio.
    try:
        encoded = encoder.encode_mutation(ref_seq, mut_seq)

        # The model is called on a batch, so add a leading batch dimension.
        tensor = encoded.unsqueeze(0).to(device)

        with torch.no_grad():
            logit, importance = model(tensor)

        # The model output is used directly as a probability (the original
        # author notes that the model already applies the sigmoid).
        probability = logit.item()
        importance_val = importance.item()

        label = "Pathogenic" if probability >= 0.5 else "Benign"

        # Pass the unbatched encoding so the report uses the model's input.
        explain = generate_explainability(
            ref_seq,
            mut_seq,
            importance_val,
            encoded,
        )

        return label, probability, importance_val, explain

    except Exception as e:
        return "Error", 0.0, 0.0, f"Error during prediction: {str(e)}"
111
+
112
+
113
# UI
def _blank_fields():
    """Return reset values for the two inputs followed by the four outputs."""
    return "", "", "", 0.0, 0.0, ""


# Reference/mutated sequence pairs offered as clickable examples.
_EXAMPLE_PAIRS = [
    [
        "AAAAAAAAAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
        "AAAAAAAAAATAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
    ],
    [
        "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA",
        "ATCGATCGATGGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA",
    ],
]

with gr.Blocks(title="DNA Mutation Pathogenicity Predictor") as demo:

    # Page header.
    gr.Markdown("""
    # 🧬 Explainable Mutation Pathogenicity Predictor

    Predict whether a DNA mutation is pathogenic or benign with explainability
    showing the mutation position and importance.
    """)

    with gr.Row():
        # Left column: the two sequence inputs and the action buttons.
        with gr.Column():
            ref_input = gr.Textbox(
                lines=3,
                label="Reference sequence (99bp)",
                placeholder="Enter reference DNA sequence (A, T, G, C)",
            )

            mut_input = gr.Textbox(
                lines=3,
                label="Mutated sequence (99bp)",
                placeholder="Enter mutated DNA sequence (A, T, G, C)",
            )

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit = gr.Button("Predict", variant="primary")

        # Right column: read-only prediction results.
        with gr.Column():
            prediction = gr.Textbox(interactive=False, label="Prediction")

            probability = gr.Number(
                interactive=False,
                label="Pathogenic Probability",
            )

            importance = gr.Number(
                interactive=False,
                label="Mutation Importance Score",
            )

    # Full-width text area holding the explanation produced by predict().
    explainability = gr.Textbox(
        interactive=False,
        lines=8,
        label="Explainability Visualization",
    )

    gr.Markdown("### Examples")
    gr.Examples(
        examples=_EXAMPLE_PAIRS,
        inputs=[ref_input, mut_input],
        label="Click an example to load",
    )

    # Wire the buttons: "Predict" runs the model, "Clear" resets every field.
    submit.click(
        fn=predict,
        inputs=[ref_input, mut_input],
        outputs=[prediction, probability, importance, explainability],
    )

    clear_btn.click(
        fn=_blank_fields,
        outputs=[
            ref_input,
            mut_input,
            prediction,
            probability,
            importance,
            explainability,
        ],
    )


if __name__ == "__main__":
    demo.launch()
196
+