File size: 4,402 Bytes
1d197a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Post-processing and fusion helpers for TF + SAM outputs."""

from typing import List, Tuple

import numpy as np


def top_k_lumen_exemplar_indices(masks: np.ndarray, lumen_conf: np.ndarray, k: int = 100) -> List[int]:
    """Select up to ``k`` frame indices to use as SAM exemplar prompts.

    Only frames whose mask contains at least one lumen pixel (label ``1``) are
    eligible. Eligible frames with a finite confidence are taken first, highest
    confidence first. If that yields fewer than ``k`` frames, the remaining slots
    are filled with other eligible frames in order of decreasing lumen area.

    Args:
        masks: Per-frame label masks indexed as ``masks[frame, row, col]``;
            lumen pixels have value ``1``.
        lumen_conf: Per-frame lumen confidence, one entry per frame of ``masks``.
        k: Maximum number of exemplars to return. ``k <= 0`` yields ``[]``.

    Returns:
        The selected frame indices, sorted ascending.
    """
    # A negative k would otherwise slice from the end (``ordered[:-1]``) and
    # return almost every valid frame.
    if k <= 0:
        return []

    masks = np.asarray(masks)
    conf = np.asarray(lumen_conf, dtype=np.float32)
    lumen_area = np.sum(masks == 1, axis=(1, 2))
    has_lumen = lumen_area > 0
    valid = np.isfinite(conf) & has_lumen

    exemplars: List[int] = []
    if np.any(valid):
        valid_idx = np.where(valid)[0]
        ordered = valid_idx[np.argsort(conf[valid])[::-1]]
        exemplars.extend(int(i) for i in ordered[:k])

    if len(exemplars) < k:
        chosen = set(exemplars)  # O(1) membership instead of scanning the list
        for idx in np.argsort(lumen_area)[::-1]:
            idx = int(idx)
            if not has_lumen[idx]:
                # Frames are visited by descending area, so every remaining
                # frame has no lumen either.
                break
            if idx in chosen:
                continue
            exemplars.append(idx)
            chosen.add(idx)
            if len(exemplars) >= k:
                break

    return sorted(exemplars)


def build_sam_label_masks(sam_lumen_masks: np.ndarray) -> np.ndarray:
    """Convert SAM lumen masks into uint8 contour labels.

    Every element greater than zero becomes label ``1`` (lumen); every other
    element, including NaN, becomes label ``2`` (non-lumen).
    """
    is_lumen = sam_lumen_masks > 0
    labels = np.full(is_lumen.shape, 2, dtype=np.uint8)
    labels[is_lumen] = 1
    return labels


def empty_contours(num_frames: int):
    """Build an all-empty ``[xs, ys]`` contour structure.

    Both the x and the y list hold ``num_frames`` independent empty lists, so
    appending to one frame never affects another.
    """
    return [[[] for _frame in range(num_frames)] for _axis in range(2)]


def merge_sam_with_tensorflow(
    sam_lumen_masks: np.ndarray,
    sam_scores: np.ndarray,
    tf_masks: np.ndarray,
    tf_lumen_conf: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Per frame, keep the SAM lumen mask unless the TF confidence beats SAM's.

    For each frame, finite scores are clipped to ``[0, 1]`` before comparison;
    non-finite scores (and scores missing because an array is shorter than the
    number of frames) are treated as absent. The TF lumen pixels
    (``tf_masks[idx] == 1``) replace the SAM mask when the TF score is present
    and the SAM score is absent or strictly lower. If neither score is present,
    the SAM mask is kept and the frame's score becomes ``0.0``.

    Args:
        sam_lumen_masks: Per-frame SAM lumen masks; the first axis is frames.
        sam_scores: Per-frame SAM confidence; may be shorter than the frame count.
        tf_masks: Per-frame TF label masks; lumen pixels have value ``1``.
        tf_lumen_conf: Per-frame TF lumen confidence; may be shorter than the
            frame count.

    Returns:
        ``(merged_masks, merged_scores)``: uint8 binary masks, and float32 scores
        with at least one entry per frame.
    """
    merged_masks = np.asarray(sam_lumen_masks).astype(np.uint8)  # astype copies
    num_frames = merged_masks.shape[0]

    merged_scores = np.asarray(sam_scores, dtype=np.float32).copy()
    if merged_scores.shape[0] < num_frames:
        # Give every frame a writable slot; padded frames count as "no SAM score".
        padding = np.full(num_frames - merged_scores.shape[0], np.nan, dtype=np.float32)
        merged_scores = np.concatenate([merged_scores, padding])
    tf_scores = np.asarray(tf_lumen_conf, dtype=np.float32)

    for idx in range(num_frames):
        tf_score = float(tf_scores[idx]) if idx < tf_scores.shape[0] else float("nan")
        sam_score = float(merged_scores[idx])

        # Clip only finite values: np.clip would map +/-inf onto a finite bound.
        if np.isfinite(tf_score):
            tf_score = float(np.clip(tf_score, 0.0, 1.0))
        if np.isfinite(sam_score):
            sam_score = float(np.clip(sam_score, 0.0, 1.0))

        if np.isfinite(tf_score) and (not np.isfinite(sam_score) or sam_score < tf_score):
            merged_masks[idx] = (tf_masks[idx] == 1).astype(np.uint8)
            merged_scores[idx] = np.float32(tf_score)
        elif not np.isfinite(sam_score):
            merged_scores[idx] = np.float32(0.0)
        # NOTE(review): a winning SAM score is stored unclipped, while a winning
        # TF score is stored clipped — confirm this asymmetry is intended.

    return merged_masks, merged_scores


