BioGeMT-miRBind2 / alignment_plot.py
dimostzim's picture
fix orientation
9c47d54
"""SHAP-guided target site alignment visualization helpers."""
from __future__ import annotations
import io
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Diverging colour ramp indexed by ``_scale_scores``: index 0 is the most
# negative score (blue), index 10 is a zero score (neutral grey) and index 20
# is the most positive score (red).
COLOR_PALETTE = [
    # negative scores: saturated blue fading toward grey
    '#034da1', '#3359a7', '#4c66ad', '#6174b3', '#7482b9',
    '#8590be', '#979ec4', '#a8adca', '#b9bbcf', '#c9cbd5',
    # zero score
    '#dadada',
    # positive scores: grey deepening toward saturated red
    '#dac4c4', '#daaeae', '#da9999', '#da8383', '#da6d6d',
    '#d95757', '#d94141', '#d92c2c', '#d91616', '#d90000',
]
# Watson-Crick partner of each DNA base. The gap symbol '-' and the ambiguous
# 'N' map to the empty string, so no real base is ever treated as their partner.
COMPLEMENT = dict(zip("ATCG-N", ("T", "A", "G", "C", "", "")))
def _plot_t(ax, base, left_edge, height, color):
    """Draw a 'T' glyph: a centred vertical stem under a full-width crossbar.

    The glyph occupies ``[left_edge, left_edge + 1]`` horizontally and
    ``[base, base + height]`` vertically.
    """
    stem = patches.Rectangle(
        xy=[left_edge + 0.4, base], width=0.2, height=height,
        facecolor=color, edgecolor=color, fill=True,
    )
    crossbar = patches.Rectangle(
        xy=[left_edge, base + 0.85 * height], width=1.0, height=0.15 * height,
        facecolor=color, edgecolor=color, fill=True,
    )
    for part in (stem, crossbar):
        ax.add_patch(part)
def _plot_g(ax, base, left_edge, height, color):
    """Draw a 'G' glyph: the open 'C' ring plus the G's spur and inner bar."""
    # The first three patches of a G (coloured ellipse, white inner ellipse,
    # white mask on the right) are exactly the 'C' glyph.
    _plot_c(ax, base, left_edge, height, color)
    # (x offset, y offset as a fraction of height, width, height fraction)
    bar_specs = (
        (0.825, 0.085, 0.174, 0.415),  # vertical spur on the right
        (0.625, 0.35, 0.374, 0.15),    # horizontal bar pointing inward
    )
    for dx, y_frac, bar_width, h_frac in bar_specs:
        ax.add_patch(patches.Rectangle(
            xy=[left_edge + dx, base + y_frac * height],
            width=bar_width, height=h_frac * height,
            facecolor=color, edgecolor=color, fill=True,
        ))
def _plot_a(ax, base, left_edge, height, color):
    """Draw an 'A' glyph from two slanted legs and a horizontal crossbar.

    Outlines are given in a unit cell; x is shifted by ``left_edge`` and y is
    stretched by ``height`` then shifted by ``base``.
    """
    outlines = (
        ((0.0, 0.0), (0.5, 1.0), (0.5, 0.8), (0.2, 0.0)),          # left leg
        ((1.0, 0.0), (0.5, 1.0), (0.5, 0.8), (0.8, 0.0)),          # right leg
        ((0.225, 0.45), (0.775, 0.45), (0.85, 0.3), (0.15, 0.3)),  # crossbar
    )
    for outline in outlines:
        unit = np.asarray(outline, dtype=float)
        vertices = np.column_stack(
            (unit[:, 0] + left_edge, unit[:, 1] * height + base)
        )
        ax.add_patch(patches.Polygon(vertices, facecolor=color, edgecolor=color))
