Spaces:
Running
Running
| """SHAP-guided target site alignment visualization helpers.""" | |
| from __future__ import annotations | |
| import io | |
| import matplotlib.patches as patches | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image | |
# Diverging colour ramp for shading alignment columns by score.
# Index 0 is the most negative score, index 10 ('#dadada') is neutral and
# index 20 is the most positive score; indices come from ``_scale_scores``.
COLOR_PALETTE = [
    # Negative side: blue shades fading towards grey (indices 0-9).
    '#034da1', '#3359a7', '#4c66ad', '#6174b3', '#7482b9',
    '#8590be', '#979ec4', '#a8adca', '#b9bbcf', '#c9cbd5',
    # Neutral midpoint (index 10).
    '#dadada',
    # Positive side: grey deepening to saturated red (indices 11-20).
    '#dac4c4', '#daaeae', '#da9999', '#da8383', '#da6d6d',
    '#d95757', '#d94141', '#d92c2c', '#d91616', '#d90000',
]
# Partner expected for each target base when drawing pair markers: an aligned
# miRNA base equal to COMPLEMENT[target_base] is drawn as a pair. Gap and
# unknown symbols map to the empty string, so they never form a pair.
COMPLEMENT = {
    # Watson-Crick partners (DNA alphabet).
    'A': 'T', 'T': 'A',
    'C': 'G', 'G': 'C',
    # No partner for gaps or unknown bases.
    '-': '', 'N': '',
}
def _plot_t(ax, base, left_edge, height, color):
    """Draw a 'T' glyph on ``ax``.

    The glyph is a vertical stem centred in the unit-wide slot starting at
    ``left_edge`` plus a full-width bar over the top 15% of ``height``. It
    rests on ``base`` and is filled with ``color``.
    """
    stem = patches.Rectangle(
        xy=[left_edge + 0.4, base], width=0.2, height=height,
        facecolor=color, edgecolor=color, fill=True,
    )
    top_bar = patches.Rectangle(
        xy=[left_edge, base + 0.85 * height], width=1.0, height=0.15 * height,
        facecolor=color, edgecolor=color, fill=True,
    )
    for shape in (stem, top_bar):
        ax.add_patch(shape)
def _plot_g(ax, base, left_edge, height, color):
    """Draw a 'G' glyph anchored at (``left_edge``, ``base``) on ``ax``.

    The glyph starts like the 'C' shape: a coloured ellipse, a white inner
    ellipse and a white mask over the right side. A short vertical spur and a
    crossbar in ``color`` are then added.
    """
    centre_x = left_edge + 0.65
    centre_y = base + 0.5 * height
    # Coloured outer ellipse, then a white inner ellipse to leave a ring.
    for ring_width, ring_height, shade in (
        (1.3, height, color),
        (0.91, 0.7 * height, 'white'),
    ):
        ax.add_patch(patches.Ellipse(
            xy=[centre_x, centre_y], width=ring_width, height=ring_height,
            facecolor=shade, edgecolor=shade,
        ))
    # White mask over the right-hand part of the ring opens the letter.
    ax.add_patch(patches.Rectangle(
        xy=[left_edge + 1, base], width=1.0, height=height,
        facecolor='white', edgecolor='white', fill=True,
    ))
    # Vertical spur, then horizontal crossbar (offsets relative to the glyph).
    for dx, dy, bar_width, bar_height in (
        (0.825, 0.085, 0.174, 0.415),
        (0.625, 0.35, 0.374, 0.15),
    ):
        ax.add_patch(patches.Rectangle(
            xy=[left_edge + dx, base + dy * height],
            width=bar_width, height=bar_height * height,
            facecolor=color, edgecolor=color, fill=True,
        ))
def _plot_a(ax, base, left_edge, height, color):
    """Draw an 'A' glyph on ``ax`` from three filled polygons.

    The outlines are defined on a unit square. Each one is stretched
    vertically by ``height`` and translated to (``left_edge``, ``base``)
    before being added.
    """
    outlines = (
        [[0.0, 0.0], [0.5, 1.0], [0.5, 0.8], [0.2, 0.0]],          # left leg
        [[1.0, 0.0], [0.5, 1.0], [0.5, 0.8], [0.8, 0.0]],          # right leg
        [[0.225, 0.45], [0.775, 0.45], [0.85, 0.3], [0.15, 0.3]],  # crossbar
    )
    stretch = np.array([[1, height]])
    shift = np.array([[left_edge, base]])
    for outline in outlines:
        vertices = np.array(outline) * stretch + shift
        ax.add_patch(patches.Polygon(vertices, facecolor=color, edgecolor=color))
