"""GraCo-style click sampling helpers used with SAM2 (training and evaluation)."""

import logging
import os
import random
import sys

import cv2
import numpy as np
import torch
import torch.distributed

from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
    get_1d_sine_pe,
    get_next_point,
    sample_box_points,
    select_closest_cond_frames,
)
from sam2.utils.misc import concat_points
from training.utils.data_utils import BatchedVideoDatapoint

# Make GraCo's `isegm` package importable. The default is the original
# checkout location; set GRACO_ROOT to use a different checkout.
_GRACO_ROOT = os.environ.get("GRACO_ROOT", "/home/yujunwei/sam2/GraCo")
sys.path.append(_GRACO_ROOT)
from isegm.inference.clicker import Clicker  # noqa: E402

# get_sam_input is defined locally below instead of being imported from:
# from training.utils.GraCo.isegm.inference.evaluation import get_sam_input


def get_points_nd(clicks_lists):
    """Convert per-object click lists into padded positive + negative slot lists.

    Every output list has ``num_max_points`` positive slots followed by
    ``num_max_points`` negative slots. ``num_max_points`` is the largest
    positive-or-negative click count over all lists, and is at least 1. A slot
    holds a click's ``coords_and_indx`` tuple. Unused slots hold
    ``(-1, -1, -1)``.

    Args:
        clicks_lists: iterable of click lists. Each click exposes
            ``is_positive`` and ``coords_and_indx``.

    Returns:
        A list with one padded slot list per input click list.
    """
    num_pos_clicks = [sum(click.is_positive for click in clicks) for clicks in clicks_lists]
    num_neg_clicks = [
        len(clicks) - n_pos for clicks, n_pos in zip(clicks_lists, num_pos_clicks)
    ]
    # `default=0` keeps an empty `clicks_lists` from raising inside max().
    num_max_points = max(1, max(num_pos_clicks + num_neg_clicks, default=0))

    pad = (-1, -1, -1)
    total_clicks = []
    for clicks in clicks_lists:
        pos = [c.coords_and_indx for c in clicks if c.is_positive]
        neg = [c.coords_and_indx for c in clicks if not c.is_positive]
        pos += (num_max_points - len(pos)) * [pad]
        neg += (num_max_points - len(neg)) * [pad]
        total_clicks.append(pos + neg)
    return total_clicks


def get_sam_input(clicker, reverse=True):
    """Turn a clicker's clicks into SAM-style point coordinates and labels.

    Args:
        clicker: object with ``get_clicks()`` (a GraCo ``Clicker``).
        reverse: if True (default), each stored coordinate pair ``(a, b)`` is
            emitted as ``(b, a)``. The stored pair is presumably ``(row, col)``,
            so the output is SAM's ``(x, y)`` order. If False, the stored order
            is kept.

    Returns:
        ``(point_coords, point_labels)``: an array of coordinate pairs and an
        array of labels (1 = positive, 0 = negative) of the same length.
        Padding slots are skipped.
    """
    points = get_points_nd([clicker.get_clicks()])[0]
    num_pos_slots = len(points) // 2
    point_coords = []
    point_labels = []
    for slot_idx, point in enumerate(points):
        if point[0] == -1:  # padding slot
            continue
        point_labels.append(1 if slot_idx < num_pos_slots else 0)
        # Previously reverse=False appended a label but no coordinate,
        # which left the two arrays with different lengths.
        point_coords.append([point[1], point[0]] if reverse else [point[0], point[1]])
    return np.array(point_coords), np.array(point_labels)


