Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,13 +2,15 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
import joblib
|
| 4 |
import numpy as np
|
|
|
|
| 5 |
import torch.nn as nn
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import io
|
| 8 |
from PIL import Image
|
| 9 |
-
from itertools import product
|
| 10 |
|
| 11 |
-
#
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class VirusClassifier(nn.Module):
|
| 14 |
def __init__(self, input_shape: int):
|
|
@@ -29,46 +31,20 @@ class VirusClassifier(nn.Module):
|
|
| 29 |
|
| 30 |
def forward(self, x):
|
| 31 |
return self.network(x)
|
| 32 |
-
|
| 33 |
-
def get_gradient_importance(self, x, class_index=1):
|
| 34 |
-
"""
|
| 35 |
-
Calculate gradient-based importance for each input feature.
|
| 36 |
-
By default, we compute the gradient wrt the 'human' class (index=1).
|
| 37 |
-
This method is akin to a raw gradient or 'saliency' approach.
|
| 38 |
-
"""
|
| 39 |
-
x = x.clone().detach().requires_grad_(True)
|
| 40 |
-
output = self.network(x)
|
| 41 |
-
probs = torch.softmax(output, dim=1)
|
| 42 |
-
|
| 43 |
-
# Probability of the specified class
|
| 44 |
-
target_prob = probs[..., class_index]
|
| 45 |
-
|
| 46 |
-
# Zero existing gradients if any
|
| 47 |
-
if x.grad is not None:
|
| 48 |
-
x.grad.zero_()
|
| 49 |
-
|
| 50 |
-
# Backprop on that probability
|
| 51 |
-
target_prob.backward()
|
| 52 |
-
|
| 53 |
-
# Raw gradient is now in x.grad
|
| 54 |
-
importance = x.grad.detach()
|
| 55 |
-
|
| 56 |
-
# Optional: Multiply by input to get a more "integrated gradients"-like measure
|
| 57 |
-
# importance = importance * x.detach()
|
| 58 |
-
|
| 59 |
-
return importance, float(target_prob)
|
| 60 |
|
| 61 |
-
#
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
def parse_fasta(text
|
| 64 |
"""
|
| 65 |
-
|
| 66 |
"""
|
| 67 |
sequences = []
|
| 68 |
current_header = None
|
| 69 |
current_sequence = []
|
| 70 |
|
| 71 |
-
for line in text.split('\n'):
|
| 72 |
line = line.strip()
|
| 73 |
if not line:
|
| 74 |
continue
|
|
@@ -85,10 +61,8 @@ def parse_fasta(text: str):
|
|
| 85 |
|
| 86 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
| 87 |
"""
|
| 88 |
-
Convert a
|
| 89 |
-
Defaults to k=4.
|
| 90 |
"""
|
| 91 |
-
# Generate all possible k-mers
|
| 92 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
| 93 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
| 94 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
|
@@ -104,385 +78,355 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
| 104 |
|
| 105 |
return vec
|
| 106 |
|
| 107 |
-
def
|
| 108 |
-
"""
|
| 109 |
-
Compute various statistics for a given sequence:
|
| 110 |
-
- Length
|
| 111 |
-
- GC content (%)
|
| 112 |
-
- A/C/G/T counts
|
| 113 |
"""
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
counts = {
|
| 123 |
-
'A': sequence.count('A'),
|
| 124 |
-
'C': sequence.count('C'),
|
| 125 |
-
'G': sequence.count('G'),
|
| 126 |
-
'T': sequence.count('T')
|
| 127 |
-
}
|
| 128 |
-
gc_content = (counts['G'] + counts['C']) / length * 100.0
|
| 129 |
-
|
| 130 |
-
return {
|
| 131 |
-
'length': length,
|
| 132 |
-
'gc_content': gc_content,
|
| 133 |
-
'counts': counts
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
# --------------- Visualization Functions ---------------
|
| 137 |
-
|
| 138 |
-
def plot_shap_like_bars(kmers, importance_values, top_k=10):
|
| 139 |
"""
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
#
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
|
| 155 |
-
fig, ax = plt.subplots(figsize=(8, 6))
|
| 156 |
-
colors = ['green' if val > 0 else 'red' for val in top_importances]
|
| 157 |
-
ax.barh(range(len(top_kmers)), np.abs(top_importances), color=colors)
|
| 158 |
-
ax.set_yticks(range(len(top_kmers)))
|
| 159 |
-
ax.set_yticklabels(top_kmers)
|
| 160 |
-
ax.invert_yaxis() # So that the highest value is at the top
|
| 161 |
-
ax.set_xlabel("Feature Importance (Gradient Magnitude)")
|
| 162 |
-
ax.set_title(f"Top-{top_k} SHAP-like Feature Importances")
|
| 163 |
-
plt.tight_layout()
|
| 164 |
-
return fig
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
(Optional if you want a quick distribution overview)
|
| 170 |
-
"""
|
| 171 |
-
fig, ax = plt.subplots(figsize=(10, 4))
|
| 172 |
-
ax.bar(range(len(kmer_freq_vector)), kmer_freq_vector, color='blue', alpha=0.6)
|
| 173 |
-
ax.set_xlabel("K-mer Index")
|
| 174 |
-
ax.set_ylabel("Frequency")
|
| 175 |
-
ax.set_title("K-mer Frequency Distribution")
|
| 176 |
-
ax.set_xticks([])
|
| 177 |
-
plt.tight_layout()
|
| 178 |
-
return fig
|
| 179 |
|
| 180 |
-
def
|
| 181 |
"""
|
| 182 |
-
|
| 183 |
-
|
| 184 |
"""
|
| 185 |
-
fig = plt.figure(figsize=(
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
|
| 188 |
-
#
|
|
|
|
| 189 |
current_prob = 0.5
|
| 190 |
steps = [('Start', current_prob, 0)]
|
| 191 |
|
| 192 |
-
for
|
| 193 |
-
change =
|
| 194 |
current_prob += change
|
| 195 |
-
steps.append((
|
| 196 |
-
|
| 197 |
-
x_vals = range(len(steps))
|
| 198 |
-
y_vals = [s[1] for s in steps]
|
| 199 |
-
|
| 200 |
-
ax.step(x_vals, y_vals, 'b-', where='post', label='Probability', linewidth=2)
|
| 201 |
-
ax.plot(x_vals, y_vals, 'b.', markersize=10)
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
for i, (kmer, prob, change) in enumerate(steps):
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
| 217 |
|
|
|
|
| 218 |
if i > 0:
|
| 219 |
change_text = f'{change:+.3f}'
|
| 220 |
color = 'green' if change > 0 else 'red'
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
def plot_kmer_freq_and_sigma(important_kmers):
|
| 233 |
-
"""
|
| 234 |
-
Plot frequencies vs. sigma from mean for the top k-mers.
|
| 235 |
-
This reuses logic from the original create_visualization second subplot,
|
| 236 |
-
but as its own function for clarity.
|
| 237 |
-
"""
|
| 238 |
-
fig, ax = plt.subplots(figsize=(8, 5))
|
| 239 |
|
| 240 |
# Prepare data
|
| 241 |
kmers = [k['kmer'] for k in important_kmers]
|
| 242 |
frequencies = [k['occurrence'] for k in important_kmers]
|
| 243 |
sigmas = [k['sigma'] for k in important_kmers]
|
| 244 |
-
colors = ['green' if k['direction'] == 'human' else 'red' for k in important_kmers]
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
x = np.arange(len(kmers))
|
| 247 |
width = 0.35
|
| 248 |
|
| 249 |
-
|
| 250 |
-
ax.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
|
| 251 |
|
| 252 |
-
#
|
| 253 |
-
|
| 254 |
-
#
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
ax2.set_ylabel('
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
|
|
|
|
|
|
| 267 |
|
| 268 |
plt.tight_layout()
|
| 269 |
return fig
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
def predict_fasta(
|
| 274 |
-
file_obj,
|
| 275 |
-
k_size=4,
|
| 276 |
-
top_k=10,
|
| 277 |
-
advanced_analysis=False
|
| 278 |
-
):
|
| 279 |
"""
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
"""
|
| 285 |
-
#
|
| 286 |
-
|
| 287 |
-
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
text = file_obj
|
| 292 |
-
else:
|
| 293 |
-
text = file_obj.decode('utf-8', errors='replace')
|
| 294 |
-
except Exception as e:
|
| 295 |
-
return f"Error reading file: {str(e)}", []
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
try:
|
| 304 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 305 |
-
model = VirusClassifier(input_shape=(4 ** k_size)).to(device)
|
| 306 |
state_dict = torch.load('model.pt', map_location=device)
|
| 307 |
model.load_state_dict(state_dict)
|
| 308 |
-
model.eval()
|
| 309 |
-
|
| 310 |
scaler = joblib.load('scaler.pkl')
|
| 311 |
except Exception as e:
|
| 312 |
-
return f"Error loading model
|
| 313 |
-
|
| 314 |
-
#
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
non_human_prob = float(probs[0][0])
|
| 340 |
-
confidence = float(torch.max(probs[0]).item())
|
| 341 |
-
|
| 342 |
-
# Compute gradient-based importance
|
| 343 |
-
importance, target_prob = model.get_gradient_importance(X_tensor, class_index=1)
|
| 344 |
-
importance = importance[0].cpu().numpy() # shape: (num_features,)
|
| 345 |
-
|
| 346 |
-
# Identify top-k features (by absolute gradient)
|
| 347 |
-
abs_importance = np.abs(importance)
|
| 348 |
-
sorted_indices = np.argsort(abs_importance)[::-1]
|
| 349 |
-
top_indices = sorted_indices[:top_k]
|
| 350 |
-
|
| 351 |
-
# Build a list of top k-mers
|
| 352 |
-
top_kmers_info = []
|
| 353 |
-
for i in top_indices:
|
| 354 |
-
kmer_name = all_kmers[i]
|
| 355 |
-
imp_val = float(importance[i])
|
| 356 |
-
direction = 'human' if imp_val > 0 else 'non-human'
|
| 357 |
-
freq_perc = float(raw_kmer_freq[i] * 100.0) # in percent
|
| 358 |
-
sigma = float(scaled_kmer_freq[0][i]) # This is the scaled value (stdev from mean if the scaler is StandardScaler)
|
| 359 |
-
|
| 360 |
-
top_kmers_info.append({
|
| 361 |
-
'kmer': kmer_name,
|
| 362 |
-
'impact': abs(imp_val),
|
| 363 |
-
'direction': direction,
|
| 364 |
-
'occurrence': freq_perc,
|
| 365 |
-
'sigma': sigma
|
| 366 |
-
})
|
| 367 |
-
|
| 368 |
-
# Text summary for this sequence
|
| 369 |
-
seq_report = []
|
| 370 |
-
seq_report.append(f"=== Sequence {idx} ===")
|
| 371 |
-
seq_report.append(f"Header: {header}")
|
| 372 |
-
seq_report.append(f"Length: {seq_stats['length']}")
|
| 373 |
-
seq_report.append(f"GC Content: {seq_stats['gc_content']:.2f}%")
|
| 374 |
-
seq_report.append(f"A: {seq_stats['counts']['A']}, C: {seq_stats['counts']['C']}, G: {seq_stats['counts']['G']}, T: {seq_stats['counts']['T']}")
|
| 375 |
-
seq_report.append(f"Prediction: {pred_label} (Confidence: {confidence:.4f})")
|
| 376 |
-
seq_report.append(f" Human Probability: {human_prob:.4f}")
|
| 377 |
-
seq_report.append(f" Non-human Probability: {non_human_prob:.4f}")
|
| 378 |
-
seq_report.append(f"\nTop-{top_k} Influential k-mers (by gradient magnitude):")
|
| 379 |
-
for tkm in top_kmers_info:
|
| 380 |
-
seq_report.append(
|
| 381 |
-
f" {tkm['kmer']}: pushes towards {tkm['direction']} "
|
| 382 |
-
f"(impact={tkm['impact']:.4f}), occurrence={tkm['occurrence']:.2f}%, "
|
| 383 |
-
f"sigma={tkm['sigma']:.2f}"
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
final_text_report.append("\n".join(seq_report))
|
| 387 |
-
|
| 388 |
-
# 6. Generate Plots (for each sequence)
|
| 389 |
-
if advanced_analysis:
|
| 390 |
-
# 6A. SHAP-like bar chart
|
| 391 |
-
fig_shap = plot_shap_like_bars(
|
| 392 |
-
kmers=all_kmers,
|
| 393 |
-
importance_values=importance,
|
| 394 |
-
top_k=top_k
|
| 395 |
-
)
|
| 396 |
-
buf_shap = io.BytesIO()
|
| 397 |
-
fig_shap.savefig(buf_shap, format='png', bbox_inches='tight', dpi=150)
|
| 398 |
-
buf_shap.seek(0)
|
| 399 |
-
plots.append(Image.open(buf_shap))
|
| 400 |
-
plt.close(fig_shap)
|
| 401 |
-
|
| 402 |
-
# 6B. k-mer distribution histogram
|
| 403 |
-
fig_kmer_dist = plot_kmer_distribution(raw_kmer_freq, all_kmers)
|
| 404 |
-
buf_dist = io.BytesIO()
|
| 405 |
-
fig_kmer_dist.savefig(buf_dist, format='png', bbox_inches='tight', dpi=150)
|
| 406 |
-
buf_dist.seek(0)
|
| 407 |
-
plots.append(Image.open(buf_dist))
|
| 408 |
-
plt.close(fig_kmer_dist)
|
| 409 |
-
|
| 410 |
-
# 6C. Original step visualization for top k k-mers
|
| 411 |
-
# Sort by actual 'impact' to preserve that step logic
|
| 412 |
-
# (largest absolute impact first)
|
| 413 |
-
top_kmers_info_step = sorted(top_kmers_info, key=lambda x: x['impact'], reverse=True)
|
| 414 |
-
fig_step = create_step_visualization(top_kmers_info_step, human_prob)
|
| 415 |
-
buf_step = io.BytesIO()
|
| 416 |
-
fig_step.savefig(buf_step, format='png', bbox_inches='tight', dpi=150)
|
| 417 |
-
buf_step.seek(0)
|
| 418 |
-
plots.append(Image.open(buf_step))
|
| 419 |
-
plt.close(fig_step)
|
| 420 |
-
|
| 421 |
-
# 6D. Frequency vs. sigma bar chart
|
| 422 |
-
fig_freq_sigma = plot_kmer_freq_and_sigma(top_kmers_info_step)
|
| 423 |
-
buf_freq_sigma = io.BytesIO()
|
| 424 |
-
fig_freq_sigma.savefig(buf_freq_sigma, format='png', bbox_inches='tight', dpi=150)
|
| 425 |
-
buf_freq_sigma.seek(0)
|
| 426 |
-
plots.append(Image.open(buf_freq_sigma))
|
| 427 |
-
plt.close(fig_freq_sigma)
|
| 428 |
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
):
|
| 441 |
-
"""
|
| 442 |
-
Wrapper for Gradio to handle the outputs in (text, List[Image]) form.
|
| 443 |
-
"""
|
| 444 |
-
text_output, pil_images = predict_fasta(
|
| 445 |
-
file_obj=file_obj,
|
| 446 |
-
k_size=k_size,
|
| 447 |
-
top_k=top_k,
|
| 448 |
-
advanced_analysis=advanced_analysis
|
| 449 |
-
)
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
return text_output, pil_images
|
| 453 |
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
)
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
if __name__ == "__main__":
|
| 487 |
-
|
| 488 |
-
|
|
|
|
| 2 |
import torch
|
| 3 |
import joblib
|
| 4 |
import numpy as np
|
| 5 |
+
from itertools import product
|
| 6 |
import torch.nn as nn
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import io
|
| 9 |
from PIL import Image
|
|
|
|
| 10 |
|
| 11 |
+
##############################################################################
|
| 12 |
+
# MODEL DEFINITION
|
| 13 |
+
##############################################################################
|
| 14 |
|
| 15 |
class VirusClassifier(nn.Module):
|
| 16 |
def __init__(self, input_shape: int):
|
|
|
|
| 31 |
|
| 32 |
def forward(self, x):
|
| 33 |
return self.network(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
##############################################################################
|
| 36 |
+
# UTILITIES
|
| 37 |
+
##############################################################################
|
| 38 |
|
| 39 |
+
def parse_fasta(text):
|
| 40 |
"""
|
| 41 |
+
Parses FASTA formatted text into a list of (header, sequence).
|
| 42 |
"""
|
| 43 |
sequences = []
|
| 44 |
current_header = None
|
| 45 |
current_sequence = []
|
| 46 |
|
| 47 |
+
for line in text.strip().split('\n'):
|
| 48 |
line = line.strip()
|
| 49 |
if not line:
|
| 50 |
continue
|
|
|
|
| 61 |
|
| 62 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
| 63 |
"""
|
| 64 |
+
Convert a sequence to a k-mer frequency vector of size len(ACGT^k).
|
|
|
|
| 65 |
"""
|
|
|
|
| 66 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
| 67 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
| 68 |
vec = np.zeros(len(kmers), dtype=np.float32)
|
|
|
|
| 78 |
|
| 79 |
return vec
|
| 80 |
|
| 81 |
+
def ablation_importance(model, x_tensor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
"""
|
| 83 |
+
Calculates a simple ablation-based importance measure for each feature:
|
| 84 |
+
1. Compute baseline human probability p_base.
|
| 85 |
+
2. For each feature i, set x[i] = 0, re-run inference, compute new p, and
|
| 86 |
+
measure delta = p_base - p.
|
| 87 |
+
3. Return array of deltas (positive means that removing that feature
|
| 88 |
+
*decreases* the probability => that feature was pushing it higher).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
"""
|
| 90 |
+
model.eval()
|
| 91 |
+
with torch.no_grad():
|
| 92 |
+
# Baseline probability
|
| 93 |
+
output = model(x_tensor)
|
| 94 |
+
probs = torch.softmax(output, dim=1)
|
| 95 |
+
p_base = probs[0, 1].item()
|
| 96 |
+
|
| 97 |
+
# Store the delta importances
|
| 98 |
+
importances = np.zeros(x_tensor.shape[1], dtype=np.float32)
|
| 99 |
+
|
| 100 |
+
# For efficiency, we do ablation one feature at a time
|
| 101 |
+
for i in range(x_tensor.shape[1]):
|
| 102 |
+
x_copy = x_tensor.clone()
|
| 103 |
+
x_copy[0, i] = 0.0 # Ablate this feature
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
output_ablation = model(x_copy)
|
| 106 |
+
probs_ablation = torch.softmax(output_ablation, dim=1)
|
| 107 |
+
p_ablation = probs_ablation[0, 1].item()
|
| 108 |
+
# Delta
|
| 109 |
+
importances[i] = p_base - p_ablation
|
| 110 |
|
| 111 |
+
return importances, p_base
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
+
##############################################################################
|
| 114 |
+
# PLOTTING
|
| 115 |
+
##############################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
+
def create_step_and_frequency_plot(important_kmers, human_prob, title):
|
| 118 |
"""
|
| 119 |
+
Creates a combined step plot (showing how each k-mer modifies the probability)
|
| 120 |
+
and a frequency vs. sigma bar chart.
|
| 121 |
"""
|
| 122 |
+
fig = plt.figure(figsize=(15, 10))
|
| 123 |
+
|
| 124 |
+
# Create grid for subplots
|
| 125 |
+
gs = plt.GridSpec(2, 1, height_ratios=[1.5, 1], hspace=0.3)
|
| 126 |
|
| 127 |
+
# 1. Probability Step Plot
|
| 128 |
+
ax1 = plt.subplot(gs[0])
|
| 129 |
current_prob = 0.5
|
| 130 |
steps = [('Start', current_prob, 0)]
|
| 131 |
|
| 132 |
+
for kmer_info in important_kmers:
|
| 133 |
+
change = kmer_info['impact'] # positive => pushes up, negative => pushes down
|
| 134 |
current_prob += change
|
| 135 |
+
steps.append((kmer_info['kmer'], current_prob, change))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
+
x = range(len(steps))
|
| 138 |
+
y = [step[1] for step in steps]
|
| 139 |
+
|
| 140 |
+
# Plot steps
|
| 141 |
+
ax1.step(x, y, 'b-', where='post', label='Probability', linewidth=2)
|
| 142 |
+
ax1.plot(x, y, 'b.', markersize=10)
|
| 143 |
+
|
| 144 |
+
# Add reference line
|
| 145 |
+
ax1.axhline(y=0.5, color='r', linestyle='--', label='Neutral (0.5)')
|
| 146 |
+
|
| 147 |
+
# Customize plot
|
| 148 |
+
ax1.grid(True, linestyle='--', alpha=0.7)
|
| 149 |
+
ax1.set_ylim(0, 1)
|
| 150 |
+
ax1.set_ylabel('Human Probability')
|
| 151 |
+
ax1.set_title(f'K-mer Contributions to Prediction (final prob: {human_prob:.3f})')
|
| 152 |
+
|
| 153 |
+
# Add labels for each point
|
| 154 |
for i, (kmer, prob, change) in enumerate(steps):
|
| 155 |
+
# Add k-mer label
|
| 156 |
+
ax1.annotate(kmer,
|
| 157 |
+
(i, prob),
|
| 158 |
+
xytext=(0, 10 if i % 2 == 0 else -20),
|
| 159 |
+
textcoords='offset points',
|
| 160 |
+
ha='center',
|
| 161 |
+
rotation=45)
|
| 162 |
|
| 163 |
+
# Add change value
|
| 164 |
if i > 0:
|
| 165 |
change_text = f'{change:+.3f}'
|
| 166 |
color = 'green' if change > 0 else 'red'
|
| 167 |
+
ax1.annotate(change_text,
|
| 168 |
+
(i, prob),
|
| 169 |
+
xytext=(0, -20 if i % 2 == 0 else 10),
|
| 170 |
+
textcoords='offset points',
|
| 171 |
+
ha='center',
|
| 172 |
+
color=color)
|
| 173 |
+
|
| 174 |
+
ax1.legend()
|
| 175 |
+
|
| 176 |
+
# 2. K-mer Frequency and Sigma Plot
|
| 177 |
+
ax2 = plt.subplot(gs[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
# Prepare data
|
| 180 |
kmers = [k['kmer'] for k in important_kmers]
|
| 181 |
frequencies = [k['occurrence'] for k in important_kmers]
|
| 182 |
sigmas = [k['sigma'] for k in important_kmers]
|
|
|
|
| 183 |
|
| 184 |
+
# Color the bars: if impact>0 => green, else red
|
| 185 |
+
colors = ['g' if k['impact'] > 0 else 'r' for k in important_kmers]
|
| 186 |
+
|
| 187 |
+
# Create bar plot for frequencies
|
| 188 |
x = np.arange(len(kmers))
|
| 189 |
width = 0.35
|
| 190 |
|
| 191 |
+
ax2.bar(x - width/2, frequencies, width, label='Frequency (%)', color=colors, alpha=0.6)
|
|
|
|
| 192 |
|
| 193 |
+
# Twin axis for sigma
|
| 194 |
+
ax2_twin = ax2.twinx()
|
| 195 |
+
# To highlight positive or negative sigma, pick color accordingly
|
| 196 |
+
sigma_colors = []
|
| 197 |
+
for s, c in zip(sigmas, colors):
|
| 198 |
+
if s >= 0:
|
| 199 |
+
sigma_colors.append('blue') # above average
|
| 200 |
+
else:
|
| 201 |
+
sigma_colors.append('gray') # below average
|
| 202 |
+
|
| 203 |
+
ax2_twin.bar(x + width/2, sigmas, width, label='σ from Mean', color=sigma_colors, alpha=0.3)
|
| 204 |
|
| 205 |
+
# Customize plot
|
| 206 |
+
ax2.set_xticks(x)
|
| 207 |
+
ax2.set_xticklabels(kmers, rotation=45)
|
| 208 |
+
ax2.set_ylabel('Frequency (%)')
|
| 209 |
+
ax2_twin.set_ylabel('Standard Deviations (σ) from Mean')
|
| 210 |
+
ax2.set_title('K-mer Frequencies and Statistical Significance')
|
| 211 |
+
|
| 212 |
+
# Add legends
|
| 213 |
+
lines1, labels1 = ax2.get_legend_handles_labels()
|
| 214 |
+
lines2, labels2 = ax2_twin.get_legend_handles_labels()
|
| 215 |
+
ax2.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
|
| 216 |
|
| 217 |
plt.tight_layout()
|
| 218 |
return fig
|
| 219 |
|
| 220 |
+
def create_shap_like_bar_plot(impact_values, kmer_list, top_k):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
"""
|
| 222 |
+
Creates a horizontal bar plot showing the top_k features by absolute impact.
|
| 223 |
+
impact_values: array of float (length=256).
|
| 224 |
+
kmer_list: list of all k=4 kmers in order.
|
| 225 |
+
top_k: integer, how many top features to display.
|
| 226 |
"""
|
| 227 |
+
# Sort by absolute impact
|
| 228 |
+
indices_sorted = np.argsort(np.abs(impact_values))[::-1]
|
| 229 |
+
top_indices = indices_sorted[:top_k]
|
| 230 |
|
| 231 |
+
top_impacts = impact_values[top_indices]
|
| 232 |
+
top_kmers = [kmer_list[i] for i in top_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
+
fig = plt.figure(figsize=(8, 6))
|
| 235 |
+
plt.barh(range(len(top_impacts)), top_impacts, color=['green' if i > 0 else 'red' for i in top_impacts])
|
| 236 |
+
plt.yticks(range(len(top_impacts)), top_kmers)
|
| 237 |
+
plt.xlabel("Impact on Human Probability (Ablation)")
|
| 238 |
+
plt.title(f"Top {top_k} K-mers by Absolute Impact")
|
| 239 |
+
plt.gca().invert_yaxis() # Highest at top
|
| 240 |
+
plt.tight_layout()
|
| 241 |
+
return fig
|
| 242 |
+
|
| 243 |
+
def create_global_bar_plot(impact_values, kmer_list):
|
| 244 |
+
"""
|
| 245 |
+
Creates a bar plot for ALL features (256) to see the global distribution.
|
| 246 |
+
"""
|
| 247 |
+
fig = plt.figure(figsize=(12, 6))
|
| 248 |
+
indices_sorted = np.argsort(np.abs(impact_values))[::-1]
|
| 249 |
+
sorted_impacts = impact_values[indices_sorted]
|
| 250 |
+
sorted_kmers = [kmer_list[i] for i in indices_sorted]
|
| 251 |
|
| 252 |
+
plt.bar(range(len(sorted_impacts)), sorted_impacts,
|
| 253 |
+
color=['green' if i > 0 else 'red' for i in sorted_impacts])
|
| 254 |
+
plt.title("Global Impact of All 256 K-mers (Ablation Method)")
|
| 255 |
+
plt.xlabel("K-mer (sorted by |impact|)")
|
| 256 |
+
plt.ylabel("Impact on Human Probability")
|
| 257 |
+
# Optionally, we can skip labeling all 256 on x-axis.
|
| 258 |
+
# But we can show only the top/bottom or none for clarity.
|
| 259 |
+
plt.tight_layout()
|
| 260 |
+
return fig
|
| 261 |
+
|
| 262 |
+
##############################################################################
|
| 263 |
+
# MAIN PREDICTION FUNCTION
|
| 264 |
+
##############################################################################
|
| 265 |
+
|
| 266 |
+
def predict(file_obj, top_kmers=10, advanced_plots=False, fasta_text=""):
|
| 267 |
+
"""
|
| 268 |
+
Main prediction function called by Gradio.
|
| 269 |
+
- file_obj: optional uploaded FASTA file
|
| 270 |
+
- top_kmers: number of top k-mers to display in the main SHAP-like plot
|
| 271 |
+
- advanced_plots: bool, whether to return global bar plots
|
| 272 |
+
- fasta_text: optional direct-pasted FASTA text
|
| 273 |
+
"""
|
| 274 |
+
# Priority: If user pasted text, use that; otherwise use uploaded file.
|
| 275 |
+
if fasta_text.strip():
|
| 276 |
+
text = fasta_text.strip()
|
| 277 |
+
else:
|
| 278 |
+
if file_obj is None:
|
| 279 |
+
return "No FASTA input provided", None, None, None
|
| 280 |
+
try:
|
| 281 |
+
if isinstance(file_obj, str):
|
| 282 |
+
text = file_obj
|
| 283 |
+
else:
|
| 284 |
+
text = file_obj.decode('utf-8')
|
| 285 |
+
except Exception as e:
|
| 286 |
+
return f"Error reading file: {str(e)}", None, None, None
|
| 287 |
+
|
| 288 |
+
# Parse FASTA
|
| 289 |
+
sequences = parse_fasta(text)
|
| 290 |
+
if len(sequences) == 0:
|
| 291 |
+
return "No valid FASTA sequences found", None, None, None
|
| 292 |
+
header, seq = sequences[0]
|
| 293 |
+
|
| 294 |
+
# Load model + scaler
|
| 295 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 296 |
+
model = VirusClassifier(256).to(device)
|
| 297 |
try:
|
|
|
|
|
|
|
| 298 |
state_dict = torch.load('model.pt', map_location=device)
|
| 299 |
model.load_state_dict(state_dict)
|
|
|
|
|
|
|
| 300 |
scaler = joblib.load('scaler.pkl')
|
| 301 |
except Exception as e:
|
| 302 |
+
return f"Error loading model or scaler: {str(e)}", None, None, None
|
| 303 |
+
|
| 304 |
+
# Prepare the vector
|
| 305 |
+
raw_freq_vector = sequence_to_kmer_vector(seq, k=4)
|
| 306 |
+
scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
| 307 |
+
X_tensor = torch.FloatTensor(scaled_vector).to(device)
|
| 308 |
+
|
| 309 |
+
# Compute ablation-based importances
|
| 310 |
+
importances, p_base = ablation_importance(model, X_tensor)
|
| 311 |
+
# p_base is baseline human probability
|
| 312 |
+
|
| 313 |
+
# We also want frequency in % and sigma from mean
|
| 314 |
+
# If your scaler is e.g. StandardScaler, then "scaled_vector[0][i]" is
|
| 315 |
+
# how many std devs from the mean that feature is.
|
| 316 |
+
# We'll gather info in a list of dicts for each k-mer.
|
| 317 |
+
kmers_4 = [''.join(p) for p in product("ACGT", repeat=4)]
|
| 318 |
+
kmer_dict = {km: i for i, km in enumerate(kmers_4)}
|
| 319 |
+
|
| 320 |
+
# We'll sort by absolute impact to get the top 10 by default.
|
| 321 |
+
abs_sorted_idx = np.argsort(np.abs(importances))[::-1]
|
| 322 |
+
# But for the final step/frequency plot we only show top_kmers
|
| 323 |
+
top_indices = abs_sorted_idx[:top_kmers]
|
| 324 |
+
|
| 325 |
+
# Build a list of the top k-mers
|
| 326 |
+
important_kmers = []
|
| 327 |
+
for idx in top_indices:
|
| 328 |
+
# "impact" is how much that feature changed the probability
|
| 329 |
+
impact = importances[idx]
|
| 330 |
+
# raw frequency => raw_freq_vector[idx] * 100 for %
|
| 331 |
+
freq_pct = float(raw_freq_vector[idx] * 100.0)
|
| 332 |
+
# sigma => scaled_vector[0][idx]
|
| 333 |
+
sigma_val = float(scaled_vector[0][idx])
|
| 334 |
+
|
| 335 |
+
important_kmers.append({
|
| 336 |
+
'kmer': kmers_4[idx],
|
| 337 |
+
'impact': impact,
|
| 338 |
+
'occurrence': freq_pct,
|
| 339 |
+
'sigma': sigma_val
|
| 340 |
+
})
|
| 341 |
|
| 342 |
+
# For text output
|
| 343 |
+
# We decide final class based on model's direct output
|
| 344 |
+
with torch.no_grad():
|
| 345 |
+
output = model(X_tensor)
|
| 346 |
+
probs = torch.softmax(output, dim=1)
|
| 347 |
+
pred_class = 1 if probs[0,1] > probs[0,0] else 0
|
| 348 |
+
pred_label = 'human' if pred_class == 1 else 'non-human'
|
| 349 |
+
human_prob = probs[0,1].item()
|
| 350 |
+
nonhuman_prob = probs[0,0].item()
|
| 351 |
+
confidence = max(human_prob, nonhuman_prob)
|
| 352 |
+
|
| 353 |
+
results_text = (f"Sequence: {header}\n"
|
| 354 |
+
f"Prediction: {pred_label}\n"
|
| 355 |
+
f"Confidence: {confidence:.4f}\n"
|
| 356 |
+
f"Human probability: {human_prob:.4f}\n"
|
| 357 |
+
f"Non-human probability: {nonhuman_prob:.4f}\n"
|
| 358 |
+
f"Most influential k-mers (by ablation impact):\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
+
for kmer_info in important_kmers:
|
| 361 |
+
# sign => if impact>0 => removing it lowers p(human), so it was pushing p(human) up
|
| 362 |
+
direction = "UP (toward human)" if kmer_info['impact'] > 0 else "DOWN (toward non-human)"
|
| 363 |
+
results_text += (
|
| 364 |
+
f" {kmer_info['kmer']}: {direction}, "
|
| 365 |
+
f"Impact={kmer_info['impact']:.4f}, "
|
| 366 |
+
f"Occ={kmer_info['occurrence']:.2f}% of seq, "
|
| 367 |
+
f"{abs(kmer_info['sigma']):.2f}σ "
|
| 368 |
+
+ ("above" if kmer_info['sigma']>0 else "below")
|
| 369 |
+
+ " mean\n"
|
| 370 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
+
# PLOT 1: A SHAP-like bar plot for the top K features
|
| 373 |
+
shap_fig = create_shap_like_bar_plot(importances, kmers_4, top_kmers)
|
| 374 |
+
|
| 375 |
+
# PLOT 2: Step + frequency plot for the top K features
|
| 376 |
+
step_fig = create_step_and_frequency_plot(important_kmers, human_prob, header)
|
| 377 |
+
|
| 378 |
+
# PLOT 3 (optional advanced): global bar plot of all 256 features
|
| 379 |
+
global_fig = None
|
| 380 |
+
if advanced_plots:
|
| 381 |
+
global_fig = create_global_bar_plot(importances, kmers_4)
|
| 382 |
+
|
| 383 |
+
# Convert figures to PIL Images
|
| 384 |
+
def fig_to_image(fig):
|
| 385 |
+
buf = io.BytesIO()
|
| 386 |
+
fig.savefig(buf, format='png', bbox_inches='tight', dpi=200)
|
| 387 |
+
buf.seek(0)
|
| 388 |
+
im = Image.open(buf)
|
| 389 |
+
plt.close(fig)
|
| 390 |
+
return im
|
| 391 |
+
|
| 392 |
+
shap_img = fig_to_image(shap_fig)
|
| 393 |
+
step_img = fig_to_image(step_fig)
|
| 394 |
+
if global_fig is not None:
|
| 395 |
+
global_img = fig_to_image(global_fig)
|
| 396 |
+
else:
|
| 397 |
+
global_img = None
|
| 398 |
+
|
| 399 |
+
return results_text, shap_img, step_img, global_img
|
| 400 |
+
|
| 401 |
+
##############################################################################
|
| 402 |
+
# GRADIO INTERFACE
|
| 403 |
+
##############################################################################
|
| 404 |
+
|
| 405 |
+
title_text = "Virus Host Classifier"
|
| 406 |
+
description_text = """
|
| 407 |
+
Upload or paste a FASTA sequence to predict if it's likely **human** or **non-human** origin.
|
| 408 |
+
- **k=4** k-mers are used as features.
|
| 409 |
+
- We display ablation-based feature importance for interpretability.
|
| 410 |
+
- Advanced plots can be toggled to see the global distribution of all 256 k-mer impacts.
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
iface = gr.Interface(
|
| 414 |
+
fn=predict,
|
| 415 |
+
inputs=[
|
| 416 |
+
gr.File(label="Upload FASTA file", type="binary", optional=True),
|
| 417 |
+
gr.Slider(label="Number of top k-mers to show", minimum=1, maximum=50, value=10, step=1),
|
| 418 |
+
gr.Checkbox(label="Show advanced (global) plots?", value=False),
|
| 419 |
+
gr.Textbox(label="Or paste FASTA text here", lines=5, placeholder=">header\nACGTACGT...")
|
| 420 |
+
],
|
| 421 |
+
outputs=[
|
| 422 |
+
gr.Textbox(label="Results", lines=10),
|
| 423 |
+
gr.Image(label="SHAP-like Top-k K-mer Bar Plot"),
|
| 424 |
+
gr.Image(label="Step & Frequency Plot (Top-k)"),
|
| 425 |
+
gr.Image(label="Global 256-K-mer Plot (advanced)", optional=True)
|
| 426 |
+
],
|
| 427 |
+
title=title_text,
|
| 428 |
+
description=description_text
|
| 429 |
+
)
|
| 430 |
|
| 431 |
if __name__ == "__main__":
|
| 432 |
+
iface.launch(share=True)
|
|
|