def _plot_c(ax, base, left_edge, height, color):
    """Draw a 'C' glyph anchored at (``left_edge``, ``base``) on ``ax``.

    A ``color`` ellipse is hollowed out by a smaller white ellipse. A white
    rectangle then covers its right side so the ring reads as a 'C'.
    """
    centre_x = left_edge + 0.65
    centre_y = base + 0.5 * height
    outer_ring = patches.Ellipse(
        xy=[centre_x, centre_y], width=1.3, height=height,
        facecolor=color, edgecolor=color,
    )
    hollow = patches.Ellipse(
        xy=[centre_x, centre_y], width=0.91, height=0.7 * height,
        facecolor='white', edgecolor='white',
    )
    right_mask = patches.Rectangle(
        xy=[left_edge + 1, base], width=1.0, height=height,
        facecolor='white', edgecolor='white', fill=True,
    )
    # Order matters: later patches paint over earlier ones.
    for shape in (outer_ring, hollow, right_mask):
        ax.add_patch(shape)
def _plot_n(ax, base, left_edge, height, color):
    """Draw an 'N' glyph on ``ax``.

    ``plot_alignment_image`` also uses this as the fallback for any symbol
    that has no dedicated plotter. The glyph is two uprights joined by a thin
    bar rotated 45 degrees about its anchor.
    """
    paint = dict(facecolor=color, edgecolor=color, fill=True)
    # Left upright.
    ax.add_patch(patches.Rectangle(
        xy=[left_edge, base], width=0.2, height=height, **paint
    ))
    # Diagonal stroke anchored at the right upright.
    ax.add_patch(patches.Rectangle(
        xy=[left_edge + 0.8, base], width=0.15, height=height, angle=45, **paint
    ))
    # Right upright.
    ax.add_patch(patches.Rectangle(
        xy=[left_edge + 0.8, base], width=0.2, height=height, **paint
    ))
def _plot_dash(ax, base, left_edge, height, color):
    """Draw the gap ('-') glyph: a short horizontal bar in the slot's lower half."""
    bar_origin = [left_edge + 0.2, base + 0.3 * height]
    ax.add_patch(
        patches.Rectangle(
            bar_origin, 0.6, 0.2 * height,
            facecolor=color, edgecolor=color, fill=True,
        )
    )
def _plot_line(ax, base, left_edge, height, color):
    """Draw a thin centred vertical bar (the pair marker between two rows)."""
    bar_origin = [left_edge + 0.45, base + 0.2 * height]
    ax.add_patch(
        patches.Rectangle(
            bar_origin, 0.1, 0.6 * height,
            facecolor=color, edgecolor=color, fill=True,
        )
    )
def _plot_dot(ax, base, left_edge, height, color):
    """Draw a small centred ellipse (the mismatch marker between two rows)."""
    centre = [left_edge + 0.5, base + 0.5 * height]
    ax.add_patch(
        patches.Ellipse(
            centre, 0.3, 0.2 * height,
            facecolor=color, edgecolor=color,
        )
    )
