UnSAMv2 / sam2 /training /utils /graco /graco_sample.py
yjwnb6
Initial HF Space upload
7b25808
raw
history blame
17.8 kB
import logging
import numpy as np
import torch
import torch.distributed
from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
get_1d_sine_pe,
get_next_point,
sample_box_points,
select_closest_cond_frames,
)
from sam2.utils.misc import concat_points
from training.utils.data_utils import BatchedVideoDatapoint
import random
import sys
sys.path.append('/home/yujunwei/sam2/GraCo')
from isegm.inference.clicker import Clicker
# from training.utils.GraCo.isegm.inference.evaluation import get_sam_input
import cv2
def get_points_nd(clicks_lists):
    """Pad each click list to a fixed positive-then-negative slot layout.

    Args:
        clicks_lists: iterable of click lists. Every click must expose
            ``is_positive`` and ``coords_and_indx``.

    Returns:
        One list per input click list. Each holds ``2 * n`` entries, where
        ``n`` is the largest positive or negative click count over all lists
        (at least 1): first the positive clicks padded with ``(-1, -1, -1)``
        up to ``n`` entries, then the negative clicks padded the same way.
        An empty ``clicks_lists`` yields ``[]``.
    """
    num_pos_clicks = [
        sum(click.is_positive for click in clicks_list) for clicks_list in clicks_lists
    ]
    num_neg_clicks = [
        len(clicks_list) - num_pos
        for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
    ]
    # `default=0` keeps an empty batch from raising; at least one slot per
    # polarity is always reserved.
    num_max_points = max(1, max(num_pos_clicks + num_neg_clicks, default=0))

    padding = (-1, -1, -1)
    total_clicks = []
    for clicks_list in clicks_lists:
        pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive]
        neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive]
        pos_clicks += (num_max_points - len(pos_clicks)) * [padding]
        neg_clicks += (num_max_points - len(neg_clicks)) * [padding]
        total_clicks.append(pos_clicks + neg_clicks)
    return total_clicks
def get_sam_input(clicker, reverse=True):
    """Convert a clicker's clicks into point-coordinate and label arrays.

    Args:
        clicker: object whose ``get_clicks()`` returns the clicks made so far.
        reverse: when True (default) each coordinate pair is emitted with its
            two components swapped relative to how the click stores them
            (the order expected by SAM); when False the stored order is kept.

    Returns:
        ``(point_coords, point_labels)`` as NumPy arrays. Positive clicks come
        first with label 1, negative clicks follow with label 0; padding
        entries are dropped.
    """
    padded_clicks = get_points_nd([clicker.get_clicks()])[0]
    # The first half of the padded layout holds positive clicks.
    num_positive_slots = len(padded_clicks) // 2

    point_coords = []
    point_labels = []
    for slot, point in enumerate(padded_clicks):
        if point[0] == -1:  # padding entry
            continue
        point_labels.append(1 if slot < num_positive_slots else 0)
        if reverse:
            point_coords.append([point[1], point[0]])  # for SAM
        else:
            # Previously nothing was appended here, leaving labels without
            # matching coordinates.
            point_coords.append([point[0], point[1]])
    return np.array(point_coords), np.array(point_labels)
