Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
"""Post-processing and fusion helpers for TF + SAM outputs."""
from typing import List, Tuple
import numpy as np
def top_k_lumen_exemplar_indices(masks: np.ndarray, lumen_conf: np.ndarray, k: int = 100) -> List[int]:
    """Select top-K valid lumen frames for SAM exemplar prompts.

    Frames are picked in two passes:

    1. Frames whose ``lumen_conf`` is finite and whose mask contains at least
       one pixel equal to 1, ordered by descending confidence.
    2. If fewer than ``k`` frames were found, the remaining frames that contain
       lumen pixels, ordered by descending lumen area (count of pixels == 1).

    Args:
        masks: Label masks indexed by frame on axis 0 and reduced over axes
            1 and 2; the value 1 marks lumen.
        lumen_conf: Per-frame lumen confidence, one entry per mask frame.
        k: Maximum number of exemplars to return. ``k <= 0`` yields ``[]``.

    Returns:
        Selected frame indices as Python ints, sorted ascending.
    """
    # Without this guard a negative k would slice ``ordered[:k]`` from the end
    # and return an arbitrary subset.
    if k <= 0:
        return []

    masks = np.asarray(masks)
    conf = np.asarray(lumen_conf, dtype=np.float32)
    lumen_area = np.sum(masks == 1, axis=(1, 2))
    has_lumen = lumen_area > 0

    # Pass 1: confident frames that actually contain lumen.
    valid_idx = np.flatnonzero(np.isfinite(conf) & has_lumen)
    ordered = valid_idx[np.argsort(conf[valid_idx])[::-1]]
    exemplars: List[int] = [int(i) for i in ordered[:k]]

    # Pass 2: top up with the largest lumen frames not already chosen.
    if len(exemplars) < k:
        chosen = set(exemplars)  # O(1) membership instead of scanning the list
        for idx in np.argsort(lumen_area)[::-1]:
            idx = int(idx)
            if idx in chosen or not has_lumen[idx]:
                continue
            exemplars.append(idx)
            chosen.add(idx)
            if len(exemplars) >= k:
                break
    return sorted(exemplars)
def build_sam_label_masks(sam_lumen_masks: np.ndarray) -> np.ndarray:
    """Turn SAM lumen masks into contour label maps.

    Elements greater than zero get label 1 (lumen). Every other element
    (zero, negative, or NaN) gets label 2 (non-lumen). The result is uint8
    and has the same shape as the input.
    """
    is_lumen = sam_lumen_masks > 0
    labels = np.full(np.shape(is_lumen), 2, dtype=np.uint8)
    labels[is_lumen] = 1
    return labels
def empty_contours(num_frames: int):
    """Build an empty ``[xs, ys]`` contour container for ``num_frames`` frames.

    Each of the two outer lists holds ``num_frames`` independent empty lists,
    one per frame, so appending to one frame never affects another.
    """
    return [[[] for _frame in range(num_frames)] for _axis in range(2)]
