"""Post-processing and fusion helpers for TF + SAM outputs.""" from typing import List, Tuple import numpy as np def top_k_lumen_exemplar_indices(masks: np.ndarray, lumen_conf: np.ndarray, k: int = 100) -> List[int]: """Select top-K valid lumen frames for SAM exemplar prompts.""" conf = np.asarray(lumen_conf, dtype=np.float32) valid = np.isfinite(conf) & np.any(masks == 1, axis=(1, 2)) exemplars: List[int] = [] if np.any(valid): ordered = np.where(valid)[0][np.argsort(conf[valid])[::-1]] exemplars.extend([int(i) for i in ordered[:k]]) if len(exemplars) < k: lumen_area = np.sum(masks == 1, axis=(1, 2)) area_order = np.argsort(lumen_area)[::-1] for idx in area_order: idx = int(idx) if idx in exemplars or not np.any(masks[idx] == 1): continue exemplars.append(idx) if len(exemplars) >= k: break return sorted(exemplars) def build_sam_label_masks(sam_lumen_masks: np.ndarray) -> np.ndarray: """Build contour labels with lumen=1 and non-lumen=2.""" return np.where(sam_lumen_masks > 0, 1, 2).astype(np.uint8) def empty_contours(num_frames: int): """Return an empty contour structure for all frames.""" return [[[] for _ in range(num_frames)], [[] for _ in range(num_frames)]] def merge_sam_with_tensorflow( sam_lumen_masks: np.ndarray, sam_scores: np.ndarray, tf_masks: np.ndarray, tf_lumen_conf: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]: """Use TF masks on frames where TF confidence beats SAM confidence.""" merged_masks = sam_lumen_masks.copy().astype(np.uint8) merged_scores = np.asarray(sam_scores, dtype=np.float32).copy() tf_scores = np.asarray(tf_lumen_conf, dtype=np.float32) for idx in range(merged_masks.shape[0]): tf_score = float(tf_scores[idx]) if idx < tf_scores.shape[0] else float("nan") sam_score = float(merged_scores[idx]) if idx < merged_scores.shape[0] else float("nan") if np.isfinite(tf_score): tf_score = float(np.clip(tf_score, 0.0, 1.0)) if np.isfinite(sam_score): sam_score = float(np.clip(sam_score, 0.0, 1.0)) if np.isfinite(tf_score) and (not np.isfinite(sam_score) or sam_score < tf_score): merged_masks[idx] = (tf_masks[idx] == 1).astype(np.uint8) merged_scores[idx] = np.float32(tf_score) elif not np.isfinite(sam_score): merged_scores[idx] = np.float32(0.0) return merged_masks, merged_scores def propagate_previous_mask_on_zero_score(lumen_masks: np.ndarray, scores: np.ndarray) -> np.ndarray: """Replace zero-score frames with previous mask.""" propagated = lumen_masks.copy().astype(np.uint8) scores_arr = np.asarray(scores, dtype=np.float32) for idx in range(1, propagated.shape[0]): if float(scores_arr[idx]) <= 0.0: propagated[idx] = propagated[idx - 1] return propagated def fill_empty_masks_temporally(lumen_masks: np.ndarray) -> np.ndarray: """Fill empty masks by temporal carry-over.""" filled = lumen_masks.copy().astype(np.uint8) num_frames = filled.shape[0] non_empty = np.any(filled > 0, axis=(1, 2)) if not np.any(non_empty): return filled first_non_empty = int(np.argmax(non_empty)) for idx in range(0, first_non_empty): filled[idx] = filled[first_non_empty] for idx in range(first_non_empty + 1, num_frames): if not np.any(filled[idx] > 0): filled[idx] = filled[idx - 1] return filled def fill_empty_lumen_contours_temporally(lumen): """Fill empty contour frames by nearest previous valid contour.""" filled_x = [list(x) for x in lumen[0]] filled_y = [list(y) for y in lumen[1]] num_frames = len(filled_x) has_contour = [len(filled_x[i]) > 0 and len(filled_y[i]) > 0 for i in range(num_frames)] if not any(has_contour): return lumen first_idx = has_contour.index(True) for idx in range(first_idx): filled_x[idx] = list(filled_x[first_idx]) filled_y[idx] = list(filled_y[first_idx]) for idx in range(first_idx + 1, num_frames): if len(filled_x[idx]) == 0 or len(filled_y[idx]) == 0: filled_x[idx] = list(filled_x[idx - 1]) filled_y[idx] = list(filled_y[idx - 1]) return [filled_x, filled_y]