# Glyph renderer for each alignment symbol. Callers use
# ``BASE_PLOTTERS.get(symbol, _plot_n)``, so unlisted symbols are drawn as 'N'.
BASE_PLOTTERS = {
    # Nucleotides.
    'A': _plot_a, 'C': _plot_c, 'T': _plot_t, 'G': _plot_g,
    # Unknown base and alignment gap.
    'N': _plot_n, '-': _plot_dash,
}
class _ShapAlignment:
    """Needleman-Wunsch-style traceback with SHAP-informed scores.

    Rows of the dynamic-programming grid index the target and columns index
    the miRNA. Row 0 and column 0 stand for the empty prefix. Each cell holds
    a ``(direction, score)`` tuple. ``direction`` names the neighbour the best
    score came from and drives the traceback.
    """

    # Move labels stored in grid cells.
    RIGHT = "right"     # from (i, j - 1): miRNA base against a target gap
    DOWN = "down"       # from (i - 1, j): target base against a miRNA gap
    DIAGONAL = "diag"   # from (i - 1, j - 1): target base paired with miRNA base

    def align(self, mirna_seq, target_seq, score_matrix,
              opening_percentile=99, elongating_percentile=90):
        """Align ``mirna_seq`` with ``target_seq`` using ``score_matrix``.

        Args:
            mirna_seq: miRNA symbols (one grid column each).
            target_seq: Target symbols (one grid row each).
            score_matrix: Pairwise scores indexed ``[target_pos, mirna_pos]``.
                The forward pass reads it as shape
                ``(len(target_seq), len(mirna_seq))``.
            opening_percentile: Percentile used for the gap-opening penalty.
            elongating_percentile: Percentile used for the gap-extension penalty.

        Returns:
            ``(aligned_target, aligned_scores, aligned_mirna)``: equal-length
            lists in traceback order, i.e. from the end of both sequences
            back to the start.
        """
        processed_scores, opening_gap, elongating_gap = self._preprocess_score(
            score_matrix, opening_percentile, elongating_percentile
        )
        grid = self._forward_pass(mirna_seq, target_seq, processed_scores,
                                  opening_gap, elongating_gap)
        return self._backward_pass(mirna_seq, target_seq, grid, processed_scores)

    def _score(self, grid, i, j, gap_penalty, score):
        """Return the best ``(direction, value)`` for cell ``(i, j)``.

        ``np.argmax`` returns the first maximum, so ties resolve in the order
        right, diagonal, down.
        """
        directions = [self.RIGHT, self.DIAGONAL, self.DOWN]
        values = [
            grid[i, j - 1][1] - gap_penalty,
            grid[i - 1, j - 1][1] + score,
            grid[i - 1, j][1] - gap_penalty,
        ]
        best_index = int(np.argmax(values))
        return directions[best_index], values[best_index]

    def _forward_pass(self, mirna_seq, target_seq, score_matrix, opening_gap, elongating_gap):
        """Fill the DP grid.

        ``score_matrix`` must already be padded with a leading zero row and
        column so that ``score_matrix[i, j]`` lines up with grid cell ``(i, j)``.
        """
        n_rows = len(target_seq) + 1
        n_cols = len(mirna_seq) + 1
        grid = np.empty((n_rows, n_cols), dtype=object)
        for i in range(n_rows):
            grid[i, 0] = (self.DOWN, 0.0)
        for j in range(n_cols):
            grid[0, j] = (self.RIGHT, 0.0)
        last_row = n_rows - 1
        last_col = n_cols - 1
        for i in range(1, n_rows):
            for j in range(1, n_cols):
                # A gap right after a pairing move is "opened" and pays the
                # larger penalty; otherwise it extends a gap. The flag comes
                # from the left neighbour in the same row, so it resets at
                # column 1 (grid[i, 0] is a DOWN cell). It must never leak over
                # from the last cell of the previous row.
                # NOTE(review): DOWN moves come from (i - 1, j) but share this
                # penalty; confirm whether per-direction penalties are intended.
                is_opening = grid[i, j - 1][0] == self.DIAGONAL
                gap_penalty = opening_gap if is_opening else elongating_gap
                # Cells on the last row or column charge no gap penalty.
                if i == last_row or j == last_col:
                    gap_penalty = 0.0
                grid[i, j] = self._score(grid, i, j, gap_penalty, score_matrix[i, j])
        return grid

    def _backward_pass(self, mirna_seq, target_seq, grid, score_matrix):
        """Trace back from the bottom-right cell to ``(0, 0)``.

        Gap columns get a score of ``0.0``. Paired columns get the padded
        ``score_matrix[i, j]``. The lists are built end-to-start.
        """
        target_bases = [""] + list(target_seq)
        mirna_bases = [""] + list(mirna_seq)
        aligned_target = []
        aligned_mirna = []
        aligned_scores = []
        i = grid.shape[0] - 1
        j = grid.shape[1] - 1
        while i != 0 or j != 0:
            direction = grid[i, j][0]
            if direction == self.RIGHT:
                aligned_target.append("-")
                aligned_mirna.append(mirna_bases[j])
                aligned_scores.append(0.0)
                j -= 1
            elif direction == self.DOWN:
                aligned_target.append(target_bases[i])
                aligned_mirna.append("-")
                aligned_scores.append(0.0)
                i -= 1
            else:
                aligned_target.append(target_bases[i])
                aligned_mirna.append(mirna_bases[j])
                aligned_scores.append(score_matrix[i, j])
                i -= 1
                j -= 1
        return aligned_target, aligned_scores, aligned_mirna

    def _preprocess_score(self, score_matrix, opening_percentile, elongating_percentile):
        """Normalise scores, derive gap penalties and pad for the DP grid.

        Returns:
            ``(padded_scores, opening_gap, elongating_gap)``. An empty input
            yields a ``(1, 1)`` zero matrix and zero penalties.
        """
        score_matrix = np.asarray(score_matrix, dtype=float)
        if score_matrix.size == 0:
            return np.zeros((1, 1), dtype=float), 0.0, 0.0
        # Scale so the largest-magnitude score is +/-1.
        max_absolute_value = float(np.max(np.abs(score_matrix)))
        if max_absolute_value > 0:
            score_matrix = score_matrix / max_absolute_value
        else:
            score_matrix = np.zeros_like(score_matrix)
        # Gap penalties are percentiles of |row sums| of the normalised matrix.
        # NOTE(review): for a 2-D (target x miRNA) matrix each row sum spans the
        # whole miRNA, so penalties grow with miRNA length relative to the
        # per-cell scores (which are <= 1 after normalisation). Confirm intended.
        abs_values = np.abs(score_matrix.sum(axis=-1)).ravel()
        opening_gap = float(np.nanpercentile(abs_values, opening_percentile)) if abs_values.size else 0.0
        elongating_gap = float(np.nanpercentile(abs_values, elongating_percentile)) if abs_values.size else 0.0
        # Prepend a zero row and column so indices match the DP grid, whose
        # row/column 0 represent the empty prefix.
        score_matrix = np.vstack([np.zeros(score_matrix.shape[1]), score_matrix])
        score_matrix = np.hstack([np.zeros((score_matrix.shape[0], 1)), score_matrix])
        return score_matrix, opening_gap, elongating_gap