def _iter_correct_pt_sampling_graco(
    self,
    is_init_cond_frame,
    point_inputs,
    gt_masks,
    high_res_features,
    pix_feat_with_mem,
    low_res_multimasks,
    high_res_multimasks,
    ious,
    low_res_masks,
    high_res_masks,
    object_score_logits,
    current_out,
):
    """Iteratively add correction clicks and re-run the SAM heads.

    Runs ``self.num_correction_pt_per_frame`` iterations. Each iteration:
      1. Samples one new point with ``get_next_point``. While training, with
         probability ``self.prob_to_sample_from_gt_for_train`` the point comes
         from the GT masks alone; otherwise it comes from the prediction/GT
         error.
      2. Appends the point to ``point_inputs``.
      3. Calls ``self._forward_sam_heads``, passing the previous low-res mask
         logits as the mask input.

    The per-step outputs are stored in ``current_out`` under the
    ``multistep_*`` keys. The low/high-res masks are concatenated along dim 1.
    The other outputs are kept as lists.

    NOTE(review): this appears to be adapted from SAM2's
    ``_iter_correct_pt_sampling``. Points are still sampled with SAM2's
    ``get_next_point``, not the GraCo samplers in this module.

    Returns:
        ``(point_inputs, sam_outputs)`` from the last iteration.

    NOTE(review): ``sam_outputs`` is only bound inside the loop. The function
    raises ``UnboundLocalError`` if ``num_correction_pt_per_frame`` is 0.
    """
    assert gt_masks is not None
    all_pred_masks = [low_res_masks]
    all_pred_high_res_masks = [high_res_masks]
    all_pred_multimasks = [low_res_multimasks]
    all_pred_high_res_multimasks = [high_res_multimasks]
    all_pred_ious = [ious]
    all_point_inputs = [point_inputs]
    all_object_score_logits = [object_score_logits]

    for _ in range(self.num_correction_pt_per_frame):
        # Sample a new point from the error between prediction and ground truth.
        # While training, with a small probability, sample from the GT masks instead.
        if self.training and self.prob_to_sample_from_gt_for_train > 0:
            sample_from_gt = self.rng.random() < self.prob_to_sample_from_gt_for_train
        else:
            sample_from_gt = False
        # If `pred_for_new_pt` is None, only the GT masks are used for point sampling.
        pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
        new_points, new_labels = get_next_point(
            gt_masks=gt_masks,
            pred_masks=pred_for_new_pt,
            method="uniform" if self.training else self.pt_sampling_for_eval,
        )
        point_inputs = concat_points(point_inputs, new_points, new_labels)

        # Feed the mask logits of the previous SAM outputs into the next SAM
        # decoder step. For tracking, this means the tracking output mask
        # logits are fed along with the user's correction click.
        mask_inputs = low_res_masks
        multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
        if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
            sam_outputs = torch.utils.checkpoint.checkpoint(
                self._forward_sam_heads,
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
                use_reentrant=False,
            )
        else:
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            _,
            object_score_logits,
        ) = sam_outputs
        all_pred_masks.append(low_res_masks)
        all_pred_high_res_masks.append(high_res_masks)
        all_pred_multimasks.append(low_res_multimasks)
        all_pred_high_res_multimasks.append(high_res_multimasks)
        all_pred_ious.append(ious)
        all_point_inputs.append(point_inputs)
        all_object_score_logits.append(object_score_logits)

    # Concatenate the masks along the channel dim, so that losses can be
    # computed on all of them (using `MultiStepIteractiveMasks`).
    current_out["multistep_pred_masks"] = torch.cat(all_pred_masks, dim=1)
    current_out["multistep_pred_masks_high_res"] = torch.cat(all_pred_high_res_masks, dim=1)
    current_out["multistep_pred_multimasks"] = all_pred_multimasks
    current_out["multistep_pred_multimasks_high_res"] = all_pred_high_res_multimasks
    current_out["multistep_pred_ious"] = all_pred_ious
    current_out["multistep_point_inputs"] = all_point_inputs
    current_out["multistep_object_score_logits"] = all_object_score_logits
    return point_inputs, sam_outputs


