File size: 5,037 Bytes
08ff31f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
from __future__ import annotations

import argparse
from collections import defaultdict
from pathlib import Path

import imageio.v3 as iio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from PIL import ImageDraw
from tqdm import tqdm

# Y-axis labels for the action plot: entry ``i`` names column ``i`` of the
# stacked ``action`` array.
ACTION_LABELS = [
    "x",
    "y",
    "z",
    "roll",
    "pitch",
    "yaw",
    "gripper",
]


def _episode_paths(dataset_root: Path, limit: int | None) -> list[Path]:
    """Return the sorted episode parquet files under ``<dataset_root>/data``.

    Files are matched with the glob ``chunk-*/episode_*.parquet``. When
    *limit* is not ``None`` only the first *limit* sorted paths are kept.

    Raises:
        FileNotFoundError: If the resulting list of paths is empty.
    """
    data_dir = dataset_root / "data"
    found = sorted(data_dir.glob("chunk-*/episode_*.parquet"))
    episodes = found if limit is None else found[:limit]
    if episodes:
        return episodes
    raise FileNotFoundError(f"No parquet episodes found under {dataset_root / 'data'}")


def _stack_column(df: pd.DataFrame, name: str) -> np.ndarray:
    """Stack the per-row values of column *name* into one float32 array."""
    per_row = df[name].to_numpy()
    stacked = np.stack(per_row)
    return stacked.astype(np.float32)


def _plot_actions(df: pd.DataFrame, out_path: Path) -> None:
    """Plot every action dimension of one episode and save the figure.

    One subplot is drawn per column of the stacked ``action`` array, all
    sharing the x axis. Columns are labelled from ``ACTION_LABELS``; columns
    beyond that list get a generic ``action[i]`` label. The title shows the
    episode index plus, when those columns exist, the speed label and the
    source episode index taken from the first row.

    Args:
        df: Episode table with ``action`` and ``episode_index`` columns.
        out_path: Image file to write; parent directories are created.
    """
    actions = _stack_column(df, "action")
    num_dims = actions.shape[1]
    # squeeze=False keeps ``axes`` a 2-D array even when there is only one
    # subplot, so the indexing below works for any number of dimensions.
    fig, axes = plt.subplots(num_dims, 1, figsize=(12, 10), sharex=True, squeeze=False)
    try:
        axis_column = axes[:, 0]
        for i, ax in enumerate(axis_column):
            ax.plot(actions[:, i], linewidth=1.2)
            label = ACTION_LABELS[i] if i < len(ACTION_LABELS) else f"action[{i}]"
            ax.set_ylabel(label)
            ax.grid(visible=True, alpha=0.2)
        title = f"episode={int(df['episode_index'].iloc[0])}"
        if "speed_label" in df:
            title += f"  speed={df['speed_label'].iloc[0]}"
        if "source_episode_index" in df:
            title += f"  source={int(df['source_episode_index'].iloc[0])}"
        axis_column[0].set_title(title)
        axis_column[-1].set_xlabel("controller step")
        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=140)
    finally:
        # Always release the figure, even if labelling or saving fails, so
        # figures do not pile up in pyplot's registry across episodes.
        plt.close(fig)


def _find_video(dataset_root: Path, episode_index: int, key: str) -> Path | None:
    """Locate the mp4 of an episode for camera/stream *key*.

    Looks for ``<chunk>/<key>/episode_<index:06d>.mp4`` in each
    ``videos/chunk-*`` directory in sorted order and returns the first path
    that exists, or ``None`` when no chunk contains it.
    """
    filename = f"episode_{episode_index:06d}.mp4"
    chunk_dirs = sorted((dataset_root / "videos").glob("chunk-*"))
    possible = (chunk_dir / key / filename for chunk_dir in chunk_dirs)
    return next((video for video in possible if video.exists()), None)


