import os
import cv2
import torch
import pickle
import argparse
import numpy as np
import warnings
from tqdm import tqdm
from pathlib import Path
from PIL import Image
from detectron2.data.detection_utils import read_image
from supervision import Detections, BoxAnnotator, MaskAnnotator, LabelAnnotator, mask_to_xyxy
from sam2.build_sam import build_sam2_video_predictor
from VLPart.build_vlpart import build_vlpart_model
warnings.filterwarnings('ignore')
# ---- Configuration constants ----
SAM2_CONFIG = "sam2_hiera_l.yaml"                      # SAM2 Hiera-Large model config
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"  # local SAM2 weights
# Per-frame binary masks go under OUTPUT_ROOT; annotated frames under OUTPUT_ROOT_IMG.
OUTPUT_ROOT = "/data/robot-merlin/mask_vlpart+sam2_tracking"
OUTPUT_ROOT_IMG = "/data/robot-merlin/mask_vlpart+sam2_tracking_with_image"
# Set up torch environment
# NOTE(review): entering autocast at module scope keeps bfloat16 autocast active
# for the entire process and the context is never exited — this mirrors the SAM2
# demo scripts, but confirm it is intentional.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # Compute capability 8.x (Ampere) or newer: allow TF32 matmuls for speed.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
device = "cuda:0" if torch.cuda.is_available() else "cpu"
def load_affordance_data(pkl_path):
    """
    Load affordance entries from a pickle file and group them by video directory.

    Args:
        pkl_path (str): Path to the pickle file containing a list of entries,
            each with a 'frame_path' key.

    Returns:
        dict: Maps each video directory (dirname of 'frame_path') to the list
        of entries belonging to that video, in file order.
    """
    with open(pkl_path, 'rb') as handle:
        entries = pickle.load(handle)

    grouped = {}
    for entry in entries:
        video_key = os.path.dirname(entry['frame_path'])
        if video_key not in grouped:
            grouped[video_key] = []
        grouped[video_key].append(entry)
    return grouped
def init_vlpart_once(text, prev_text, vlpart_model):
    """
    Rebuild the VLPart model only when the prompt text changes.

    Returns the (possibly new) model and the text it was built for, so callers
    can thread both values through subsequent iterations.
    """
    # Prompt unchanged: reuse the existing model as-is.
    if text == prev_text:
        return vlpart_model, text

    # Prompt changed: drop the stale model (if any) and build a fresh one.
    if vlpart_model is not None:
        del vlpart_model
    return build_vlpart_model(text), text
def run_vlpart_on_first_frame(vlpart_model, image_path):
    """
    Run VLPart on a single frame and return its detection boxes.

    Returns the predicted boxes as a numpy array only when exactly one
    instance is detected; otherwise returns None so the caller can skip
    ambiguous videos.
    """
    bgr_image = read_image(image_path, format="BGR")
    predictions, _ = vlpart_model.run_on_image(bgr_image)
    instances = predictions["instances"]
    if len(instances) == 1:
        return instances.pred_boxes.tensor.cpu().numpy()
    return None
def run_sam2_tracking(video_dir, frame_names, sam2_predictor, boxes):
    """
    Track the boxed object through a video with SAM2.

    Seeds the predictor with `boxes` on frame 0 (object id 1), then propagates
    through the whole video.

    Returns:
        dict: {frame_idx: {obj_id: binary_mask_ndarray}} where each mask is the
        thresholded (logit > 0) output converted to a numpy array.
    """
    state = sam2_predictor.init_state(video_path=video_dir)
    sam2_predictor.reset_state(state)
    sam2_predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        box=boxes,
    )

    per_frame = {}
    for frame_idx, out_ids, out_logits in sam2_predictor.propagate_in_video(state):
        frame_masks = {}
        for pos, oid in enumerate(out_ids):
            frame_masks[oid] = (out_logits[pos] > 0).cpu().numpy()
        per_frame[frame_idx] = frame_masks
    return per_frame