def _plot_c(ax, base, left_edge, height, color):
    """Draw a 'C' glyph: a coloured ellipse hollowed out and opened on the right.

    The white masks paint over whatever lies beneath them, including the
    region ``[left_edge + 1, left_edge + 2]`` to the right of the glyph.
    """
    center_x = left_edge + 0.65
    center_y = base + 0.5 * height
    # Solid outer ring ...
    ax.add_patch(patches.Ellipse(
        xy=[center_x, center_y], width=1.3, height=height,
        facecolor=color, edgecolor=color,
    ))
    # ... hollowed out by a smaller white ellipse ...
    ax.add_patch(patches.Ellipse(
        xy=[center_x, center_y], width=0.91, height=0.7 * height,
        facecolor='white', edgecolor='white',
    ))
    # ... and opened on the right by a white mask.
    ax.add_patch(patches.Rectangle(
        xy=[left_edge + 1, base], width=1.0, height=height,
        facecolor='white', edgecolor='white', fill=True,
    ))
def _plot_n(ax, base, left_edge, height, color):
    """Draw an 'N' glyph (also the fallback for unrecognised symbols).

    Two vertical bars (x in ``[0, 0.2]`` and ``[0.8, 1.0]`` of the unit cell)
    are joined by a diagonal stroke from the top of the left bar to the bottom
    of the right bar. The glyph spans ``[base, base + height]`` vertically.
    """
    ax.add_patch(patches.Rectangle(
        xy=[left_edge, base], width=0.2, height=height,
        facecolor=color, edgecolor=color, fill=True
    ))
    # The diagonal is an explicit parallelogram. A Rectangle of length
    # ``height`` rotated 45 degrees about its bottom corner stops well below
    # the top of the bars, which leaves the N visibly malformed.
    diagonal = np.array([[0.0, 1.0], [0.2, 1.0], [1.0, 0.0], [0.8, 0.0]])
    diagonal = diagonal * np.array([1.0, height]) + np.array([left_edge, base])
    ax.add_patch(patches.Polygon(diagonal, facecolor=color, edgecolor=color))
    ax.add_patch(patches.Rectangle(
        xy=[left_edge + 0.8, base], width=0.2, height=height,
        facecolor=color, edgecolor=color, fill=True
    ))
def _plot_dash(ax, base, left_edge, height, color):
    """Draw the gap symbol '-': a short horizontal bar in the cell.

    The bar covers x in ``[0.2, 0.8]`` of the unit cell and 30%-50% of the
    cell height.
    """
    bar_x = left_edge + 0.2
    bar_y = base + 0.3 * height
    bar = patches.Rectangle(
        xy=[bar_x, bar_y], width=0.6, height=0.2 * height,
        facecolor=color, edgecolor=color, fill=True,
    )
    ax.add_patch(bar)
def _plot_line(ax, base, left_edge, height, color):
    """Draw a thin vertical bar centred in the cell (drawn between paired bases).

    The bar is 0.1 wide and covers 20%-80% of the cell height.
    """
    bar_x = left_edge + 0.45
    bar_y = base + 0.2 * height
    bar = patches.Rectangle(
        xy=[bar_x, bar_y], width=0.1, height=0.6 * height,
        facecolor=color, edgecolor=color, fill=True,
    )
    ax.add_patch(bar)
def _plot_dot(ax, base, left_edge, height, color):
    """Draw a small ellipse at the centre of the cell (drawn between mismatches)."""
    center = [left_edge + 0.5, base + 0.5 * height]
    dot = patches.Ellipse(
        xy=center, width=0.3, height=0.2 * height,
        facecolor=color, edgecolor=color,
    )
    ax.add_patch(dot)
