Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -90,7 +90,7 @@ def calculate_shap_values(model, x_tensor):
|
|
| 90 |
# Baseline
|
| 91 |
baseline_output = model(x_tensor)
|
| 92 |
baseline_probs = torch.softmax(baseline_output, dim=1)
|
| 93 |
-
baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
|
| 94 |
|
| 95 |
# Zeroing each feature to measure impact
|
| 96 |
shap_values = []
|
|
@@ -100,7 +100,7 @@ def calculate_shap_values(model, x_tensor):
|
|
| 100 |
x_zeroed[0, i] = 0.0
|
| 101 |
output = model(x_zeroed)
|
| 102 |
probs = torch.softmax(output, dim=1)
|
| 103 |
-
prob = probs[0, 1].item()
|
| 104 |
impact = baseline_prob - prob
|
| 105 |
shap_values.append(impact)
|
| 106 |
x_zeroed[0, i] = original_val # restore
|
|
@@ -354,6 +354,7 @@ def analyze_sequence(file_path, top_k=10, fasta_text="", window_size=500):
|
|
| 354 |
with torch.no_grad():
|
| 355 |
output = model(x_tensor)
|
| 356 |
probs = torch.softmax(output, dim=1)
|
|
|
|
| 357 |
pred_human = probs[0, 1].item()
|
| 358 |
|
| 359 |
# Calculate SHAP values
|
|
@@ -814,4 +815,4 @@ if __name__ == "__main__":
|
|
| 814 |
server_port=7860, # Default Gradio port
|
| 815 |
show_api=False, # Hide API docs
|
| 816 |
debug=False # Set to True for debugging
|
| 817 |
-
)
|
|
|
|
| 90 |
# Baseline
|
| 91 |
baseline_output = model(x_tensor)
|
| 92 |
baseline_probs = torch.softmax(baseline_output, dim=1)
|
| 93 |
+
baseline_prob = baseline_probs[0, 1].item() # Probability of 'human'
|
| 94 |
|
| 95 |
# Zeroing each feature to measure impact
|
| 96 |
shap_values = []
|
|
|
|
| 100 |
x_zeroed[0, i] = 0.0
|
| 101 |
output = model(x_zeroed)
|
| 102 |
probs = torch.softmax(output, dim=1)
|
| 103 |
+
prob = probs[0, 1].item() # Probability of 'human'
|
| 104 |
impact = baseline_prob - prob
|
| 105 |
shap_values.append(impact)
|
| 106 |
x_zeroed[0, i] = original_val # restore
|
|
|
|
| 354 |
with torch.no_grad():
|
| 355 |
output = model(x_tensor)
|
| 356 |
probs = torch.softmax(output, dim=1)
|
| 357 |
+
# Using index 1 for probability of human
|
| 358 |
pred_human = probs[0, 1].item()
|
| 359 |
|
| 360 |
# Calculate SHAP values
|
|
|
|
| 815 |
server_port=7860, # Default Gradio port
|
| 816 |
show_api=False, # Hide API docs
|
| 817 |
debug=False # Set to True for debugging
|
| 818 |
+
)
|