Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -328,32 +328,70 @@ def compute_shap_difference(shap1_norm, shap2_norm):
|
|
| 328 |
"""Compute the SHAP difference between normalized sequences"""
|
| 329 |
return shap2_norm - shap1_norm
|
| 330 |
|
| 331 |
-
def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
| 332 |
"""
|
| 333 |
-
Plot heatmap using relative positions (0-100%)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
"""
|
| 335 |
heatmap_data = shap_diff.reshape(1, -1)
|
| 336 |
extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
|
| 337 |
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
| 339 |
cmap = get_zero_centered_cmap()
|
| 340 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
| 341 |
|
| 342 |
-
# Create percentage-based x-axis ticks
|
| 343 |
num_ticks = 5
|
| 344 |
tick_positions = np.linspace(0, shap_diff.shape[0]-1, num_ticks)
|
| 345 |
tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
|
| 346 |
ax.set_xticks(tick_positions)
|
| 347 |
ax.set_xticklabels(tick_labels)
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
| 350 |
cbar.ax.tick_params(labelsize=8)
|
| 351 |
cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
| 352 |
|
|
|
|
| 353 |
ax.set_yticks([])
|
| 354 |
-
ax.set_xlabel('Relative Position
|
|
|
|
| 355 |
ax.set_title(title, pad=10)
|
| 356 |
-
|
|
|
|
|
|
|
| 357 |
|
| 358 |
return fig
|
| 359 |
|
|
|
|
| 328 |
"""Compute the SHAP difference between normalized sequences"""
|
| 329 |
return shap2_norm - shap1_norm
|
| 330 |
|
| 331 |
+
def plot_comparative_heatmap(shap_diff, seq1_length, seq2_length, title="SHAP Difference Heatmap"):
    """
    Plot a one-row heatmap of SHAP differences with two x-axes.

    The bottom axis shows relative position (0-100%); the top axis (a twin
    created with ``twiny``) shows the matching absolute positions in
    sequence 1 and sequence 2.

    Parameters:
        shap_diff: array of SHAP differences; it is flattened into a single row
        seq1_length: length of sequence 1
        seq2_length: length of sequence 2
        title: plot title

    Returns:
        The matplotlib figure containing the heatmap and its colorbar.

    Raises:
        ValueError: if shap_diff is empty.
    """
    shap_diff = np.asarray(shap_diff)
    if shap_diff.size == 0:
        raise ValueError("shap_diff must contain at least one value")

    heatmap_data = shap_diff.reshape(1, -1)
    num_columns = heatmap_data.shape[1]

    # Symmetric colour limits around zero. Non-finite values are ignored and an
    # all-zero input falls back to 1.0 so that vmin != vmax.
    finite_values = shap_diff[np.isfinite(shap_diff)]
    extent = float(np.max(np.abs(finite_values))) if finite_values.size else 0.0
    if extent == 0:
        extent = 1.0

    fig, ax = plt.subplots(figsize=(12, 2.4))

    # Plot main heatmap
    cmap = get_zero_centered_cmap()
    cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)

    # Add the colorbar before the twin axis exists and name the parent axes
    # explicitly, so the space is taken from the heatmap axes (not from
    # whichever axes happens to be current) and the twin inherits the
    # already-shrunk subplot slot.
    cbar = fig.colorbar(cax, ax=ax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label('SHAP Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)

    # Percentage-based x-axis ticks (bottom axis)
    num_ticks = 5
    tick_positions = np.linspace(0, num_columns - 1, num_ticks)
    tick_labels = [f"{int(x*100)}%" for x in np.linspace(0, 1, num_ticks)]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)

    # Second x-axis for actual sequence positions (twiny places it on top)
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())

    # Absolute positions in each sequence at the same relative ticks
    seq1_labels = [_format_position(x) for x in np.linspace(0, seq1_length, num_ticks)]
    seq2_labels = [_format_position(x) for x in np.linspace(0, seq2_length, num_ticks)]

    ax2.set_xticks(tick_positions)
    ax2.set_xticklabels([f"S1: {s1}\nS2: {s2}" for s1, s2 in zip(seq1_labels, seq2_labels)])

    # Adjust labels and layout
    ax.set_yticks([])
    ax.set_xlabel('Relative Position (%)', fontsize=10)
    ax2.set_xlabel('Sequence Positions', fontsize=10)
    ax.set_title(title, pad=10)

    # Leave room below for the colorbar and above for the top axis and title
    fig.subplots_adjust(bottom=0.35, left=0.05, right=0.95, top=0.85)

    return fig


def _format_position(x):
    """Format a sequence position compactly, e.g. 750 -> "750", 7500 -> "7.5K", 2_500_000 -> "2.5M"."""
    # Values that would round up to "1000.0K" are promoted to the M scale.
    if x >= 999_950:
        return f"{x / 1e6:.1f}M"
    if x >= 1e3:
        # One decimal keeps quarter ticks exact ("7.5K", "22.5K") instead of
        # rounding half-to-even ("8K", "22K"); whole thousands stay "15K".
        kilo = f"{x / 1e3:.1f}"
        if kilo.endswith(".0"):
            kilo = kilo[:-2]
        return f"{kilo}K"
    return f"{int(x)}"
|
| 397 |
|