def _iter_correct_pt_sampling_graco(
    self,
    is_init_cond_frame,
    point_inputs,
    gt_masks,
    high_res_features,
    pix_feat_with_mem,
    low_res_multimasks,
    high_res_multimasks,
    ious,
    low_res_masks,
    high_res_masks,
    object_score_logits,
    current_out,
):
    """Iteratively add correction points and re-run the SAM heads.

    Starting from the initial prediction passed in, this runs
    ``self.num_correction_pt_per_frame`` rounds. Each round samples one new
    point via ``get_next_point``, appends it to ``point_inputs``, and calls
    ``self._forward_sam_heads`` again with the previous low-res mask as the
    mask prompt.

    The initial prediction and every round's prediction are recorded in
    ``current_out`` under the ``multistep_*`` keys.

    Returns:
        ``(point_inputs, sam_outputs)``: the accumulated point prompts and the
        outputs of the last SAM-head call.

    NOTE: ``sam_outputs`` is only bound inside the loop, so this requires
    ``self.num_correction_pt_per_frame >= 1``.
    """
    assert gt_masks is not None
    all_pred_masks = [low_res_masks]
    all_pred_high_res_masks = [high_res_masks]
    all_pred_multimasks = [low_res_multimasks]
    all_pred_high_res_multimasks = [high_res_multimasks]
    all_pred_ious = [ious]
    all_point_inputs = [point_inputs]
    all_object_score_logits = [object_score_logits]

    for _ in range(self.num_correction_pt_per_frame):
        # sample a new point from the error between prediction and ground-truth
        # (with a small probability, directly sample from GT masks instead of errors)
        if self.training and self.prob_to_sample_from_gt_for_train > 0:
            sample_from_gt = (
                self.rng.random() < self.prob_to_sample_from_gt_for_train
            )
        else:
            sample_from_gt = False
        # if `pred_for_new_pt` is None, only GT masks will be used for point sampling
        pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
        new_points, new_labels = get_next_point(
            gt_masks=gt_masks,
            pred_masks=pred_for_new_pt,
            method="uniform" if self.training else self.pt_sampling_for_eval,
        )
        point_inputs = concat_points(point_inputs, new_points, new_labels)
        # Feed the mask logits of the previous SAM outputs in the next SAM decoder step.
        # For tracking, this means that when the user adds a correction click, we also feed
        # the tracking output mask logits along with the click as input to the SAM decoder.
        mask_inputs = low_res_masks
        multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
        if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
            sam_outputs = torch.utils.checkpoint.checkpoint(
                self._forward_sam_heads,
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
                use_reentrant=False,
            )
        else:
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            _,
            object_score_logits,
        ) = sam_outputs
        all_pred_masks.append(low_res_masks)
        all_pred_high_res_masks.append(high_res_masks)
        all_pred_multimasks.append(low_res_multimasks)
        all_pred_high_res_multimasks.append(high_res_multimasks)
        all_pred_ious.append(ious)
        all_point_inputs.append(point_inputs)
        all_object_score_logits.append(object_score_logits)

    # Concatenate the masks along channel (to compute losses on all of them,
    # using `MultiStepIteractiveMasks`)
    current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
    current_out["multistep_pred_masks_high_res"] = torch.cat(
        all_pred_high_res_masks, dim=1
    )
    current_out["multistep_pred_multimasks"] = all_pred_multimasks
    current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
    current_out["multistep_pred_ious"] = all_pred_ious
    current_out["multistep_point_inputs"] = all_point_inputs
    current_out["multistep_object_score_logits"] = all_object_score_logits
    return point_inputs, sam_outputs