def propagate_previous_mask_on_zero_score(lumen_masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """Carry the preceding frame's mask into every frame scored at or below zero.

    Frames are processed in order from index 1, so a run of low-scored frames
    all inherit the last mask before the run. Frame 0 is never replaced, and a
    NaN score does not trigger replacement. The result is a new uint8 array.
    """
    result = lumen_masks.astype(np.uint8)  # astype always returns a fresh array
    frame_scores = np.asarray(scores, dtype=np.float32)
    frame_count = result.shape[0]
    for frame in range(1, frame_count):
        should_copy = float(frame_scores[frame]) <= 0.0
        if should_copy:
            result[frame] = result[frame - 1]
    return result


def fill_empty_masks_temporally(lumen_masks: np.ndarray) -> np.ndarray:
    """Replace empty frames with the most recent non-empty mask.

    A frame is empty when none of its (uint8-cast) pixels is above zero.
    Leading empty frames take the first non-empty mask. Every later empty
    frame takes the closest non-empty mask before it. If every frame is
    empty, the uint8 copy is returned unchanged.
    """
    masks = lumen_masks.astype(np.uint8)
    occupied = np.any(masks > 0, axis=(1, 2))
    if not occupied.any():
        return masks

    first_occupied = int(np.argmax(occupied))
    frame_ids = np.arange(masks.shape[0])
    # Empty frames point at first_occupied. A running maximum then turns each
    # entry into "latest occupied index so far" (first_occupied for the leading run).
    source = np.maximum.accumulate(np.where(occupied, frame_ids, first_occupied))
    return masks[source]


def fill_empty_lumen_contours_temporally(lumen):
    """Fill empty contour frames with the nearest earlier valid contour.

    A frame is valid when both its x and its y coordinate lists are non-empty.
    Frames before the first valid frame get a copy of that first valid contour.
    Each later invalid frame gets a copy of the frame before it, so the most
    recent valid contour is carried forward.

    Args:
        lumen: ``[xs, ys]`` where ``xs[i]`` / ``ys[i]`` hold the x / y
            coordinates of frame ``i``. Assumes ``ys`` has at least as many
            frames as ``xs`` — TODO confirm with callers.

    Returns:
        A new ``[xs, ys]`` pair of lists of lists. The input object is never
        returned or mutated, even when no frame has a contour.
    """
    filled_x = [list(x) for x in lumen[0]]
    filled_y = [list(y) for y in lumen[1]]
    num_frames = len(filled_x)
    has_contour = [len(filled_x[i]) > 0 and len(filled_y[i]) > 0 for i in range(num_frames)]
    if not any(has_contour):
        # Return the fresh copies rather than ``lumen`` itself, so callers always
        # receive an object that does not alias their input.
        return [filled_x, filled_y]

    first_idx = has_contour.index(True)
    for idx in range(first_idx):
        filled_x[idx] = list(filled_x[first_idx])
        filled_y[idx] = list(filled_y[first_idx])

    for idx in range(first_idx + 1, num_frames):
        if len(filled_x[idx]) == 0 or len(filled_y[idx]) == 0:
            # filled_*[idx - 1] is already filled, so this carries the last valid contour forward.
            filled_x[idx] = list(filled_x[idx - 1])
            filled_y[idx] = list(filled_y[idx - 1])

    return [filled_x, filled_y]