File size: 10,782 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# -*- coding: utf-8 -*-
# Common utility for saving videos with captions during reset phase (demonstration phase).

import os
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
import cv2
import imageio
import torch

from ...logging_utils import logger

TEXT_AREA_HEIGHT = 60

# Tasks without subgoal demonstration: do not add red border when saving video (no demonstration phase to highlight)
NO_HIGHLIGHT_BORDER_ENV_IDS = frozenset({
    "SwingXtimes",
    "PickXtimes",
    "ButtonUnmaskSwap",
    "StopCube",
    "PickHighlight",
    "ButtonUnmask",
    "BinFill",
})


def _frame_to_numpy(frame: Any) -> np.ndarray:
    """Convert frame-like input to CPU numpy array for OpenCV/imageio writing."""
    if isinstance(frame, torch.Tensor):
        frame = frame.detach()
        if frame.is_cuda:
            frame = frame.cpu()
        frame = frame.numpy()
    else:
        frame = np.asarray(frame)
    return frame


def _flatten_column(batch_dict: Dict[str, List[Any]], key: str) -> List[Any]:
    out = []
    for item in (batch_dict or {}).get(key, []) or []:
        if item is None:
            continue
        if isinstance(item, (list, tuple)):
            out.extend([x for x in item if x is not None])
        else:
            out.append(item)
    return out


def _ensure_rgb(frame: np.ndarray) -> np.ndarray:
    if frame.ndim == 2:
        return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    if frame.ndim == 3 and frame.shape[2] == 1:
        return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    return frame


def _resize_to_height(frame: np.ndarray, target_h: int) -> np.ndarray:
    if frame.shape[0] == target_h:
        return frame
    scale = target_h / max(1, frame.shape[0])
    target_w = max(1, int(round(frame.shape[1] * scale)))
    return cv2.resize(frame, (target_w, target_h), interpolation=cv2.INTER_LINEAR)


def _concat_horizontal(frames: List[Any]) -> np.ndarray:
    rgb_frames = [_ensure_rgb(_frame_to_numpy(frame)) for frame in frames if frame is not None]
    if not rgb_frames:
        return np.zeros((1, 1, 3), dtype=np.uint8)
    target_h = rgb_frames[0].shape[0]
    aligned = [_resize_to_height(frame, target_h) for frame in rgb_frames]
    return np.hstack(aligned)


def add_text_to_frame(
    frame: np.ndarray,
    text: Any,
    text_area_height: int = TEXT_AREA_HEIGHT,
) -> np.ndarray:
    """Overlay caption at the top of the frame (white text on black background), style consistent with DemonstrationWrapper.save_video."""
    frame = _frame_to_numpy(frame).copy()
    if frame.ndim == 2:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    if text is None:
        text = ""
    if isinstance(text, (list, tuple)):
        text = " | ".join(str(t).strip() for t in text if t)
    text = str(text).strip()
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.3
    thickness = 1
    max_width = max(1, frame.shape[1] - 20)
    lines = []
    if text:
        words = text.replace(",", " ").split()
        if words:
            current_line = words[0]
            for word in words[1:]:
                test_line = f"{current_line} {word}"
                (text_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness)
                if text_width <= max_width:
                    current_line = test_line
                else:
                    lines.append(current_line)
                    current_line = word
            lines.append(current_line)
    if not lines:
        text_area = np.zeros((text_area_height, frame.shape[1], 3), dtype=np.uint8)
        return np.vstack((text_area, frame))
    line_height = 20
    text_area = np.zeros((text_area_height, frame.shape[1], 3), dtype=np.uint8)
    text_area[:] = (0, 0, 0)
    max_visible_lines = (text_area_height - 15) // line_height
    for i, line in enumerate(lines[:max_visible_lines]):
        y_position = 15 + i * line_height
        cv2.putText(text_area, line, (10, y_position), font, font_scale, (255, 255, 255), thickness)
    return np.vstack((text_area, frame))


