"""SHAP-guided target site alignment visualization helpers.""" from __future__ import annotations import io import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np from PIL import Image COLOR_PALETTE = [ '#034da1', '#3359a7', '#4c66ad', '#6174b3', '#7482b9', '#8590be', '#979ec4', '#a8adca', '#b9bbcf', '#c9cbd5', '#dadada', '#dac4c4', '#daaeae', '#da9999', '#da8383', '#da6d6d', '#d95757', '#d94141', '#d92c2c', '#d91616', '#d90000', ] COMPLEMENT = { 'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', '-': '', 'N': '', } def _plot_t(ax, base, left_edge, height, color): ax.add_patch(patches.Rectangle( xy=[left_edge + 0.4, base], width=0.2, height=height, facecolor=color, edgecolor=color, fill=True )) ax.add_patch(patches.Rectangle( xy=[left_edge, base + 0.85 * height], width=1.0, height=0.15 * height, facecolor=color, edgecolor=color, fill=True )) def _plot_g(ax, base, left_edge, height, color): ax.add_patch(patches.Ellipse( xy=[left_edge + 0.65, base + 0.5 * height], width=1.3, height=height, facecolor=color, edgecolor=color )) ax.add_patch(patches.Ellipse( xy=[left_edge + 0.65, base + 0.5 * height], width=0.91, height=0.7 * height, facecolor='white', edgecolor='white' )) ax.add_patch(patches.Rectangle( xy=[left_edge + 1, base], width=1.0, height=height, facecolor='white', edgecolor='white', fill=True )) ax.add_patch(patches.Rectangle( xy=[left_edge + 0.825, base + 0.085 * height], width=0.174, height=0.415 * height, facecolor=color, edgecolor=color, fill=True )) ax.add_patch(patches.Rectangle( xy=[left_edge + 0.625, base + 0.35 * height], width=0.374, height=0.15 * height, facecolor=color, edgecolor=color, fill=True )) def _plot_a(ax, base, left_edge, height, color): polygons = [ np.array([[0.0, 0.0], [0.5, 1.0], [0.5, 0.8], [0.2, 0.0]]), np.array([[1.0, 0.0], [0.5, 1.0], [0.5, 0.8], [0.8, 0.0]]), np.array([[0.225, 0.45], [0.775, 0.45], [0.85, 0.3], [0.15, 0.3]]), ] scale = np.array([1, height])[None, :] offset = np.array([left_edge, base])[None, :] for polygon in polygons: ax.add_patch(patches.Polygon(polygon * scale + offset, facecolor=color, edgecolor=color)) def _plot_c(ax, base, left_edge, height, color): ax.add_patch(patches.Ellipse( xy=[left_edge + 0.65, base + 0.5 * height], width=1.3, height=height, facecolor=color, edgecolor=color )) ax.add_patch(patches.Ellipse( xy=[left_edge + 0.65, base + 0.5 * height], width=0.91, height=0.7 * height, facecolor='white', edgecolor='white' )) ax.add_patch(patches.Rectangle( xy=[left_edge + 1, base], width=1.0, height=height, facecolor='white', edgecolor='white', fill=True )) def _plot_n(ax, base, left_edge, height, color): ax.add_patch(patches.Rectangle( xy=[left_edge, base], width=0.2, height=height, facecolor=color, edgecolor=color, fill=True )) ax.add_patch(patches.Rectangle( xy=[left_edge + 0.8, base], width=0.15, height=height, facecolor=color, edgecolor=color, fill=True, angle=45 )) ax.add_patch(patches.Rectangle( xy=[left_edge + 0.8, base], width=0.2, height=height, facecolor=color, edgecolor=color, fill=True )) def _plot_dash(ax, base, left_edge, height, color): ax.add_patch(patches.Rectangle( xy=[left_edge + 0.2, base + 0.3 * height], width=0.6, height=0.2 * height, facecolor=color, edgecolor=color, fill=True )) def _plot_line(ax, base, left_edge, height, color): ax.add_patch(patches.Rectangle( xy=[left_edge + 0.45, base + 0.2 * height], width=0.1, height=0.6 * height, facecolor=color, edgecolor=color, fill=True )) def _plot_dot(ax, base, left_edge, height, color): ax.add_patch(patches.Ellipse( xy=[left_edge + 0.5, base + 0.5 * height], width=0.3, height=0.2 * height, facecolor=color, edgecolor=color )) BASE_PLOTTERS = { 'A': _plot_a, 'C': _plot_c, 'T': _plot_t, 'G': _plot_g, 'N': _plot_n, '-': _plot_dash, } class _ShapAlignment: """Needleman-Wunsch-style traceback with SHAP-informed scores.""" RIGHT = "right" DOWN = "down" DIAGONAL = "diag" def align(self, mirna_seq, target_seq, score_matrix, opening_percentile=99, elongating_percentile=90): processed_scores, opening_gap, elongating_gap = self._preprocess_score( score_matrix, opening_percentile, elongating_percentile ) grid = self._forward_pass(mirna_seq, target_seq, processed_scores, opening_gap, elongating_gap) return self._backward_pass(mirna_seq, target_seq, grid, processed_scores) def _score(self, grid, i, j, gap_penalty, score): directions = [self.RIGHT, self.DIAGONAL, self.DOWN] values = [ grid[i, j - 1][1] - gap_penalty, grid[i - 1, j - 1][1] + score, grid[i - 1, j][1] - gap_penalty, ] best_index = int(np.argmax(values)) return directions[best_index], values[best_index] def _forward_pass(self, mirna_seq, target_seq, score_matrix, opening_gap, elongating_gap): target_bases = [""] + list(target_seq) mirna_bases = [""] + list(mirna_seq) grid = np.empty((len(target_bases), len(mirna_bases)), dtype=object) for i in range(len(target_bases)): grid[i, 0] = (self.DOWN, 0.0) for j in range(len(mirna_bases)): grid[0, j] = (self.RIGHT, 0.0) is_opening = False for i in range(1, len(target_bases)): for j in range(1, len(mirna_bases)): gap_penalty = opening_gap if is_opening else elongating_gap if i == len(target_bases) - 1 or j == len(mirna_bases) - 1: gap_penalty = 0.0 grid[i, j] = self._score(grid, i, j, gap_penalty, score_matrix[i, j]) is_opening = grid[i, j][0] == self.DIAGONAL return grid def _backward_pass(self, mirna_seq, target_seq, grid, score_matrix): target_bases = [""] + list(target_seq) mirna_bases = [""] + list(mirna_seq) aligned_target = [] aligned_mirna = [] aligned_scores = [] i = grid.shape[0] - 1 j = grid.shape[1] - 1 while i != 0 or j != 0: direction = grid[i, j][0] if direction == self.RIGHT: aligned_target.append("-") aligned_mirna.append(mirna_bases[j]) aligned_scores.append(0.0) j -= 1 elif direction == self.DOWN: aligned_target.append(target_bases[i]) aligned_mirna.append("-") aligned_scores.append(0.0) i -= 1 else: aligned_target.append(target_bases[i]) aligned_mirna.append(mirna_bases[j]) aligned_scores.append(score_matrix[i, j]) i -= 1 j -= 1 return aligned_target, aligned_scores, aligned_mirna def _preprocess_score(self, score_matrix, opening_percentile, elongating_percentile): score_matrix = np.asarray(score_matrix, dtype=float) if score_matrix.size == 0: return np.zeros((1, 1), dtype=float), 0.0, 0.0 max_absolute_value = float(np.max(np.abs(score_matrix))) if max_absolute_value > 0: score_matrix = score_matrix / max_absolute_value else: score_matrix = np.zeros_like(score_matrix) abs_values = np.abs(score_matrix.sum(axis=-1)).ravel() opening_gap = float(np.nanpercentile(abs_values, opening_percentile)) if abs_values.size else 0.0 elongating_gap = float(np.nanpercentile(abs_values, elongating_percentile)) if abs_values.size else 0.0 score_matrix = np.vstack([np.zeros(score_matrix.shape[1]), score_matrix]) score_matrix = np.hstack([np.zeros((score_matrix.shape[0], 1)), score_matrix]) return score_matrix, opening_gap, elongating_gap def _scale_scores(alignment_scores): values = np.asarray(alignment_scores, dtype=float) if values.size == 0: return np.array([], dtype=int) max_absolute_value = float(np.max(np.abs(values))) if max_absolute_value == 0.0: return np.full(values.shape, 10, dtype=int) scaled = values / max_absolute_value scaled = scaled * 10 + 10 return np.clip(np.rint(scaled).astype(int), 0, len(COLOR_PALETTE) - 1) def compute_shap_alignment(mirna_seq, target_seq, shap_2d, opening_percentile=99, elongating_percentile=90): """Align miRNA and target using SHAP attributions as pairwise scores.""" shap_array = np.asarray(shap_2d, dtype=float) expected_shape = (len(mirna_seq), len(target_seq)) if shap_array.shape != expected_shape: raise ValueError( f"Expected shap_2d shape {expected_shape}, got {tuple(shap_array.shape)}" ) aligner = _ShapAlignment() aligned_target, aligned_scores, aligned_mirna = aligner.align( mirna_seq[::-1], target_seq, shap_array[::-1].T, opening_percentile=opening_percentile, elongating_percentile=elongating_percentile, ) return aligned_target[::-1], aligned_scores[::-1], aligned_mirna[::-1] def plot_alignment_image(mirna_seq, target_seq, shap_2d, arrows=True): """Render the SHAP-guided alignment as a PIL image.""" aligned_target, aligned_scores, aligned_mirna = compute_shap_alignment( mirna_seq, target_seq, shap_2d ) color_indices = _scale_scores(aligned_scores) fig_width = max(12, len(aligned_target) * 0.42) fig, ax = plt.subplots(figsize=(fig_width, 3.2)) step = 1.2 left_margin = 2.8 alignment_width = len(aligned_target) * step ax.set_xlim(-left_margin, alignment_width + 0.6) ax.set_ylim(0, 4.8) ax.text(-left_margin + 0.2, 3.5, "Target", ha='left', va='center', fontsize=12, fontweight='bold') ax.text(-left_margin + 0.2, 1.1, "miRNA", ha='left', va='center', fontsize=12, fontweight='bold') if arrows: ax.add_patch(patches.Arrow(x=0, y=4.4, dx=2, dy=0, width=0.6, color=COLOR_PALETTE[10])) ax.add_patch(patches.Arrow( x=max(alignment_width - 0.2, 0), y=0.2, dx=-2, dy=0, width=0.6, color=COLOR_PALETTE[10] )) ax.text(-0.1, 4.45, "5'", ha='right', va='center', fontsize=10, fontweight='bold') ax.text(alignment_width + 0.15, 4.45, "3'", ha='left', va='center', fontsize=10, fontweight='bold') ax.text(-0.1, 0.15, "3'", ha='right', va='center', fontsize=10, fontweight='bold') ax.text(alignment_width + 0.15, 0.15, "5'", ha='left', va='center', fontsize=10, fontweight='bold') for i, (target_base, mirna_base, color_index) in enumerate( zip(aligned_target, aligned_mirna, color_indices) ): color = COLOR_PALETTE[int(color_index)] left_edge = step * i BASE_PLOTTERS.get(mirna_base, _plot_n)(ax, 0.6, left_edge, 1.0, color) if mirna_base == COMPLEMENT.get(target_base): _plot_line(ax, 1.7, left_edge, 1.0, color) elif mirna_base in "ACTG" and target_base in "ACTG": _plot_dot(ax, 1.7, left_edge, 1.0, color) BASE_PLOTTERS.get(target_base, _plot_n)(ax, 3.0, left_edge, 1.0, color) ax.axis('off') fig.tight_layout() buffer = io.BytesIO() fig.savefig(buffer, format='png', dpi=120, bbox_inches='tight') buffer.seek(0) image = Image.open(buffer) plt.close(fig) return image