Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -330,76 +330,64 @@ def compute_shap_difference(shap1_norm, shap2_norm):
|
|
| 330 |
|
| 331 |
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
|
| 332 |
"""
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
"""
|
| 341 |
-
|
| 342 |
-
# Prepare data for the heatmap
|
| 343 |
heatmap_data = shap_diff.reshape(1, -1)
|
| 344 |
-
extent = max(abs(
|
| 345 |
|
| 346 |
-
# Create figure and
|
| 347 |
fig, ax = plt.subplots(figsize=(12, 3))
|
|
|
|
|
|
|
| 348 |
|
| 349 |
-
#
|
| 350 |
-
cmap = get_zero_centered_cmap()
|
| 351 |
-
cax = ax.imshow(
|
| 352 |
-
heatmap_data,
|
| 353 |
-
aspect='auto',
|
| 354 |
-
cmap=cmap,
|
| 355 |
-
vmin=-extent,
|
| 356 |
-
vmax=extent
|
| 357 |
-
)
|
| 358 |
-
|
| 359 |
-
# Add a vertical colorbar on the right
|
| 360 |
-
cbar = plt.colorbar(cax, ax=ax, orientation='vertical', fraction=0.025, pad=0.03)
|
| 361 |
-
cbar.ax.tick_params(labelsize=9)
|
| 362 |
-
cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
| 363 |
-
|
| 364 |
-
# Configure the top axis for relative (%) positions
|
| 365 |
num_ticks = 5
|
| 366 |
tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
|
| 367 |
-
tick_labels = [f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)]
|
| 368 |
ax.set_xticks(tick_positions)
|
| 369 |
-
ax.set_xticklabels(
|
| 370 |
-
|
| 371 |
-
ax.set_xlabel('Relative Position (%)', fontsize=10, labelpad=8)
|
| 372 |
-
ax.set_title(title, fontsize=12, pad=10)
|
| 373 |
-
ax.set_yticks([]) # Hide the y-axis ticks for a 1D heatmap
|
| 374 |
-
|
| 375 |
-
# Create a second (bottom) x-axis for actual positions
|
| 376 |
-
ax2 = ax.secondary_xaxis('bottom')
|
| 377 |
-
ax2.set_xlim(ax.get_xlim()) # Match the same data range as the top axis
|
| 378 |
|
| 379 |
-
#
|
|
|
|
|
|
|
| 380 |
seq1_positions = np.linspace(0, seq1_length, num_ticks)
|
| 381 |
seq2_positions = np.linspace(0, seq2_length, num_ticks)
|
| 382 |
|
| 383 |
-
# Format large numbers with 'K' or 'M'
|
| 384 |
def format_position(x):
|
| 385 |
if x >= 1e6:
|
| 386 |
-
return f"{x/1e6:.1f}M"
|
| 387 |
elif x >= 1e3:
|
| 388 |
-
return f"{x/1e3:.0f}K"
|
| 389 |
else:
|
| 390 |
-
return
|
| 391 |
|
| 392 |
seq1_labels = [format_position(x) for x in seq1_positions]
|
| 393 |
seq2_labels = [format_position(x) for x in seq2_positions]
|
| 394 |
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
|
| 401 |
-
# Use tight_layout to reduce overlap
|
| 402 |
-
plt.tight_layout()
|
| 403 |
return fig
|
| 404 |
|
| 405 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|
|
|
|
| 330 |
|
| 331 |
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
|
| 332 |
"""
|
| 333 |
+
Plots a comparative heatmap of SHAP differences between two sequences.
|
| 334 |
+
- The bottom x-axis shows relative positions (%) across the normalized dimension.
|
| 335 |
+
- The top x-axis shows actual positions for each sequence (S1 and S2).
|
| 336 |
+
- A vertical colorbar is placed to the right.
|
| 337 |
+
- Negative (blue) indicates Seq1 is more human-like in that region,
|
| 338 |
+
positive (red) indicates Seq2 is more human-like,
|
| 339 |
+
white indicates no substantial difference.
|
| 340 |
"""
|
| 341 |
+
# Prepare data
|
|
|
|
| 342 |
heatmap_data = shap_diff.reshape(1, -1)
|
| 343 |
+
extent = max(abs(shap_diff.min()), abs(shap_diff.max()))
|
| 344 |
|
| 345 |
+
# Create figure and axis
|
| 346 |
fig, ax = plt.subplots(figsize=(12, 3))
|
| 347 |
+
cmap = get_zero_centered_cmap() # Ensure this function is defined above
|
| 348 |
+
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
| 349 |
|
| 350 |
+
# Bottom axis: percentage-based x-axis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
num_ticks = 5
|
| 352 |
tick_positions = np.linspace(0, shap_diff.shape[0] - 1, num_ticks)
|
|
|
|
| 353 |
ax.set_xticks(tick_positions)
|
| 354 |
+
ax.set_xticklabels([f"{int(x * 100)}%" for x in np.linspace(0, 1, num_ticks)], fontsize=9)
|
| 355 |
+
ax.set_xlabel("Relative Position (%)", fontsize=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
+
# Top axis: actual sequence positions for both Seq1 and Seq2
|
| 358 |
+
ax_top = ax.twiny()
|
| 359 |
+
ax_top.set_xlim(ax.get_xlim())
|
| 360 |
seq1_positions = np.linspace(0, seq1_length, num_ticks)
|
| 361 |
seq2_positions = np.linspace(0, seq2_length, num_ticks)
|
| 362 |
|
|
|
|
| 363 |
def format_position(x):
|
| 364 |
if x >= 1e6:
|
| 365 |
+
return f"{x / 1e6:.1f}M"
|
| 366 |
elif x >= 1e3:
|
| 367 |
+
return f"{x / 1e3:.0f}K"
|
| 368 |
else:
|
| 369 |
+
return f"{int(x)}"
|
| 370 |
|
| 371 |
seq1_labels = [format_position(x) for x in seq1_positions]
|
| 372 |
seq2_labels = [format_position(x) for x in seq2_positions]
|
| 373 |
|
| 374 |
+
ax_top.set_xticks(tick_positions)
|
| 375 |
+
ax_top.set_xticklabels(
|
| 376 |
+
[f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)],
|
| 377 |
+
fontsize=8
|
| 378 |
+
)
|
| 379 |
+
ax_top.set_xlabel("Sequence Positions", fontsize=10, labelpad=15)
|
| 380 |
+
|
| 381 |
+
# Colorbar on the right
|
| 382 |
+
cbar = fig.colorbar(cax, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
|
| 383 |
+
cbar.set_label("SHAP Difference (Seq2 - Seq1)", fontsize=10)
|
| 384 |
+
|
| 385 |
+
# Aesthetics
|
| 386 |
+
ax.set_yticks([])
|
| 387 |
+
ax.set_title(title, fontsize=12, pad=10)
|
| 388 |
+
|
| 389 |
+
fig.tight_layout(rect=[0, 0, 0.88, 1])
|
| 390 |
|
|
|
|
|
|
|
| 391 |
return fig
|
| 392 |
|
| 393 |
def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
|