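"""
VLPart + SAM2 part tracking.

For each video: detect the target part (e.g. a drawer handle) with VLPart on
the first frame, seed SAM2's video predictor with the resulting box, propagate
the mask through the video, and save binary masks (OUTPUT_ROOT) plus annotated
frames (OUTPUT_ROOT_IMG).
"""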
import os
import cv2
import torch
import pickle
import argparse
import numpy as np
import warnings
from tqdm import tqdm
from pathlib import Path

from detectron2.data.detection_utils import read_image
from supervision import Detections, BoxAnnotator, MaskAnnotator, LabelAnnotator, mask_to_xyxy

from sam2.build_sam import build_sam2_video_predictor
from VLPart.build_vlpart import build_vlpart_model


warnings.filterwarnings('ignore')

# Constants
SAM2_CONFIG = "sam2_hiera_l.yaml"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
OUTPUT_ROOT = "/data/robot-merlin/mask_vlpart+sam2_tracking"
OUTPUT_ROOT_IMG = "/data/robot-merlin/mask_vlpart+sam2_tracking_with_image"
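# NOTE: the checkpoint and output paths above are machine-specific defaults;
# adjust them for your environment.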

# Set up the torch environment (global bfloat16 autocast, as in the SAM2 demo notebooks)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

if torch.cuda.is_available():
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # TF32 is only available on Ampere (compute capability >= 8) and newer GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


def load_affordance_data(pkl_path):
    """
    Load affordance data from a pickle file and organize it by video directory.
    Args:
        pkl_path (str): Path to the pickle file containing affordance data.
    Returns:
        dict: A dictionary where keys are video directory paths and values are lists of data entries.
    """
    with open(pkl_path, 'rb') as f:
        datas = pickle.load(f)

    data_dict = {}
    for data in datas:
        vid_path = os.path.dirname(data['frame_path'])
        data_dict.setdefault(vid_path, []).append(data)
    return data_dict
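
# Each pickle entry is expected to provide at least the three fields used below
# (values here are illustrative, not from the real dataset):
#   {'frame_path': '/path/to/video_dir/0.jpg',
#    'task_object_class': 'drawer',
#    'vid': 123}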


def init_vlpart_once(text, prev_text, vlpart_model):
    """
    (Re)build the VLPart model only when the prompt text changes, so that
    weights are not reloaded for every video.
    """
    if text != prev_text:
        if vlpart_model is not None:
            del vlpart_model  # release the previous model before building a new one
        vlpart_model = build_vlpart_model(text)
    return vlpart_model, text


def run_vlpart_on_first_frame(vlpart_model, image_path):
    """
    Run the VLPart model on the first frame to get a bounding box.
    Returns None unless exactly one instance is detected, so that ambiguous
    videos are skipped.
    """
    img = read_image(image_path, format="BGR")
    predictions, _ = vlpart_model.run_on_image(img)
    if len(predictions["instances"]) != 1:
        return None
    return predictions["instances"].pred_boxes.tensor.cpu().numpy()


def run_sam2_tracking(video_dir, sam2_predictor, boxes):
    """
    Run SAM2 tracking over the video's frames, seeded with the first-frame box.
    """
    inference_state = sam2_predictor.init_state(video_path=video_dir)
    sam2_predictor.reset_state(inference_state)

    # Seed the tracker with the VLPart box on frame 0 (single object, id 1)
    sam2_predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        box=boxes,
    )

    # Propagate through the video; thresholding the logits at 0 yields binary masks
    results = {}
    for frame_idx, out_ids, out_logits in sam2_predictor.propagate_in_video(inference_state):
        results[frame_idx] = {
            oid: (out_logits[i] > 0).cpu().numpy()
            for i, oid in enumerate(out_ids)
        }
    return results
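
# run_sam2_tracking returns {frame_idx: {obj_id: bool mask of shape (1, H, W)}};
# save_tracking_results() below concatenates these masks along axis 0 per frame.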


def save_tracking_results(video_dir, frame_names, video_segments, object_name, output_base, vid):
    """
    Save the tracking results to the specified output directory.
    """
    objects = [object_name]
    id_to_objects = {i: obj for i, obj in enumerate(objects, start=1)}

    output_dir = Path(f"{output_base}/{vid:06d}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Annotated previews always go under OUTPUT_ROOT_IMG (not output_base)
    output_dir_img = Path(f"{OUTPUT_ROOT_IMG}/{vid:06d}")
    output_dir_img.mkdir(parents=True, exist_ok=True)

    box_annotator = BoxAnnotator()
    label_annotator = LabelAnnotator()
    mask_annotator = MaskAnnotator()

    for idx, masks in video_segments.items():
        frame_path = os.path.join(video_dir, frame_names[idx])
        frame = cv2.imread(frame_path)

        obj_ids = list(masks.keys())
        mask_arr = np.concatenate(list(masks.values()), axis=0)  # (N, H, W) bool

        detections = Detections(
            xyxy=mask_to_xyxy(mask_arr),
            mask=mask_arr,
            class_id=np.array(obj_ids, dtype=np.int32),
        )

        annotated = box_annotator.annotate(frame.copy(), detections)
        annotated = label_annotator.annotate(annotated, detections, [id_to_objects[i] for i in obj_ids])
        annotated = mask_annotator.annotate(annotated, detections)

        cv2.imwrite(str(output_dir_img / frame_names[idx]), annotated)
        # A single object is tracked (obj_id=1); save its mask as an 8-bit image
        cv2.imwrite(str(output_dir / frame_names[idx]), (mask_arr[0] * 255).astype(np.uint8))


def get_sorted_frame_names(video_dir):
    """Return the video's JPEG frame filenames, sorted by integer index."""
    return sorted([
        f for f in os.listdir(video_dir)
        if f.lower().endswith(('.jpg', '.jpeg'))
    ], key=lambda name: int(os.path.splitext(name)[0]))
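
# Frames are assumed to be named by integer index ("0.jpg", "1.jpg", ...),
# which is also the directory layout SAM2's init_state(video_path=...) expects.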


def main(openx_data, text_override=None):
    # Load affordance annotations, grouped by video directory
    data_dict = load_affordance_data(f'./data/{openx_data}_for_affordance.pkl')

    # Initialize SAM2 predictor
    sam2_predictor = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)

    prev_text = ''
    vlpart_model = None

    for video_dir, data_list in tqdm(data_dict.items()):
        first_sample = data_list[0]
        frame_path = first_sample['frame_path']
        task_class = first_sample['task_object_class']

        # Only process classes with handle-like parts
        if not any(k in task_class for k in ['door', 'drawer', 'knife']):
            continue

        # Build the VLPart prompt (e.g. "drawer handle"), unless overridden via CLI
        input_text = text_override or f"{task_class} handle"
        vlpart_model, prev_text = init_vlpart_once(input_text, prev_text, vlpart_model)

        # Process the first frame to get bounding boxes
        boxes = run_vlpart_on_first_frame(vlpart_model, frame_path)
        if boxes is None:
            continue

        # Run SAM2 tracking on the video frames
        frame_names = get_sorted_frame_names(video_dir)
        segments = run_sam2_tracking(video_dir, sam2_predictor, boxes)
        save_tracking_results(video_dir, frame_names, segments, input_text,
                              OUTPUT_ROOT, first_sample['vid'])
        print(f"[Done] {frame_path} | {task_class}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("VLPart + SAM2 Tracking Demo")
    parser.add_argument("--pipeline", type=str, default="referring_expression_segmentation", help="Pipeline task")
    parser.add_argument("--text_input", type=str, default=None, help="Optional override for input text")
    parser.add_argument("--dataset", type=str, default="bridge", help="Dataset name (e.g., bridge)")
    args = parser.parse_args()
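    # Example invocation (the script filename here is illustrative):
    #   python vlpart_sam2_tracking.py --dataset bridge --text_input "drawer handle"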

    main(args.dataset, text_override=args.text_input)