def _contact_sheet(df: pd.DataFrame, dataset_root: Path, out_path: Path, video_key: str, samples: int) -> None:
    """Save a grid of evenly spaced, labelled thumbnails from an episode video.

    The video is located via ``_find_video`` using the episode index in the
    first row of *df*. Up to *samples* frames, evenly spaced between the first
    and last frame, are resized to 160x160, stamped with ``t=<frame> m=<mask>``
    and pasted row by row into a sheet that is at most five thumbnails wide.

    Nothing is written when *samples* is not positive, when the video cannot
    be found, or when it contains no frames.

    Args:
        df: Episode table; ``observation_mask`` is optional (mask defaults to 1).
        dataset_root: Dataset directory containing ``videos/chunk-*``.
        out_path: Image file to write; parent directories are created.
        video_key: Sub-directory name of the video stream inside each chunk.
        samples: Maximum number of thumbnails to include.
    """
    thumb_px = 160
    max_cols = 5

    if samples <= 0:
        # Zero thumbnails would make the column count 0 and divide by zero
        # below; a negative count is rejected by np.linspace.
        return

    episode_index = int(df["episode_index"].iloc[0])
    video_path = _find_video(dataset_root, episode_index, video_key)
    if video_path is None:
        return

    frames = [np.asarray(frame) for frame in iio.imiter(video_path)]
    if not frames:
        return

    masks = df["observation_mask"] if "observation_mask" in df else None
    indices = np.linspace(0, len(frames) - 1, num=min(samples, len(frames)), dtype=int)
    thumbs = []
    for idx in indices:
        img = Image.fromarray(frames[idx]).resize((thumb_px, thumb_px))
        draw = ImageDraw.Draw(img)
        if masks is None:
            mask = 1
        else:
            # The video may hold more frames than the table has rows; reuse
            # the last row's mask instead of raising IndexError.
            # NOTE(review): assumes frame i corresponds to row i - confirm.
            mask = int(masks.iloc[min(int(idx), len(masks) - 1)])
        label = f"t={idx} m={mask}"
        # Black backing box keeps the white label readable on any frame.
        draw.rectangle((0, 0, 78, 18), fill=(0, 0, 0))
        draw.text((4, 3), label, fill=(255, 255, 255))
        thumbs.append(img)

    cols = min(max_cols, len(thumbs))
    rows = int(np.ceil(len(thumbs) / cols))
    sheet = Image.new("RGB", (cols * thumb_px, rows * thumb_px), color=(255, 255, 255))
    for i, thumb in enumerate(thumbs):
        sheet.paste(thumb, ((i % cols) * thumb_px, (i // cols) * thumb_px))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sheet.save(out_path)


def _source_episode_index(path: Path) -> int:
    """Return the source-episode id stored in the parquet episode at *path*.

    Uses the first row's ``source_episode_index`` when the file has that
    column, otherwise the first row's ``episode_index``.
    """
    try:
        head = pd.read_parquet(path, columns=["episode_index", "source_episode_index"])
    except (KeyError, ValueError):
        # Asking for a column the file does not contain makes the parquet
        # engine raise (pyarrow's ArrowInvalid derives from ValueError), so
        # the fallback has to be handled here. A genuinely unreadable file
        # still raises from this second read.
        head = pd.read_parquet(path, columns=["episode_index"])
    column = "source_episode_index" if "source_episode_index" in head else "episode_index"
    return int(head[column].iloc[0])


def visualize(args: argparse.Namespace) -> None:
    """Render action plots and video contact sheets for a subset of episodes.

    Episodes are grouped by source episode index; groups are taken in
    ascending source order until at least ``args.num_demos`` episodes are
    collected, and the first ``args.num_demos`` of those are rendered. For
    each one an ``*_actions.png`` plot and a ``*_<video_key>.jpg`` contact
    sheet are written into ``args.out``.

    Args:
        args: Namespace with ``dataset``, ``out``, ``num_demos``, ``frames``
            and ``video_key`` attributes, as produced by ``parse_args``.
    """
    dataset_root = Path(args.dataset).resolve()
    out_dir = Path(args.out).resolve()
    by_source: dict[int, list[Path]] = defaultdict(list)
    for path in _episode_paths(dataset_root, None):
        by_source[_source_episode_index(path)].append(path)

    selected = []
    for _source, paths in sorted(by_source.items()):
        selected.extend(paths)
        if len(selected) >= args.num_demos:
            break

    for path in tqdm(selected[: args.num_demos], desc="visualize"):
        episode_df = pd.read_parquet(path)
        episode_index = int(episode_df["episode_index"].iloc[0])
        speed_label = str(episode_df["speed_label"].iloc[0]) if "speed_label" in episode_df else "speed"
        stem = f"episode_{episode_index:06d}_{speed_label}"
        _plot_actions(episode_df, out_dir / f"{stem}_actions.png")
        _contact_sheet(
            episode_df,
            dataset_root,
            # Video keys may contain '/', which must not create sub-directories.
            out_dir / f"{stem}_{args.video_key.replace('/', '_')}.jpg",
            args.video_key,
            args.frames,
        )
    print(f"Wrote visualizations to {out_dir}")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the visualization script.

    Options: ``--dataset`` and ``--out`` (both required), ``--num-demos``
    (int, default 20), ``--frames`` (int, default 10) and ``--video-key``
    (default ``observation.images.image``).
    """
    cli = argparse.ArgumentParser(description="Visualize variable-speed LIBERO episodes.")
    for required_flag in ("--dataset", "--out"):
        cli.add_argument(required_flag, required=True)
    for int_flag, fallback in (("--num-demos", 20), ("--frames", 10)):
        cli.add_argument(int_flag, type=int, default=fallback)
    cli.add_argument("--video-key", default="observation.images.image")
    return cli.parse_args()


# Script entry point: parse the CLI options, then render the visualizations.
if __name__ == "__main__":
    cli_args = parse_args()
    visualize(cli_args)