File size: 5,319 Bytes
0e1a816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9bdee3
 
 
 
 
 
 
 
 
0e1a816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9bdee3
 
0e1a816
 
 
 
 
 
 
 
 
 
 
 
 
 
6727da5
0e1a816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482c7f2
0e1a816
 
 
 
 
482c7f2
 
 
 
 
 
 
 
 
6727da5
482c7f2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""Heatmap visualization utilities (colorbar, overlays, Plotly)."""
import base64

import cv2
import numpy as np
import streamlit as st
import plotly.graph_objects as go

from config.constants import COLORMAPS


def _colormap_gradient_base64(colormap_name, width=512):
    """Return a horizontal gradient strip for a colormap as a base64 PNG string.

    Args:
        colormap_name: Key looked up in ``COLORMAPS``. Names that are not
            present fall back to ``cv2.COLORMAP_JET``.
        width: Number of columns (gradient steps) in the strip.

    Returns:
        The PNG-encoded strip (6 rows by ``width`` columns), base64-encoded
        and decoded to a ``str`` so it can be embedded in a data URL.
    """
    cv2_cmap = COLORMAPS.get(colormap_name, cv2.COLORMAP_JET)
    # Single row ramping from 0 to 255 across the requested width.
    ramp = np.linspace(0, 255, width, dtype=np.uint8).reshape(1, -1)
    # applyColorMap output is already in the BGR order that imencode consumes,
    # so it is encoded directly instead of round-tripping through RGB.
    strip_bgr = np.repeat(cv2.applyColorMap(ramp, cv2_cmap), 6, axis=0)
    _, png = cv2.imencode(".png", strip_bgr)
    return base64.b64encode(png.tobytes()).decode("utf-8")

# Palette of visually distinct overlay colors, one per drawn region.
# Tuples are in RGB order because the heatmap images they are painted onto
# (heatmap_rgb) are RGB. Regions beyond the palette length reuse colors
# (the multi-region renderer indexes with ``i % len(_REGION_COLORS)``).
_REGION_COLORS = [
    (0, 188, 212),   # cyan (matches drawing tool)
    (0, 230, 118),   # green
    (255, 235, 59),  # yellow
    (171, 71, 188),  # purple
    (0, 150, 255),   # blue
    (255, 167, 38),  # amber
    (124, 179, 66),  # light green
    (233, 30, 99),   # pink
]


def _draw_region_overlay(annotated, mask, color, fill_alpha=0.3, stroke_width=2):
    """Paint one region onto ``annotated`` in place: translucent fill plus outline.

    Args:
        annotated: Image array that is modified in place.
        mask: Region mask handed to ``cv2.findContours``; non-zero pixels mark
            the region. Assumed to be single-channel and the same height/width
            as ``annotated`` (not checked here).
        color: Fill and outline color, in the channel order of ``annotated``.
        fill_alpha: Weight of ``color`` in the blend (0 keeps the image,
            1 replaces it inside the region).
        stroke_width: Outline thickness passed to ``cv2.drawContours``.

    Returns:
        None. The result is written into ``annotated``.
    """
    outlines, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Fully opaque version of the fill, drawn on a scratch copy.
    filled = annotated.copy()
    cv2.fillPoly(filled, outlines, color)

    # Expand the mask to three channels so it selects whole pixels.
    inside = np.repeat(mask[..., np.newaxis], 3, axis=-1).astype(bool)

    # Blend in float32, then truncate back to uint8, only where the mask is set.
    base_px = annotated[inside].astype(np.float32)
    fill_px = filled[inside].astype(np.float32)
    blended = (1 - fill_alpha) * base_px + fill_alpha * fill_px
    annotated[inside] = blended.astype(np.uint8)

    # The outline goes on last so it stays fully opaque.
    cv2.drawContours(annotated, outlines, -1, color, stroke_width)


def render_horizontal_colorbar(colormap_name, clip_min=0, clip_max=1, is_rescale=False, caption=None):
    """
    Render a compact horizontal colorbar (tick labels over a gradient strip) with Streamlit.

    Intended for batch mode, where the bar sits above the results table.

    Five tick labels are emitted, at 0 %, 25 %, 50 %, 75 % and 100 % of the bar:

    * ``is_rescale`` False: labels read ``0.00`` to ``1.00``.
    * ``is_rescale`` True (Force scale **Range** with a strict sub-interval): labels are
      interpolated linearly between ``clip_min`` and ``clip_max``. The gradient itself still
      spans the whole colormap, because the heatmap has already been **rescaled** so the ends
      of the chosen range map to the colormap minimum and maximum (same convention as the
      main Plotly view).

    If ``caption`` is truthy it is shown underneath via ``st.caption``.
    """
    fractions = (0, 0.25, 0.5, 0.75, 1)
    if is_rescale:
        span = clip_max - clip_min
        tick_texts = [f"{clip_min + frac * span:.2f}" for frac in fractions]
    else:
        tick_texts = [f"{frac:.2f}" for frac in fractions]

    data_url = _colormap_gradient_base64(colormap_name)
    labels_html = "".join(map('<span class="cb-tick">{}</span>'.format, tick_texts))
    html = f"""
    <div class="colorbar-table-header">
        <div class="colorbar-ticks">{labels_html}</div>
        <div class="colorbar-bar" style="background-image: url(data:image/png;base64,{data_url});"></div>
    </div>
    """
    # Raw HTML is needed for the CSS-classed layout and the inline data URL.
    st.markdown(html, unsafe_allow_html=True)
    if caption:
        st.caption(caption)


def make_annotated_heatmap(heatmap_rgb, mask, fill_alpha=0.3, stroke_color=(0, 188, 212), stroke_width=2):
    """Return a copy of ``heatmap_rgb`` with a single region overlay drawn on it.

    The input image is left untouched. ``stroke_color`` is used for both the
    translucent fill (weighted by ``fill_alpha``) and the outline
    (``stroke_width`` thick) of the region described by ``mask``.
    """
    canvas = heatmap_rgb.copy()
    _draw_region_overlay(
        canvas,
        mask,
        stroke_color,
        fill_alpha=fill_alpha,
        stroke_width=stroke_width,
    )
    return canvas


def make_annotated_heatmap_multi_regions(heatmap_rgb, masks, labels, cell_mask=None, fill_alpha=0.3):
    """Return a copy of the heatmap with each region drawn and labelled separately.

    Regions are never merged: every mask gets its own color (cycled from
    ``_REGION_COLORS``), a translucent fill with outline, and a text label
    placed near its centroid.

    Args:
        heatmap_rgb: Image array to annotate. It is copied, not modified.
        masks: Iterable of region masks, one per region. Assumed to be
            single-channel arrays matching the heatmap size, as required by
            the OpenCV contour and moment routines they are passed to.
        labels: Sequence of labels aligned with ``masks``. May be ``None`` or
            shorter than ``masks``; regions without a label are named
            ``R<n>`` (1-based). Labels are converted with ``str()``.
        cell_mask: Optional mask whose outer contours are outlined in
            ``(255, 0, 0)`` (red in RGB) with a 3 px line before the regions
            are drawn.
        fill_alpha: Opacity of each region's fill.

    Returns:
        The annotated copy of ``heatmap_rgb``.
    """
    annotated = heatmap_rgb.copy()
    if cell_mask is not None and np.any(cell_mask > 0):
        contours, _ = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(annotated, contours, -1, (255, 0, 0), 3)

    # Treat a missing label list like an empty one so the R<n> fallback applies.
    if labels is None:
        labels = ()

    for i, mask in enumerate(masks):
        color = _REGION_COLORS[i % len(_REGION_COLORS)]
        _draw_region_overlay(annotated, mask, color, fill_alpha, stroke_width=2)

        moments = cv2.moments(mask)
        if moments["m00"] <= 0:
            # Empty mask: no centroid to anchor a label on.
            continue
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        # putText needs a str, so coerce caller-supplied labels (e.g. ints).
        label = str(labels[i]) if i < len(labels) else f"R{i + 1}"
        origin = (cx - 12, cy + 5)
        # Thick white pass, then thin black pass on top: outlined text that
        # stays readable on any background color.
        for text_color, thickness in (((255, 255, 255), 2), ((0, 0, 0), 1)):
            cv2.putText(
                annotated, label, origin,
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, text_color, thickness, cv2.LINE_AA
            )
    return annotated


def add_cell_contour_to_fig(fig_pl, cell_mask, row=1, col=2):
    """Overlay the outer contours of ``cell_mask`` as red lines on a Plotly subplot.

    Every external contour is added as its own closed line trace, so masks
    made of several disconnected blobs are fully outlined. Nothing is added
    when ``cell_mask`` is ``None`` or has no positive pixels.

    Args:
        fig_pl: Figure that receives the traces via ``add_trace``.
        cell_mask: Mask passed to ``cv2.findContours``, or ``None``.
        row: Subplot row forwarded to ``add_trace``.
        col: Subplot column forwarded to ``add_trace``.
    """
    has_cell = cell_mask is not None and np.any(cell_mask > 0)
    if not has_cell:
        return

    outlines, _hierarchy = cv2.findContours(cell_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for outline in outlines:
        # Flatten OpenCV's (N, 1, 2) point layout to (N, 2); also covers N == 1.
        points = outline.reshape(-1, 2)
        xs = points[:, 0].tolist()
        ys = points[:, 1].tolist()
        # Repeat the first vertex so the polyline is drawn as a closed loop.
        already_closed = xs[0] == xs[-1] and ys[0] == ys[-1]
        if not already_closed:
            xs.append(xs[0])
            ys.append(ys[0])
        outline_trace = go.Scatter(
            x=xs,
            y=ys,
            mode="lines",
            line=dict(color="red", width=3),
            showlegend=False,
        )
        fig_pl.add_trace(outline_trace, row=row, col=col)