Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -328,106 +328,35 @@ def compute_shap_difference(shap1_norm, shap2_norm):
|
|
| 328 |
"""Compute the SHAP difference between normalized sequences"""
|
| 329 |
return shap2_norm - shap1_norm
|
| 330 |
|
| 331 |
-
def plot_comparative_heatmap(shap_diff,
|
| 332 |
"""
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
Parameters
|
| 336 |
-
----------
|
| 337 |
-
shap_diff : 1D array of differences (Seq2 SHAP - Seq1 SHAP).
|
| 338 |
-
Negative = Seq1 more human-like, Positive = Seq2 more human-like.
|
| 339 |
-
seq1_length : int, length of sequence 1 (for labeling).
|
| 340 |
-
seq2_length : int, length of sequence 2 (for labeling).
|
| 341 |
-
title : str, figure title.
|
| 342 |
-
|
| 343 |
-
Figure Layout
|
| 344 |
-
-------------
|
| 345 |
-
- Bottom X-Axis: Relative positions in percent (0% to 100%).
|
| 346 |
-
- Top X-Axis : Actual positions for both sequences (S1 and S2).
|
| 347 |
-
- Y-Axis : Hidden (this is effectively a 1D heatmap).
|
| 348 |
-
- Colorbar : Vertical, placed on the right.
|
| 349 |
-
|
| 350 |
-
The final layout uses `tight_layout` with a rectangular constraint to ensure
|
| 351 |
-
nothing overlaps while still providing clear labeling.
|
| 352 |
"""
|
| 353 |
-
# Reshape the 1D differences into a 1 x N image
|
| 354 |
heatmap_data = shap_diff.reshape(1, -1)
|
| 355 |
-
|
| 356 |
-
extent = max(abs(shap_diff.min()), abs(shap_diff.max()))
|
| 357 |
-
|
| 358 |
-
# Create the figure (width x height in inches)
|
| 359 |
-
fig, ax = plt.subplots(figsize=(10, 3))
|
| 360 |
|
| 361 |
-
|
| 362 |
cmap = get_zero_centered_cmap()
|
| 363 |
-
cax = ax.imshow(
|
| 364 |
-
heatmap_data,
|
| 365 |
-
aspect='auto',
|
| 366 |
-
cmap=cmap,
|
| 367 |
-
vmin=-extent,
|
| 368 |
-
vmax=extent
|
| 369 |
-
)
|
| 370 |
|
| 371 |
-
#
|
| 372 |
num_ticks = 5
|
| 373 |
-
tick_positions = np.linspace(0, shap_diff.shape[0]
|
| 374 |
-
|
| 375 |
-
# ----------------- Bottom Axis: Percentage ----------------- #
|
| 376 |
ax.set_xticks(tick_positions)
|
| 377 |
-
ax.set_xticklabels(
|
| 378 |
-
[f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)],
|
| 379 |
-
fontsize=9
|
| 380 |
-
)
|
| 381 |
-
ax.set_xlabel("Relative Position (%)", fontsize=10, labelpad=10)
|
| 382 |
-
|
| 383 |
-
# ----------------- Top Axis: Actual Positions ----------------- #
|
| 384 |
-
ax_top = ax.twiny()
|
| 385 |
-
ax_top.set_xlim(ax.get_xlim()) # Match the bottom axis
|
| 386 |
-
|
| 387 |
-
# Create position arrays for both sequences
|
| 388 |
-
seq1_positions = np.linspace(0, seq1_length, num_ticks)
|
| 389 |
-
seq2_positions = np.linspace(0, seq2_length, num_ticks)
|
| 390 |
-
|
| 391 |
-
# Helper function to format large positions nicely
|
| 392 |
-
def format_position(x):
|
| 393 |
-
if x >= 1e6:
|
| 394 |
-
return f"{x / 1e6:.1f}M"
|
| 395 |
-
elif x >= 1e3:
|
| 396 |
-
return f"{int(x / 1e3)}K"
|
| 397 |
-
else:
|
| 398 |
-
return f"{int(x)}"
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
ax_top.set_xticks(tick_positions)
|
| 404 |
-
# Each tick label shows the corresponding position in Seq1 and Seq2
|
| 405 |
-
ax_top.set_xticklabels(
|
| 406 |
-
[f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
|
| 407 |
-
fontsize=9
|
| 408 |
-
)
|
| 409 |
-
ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)
|
| 410 |
-
|
| 411 |
-
# ----------------- Colorbar (Vertical, on the right) ----------------- #
|
| 412 |
-
# 'fraction' = thickness of colorbar, 'pad' = gap from the right edge
|
| 413 |
-
cbar = fig.colorbar(
|
| 414 |
-
cax, ax=ax, orientation='vertical', fraction=0.03, pad=0.07
|
| 415 |
-
)
|
| 416 |
-
cbar.set_label("SHAP Difference\n(Seq2 - Seq1)", fontsize=10, labelpad=5)
|
| 417 |
|
| 418 |
-
# Hide the y-axis (not needed in a 1D heatmap)
|
| 419 |
ax.set_yticks([])
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
# Adjust layout so everything fits without overlapping
|
| 425 |
-
# The rect parameter leaves space on the right for the colorbar
|
| 426 |
-
fig.tight_layout(rect=[0, 0, 0.9, 1]) # Adjust as necessary
|
| 427 |
|
| 428 |
return fig
|
| 429 |
|
| 430 |
-
|
| 431 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|
| 432 |
"""
|
| 433 |
Plot histogram of SHAP values with configurable number of bins
|
|
@@ -608,8 +537,6 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
| 608 |
# Generate visualizations
|
| 609 |
heatmap_fig = plot_comparative_heatmap(
|
| 610 |
shap_diff,
|
| 611 |
-
seq1_length=len1,
|
| 612 |
-
seq2_length=len2,
|
| 613 |
title=f"SHAP Difference Heatmap (window: {smooth_window})"
|
| 614 |
)
|
| 615 |
heatmap_img = fig_to_image(heatmap_fig)
|
|
|
|
| 328 |
"""Compute the SHAP difference between normalized sequences"""
|
| 329 |
return shap2_norm - shap1_norm
|
| 330 |
|
| 331 |
+
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
| 332 |
"""
|
| 333 |
+
Plot heatmap using relative positions (0-100%)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
"""
|
|
|
|
| 335 |
heatmap_data = shap_diff.reshape(1, -1)
|
| 336 |
+
extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
|
| 338 |
+
fig, ax = plt.subplots(figsize=(12, 1.8))
|
| 339 |
cmap = get_zero_centered_cmap()
|
| 340 |
+
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
+
# Create percentage-based x-axis ticks
|
| 343 |
num_ticks = 5
|
| 344 |
+
tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
|
| 345 |
+
tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
|
|
|
|
| 346 |
ax.set_xticks(tick_positions)
|
| 347 |
+
ax.set_xticklabels(tick_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
+
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
| 350 |
+
cbar.ax.tick_params(labelsize=8)
|
| 351 |
+
cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
|
|
|
| 353 |
ax.set_yticks([])
|
| 354 |
+
ax.set_xlabel('Relative Position in Sequence', fontsize=10)
|
| 355 |
+
ax.set_title(title, pad=10)
|
| 356 |
+
plt.subplots_adjust(bottom=0.25, left=0.05, right=0.95)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
|
| 358 |
return fig
|
| 359 |
|
|
|
|
| 360 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|
| 361 |
"""
|
| 362 |
Plot histogram of SHAP values with configurable number of bins
|
|
|
|
| 537 |
# Generate visualizations
|
| 538 |
heatmap_fig = plot_comparative_heatmap(
|
| 539 |
shap_diff,
|
|
|
|
|
|
|
| 540 |
title=f"SHAP Difference Heatmap (window: {smooth_window})"
|
| 541 |
)
|
| 542 |
heatmap_img = fig_to_image(heatmap_fig)
|