# Glyph drawer for each alignment symbol. Callers look symbols up with
# ``.get(symbol, _plot_n)``, so anything not listed here is drawn as 'N'.
BASE_PLOTTERS = dict(zip(
    ('A', 'C', 'T', 'G', 'N', '-'),
    (_plot_a, _plot_c, _plot_t, _plot_g, _plot_n, _plot_dash),
))
class _ShapAlignment:
    """Needleman-Wunsch-style traceback with SHAP-informed scores.

    Grid cell ``[i, j]`` stores ``(direction, score)`` for aligning the first
    ``i`` target bases (rows) with the first ``j`` miRNA bases (columns).
    Row 0 and column 0 start at 0.0, and moves into the last row or last
    column cost no gap penalty. Leading and trailing overhangs are therefore
    free (semi-global alignment).

    The opening/elongating gap distinction is approximate. A cell uses the
    opening penalty when the cell to its left in the same row was a pairing,
    and that penalty applies to both RIGHT and DOWN moves into the cell.
    """

    # Traceback directions stored in each grid cell.
    RIGHT = "right"      # from [i, j - 1]: miRNA base j faces a gap
    DOWN = "down"        # from [i - 1, j]: target base i faces a gap
    DIAGONAL = "diag"    # from [i - 1, j - 1]: bases i and j are paired

    def align(self, mirna_seq, target_seq, score_matrix,
              opening_percentile=99, elongating_percentile=90):
        """Align ``mirna_seq`` (columns) against ``target_seq`` (rows).

        Args:
            mirna_seq: miRNA bases; one column of ``score_matrix`` each.
            target_seq: target bases; one row of ``score_matrix`` each.
            score_matrix: pairwise scores of shape
                ``(len(target_seq), len(mirna_seq))``.
            opening_percentile: percentile of the normalised |score| values
                used as the penalty for a gap that follows a pairing.
            elongating_percentile: percentile used for all other gaps.

        Returns:
            ``(aligned_target, aligned_scores, aligned_mirna)`` lists in
            traceback order, i.e. from the END of both sequences back to the
            start. Gap positions hold ``"-"`` and a score of 0.0.
        """
        processed_scores, opening_gap, elongating_gap = self._preprocess_score(
            score_matrix, opening_percentile, elongating_percentile
        )
        grid = self._forward_pass(mirna_seq, target_seq, processed_scores, opening_gap, elongating_gap)
        return self._backward_pass(mirna_seq, target_seq, grid, processed_scores)

    def _score(self, grid, i, j, gap_penalty, score):
        """Return the best ``(direction, value)`` for cell ``[i, j]``.

        Candidates are listed as RIGHT, DIAGONAL, DOWN, and ``np.argmax``
        keeps the first maximum. Ties therefore prefer RIGHT, then DIAGONAL.
        """
        directions = [self.RIGHT, self.DIAGONAL, self.DOWN]
        values = [
            grid[i, j - 1][1] - gap_penalty,
            grid[i - 1, j - 1][1] + score,
            grid[i - 1, j][1] - gap_penalty,
        ]
        best_index = int(np.argmax(values))
        return directions[best_index], values[best_index]

    def _forward_pass(self, mirna_seq, target_seq, score_matrix, opening_gap, elongating_gap):
        """Fill the DP grid; ``score_matrix`` must already be zero-padded."""
        n_rows = len(target_seq) + 1
        n_cols = len(mirna_seq) + 1
        grid = np.empty((n_rows, n_cols), dtype=object)
        for i in range(n_rows):
            grid[i, 0] = (self.DOWN, 0.0)
        for j in range(n_cols):
            grid[0, j] = (self.RIGHT, 0.0)
        for i in range(1, n_rows):
            # The flag describes the left neighbour in this row. The left
            # neighbour of column 1 is the boundary (a DOWN cell), so the
            # flag must not carry over from the end of the previous row.
            is_opening = False
            for j in range(1, n_cols):
                if i == n_rows - 1 or j == n_cols - 1:
                    gap_penalty = 0.0  # free trailing gaps
                else:
                    gap_penalty = opening_gap if is_opening else elongating_gap
                grid[i, j] = self._score(grid, i, j, gap_penalty, score_matrix[i, j])
                is_opening = grid[i, j][0] == self.DIAGONAL
        return grid

    def _backward_pass(self, mirna_seq, target_seq, grid, score_matrix):
        """Trace back from the bottom-right cell to ``[0, 0]``.

        Output lists are in traceback order (last position first).
        """
        target_bases = [""] + list(target_seq)
        mirna_bases = [""] + list(mirna_seq)
        aligned_target = []
        aligned_mirna = []
        aligned_scores = []
        i = grid.shape[0] - 1
        j = grid.shape[1] - 1
        while i != 0 or j != 0:
            direction = grid[i, j][0]
            if direction == self.RIGHT:
                aligned_target.append("-")
                aligned_mirna.append(mirna_bases[j])
                aligned_scores.append(0.0)
                j -= 1
            elif direction == self.DOWN:
                aligned_target.append(target_bases[i])
                aligned_mirna.append("-")
                aligned_scores.append(0.0)
                i -= 1
            else:
                aligned_target.append(target_bases[i])
                aligned_mirna.append(mirna_bases[j])
                aligned_scores.append(score_matrix[i, j])
                i -= 1
                j -= 1
        return aligned_target, aligned_scores, aligned_mirna

    def _preprocess_score(self, score_matrix, opening_percentile, elongating_percentile):
        """Normalise scores to ``[-1, 1]``, derive gap penalties, and pad.

        Returns:
            ``(padded, opening_gap, elongating_gap)``. ``padded`` has an extra
            zero row and zero column at index 0, so ``padded[i, j]`` scores
            target base ``i`` against miRNA base ``j`` (both 1-based).
        """
        score_matrix = np.asarray(score_matrix, dtype=float)
        if score_matrix.size == 0:
            return np.zeros((1, 1), dtype=float), 0.0, 0.0
        max_absolute_value = float(np.max(np.abs(score_matrix)))
        if max_absolute_value > 0:
            score_matrix = score_matrix / max_absolute_value
        else:
            score_matrix = np.zeros_like(score_matrix)
        # Gap penalties compete with individual cell scores in ``_score``, so
        # they must be percentiles of per-cell magnitudes. Summing over the
        # last axis would add up a whole row of miRNA positions and inflate
        # the penalties with miRNA length.
        abs_values = np.abs(score_matrix).ravel()
        opening_gap = float(np.nanpercentile(abs_values, opening_percentile))
        elongating_gap = float(np.nanpercentile(abs_values, elongating_percentile))
        padded = np.zeros((score_matrix.shape[0] + 1, score_matrix.shape[1] + 1))
        padded[1:, 1:] = score_matrix
        return padded, opening_gap, elongating_gap