def _scale_scores(alignment_scores):
    """Map alignment scores to integer indices into ``COLOR_PALETTE``.

    Scores are divided by the largest absolute score, so they lie in
    ``[-1, 1]``. They are then mapped linearly onto ``[0, 20]`` with 0 going
    to the neutral index 10, and clipped to the palette bounds. All-zero input
    maps every position to 10.

    Non-finite values are sanitised first. NaN counts as a neutral 0.0 and
    +/-inf as the largest finite magnitude. Without this, one NaN would poison
    the normaliser and paint every column with the most-negative colour.
    """
    values = np.nan_to_num(np.asarray(alignment_scores, dtype=float), nan=0.0)
    if values.size == 0:
        return np.array([], dtype=int)
    max_absolute_value = float(np.max(np.abs(values)))
    if max_absolute_value == 0.0:
        return np.full(values.shape, 10, dtype=int)
    scaled = values / max_absolute_value
    scaled = scaled * 10 + 10
    return np.clip(np.rint(scaled).astype(int), 0, len(COLOR_PALETTE) - 1)
def compute_shap_alignment(mirna_seq, target_seq, shap_2d,
                           opening_percentile=99, elongating_percentile=90):
    """Align miRNA and target using SHAP attributions as pairwise scores.

    Args:
        mirna_seq: miRNA sequence; ``shap_2d`` rows follow its order.
        target_seq: Target sequence; ``shap_2d`` columns follow its order.
        shap_2d: Array-like of shape ``(len(mirna_seq), len(target_seq))``.
        opening_percentile: Forwarded to the aligner's gap-opening penalty.
        elongating_percentile: Forwarded to the aligner's gap-extension penalty.

    Returns:
        ``(aligned_target, aligned_scores, aligned_mirna)`` as equal-length
        lists. Gap columns carry a score of ``0.0``.

    Raises:
        ValueError: If ``shap_2d`` does not have the expected shape.
    """
    shap_array = np.asarray(shap_2d, dtype=float)
    expected_shape = (len(mirna_seq), len(target_seq))
    if shap_array.shape != expected_shape:
        raise ValueError(
            f"Expected shap_2d shape {expected_shape}, got {tuple(shap_array.shape)}"
        )
    # The miRNA is laid out in reverse against the target. Reverse it together
    # with its SHAP rows, then transpose so target positions index the rows.
    reversed_mirna = mirna_seq[::-1]
    target_by_mirna = shap_array[::-1].T
    target_rev, scores_rev, mirna_rev = _ShapAlignment().align(
        reversed_mirna,
        target_seq,
        target_by_mirna,
        opening_percentile=opening_percentile,
        elongating_percentile=elongating_percentile,
    )
    # The traceback emits columns end-to-start; flip them back.
    return target_rev[::-1], scores_rev[::-1], mirna_rev[::-1]