def process_points(points):
    """Drop padding slots and return per-item point coordinates and labels.

    ``points`` is indexed as ``points[batch, slot, component]``. The first half
    of the slots is treated as positive clicks and the rest as negative clicks,
    matching the ``num_points = size(1) // 2`` layout that
    ``get_next_points_graco`` writes into. A slot whose first component is
    ``-1`` is padding and is skipped.

    Args:
        points: array-like with a ``shape`` attribute and at least two
            components per slot.

    Returns:
        ``(filtered_points, filtered_labels)``: two lists with one NumPy array
        per batch item. Coordinates are emitted with the two stored components
        swapped and have shape ``(k, 2)``; labels are 1 for positive slots and
        0 for negative slots, positives first.
    """
    # Split at half the slot count instead of a hard-coded 1 so layouts with
    # several positive slots are labelled correctly. `max(1, ...)` keeps the
    # previous behaviour for inputs with fewer than four slots.
    num_points = max(1, points.shape[1] // 2)

    filtered_points = []
    filtered_labels = []
    for item in points:
        batch_points = []
        batch_labels = []
        for slot, point in enumerate(item):
            if point[0] == -1:  # padding slot
                continue
            point_y, point_x = point[:2]
            batch_points.append([point_x, point_y])
            batch_labels.append(1 if slot < num_points else 0)
        # reshape keeps a consistent (k, 2) shape even when k == 0
        filtered_points.append(np.array(batch_points).reshape(-1, 2))
        filtered_labels.append(np.array(batch_labels))
    return filtered_points, filtered_labels
def get_next_points_graco(pred, gt, points, click_indx, pred_thresh=0.49):
    """Write one new correction click per batch item into a copy of ``points``.

    Channel 0 of ``pred`` and ``gt`` is moved to NumPy and compared with
    logical operations (values are interpreted by truthiness). For each item
    the larger of the false-negative and false-positive regions, measured by
    the peak of its distance transform, decides whether the click is positive
    or negative. The click is drawn uniformly from the pixels whose distance
    exceeds half of that peak.

    Args:
        pred: predicted masks tensor, indexed ``[batch, channel, row, col]``.
        gt: ground-truth masks tensor with the same indexing.
        points: tensor indexed ``[batch, slot, component]``; the first half of
            the slots is used for positive clicks, the second half for
            negative clicks.
        click_indx: 1-based click counter; the slot written is
            ``num_points - click_indx`` (positive) or
            ``2 * num_points - click_indx`` (negative).
        pred_thresh: accepted for signature compatibility; not used here.

    Returns:
        A clone of ``points`` with components 0 and 1 of the chosen slot set to
        the sampled row and column index. Items with no error pixels are left
        unchanged.
    """
    assert click_indx > 0
    pred_np = pred.cpu().numpy()[:, 0, :, :]
    gt_np = gt.cpu().numpy()[:, 0, :, :]

    # pred = pred[:, 0, :, :]
    # gt = gt[:, 0, :, :]
    # fn_mask = np.logical_and(gt, pred < pred_thresh)
    # fp_mask = np.logical_and(np.logical_not(gt), pred > pred_thresh)
    missed = np.logical_and(gt_np, np.logical_not(pred_np))
    spurious = np.logical_and(np.logical_not(gt_np), pred_np)

    # One-pixel zero border so regions touching the image edge still get a
    # distance measured from that edge; the border is cropped off again below.
    border = ((0, 0), (1, 1), (1, 1))
    missed = np.pad(missed, border, 'constant').astype(np.uint8)
    spurious = np.pad(spurious, border, 'constant').astype(np.uint8)

    half = points.size(1) // 2
    updated = points.clone()

    for item, (missed_item, spurious_item) in enumerate(zip(missed, spurious)):
        fn_dist = cv2.distanceTransform(missed_item, cv2.DIST_L2, 5)[1:-1, 1:-1]
        fp_dist = cv2.distanceTransform(spurious_item, cv2.DIST_L2, 5)[1:-1, 1:-1]
        fn_peak = np.max(fn_dist)
        fp_peak = np.max(fp_dist)

        if fn_peak > fp_peak:
            chosen_dist, slot = fn_dist, half - click_indx
        else:
            chosen_dist, slot = fp_dist, 2 * half - click_indx

        candidates = np.argwhere(chosen_dist > max(fn_peak, fp_peak) / 2.0)
        if len(candidates) == 0:
            continue

        picked = candidates[np.random.randint(0, len(candidates))]
        updated[item, slot, 0] = float(picked[0])
        updated[item, slot, 1] = float(picked[1])
        # component 2 (click index) is intentionally left untouched
    return updated
def graco_sample_optimized(gt_masks, pred_masks, mode, click_indx, points):
    """Variant of ``graco_sample`` that reduces torch/NumPy conversions.

    Args:
        gt_masks: ground-truth masks (tensor; iterated per item in eval mode).
        pred_masks: predicted masks, indexable per batch item.
        mode: ``"train"`` selects the tensor-based sampling path; any other
            value selects the ``Clicker``-based path.
        click_indx: zero-based click counter; ``click_indx + 1`` is forwarded
            to ``get_next_points_graco_torch``.
        points: point tensor; its device is used for the returned tensors.

    Returns:
        ``(new_points_graco, new_labels_graco)`` tensors stacked over the batch.
    """
    device = points.device

    if mode == "train":
        # Sample the next click, then strip padding slots.
        points = get_next_points_graco_torch(pred_masks, gt_masks, points, click_indx + 1)
        per_item_points, per_item_labels = process_points_torch(points)
        # Stacking requires every batch item to hold the same number of
        # points; otherwise padding would be needed.
        stacked_points = torch.stack(per_item_points, dim=0)
        stacked_labels = torch.stack(per_item_labels, dim=0)
        return stacked_points, stacked_labels

    # Evaluation path: one Clicker per ground-truth mask.
    gt_masks_np = [gt_mask.cpu().numpy() for gt_mask in gt_masks]
    clicker_list = [Clicker(gt_mask=mask_np) for mask_np in gt_masks_np]

    coords_per_item = []
    labels_per_item = []
    for idx, clicker in enumerate(clicker_list):
        clicker.make_next_click(pred_masks[idx].cpu().numpy())
        item_coords, item_labels = get_sam_input(clicker)
        coords_per_item.append(item_coords)
        labels_per_item.append(item_labels)

    # Convert the final results in one go.
    stacked_points = torch.tensor(np.stack(coords_per_item, axis=0), device=device)
    stacked_labels = torch.tensor(np.stack(labels_per_item, axis=0), device=device)
    return stacked_points, stacked_labels
def get_next_points_graco_torch(pred, gt, points, click_indx, pred_thresh=0.49):
    """Tensor-input version of ``get_next_points_graco``.

    The false-negative / false-positive maps are computed with torch ops on
    the input device. They are then copied to host memory once for the whole
    batch, because the distance transform runs through OpenCV on NumPy arrays.

    Args:
        pred: predicted masks tensor, indexed ``[batch, channel, row, col]``;
            channel 0 is used and interpreted by truthiness.
        gt: ground-truth masks tensor with the same indexing.
        points: tensor indexed ``[batch, slot, component]``; the first half of
            the slots is used for positive clicks, the second half for
            negative clicks.
        click_indx: 1-based click counter; the slot written is
            ``num_points - click_indx`` (positive) or
            ``2 * num_points - click_indx`` (negative).
        pred_thresh: accepted for signature compatibility; not used here.

    Returns:
        A clone of ``points`` with components 0 and 1 of the chosen slot set to
        the sampled row and column index. Items with no error pixels are left
        unchanged.
    """
    assert click_indx > 0

    pred = pred[:, 0, :, :]
    gt = gt[:, 0, :, :]

    fn_mask = torch.logical_and(gt, torch.logical_not(pred))
    fp_mask = torch.logical_and(torch.logical_not(gt), pred)

    points = points.clone()
    num_points = points.size(1) // 2

    # Single device-to-host copy for both error maps of the whole batch,
    # instead of two copies per batch item inside the loop.
    error_masks = torch.stack((fn_mask, fp_mask), dim=0).to(torch.uint8).cpu().numpy()
    # One-pixel zero border so regions touching the image edge still get a
    # distance measured from that edge; the border is cropped off again below.
    error_masks = np.pad(error_masks, ((0, 0), (0, 0), (1, 1), (1, 1)), 'constant')
    fn_masks_np, fp_masks_np = error_masks[0], error_masks[1]

    for bindx in range(fn_masks_np.shape[0]):
        fn_mask_dt = cv2.distanceTransform(fn_masks_np[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]
        fp_mask_dt = cv2.distanceTransform(fp_masks_np[bindx], cv2.DIST_L2, 5)[1:-1, 1:-1]

        fn_max_dist = np.max(fn_mask_dt)
        fp_max_dist = np.max(fp_mask_dt)

        # The error type with the deeper region decides the click polarity.
        is_positive = fn_max_dist > fp_max_dist
        dt = fn_mask_dt if is_positive else fp_mask_dt
        inner_mask = dt > max(fn_max_dist, fp_max_dist) / 2.0
        indices = np.argwhere(inner_mask)
        if len(indices) == 0:
            continue

        # Pick one candidate pixel at random.
        coords = indices[np.random.randint(0, len(indices))]
        slot = (num_points if is_positive else 2 * num_points) - click_indx
        points[bindx, slot, 0] = float(coords[0])
        points[bindx, slot, 1] = float(coords[1])

    return points
def process_points_torch(points):
    """Tensor version of ``process_points``; keeps all work on ``points.device``.

    ``points`` is indexed as ``points[batch, slot, component]``. The first half
    of the slots is treated as positive clicks and the rest as negative clicks,
    matching the ``num_points = size(1) // 2`` layout that
    ``get_next_points_graco_torch`` writes into. A slot whose component 0
    equals ``-1`` is padding and is dropped.

    Returns:
        ``(filtered_points, filtered_labels)``: two lists with one tensor per
        batch item. Points have shape ``(k, 2)`` with the two stored components
        swapped; labels are ``int32`` with 1 for positive and 0 for negative
        clicks, positives first.
    """
    # Split at half the slot count instead of a hard-coded 1 so layouts with
    # several positive slots are labelled correctly. `max(1, ...)` keeps the
    # previous behaviour for inputs with fewer than four slots.
    num_points = max(1, points.shape[1] // 2)
    positive = points[:, :num_points, :]
    negative = points[:, num_points:, :]

    filtered_points = []
    filtered_labels = []
    for batch in range(points.shape[0]):
        pos_mask = positive[batch, :, 0] != -1
        neg_mask = negative[batch, :, 0] != -1

        # Keep the first two components and swap their order.
        pos_points = positive[batch, pos_mask, :2][:, [1, 0]]
        neg_points = negative[batch, neg_mask, :2][:, [1, 0]]

        pos_labels = torch.ones(pos_points.size(0), device=points.device, dtype=torch.int32)
        neg_labels = torch.zeros(neg_points.size(0), device=points.device, dtype=torch.int32)

        # torch.cat handles zero-length (0, 2) inputs, so no special-casing
        # of empty positive/negative sets is needed.
        filtered_points.append(torch.cat([pos_points, neg_points], dim=0))
        filtered_labels.append(torch.cat([pos_labels, neg_labels], dim=0))

    return filtered_points, filtered_labels
def graco_sample(gt_masks, pred_masks, mode, click_indx, points):
    """Sample the next GraCo click for every item in the batch.

    Args:
        gt_masks: ground-truth masks. A tensor in train mode; in eval mode an
            iterable of tensors or NumPy arrays.
        pred_masks: predicted masks, indexable per batch item.
        mode: ``"train"`` samples with ``get_next_points_graco``; any other
            value uses one ``Clicker`` per ground-truth mask.
        click_indx: zero-based click counter; ``click_indx + 1`` is forwarded
            to ``get_next_points_graco``.
        points: point tensor (train mode updates a copy of it); its device is
            used for the returned tensors.

    Returns:
        ``(new_points_graco, new_labels_graco)`` tensors on ``points.device``,
        stacked over the batch.
    """
    device = points.device

    ## GraCo's sampling method
    if mode == "train":
        # Uses the original tensors. TODO: need to change to raw logits as inputs
        # `points` is expected to come from SAM2's initial sampling.
        points = get_next_points_graco(pred_masks, gt_masks, points, click_indx + 1)
        input_point, input_label = process_points(points.cpu().numpy())
        new_points_graco = np.stack(input_point, axis=0)
        new_labels_graco = np.stack(input_label, axis=0)
        # back to tensors
        return torch.from_numpy(new_points_graco).to(device), torch.from_numpy(new_labels_graco).to(device)

    # Evaluation path. The NumPy copies and Clickers are built only here so
    # the train path does not pay for host copies it never uses. The original
    # `gt_masks` is not modified.
    gt_masks_np = [gt_mask.cpu().numpy() if torch.is_tensor(gt_mask) else gt_mask for gt_mask in gt_masks]
    clicker_list = [Clicker(gt_mask=gt_mask_np) for gt_mask_np in gt_masks_np]

    point_coords = []
    points_labels = []
    for idx, clicker in enumerate(clicker_list):
        pred_mask = pred_masks[idx, :, :].cpu().numpy()
        clicker.make_next_click(pred_mask)
        curr_point_coords, curr_point_labels = get_sam_input(clicker)
        point_coords.append(curr_point_coords)
        points_labels.append(curr_point_labels)

    new_points_graco = np.stack(point_coords, axis=0)
    new_labels_graco = np.stack(points_labels, axis=0)
    # back to tensors
    return torch.from_numpy(new_points_graco).to(device), torch.from_numpy(new_labels_graco).to(device)
# def graco_sample(gt_masks, pred_masks, mode, click_indx):
# clicker_list = [Clicker(gt_mask=gt_mask.cpu()) for gt_mask in gt_masks]
# pred_masks_list = [np.zeros_like(gt_mask) for gt_mask in gt_masks]
# point_coords = []
# points_labels = []
# ## GraCo's sampling method
# if mode == "train":
# prev_output = torch.sigmoid(pred_masks)
# points = get_next_points_graco(prev_output, gt_masks, points, click_indx + 1)
# input_point, input_label = process_points(points.cpu().numpy())
# new_points_graco = np.stack(input_point, axis=0)
# new_labels_graco = np.stack(input_label, axis=0)
# return new_points_graco.unsqueeze(1), new_labels_graco.unsqueeze(1)
# else:
# for idx, gt_mask in enumerate(gt_masks):
# clicker = clicker_list[idx]
# pred_mask = pred_masks_list[idx]
# clicker.make_next_click(pred_mask)
# curr_point_coords, curr_point_labels = get_sam_input(clicker)
# point_coords.append(curr_point_coords)
# points_labels.append(curr_point_labels)
# new_points_graco = np.stack(point_coords, axis=0)
# new_labels_graco = np.stack(points_labels, axis=0)
# return new_points_graco.unsqueeze(1), new_labels_graco.unsqueeze(1)