def _scale_scores(alignment_scores):
    """Convert signed alignment scores to ``COLOR_PALETTE`` indices.

    Scores are divided by their largest magnitude and mapped linearly onto
    ``[0, 20]``, with a zero score landing on index 10. The result is rounded
    to the nearest integer (ties to even) and clipped to the palette bounds.
    An all-zero input maps entirely to 10; an empty input gives an empty
    integer array.
    """
    scores = np.asarray(alignment_scores, dtype=float)
    if not scores.size:
        return np.array([], dtype=int)
    peak = float(np.abs(scores).max())
    if peak == 0.0:
        return np.full(scores.shape, 10, dtype=int)
    last_index = len(COLOR_PALETTE) - 1
    indices = np.rint(10 + 10 * (scores / peak)).astype(int)
    return np.clip(indices, 0, last_index)
def compute_shap_alignment(mirna_seq, target_seq, shap_2d,
                           opening_percentile=99, elongating_percentile=90):
    """Align miRNA and target using SHAP attributions as pairwise scores.

    Args:
        mirna_seq: miRNA sequence.
        target_seq: target sequence.
        shap_2d: attributions of shape ``(len(mirna_seq), len(target_seq))``.
        opening_percentile / elongating_percentile: forwarded to the aligner.

    Returns:
        ``(aligned_target, aligned_scores, aligned_mirna)``. The target runs in
        its given order; the miRNA runs in reverse of its given order (3'->5'
        if it was supplied 5'->3').

    Raises:
        ValueError: if ``shap_2d`` does not have the expected shape.
    """
    attributions = np.asarray(shap_2d, dtype=float)
    wanted_shape = (len(mirna_seq), len(target_seq))
    if attributions.shape != wanted_shape:
        raise ValueError(
            f"Expected shap_2d shape {wanted_shape}, got {tuple(attributions.shape)}"
        )
    # Reverse the miRNA (presumably so it reads antiparallel to the target),
    # then transpose so rows follow the target and columns the reversed miRNA.
    pair_scores = np.flip(attributions, axis=0).T
    traceback = _ShapAlignment().align(
        mirna_seq[::-1],
        target_seq,
        pair_scores,
        opening_percentile=opening_percentile,
        elongating_percentile=elongating_percentile,
    )
    # align() reports positions from the end backwards; restore forward order.
    return tuple(part[::-1] for part in traceback)