def save_tracking_results(video_dir, frame_names, video_segments, object_name, output_base, vid):
    """
    Save the tracking results to the specified output directory.

    Writes two artifacts per tracked frame:
      * ``{output_base}/{vid:06d}/<frame>`` — the first object's binary mask
        scaled to 0/255.
      * ``{OUTPUT_ROOT_IMG}/{vid:06d}/<frame>`` — the source frame annotated
        with boxes, labels, and mask overlays.

    Args:
        video_dir: Directory containing the source frames.
        frame_names: Frame filenames, indexable by the frame indices used as
            keys of ``video_segments``.
        video_segments: Mapping ``{frame_idx: {obj_id: mask_array}}`` as
            produced by ``run_sam2_tracking``.
        object_name: Human-readable label for the tracked object.
        output_base: Root directory for the raw-mask outputs.
            NOTE(review): the annotated-image root comes from the module-level
            ``OUTPUT_ROOT_IMG`` rather than a parameter — confirm the
            asymmetry is intentional.
        vid: Numeric video id used for the zero-padded subdirectory name.
    """
    objects = [object_name]
    # Object ids start at 1, matching the obj_id used to seed SAM2.
    id_to_objects = {i: obj for i, obj in enumerate(objects, start=1)}
    output_dir = Path(f"{output_base}/{vid:06d}")
    output_dir.mkdir(parents=True, exist_ok=True)
    output_dir_img = Path(f"{OUTPUT_ROOT_IMG}/{vid:06d}")
    output_dir_img.mkdir(parents=True, exist_ok=True)
    box_annotator = BoxAnnotator()
    label_annotator = LabelAnnotator()
    mask_annotator = MaskAnnotator()
    for idx, masks in video_segments.items():
        frame_path = os.path.join(video_dir, frame_names[idx])
        frame = cv2.imread(frame_path)
        obj_ids = list(masks.keys())
        # Stack every object's mask into one (num_objects, H, W) array.
        mask_arr = np.concatenate(list(masks.values()), axis=0)
        detections = Detections(
            xyxy=mask_to_xyxy(mask_arr),
            mask=mask_arr,
            class_id=np.array(obj_ids, dtype=np.int32),
        )
        annotated = box_annotator.annotate(frame.copy(), detections)
        annotated = label_annotator.annotate(annotated, detections, [id_to_objects[i] for i in obj_ids])
        annotated = mask_annotator.annotate(annotated, detections)
        cv2.imwrite(str(output_dir_img / frame_names[idx]), annotated)
        # Only the first object's mask is persisted as the raw-mask output.
        cv2.imwrite(str(output_dir / frame_names[idx]), mask_arr[0] * 255)
def get_sorted_frame_names(video_dir):
    """Return the JPEG frame filenames in `video_dir`, ordered by numeric stem."""
    frames = []
    for name in os.listdir(video_dir):
        # Accept .jpg / .jpeg in any letter case; skip everything else.
        if name.lower().endswith(('.jpg', '.jpeg')):
            frames.append(name)
    frames.sort(key=lambda name: int(os.path.splitext(name)[0]))
    return frames
def main(openx_data, text_override=None):
    """
    Run the VLPart + SAM2 tracking pipeline over one dataset.

    For each video in the affordance pickle whose task object class contains
    'door', 'drawer', or 'knife': detect the target part with VLPart on the
    first frame, track it across the video with SAM2, and save mask and
    annotated-image outputs.

    Args:
        openx_data: Dataset name; selects ``./data/{openx_data}_for_affordance.pkl``.
        text_override: Optional text prompt that replaces the default
            "<task_object_class> handle" query for every video.
    """
    # You can reorganize the data loading logic as needed
    data_dict = load_affordance_data(f'./data/{openx_data}_for_affordance.pkl')
    # Initialize SAM2 predictor (built once, reused across all videos)
    sam2_predictor = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)
    prev_text = ''
    vlpart_model = None
    for video_dir, data_list in tqdm(data_dict.items()):
        # The first sample of each video supplies the seed frame and class.
        first_sample = data_list[0]
        frame_path = first_sample['frame_path']
        task_class = first_sample['task_object_class']
        # Only process specific classes
        if not any(k in task_class for k in ['door', 'drawer', 'knife']):
            continue
        # Initialize VLPart model with the task class (rebuilt only when the
        # prompt text changes between videos)
        input_text = f"{task_class} handle" if not text_override else text_override
        vlpart_model, prev_text = init_vlpart_once(input_text, prev_text, vlpart_model)
        # Process the first frame to get bounding boxes; None means the
        # detection was not exactly one instance, so skip the video.
        boxes = run_vlpart_on_first_frame(vlpart_model, frame_path)
        if boxes is None:
            continue
        # Run SAM2 tracking on the video frames
        frame_names = get_sorted_frame_names(video_dir)
        segments = run_sam2_tracking(video_dir, frame_names, sam2_predictor, boxes)
        save_tracking_results(video_dir, frame_names, segments, input_text,
                              f"{OUTPUT_ROOT}/", first_sample['vid'])
        print(f"[Done] {frame_path} | {task_class}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser("VLPart + SAM2 Tracking Demo")
    parser.add_argument("--pipeline", type=str, default="referring_expression_segmentation", help="Pipeline task")
    parser.add_argument("--text_input", type=str, default=None, help="Optional override for input text")
    parser.add_argument("--dataset", type=str, default="bridge", help="Dataset name (e.g., bridge)")
    args = parser.parse_args()
    # BUG FIX: main(openx_data, text_override=None) accepts two parameters, but
    # the original call passed three positionals (dataset, pipeline, text_input),
    # raising TypeError at startup. Pass only the arguments main consumes;
    # --pipeline stays parsed for CLI backward compatibility but is unused here.
    main(args.dataset, text_override=args.text_input)