def add_border_to_frame(
    frame: np.ndarray,
    color: Tuple[int, int, int] = (255, 0, 0),
    thickness: int = 4,
) -> np.ndarray:
    """Overlay border around the frame (RGB color)."""
    frame = _frame_to_numpy(frame).copy()
    if frame.ndim == 2:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    if frame.ndim != 3 or frame.shape[2] != 3:
        return frame

    h, w = frame.shape[:2]
    t = max(1, int(thickness))
    t = min(t, h // 2 if h > 1 else 1, w // 2 if w > 1 else 1)

    frame[:t, :, :] = color
    frame[-t:, :, :] = color
    frame[:, :t, :] = color
    frame[:, -t:, :] = color
    return frame


def save_robomme_video(
    reset_base_frames: List[Any],
    reset_wrist_frames: List[Any],
    rollout_base_frames: List[Any],
    rollout_wrist_frames: List[Any],
    reset_subgoal_grounded: List[Any],
    rollout_subgoal_grounded: List[Any],
    out_video_dir: str,
    action_space: str,
    env_id: str,
    episode: int,
    episode_success: bool,
    reset_right_frames: Optional[List[Any]] = None,
    rollout_right_frames: Optional[List[Any]] = None,
    reset_far_right_frames: Optional[List[Any]] = None,
    rollout_far_right_frames: Optional[List[Any]] = None,
    fps: int = 20,
    highlight_color: Tuple[int, int, int] = (255, 0, 0),
    highlight_thickness: int = 4,
) -> bool:
    """
    Unified method to save replay videos (including reset prefix highlighting, naming, and output path concatenation).
    Whether to draw the red border on reset-phase frames is determined by env_id: tasks in NO_HIGHLIGHT_BORDER_ENV_IDS (no subgoal demonstration) do not get the border.

    Args:
        reset_base_frames/reset_wrist_frames/reset_right_frames/reset_far_right_frames: List of camera frames from reset phase.
        rollout_base_frames/rollout_wrist_frames/rollout_right_frames/rollout_far_right_frames: List of camera frames from rollout phase.
        reset_subgoal_grounded/rollout_subgoal_grounded: List of captions for corresponding phases.
        out_video_dir: Output directory.
        action_space: Current action space, used for filename prefix.
        env_id: Environment ID (used to decide if highlight border is drawn; see NO_HIGHLIGHT_BORDER_ENV_IDS).
        episode: Current episode index.
        episode_success: Whether the current episode was successful.
        fps: Output frame rate.
        highlight_color: Border color (RGB).
        highlight_thickness: Border thickness (pixels).

    Returns:
        True if at least one frame was written, False otherwise.
    """
    success_prefix = "success" if episode_success else "fail"
    mode_prefix = action_space
    out_video_path = os.path.join(
        out_video_dir,
        f"{success_prefix}_replay_{mode_prefix}_{env_id}_ep{episode}.mp4",
    )

    reset_base_frames = list(reset_base_frames or [])
    reset_wrist_frames = list(reset_wrist_frames or [])
    reset_right_frames = list(reset_right_frames or [])
    reset_far_right_frames = list(reset_far_right_frames or [])
    rollout_base_frames = list(rollout_base_frames or [])
    rollout_wrist_frames = list(rollout_wrist_frames or [])
    rollout_right_frames = list(rollout_right_frames or [])
    rollout_far_right_frames = list(rollout_far_right_frames or [])
    reset_subgoal_grounded = list(reset_subgoal_grounded or [])
    rollout_subgoal_grounded = list(rollout_subgoal_grounded or [])

    merged_base_frames = reset_base_frames + rollout_base_frames
    merged_wrist_frames = reset_wrist_frames + rollout_wrist_frames
    merged_right_frames = reset_right_frames + rollout_right_frames
    merged_far_right_frames = reset_far_right_frames + rollout_far_right_frames
    merged_subgoal_grounded = reset_subgoal_grounded + rollout_subgoal_grounded

    if not (
        merged_base_frames
        or merged_wrist_frames
        or merged_right_frames
        or merged_far_right_frames
    ):
        logger.debug(f"Skipped video (no frames): {out_video_path}")
        return False

    base_camera = [_frame_to_numpy(f) for f in merged_base_frames if f is not None]
    wrist_camera = [_frame_to_numpy(f) for f in merged_wrist_frames if f is not None]
    right_camera = [_frame_to_numpy(f) for f in merged_right_frames if f is not None]
    far_right_camera = [_frame_to_numpy(f) for f in merged_far_right_frames if f is not None]
    reset_base_camera = [_frame_to_numpy(f) for f in reset_base_frames if f is not None]
    reset_wrist_camera = [_frame_to_numpy(f) for f in reset_wrist_frames if f is not None]
    reset_right_camera = [_frame_to_numpy(f) for f in reset_right_frames if f is not None]
    reset_far_right_camera = [
        _frame_to_numpy(f) for f in reset_far_right_frames if f is not None
    ]

    merged_feeds: List[Tuple[List[np.ndarray], List[np.ndarray]]] = []
    if base_camera:
        merged_feeds.append((base_camera, reset_base_camera))
    if wrist_camera:
        merged_feeds.append((wrist_camera, reset_wrist_camera))
    if right_camera:
        merged_feeds.append((right_camera, reset_right_camera))
    if far_right_camera:
        merged_feeds.append((far_right_camera, reset_far_right_camera))
    if not merged_feeds:
        logger.debug(f"Skipped video (no valid frames): {out_video_path}")
        return False

    n_frames = min(len(feed) for feed, _ in merged_feeds)
    if n_frames <= 0:
        logger.debug(f"Skipped video (no aligned frames): {out_video_path}")
        return False
    image = [
        _concat_horizontal([feed[i] for feed, _ in merged_feeds])
        for i in range(n_frames)
    ]
    reset_frame_count = min(len(reset_feed) for _, reset_feed in merged_feeds)
    reset_frame_count = max(0, min(reset_frame_count, n_frames))

    draw_highlight_border = env_id not in NO_HIGHLIGHT_BORDER_ENV_IDS
    highlight_prefix_count = reset_frame_count if draw_highlight_border else 0

    subgoal_grounded = [text for text in merged_subgoal_grounded if text is not None]

    out_dir = os.path.dirname(os.path.abspath(out_video_path))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with imageio.get_writer(out_video_path, fps=fps, codec="libx264", quality=8) as writer:
        for i in range(n_frames):
            frame = _frame_to_numpy(image[i])
            caption = subgoal_grounded[i] if i < len(subgoal_grounded) else ""
            combined = add_text_to_frame(frame, caption)
            if i < max(0, int(highlight_prefix_count)):
                combined = add_border_to_frame(
                    combined,
                    color=highlight_color,
                    thickness=highlight_thickness,
                )
            writer.append_data(combined)

    logger.debug(f"Saved video: {out_video_path}")
    return True