def merge_sam_with_tensorflow(
    sam_lumen_masks: np.ndarray,
    sam_scores: np.ndarray,
    tf_masks: np.ndarray,
    tf_lumen_conf: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Use TF masks on frames where TF confidence beats SAM confidence.

    Finiteness is checked on the raw values. Finite confidences are then
    clipped to [0, 1] before comparison, and a non-finite confidence never wins.

    - A frame takes the TF mask (``tf_masks == 1``) and the clipped TF score
      when the TF score is finite and the SAM score is non-finite or strictly
      lower.
    - Otherwise the SAM mask and the raw SAM score are kept.
    - A frame with neither a usable TF nor a finite SAM score gets score 0.

    Args:
        sam_lumen_masks: SAM lumen masks indexed by frame on axis 0.
        sam_scores: Per-frame SAM confidences. Frames beyond its length are
            treated as having a non-finite SAM score.
        tf_masks: TF label masks indexed by frame; label 1 is taken as lumen.
            Only read for frames where TF wins.
        tf_lumen_conf: Per-frame TF confidences. Frames beyond its length are
            treated as having a non-finite TF score.

    Returns:
        ``(merged_masks, merged_scores)``:

        - ``merged_masks`` is uint8 with the shape of ``sam_lumen_masks``.
        - ``merged_scores`` is float32 with one entry per frame. Any extra
          trailing ``sam_scores`` entries are passed through unchanged.
    """
    merged_masks = np.asarray(sam_lumen_masks).astype(np.uint8)  # astype copies
    num_frames = merged_masks.shape[0]

    sam_all = np.asarray(sam_scores, dtype=np.float32)
    # Pad with NaN so frames lacking a SAM score follow the non-finite path
    # instead of raising IndexError on write-back.
    merged_scores = np.full(max(num_frames, sam_all.shape[0]), np.nan, dtype=np.float32)
    merged_scores[: sam_all.shape[0]] = sam_all

    tf_all = np.asarray(tf_lumen_conf, dtype=np.float32)
    tf_frame = np.full(num_frames, np.nan, dtype=np.float32)
    n_tf = min(num_frames, tf_all.shape[0])
    tf_frame[:n_tf] = tf_all[:n_tf]

    sam_frame = merged_scores[:num_frames].copy()
    # Judge finiteness on raw values: clipping would turn +inf into 1.0.
    sam_ok = np.isfinite(sam_frame)
    tf_ok = np.isfinite(tf_frame)
    sam_clipped = np.clip(sam_frame, 0.0, 1.0)
    tf_clipped = np.clip(tf_frame, 0.0, 1.0)

    # Strict "<": on a tie the SAM result is kept.
    use_tf = tf_ok & (~sam_ok | (sam_clipped < tf_clipped))
    tf_idx = np.flatnonzero(use_tf)
    if tf_idx.size:
        tf_masks_arr = np.asarray(tf_masks)
        merged_masks[tf_idx] = (tf_masks_arr[tf_idx] == 1).astype(np.uint8)
        merged_scores[tf_idx] = tf_clipped[tf_idx]

    # No usable confidence from either source: report zero.
    merged_scores[np.flatnonzero(~use_tf & ~sam_ok)] = 0.0
    return merged_masks, merged_scores
def propagate_previous_mask_on_zero_score(lumen_masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """Carry the previous frame's mask into frames whose score is <= 0.

    Frames are processed in order starting from frame 1, so a run of zero-score
    frames all inherit the last mask before the run. Frame 0 is never replaced.
    NaN scores do not compare <= 0 and leave their frame untouched. The input is
    not modified; a uint8 copy is returned.
    """
    result = lumen_masks.astype(np.uint8)
    score_values = np.asarray(scores, dtype=np.float32)
    frame = 1
    while frame < result.shape[0]:
        if score_values[frame] <= 0.0:
            result[frame] = result[frame - 1]
        frame += 1
    return result
def fill_empty_masks_temporally(lumen_masks: np.ndarray) -> np.ndarray:
    """Replace empty frames with the nearest earlier non-empty frame.

    A frame is non-empty when any pixel is > 0 after the cast to uint8.
    Leading empty frames take the first non-empty frame. Every later empty frame
    takes the most recent non-empty frame before it. If all frames are empty,
    an unmodified uint8 copy is returned.
    """
    masks_u8 = lumen_masks.astype(np.uint8)
    occupied = np.any(masks_u8 > 0, axis=(1, 2))
    if not occupied.any():
        return masks_u8

    first = int(np.argmax(occupied))
    frame_ids = np.arange(masks_u8.shape[0])
    # Empty frames point at ``first``; the running max then resolves every
    # frame to the latest non-empty index at or before it (or ``first`` for
    # leading frames).
    source = np.maximum.accumulate(np.where(occupied, frame_ids, first))
    return masks_u8[source]
def fill_empty_lumen_contours_temporally(lumen):
    """Fill empty contour frames by nearest previous valid contour.

    ``lumen`` is a pair ``[xs, ys]`` of per-frame coordinate sequences. A frame
    is valid only when both its x and y sequences are non-empty.

    - Frames before the first valid frame copy that first valid frame.
    - Every later invalid frame copies the (already filled) frame just before it.

    Returns:
        A new ``[xs, ys]`` pair of lists in which every frame is a fresh list,
        including when no frame is valid. The input is never returned or
        aliased, so callers may mutate the result safely.
    """
    filled_x = [list(x) for x in lumen[0]]
    filled_y = [list(y) for y in lumen[1]]
    num_frames = len(filled_x)
    has_contour = [bool(filled_x[i]) and bool(filled_y[i]) for i in range(num_frames)]
    if not any(has_contour):
        # Return the fresh copies rather than ``lumen`` itself so the result
        # type and aliasing match the normal path.
        return [filled_x, filled_y]

    first_idx = has_contour.index(True)
    for idx in range(first_idx):
        filled_x[idx] = list(filled_x[first_idx])
        filled_y[idx] = list(filled_y[first_idx])
    for idx in range(first_idx + 1, num_frames):
        if not (filled_x[idx] and filled_y[idx]):
            filled_x[idx] = list(filled_x[idx - 1])
            filled_y[idx] = list(filled_y[idx - 1])
    return [filled_x, filled_y]