def plot_alignment_image(mirna_seq, target_seq, shap_2d, arrows=True):
    """Render the SHAP-guided alignment as a PIL image.

    The target is drawn on the top row and the miRNA on the bottom row, one
    column per alignment position. Each column is coloured by its alignment
    score (a ``COLOR_PALETTE`` index from ``_scale_scores``). A vertical bar
    marks complementary pairs (per ``COMPLEMENT``), a dot marks any other
    pairing of two nucleotides, and gap columns get neither.

    Args:
        mirna_seq: miRNA sequence; drawn in reverse of its given order.
        target_seq: target sequence; drawn in its given order.
        shap_2d: attributions of shape ``(len(mirna_seq), len(target_seq))``.
        arrows: draw the direction arrows and the 5'/3' end labels.

    Returns:
        A fully loaded ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        ValueError: if ``shap_2d`` has the wrong shape.
    """
    aligned_target, aligned_scores, aligned_mirna = compute_shap_alignment(
        mirna_seq, target_seq, shap_2d
    )
    color_indices = _scale_scores(aligned_scores)
    # Set membership: ``x in "ACTG"`` is a substring test that would also
    # accept "" or multi-letter strings.
    nucleotides = frozenset("ACTG")

    fig_width = max(12, len(aligned_target) * 0.42)
    fig, ax = plt.subplots(figsize=(fig_width, 3.2))
    # pyplot keeps every open figure alive, so close it even if drawing fails.
    try:
        step = 1.2
        left_margin = 2.8
        alignment_width = len(aligned_target) * step
        ax.set_xlim(-left_margin, alignment_width + 0.6)
        ax.set_ylim(0, 4.8)
        ax.text(-left_margin + 0.2, 3.5, "Target", ha='left', va='center',
                fontsize=12, fontweight='bold')
        ax.text(-left_margin + 0.2, 1.1, "miRNA", ha='left', va='center',
                fontsize=12, fontweight='bold')
        if arrows:
            # Target labelled 5' on the left; the miRNA row runs the other way.
            ax.add_patch(patches.Arrow(x=0, y=4.4, dx=2, dy=0, width=0.6, color=COLOR_PALETTE[10]))
            ax.add_patch(patches.Arrow(
                x=max(alignment_width - 0.2, 0), y=0.2, dx=-2, dy=0, width=0.6, color=COLOR_PALETTE[10]
            ))
            ax.text(-0.1, 4.45, "5'", ha='right', va='center', fontsize=10, fontweight='bold')
            ax.text(alignment_width + 0.15, 4.45, "3'", ha='left', va='center', fontsize=10, fontweight='bold')
            ax.text(-0.1, 0.15, "3'", ha='right', va='center', fontsize=10, fontweight='bold')
            ax.text(alignment_width + 0.15, 0.15, "5'", ha='left', va='center', fontsize=10, fontweight='bold')
        for i, (target_base, mirna_base, color_index) in enumerate(
            zip(aligned_target, aligned_mirna, color_indices)
        ):
            color = COLOR_PALETTE[int(color_index)]
            left_edge = step * i
            # Bottom row: miRNA glyph (unknown symbols fall back to 'N').
            BASE_PLOTTERS.get(mirna_base, _plot_n)(ax, 0.6, left_edge, 1.0, color)
            # Middle row: bar for complementary pairs, dot for other
            # nucleotide pairings, nothing when either side is a gap.
            if mirna_base == COMPLEMENT.get(target_base):
                _plot_line(ax, 1.7, left_edge, 1.0, color)
            elif mirna_base in nucleotides and target_base in nucleotides:
                _plot_dot(ax, 1.7, left_edge, 1.0, color)
            # Top row: target glyph.
            BASE_PLOTTERS.get(target_base, _plot_n)(ax, 3.0, left_edge, 1.0, color)
        ax.axis('off')
        fig.tight_layout()
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)
    buffer.seek(0)
    image = Image.open(buffer)
    # Image.open is lazy; decode now so problems surface here and the image
    # no longer depends on reading from the buffer later.
    image.load()
    return image