def process_points(points):
    """Split GraCo-layout click slots into per-sample point and label arrays.

    Args:
        points: NumPy array of shape ``(B, 2 * N, >=2)`` in the slot layout
            written by ``get_next_points_graco``. Slots ``[0, N)`` hold
            positive clicks and slots ``[N, 2N)`` hold negative clicks, each as
            ``(row, col, ...)``. Unused slots have ``-1`` as their first entry.

    Returns:
        ``(filtered_points, filtered_labels)``: per-sample lists holding
        ``(K, 2)`` ``(x, y)`` coordinate arrays and ``(K,)`` label arrays
        (1 = positive, 0 = negative). Positives come before negatives.
    """
    # The positive/negative boundary is the middle of the slot axis, matching
    # `num_points = points.size(1) // 2` in get_next_points_graco. The old
    # split at index 1 only agreed with that layout when N == 1.
    num_pos_slots = max(1, points.shape[1] // 2)
    filtered_points = []
    filtered_labels = []
    for sample in points:
        sample_points = []
        sample_labels = []
        for slot_idx, point in enumerate(sample):
            if point[0] == -1:  # unused slot
                continue
            row, col = point[:2]
            sample_points.append([col, row])
            sample_labels.append(1 if slot_idx < num_pos_slots else 0)
        filtered_points.append(np.array(sample_points))
        filtered_labels.append(np.array(sample_labels))
    return filtered_points, filtered_labels


def _sample_graco_click(fn_mask, fp_mask):
    """Pick the next click for one sample from its 2-D uint8 error masks.

    The larger error region (by maximum distance-transform value) decides the
    click polarity. A random pixel whose distance exceeds half of that maximum
    is returned.

    Returns:
        ``(is_positive, (row, col))``, or ``None`` when no candidate pixel exists.
    """
    # Zero-pad by one pixel so the image border counts as background for the
    # distance transform, then crop the padding away again.
    fn_dt = cv2.distanceTransform(np.pad(fn_mask, 1, "constant"), cv2.DIST_L2, 5)[1:-1, 1:-1]
    fp_dt = cv2.distanceTransform(np.pad(fp_mask, 1, "constant"), cv2.DIST_L2, 5)[1:-1, 1:-1]
    fn_max_dist = np.max(fn_dt)
    fp_max_dist = np.max(fp_dt)

    is_positive = fn_max_dist > fp_max_dist
    dt = fn_dt if is_positive else fp_dt
    candidates = np.argwhere(dt > max(fn_max_dist, fp_max_dist) / 2.0)
    if len(candidates) == 0:
        return None
    return is_positive, candidates[np.random.randint(0, len(candidates))]


def _write_graco_click(points, bindx, num_points, click_indx, click):
    """Write a sampled ``click`` (see _sample_graco_click) into ``points`` in place."""
    is_positive, (row, col) = click
    # Positives fill slots [0, N) and negatives fill slots [N, 2N), both
    # counting down from the end of their half.
    slot = num_points - click_indx if is_positive else 2 * num_points - click_indx
    points[bindx, slot, 0] = float(row)
    points[bindx, slot, 1] = float(col)


def get_next_points_graco(pred, gt, points, click_indx, pred_thresh=0.49):
    """Add the next GraCo-style correction click for every sample (NumPy/OpenCV path).

    Args:
        pred: prediction tensor, indexed as ``pred[:, 0]`` (presumably
            ``(B, 1, H, W)``). Boolean predictions are used as-is; other dtypes
            are thresholded with ``> pred_thresh``.
        gt: ground-truth tensor, indexed the same way. Non-zero means foreground.
        points: ``(B, 2N, C)`` tensor of click slots (N positive, then N
            negative). It is not modified in place.
        click_indx: 1-based index of the click being added. A positive click
            goes to slot ``N - click_indx``, a negative one to
            ``2N - click_indx``.
        pred_thresh: foreground threshold for non-boolean predictions.

    Returns:
        A cloned ``points`` tensor with the new (row, col) clicks written in.
        Samples without a candidate pixel are left unchanged.
    """
    assert click_indx > 0
    pred = pred.cpu().numpy()[:, 0, :, :]
    gt = gt.cpu().numpy()[:, 0, :, :]

    # Previously `pred_thresh` was ignored and every non-zero value (including
    # negative logits) counted as foreground. Boolean masks behave as before.
    pred_fg = pred if pred.dtype == np.bool_ else pred > pred_thresh
    fn_mask = np.logical_and(gt, np.logical_not(pred_fg)).astype(np.uint8)
    fp_mask = np.logical_and(np.logical_not(gt), pred_fg).astype(np.uint8)

    num_points = points.size(1) // 2
    points = points.clone()
    for bindx in range(fn_mask.shape[0]):
        click = _sample_graco_click(fn_mask[bindx], fp_mask[bindx])
        if click is not None:
            _write_graco_click(points, bindx, num_points, click_indx, click)
    return points


def _common_dtype(arrays, default):
    """Return the promoted dtype of the non-empty ``arrays``, or ``default`` if all are empty."""
    dtypes = [a.dtype for a in arrays if a.size]
    return np.result_type(*dtypes) if dtypes else np.dtype(default)


def _pad_and_stack_np(coords_list, labels_list, pad_label=-1):
    """Stack per-sample point arrays into batch arrays, padding ragged samples.

    Args:
        coords_list: per-sample coordinate arrays, each reshaped to ``(K_i, 2)``.
        labels_list: per-sample label arrays of length ``K_i``.
        pad_label: label for padding entries. SAM/SAM2's prompt encoder treats
            label -1 as "not a point".

    Returns:
        ``(coords, labels)`` of shapes ``(B, K_max, 2)`` and ``(B, K_max)``.
        When every ``K_i`` is equal, this is the same as ``np.stack``.
    """
    coords_list = [np.asarray(c).reshape(-1, 2) for c in coords_list]
    labels_list = [np.asarray(lbl).reshape(-1) for lbl in labels_list]
    max_len = max((len(lbl) for lbl in labels_list), default=0)
    coords = np.zeros(
        (len(coords_list), max_len, 2), dtype=_common_dtype(coords_list, np.float32)
    )
    labels = np.full(
        (len(labels_list), max_len), pad_label, dtype=_common_dtype(labels_list, np.int64)
    )
    for i, (c, lbl) in enumerate(zip(coords_list, labels_list)):
        coords[i, : len(c)] = c
        labels[i, : len(lbl)] = lbl
    return coords, labels


def _next_clicks_from_clicker(gt_masks, pred_masks):
    """Run one GraCo ``Clicker`` step per sample and return SAM-style inputs.

    NOTE(review): a fresh Clicker is built on every call, so earlier clicks
    are presumably not carried over between calls.

    Returns:
        ``(coords_list, labels_list)``: per-sample outputs of ``get_sam_input``.
    """
    coords_list = []
    labels_list = []
    for idx, gt_mask in enumerate(gt_masks):
        gt_mask_np = gt_mask.cpu().numpy() if torch.is_tensor(gt_mask) else gt_mask
        clicker = Clicker(gt_mask=gt_mask_np)
        clicker.make_next_click(pred_masks[idx].cpu().numpy())
        coords, labels = get_sam_input(clicker)
        coords_list.append(coords)
        labels_list.append(labels)
    return coords_list, labels_list


def graco_sample_optimized(gt_masks, pred_masks, mode, click_indx, points):
    """Variant of ``graco_sample`` that keeps the training path in torch.

    Args:
        gt_masks: ground-truth mask tensor.
        pred_masks: prediction tensor (boolean, or values thresholded by
            ``get_next_points_graco_torch``).
        mode: ``"train"`` uses GraCo slot sampling on ``points``. Any other
            value uses GraCo's ``Clicker``.
        click_indx: 0-based index of the click being added.
        points: ``(B, 2N, C)`` click-slot tensor. Only its device is used
            outside train mode.

    Returns:
        ``(coords, labels)`` tensors on ``points.device``, shaped
        ``(B, K, 2)`` and ``(B, K)``. Samples with fewer than ``K`` points are
        padded with coordinate (0, 0) and label -1.
    """
    device = points.device

    if mode == "train":
        # Only the distance transform runs on the CPU (inside the helper).
        points = get_next_points_graco_torch(pred_masks, gt_masks, points, click_indx + 1)
        filtered_points, filtered_labels = process_points_torch(points)
        # Samples can end up with different point counts (e.g. when no
        # candidate pixel is found), which torch.stack cannot handle.
        new_points = torch.nn.utils.rnn.pad_sequence(
            filtered_points, batch_first=True, padding_value=0.0
        )
        new_labels = torch.nn.utils.rnn.pad_sequence(
            filtered_labels, batch_first=True, padding_value=-1
        )
        return new_points, new_labels

    coords_list, labels_list = _next_clicks_from_clicker(gt_masks, pred_masks)
    coords, labels = _pad_and_stack_np(coords_list, labels_list)
    return torch.tensor(coords, device=device), torch.tensor(labels, device=device)


def get_next_points_graco_torch(pred, gt, points, click_indx, pred_thresh=0.49):
    """Torch version of ``get_next_points_graco``. Same arguments and return value.

    The mask logic runs on the input tensors' device. Only the distance
    transform needs NumPy/OpenCV, so both error masks are copied to the host
    once per batch.
    """
    assert click_indx > 0
    pred = pred[:, 0, :, :]
    gt = gt[:, 0, :, :]

    # Boolean predictions are used as-is; other dtypes are thresholded
    # (previously `pred_thresh` was ignored).
    pred_fg = pred if pred.dtype == torch.bool else pred > pred_thresh
    fn_mask = torch.logical_and(gt, torch.logical_not(pred_fg))
    fp_mask = torch.logical_and(torch.logical_not(gt), pred_fg)
    fn_mask_np = fn_mask.cpu().numpy().astype(np.uint8)
    fp_mask_np = fp_mask.cpu().numpy().astype(np.uint8)

    points = points.clone()
    num_points = points.size(1) // 2
    for bindx in range(fn_mask_np.shape[0]):
        click = _sample_graco_click(fn_mask_np[bindx], fp_mask_np[bindx])
        if click is not None:
            _write_graco_click(points, bindx, num_points, click_indx, click)
    return points


def process_points_torch(points):
    """Torch version of ``process_points``. Outputs stay on ``points.device``.

    Returns:
        ``(filtered_points, filtered_labels)``: per-sample lists holding
        ``(K, 2)`` ``(x, y)`` tensors (dtype of ``points``) and ``(K,)`` int32
        label tensors (1 = positive, 0 = negative). Positives come first.
    """
    # Split at the middle of the slot axis, matching get_next_points_graco*.
    num_pos_slots = max(1, points.shape[1] // 2)
    is_pos_slot = torch.arange(points.shape[1], device=points.device) < num_pos_slots

    filtered_points = []
    filtered_labels = []
    for sample in points:
        valid = sample[:, 0] != -1
        # (row, col) -> (x, y)
        filtered_points.append(sample[valid][:, [1, 0]])
        filtered_labels.append(is_pos_slot[valid].to(torch.int32))
    return filtered_points, filtered_labels


def graco_sample(gt_masks, pred_masks, mode, click_indx, points):
    """Sample the next GraCo click(s) and return SAM-style point inputs.

    Args:
        gt_masks: ground-truth masks. These are tensors in train mode; outside
            train mode, tensors or NumPy arrays.
        pred_masks: prediction tensor (see ``get_next_points_graco`` for how
            non-boolean values are handled in train mode).
        mode: ``"train"`` uses GraCo slot sampling on ``points``. Any other
            value uses GraCo's ``Clicker``.
        click_indx: 0-based index of the click being added.
        points: ``(B, 2N, C)`` click-slot tensor. Only its device is used
            outside train mode.

    Returns:
        ``(coords, labels)`` tensors on ``points.device``, shaped
        ``(B, K, 2)`` and ``(B, K)``. Samples with fewer than ``K`` points are
        padded with coordinate (0, 0) and label -1.
    """
    device = points.device
    if mode == "train":
        # TODO: consider feeding raw logits / probabilities. Non-boolean
        # predictions are thresholded at `pred_thresh` by get_next_points_graco.
        points = get_next_points_graco(pred_masks, gt_masks, points, click_indx + 1)
        coords_list, labels_list = process_points(points.cpu().numpy())
    else:
        coords_list, labels_list = _next_clicks_from_clicker(gt_masks, pred_masks)

    coords, labels = _pad_and_stack_np(coords_list, labels_list)
    return torch.from_numpy(coords).to(device), torch.from_numpy(labels).to(device)