def plot_alignment_image(mirna_seq, target_seq, shap_2d, arrows=True):
    """Render the SHAP-guided alignment as a PIL image.

    The target is drawn on the top row and the miRNA on the bottom row. The
    middle row marks a complementary pair with a bar and a mismatch between
    two known bases (A/C/T/G) with a dot. Every column is coloured by its
    alignment score.

    Args:
        mirna_seq: miRNA sequence.
        target_seq: Target sequence.
        shap_2d: SHAP scores of shape ``(len(mirna_seq), len(target_seq))``.
        arrows: If true, draw the direction arrows and 5'/3' end labels.

    Returns:
        A PIL image holding the rendered PNG, fully decoded.

    Raises:
        ValueError: Propagated from ``compute_shap_alignment`` on a shape mismatch.
    """
    aligned_target, aligned_scores, aligned_mirna = compute_shap_alignment(
        mirna_seq, target_seq, shap_2d
    )
    color_indices = _scale_scores(aligned_scores)
    # Exact symbol membership; a substring test against "ACTG" would also
    # accept "" and multi-letter tokens such as "AC".
    known_bases = frozenset("ACTG")

    # Each alignment column is `step` units wide; the left margin holds labels.
    step = 1.2
    left_margin = 2.8
    alignment_width = len(aligned_target) * step

    fig_width = max(12, len(aligned_target) * 0.42)
    fig, ax = plt.subplots(figsize=(fig_width, 3.2))
    # pyplot keeps every open figure alive until closed, so release it even if
    # drawing or saving raises.
    try:
        ax.set_xlim(-left_margin, alignment_width + 0.6)
        ax.set_ylim(0, 4.8)
        ax.text(-left_margin + 0.2, 3.5, "Target", ha='left', va='center',
                fontsize=12, fontweight='bold')
        ax.text(-left_margin + 0.2, 1.1, "miRNA", ha='left', va='center',
                fontsize=12, fontweight='bold')
        if arrows:
            ax.add_patch(patches.Arrow(x=0, y=4.4, dx=2, dy=0, width=0.6, color=COLOR_PALETTE[10]))
            ax.add_patch(patches.Arrow(
                x=max(alignment_width - 0.2, 0), y=0.2, dx=-2, dy=0, width=0.6, color=COLOR_PALETTE[10]
            ))
            ax.text(-0.1, 4.45, "5'", ha='right', va='center', fontsize=10, fontweight='bold')
            ax.text(alignment_width + 0.15, 4.45, "3'", ha='left', va='center', fontsize=10, fontweight='bold')
            ax.text(-0.1, 0.15, "3'", ha='right', va='center', fontsize=10, fontweight='bold')
            ax.text(alignment_width + 0.15, 0.15, "5'", ha='left', va='center', fontsize=10, fontweight='bold')
        for i, (target_base, mirna_base, color_index) in enumerate(
            zip(aligned_target, aligned_mirna, color_indices)
        ):
            color = COLOR_PALETTE[int(color_index)]
            left_edge = step * i
            # Bottom row: miRNA symbol (unlisted symbols fall back to 'N').
            BASE_PLOTTERS.get(mirna_base, _plot_n)(ax, 0.6, left_edge, 1.0, color)
            # Middle row: bar for a complementary pair, dot for a known mismatch.
            if mirna_base == COMPLEMENT.get(target_base):
                _plot_line(ax, 1.7, left_edge, 1.0, color)
            elif mirna_base in known_bases and target_base in known_bases:
                _plot_dot(ax, 1.7, left_edge, 1.0, color)
            # Top row: target symbol.
            BASE_PLOTTERS.get(target_base, _plot_n)(ax, 3.0, left_edge, 1.0, color)
        ax.axis('off')
        fig.tight_layout()
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', dpi=120, bbox_inches='tight')
    finally:
        plt.close(fig)
    buffer.seek(0)
    image = Image.open(buffer)
    # Image.open is lazy; decode now so the returned image is self-contained.
    image.load()
    return image