# Source: TTI/Dev/visual_prompting/generate_video.py
# Uploaded by JosephBai via huggingface_hub (commit 857c2e9).
import os
import sys
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from xml.etree import ElementTree as ET
from libero.libero import benchmark, get_libero_path
import h5py
import imageio
from PIL import Image
from arrowline import *
def _figure_to_rgb(fig):
    """Render ``fig`` and return its pixels as a contiguous ``(H, W, 3)`` uint8 array.

    ``buffer_rgba`` is preferred: ``FigureCanvasAgg.tostring_rgb`` is deprecated
    since Matplotlib 3.8 (and removed later), so calling it first emitted a
    deprecation warning for every rendered frame. ``tostring_rgb`` is kept only
    as a fallback for canvases that do not provide ``buffer_rgba``.

    Args:
        fig: A Matplotlib figure whose canvas supports ``draw()`` and either
            ``buffer_rgba()`` or ``tostring_rgb()``.

    Returns:
        np.ndarray: A writable ``(H, W, 3)`` uint8 copy of the rendered image,
        independent of the canvas buffer, so it stays valid after the figure
        is closed or redrawn.
    """
    canvas = fig.canvas
    canvas.draw()
    if hasattr(canvas, "buffer_rgba"):
        # The Agg buffer already carries its (H, W, 4) shape; trust it rather
        # than get_width_height(), which may report logical (not physical) size.
        rgba = np.asarray(canvas.buffer_rgba(), dtype=np.uint8)
        if rgba.ndim != 3:
            width, height = canvas.get_width_height()
            rgba = rgba.reshape(height, width, 4)
        return rgba[..., :3].copy()
    width, height = canvas.get_width_height()
    rgb = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return rgb.reshape(height, width, 3).copy()
def _load_demo_data(demo_file, demo_file_for_xml, include_trajectory):
    """Read the frames, optional end-effector positions and the model XML.

    Only the first demonstration (``data/demo_0``) of each HDF5 file is used.

    Args:
        demo_file: HDF5 file providing ``obs/agentview_rgb`` and, optionally,
            ``obs/ee_pos``.
        demo_file_for_xml: HDF5 file whose ``data/demo_0`` attributes hold
            ``model_file``.
        include_trajectory: When true, also try to load ``obs/ee_pos``.

    Returns:
        tuple: ``(images, eef_positions, model_xml)``. ``eef_positions`` is
        None when it was not requested or is absent from the file;
        ``model_xml`` is None when the ``model_file`` attribute is missing and
        is decoded to ``str`` when stored as bytes.
    """
    with h5py.File(demo_file_for_xml, "r") as xml_source:
        model_xml = xml_source["data"]["demo_0"].attrs.get("model_file")
    if isinstance(model_xml, bytes):
        model_xml = model_xml.decode("utf-8")

    eef_positions = None
    with h5py.File(demo_file, "r") as demo_source:
        observations = demo_source["data"]["demo_0"]["obs"]
        images = observations["agentview_rgb"][()]
        if include_trajectory:
            if "ee_pos" not in observations:
                print("Dataset missing 'ee_pos' trajectory; skipping overlay.")
            else:
                eef_positions = observations["ee_pos"][()]
    return images, eef_positions, model_xml
def _quat_to_rotation_matrix(quat):
    """Convert a ``(w, x, y, z)`` quaternion into a 3x3 float64 rotation matrix.

    The quaternion is normalized first, so non-unit inputs are accepted.

    Raises:
        ValueError: if ``quat`` does not have exactly four components, or if
            its norm is zero.
    """
    components = np.asarray(quat, dtype=np.float64)
    if components.shape != (4,):
        raise ValueError("Quaternion must have four components.")
    magnitude = np.linalg.norm(components)
    if magnitude == 0:
        raise ValueError("Quaternion norm must be positive.")
    w, x, y, z = components / magnitude

    # Pairwise products reused across the matrix entries.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    rows = (
        (1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)),
        (2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)),
        (2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)),
    )
    return np.array(rows, dtype=np.float64)
def _parse_float_list(text):
    """Parse whitespace-separated numbers into a float64 array.

    Returns None when any token is not a number, so callers can reject a
    malformed attribute instead of using a partially parsed vector.
    """
    try:
        return np.array(text.split(), dtype=np.float64)
    except ValueError:
        return None


def _get_camera_parameters(model_xml, camera_name, image_shape):
    """Build a pinhole camera model for a named camera in a MuJoCo XML string.

    Args:
        model_xml: MuJoCo model XML as a string (or None).
        camera_name: Value of the ``name`` attribute of the ``<camera>`` element.
        image_shape: ``(height, width)`` of the rendered images, in pixels.

    Returns:
        dict: ``position`` (3,), ``rotation`` (3x3, from the ``quat``
        attribute), focal lengths ``fx``/``fy`` and principal point
        ``cx``/``cy`` in pixels, plus the image ``width`` and ``height``.

    Raises:
        ValueError: if ``model_xml`` is None, or ``pos``/``quat`` is malformed.
        KeyError: if no ``<camera>`` with ``camera_name`` exists.

    NOTE(review): ``pos``/``quat`` are used verbatim, i.e. the camera is assumed
    to be defined in the world frame (not nested inside a moved body), and
    orientations given via ``xyaxes``/``euler``/``axisangle``/``zaxis`` are not
    read (identity is used). Confirm this holds for the cameras in use.
    """
    if model_xml is None:
        raise ValueError("Model XML is not available in the demonstration file.")
    root = ET.fromstring(model_xml)
    # Compare the attribute directly instead of interpolating the name into an
    # XPath predicate, which fails for names containing quotes.
    camera_elem = next(
        (elem for elem in root.iter("camera") if elem.get("name") == camera_name),
        None,
    )
    if camera_elem is None:
        raise KeyError(f"Camera '{camera_name}' not found in the MuJoCo model.")

    pos = _parse_float_list(camera_elem.get("pos", "0 0 0"))
    if pos is None or pos.size != 3:
        raise ValueError(f"Camera '{camera_name}' position must have three components.")
    quat = _parse_float_list(camera_elem.get("quat", "1 0 0 0"))
    if quat is None or quat.size != 4:
        raise ValueError(f"Camera '{camera_name}' quaternion must have four components.")
    rotation = _quat_to_rotation_matrix(quat)

    # fovy is treated as the vertical field of view in degrees (default 45.0).
    fovy_rad = np.deg2rad(float(camera_elem.get("fovy", 45.0)))
    height, width = image_shape
    fy = 0.5 * height / np.tan(0.5 * fovy_rad)
    fx = fy  # square pixels: horizontal focal length equals the vertical one
    # Principal point at the image centre, in pixel-centre coordinates.
    cx = 0.5 * (width - 1)
    cy = 0.5 * (height - 1)
    return {
        "position": pos,
        "rotation": rotation,
        "fx": fx,
        "fy": fy,
        "cx": cx,
        "cy": cy,
        "width": width,
        "height": height,
    }
def _project_points_to_image(points, camera_params):
    """Project world-space points into pixel coordinates of a pinhole camera.

    Args:
        points: Sequence/array of points, one ``(x, y, z)`` per row, or None.
        camera_params: Dict with ``position``, ``rotation``, ``fx``, ``fy``,
            ``cx``, ``cy``, ``width`` and ``height`` (as built by
            ``_get_camera_parameters``).

    Returns:
        np.ndarray: ``(N, 2)`` float64 array of ``(u, v)`` pixels, with ``v``
        mirrored as ``height - 1 - v``. Rows for points with depth <= 1e-6,
        non-finite coordinates, or outside the image bounds are NaN. None or
        an empty input yields an empty ``(0, 2)`` array.
    """
    if points is None or len(points) == 0:
        return np.empty((0, 2), dtype=np.float64)

    world = np.asarray(points, dtype=np.float64)
    # Row-vector form of R^T (p - t): express offsets in the camera frame.
    camera_frame = (world - camera_params["position"]) @ camera_params["rotation"]
    x_cam = camera_frame[:, 0]
    y_cam = camera_frame[:, 1]
    depth = -camera_frame[:, 2]  # depth is measured along the camera's -z axis

    width = camera_params["width"]
    height = camera_params["height"]
    with np.errstate(divide="ignore", invalid="ignore"):
        u = camera_params["fx"] * (x_cam / depth) + camera_params["cx"]
        v_raw = camera_params["fy"] * (-y_cam / depth) + camera_params["cy"]
        v = height - 1 - v_raw

    in_front = depth > 1e-6
    finite = np.isfinite(u) & np.isfinite(v)
    inside = (u >= 0) & (u <= width - 1) & (v >= 0) & (v <= height - 1)

    pixels = np.column_stack((u, v))
    pixels[~(in_front & finite & inside)] = np.nan
    return pixels
def _create_frame_with_trajectory(
    image,
    trajectory_pixels,
    frame_idx,
    total_traj_frames,
    figsize,
):
    """Render one video frame with the projected trajectory drawn on top.

    Drawn layers, bottom to top:
    - the full path, dashed grey;
    - the part executed up to ``frame_idx``, in blue;
    - the first and last finite points, as start (green) and goal (red) markers;
    - the point at ``frame_idx``, clamped to the last point, as an amber marker.

    Args:
        image: Image frame; its channel order is reversed before display
            (treated as BGR -> RGB).
        trajectory_pixels: ``(N, 2)`` pixel coordinates, where non-finite rows
            are gaps, or None for no overlay.
        frame_idx: Index of the frame being rendered.
        total_traj_frames: Number of trajectory points to consider.
        figsize: Matplotlib figure size in inches.

    Returns:
        np.ndarray: The rendered figure as an RGB uint8 array.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        rgb_image = image[..., ::-1]
        ax.imshow(rgb_image)
        height, width = rgb_image.shape[:2]
        ax.set_xlim([0, width - 1])
        ax.set_ylim([height - 1, 0])  # keep row 0 at the top
        ax.set_axis_off()

        if trajectory_pixels is not None and trajectory_pixels.size > 0:
            xs = trajectory_pixels[:, 0]
            ys = trajectory_pixels[:, 1]
            ax.plot(xs, ys, color="#b0bec5", linestyle="--", linewidth=1.5, label="Full path")

            executed_range = min(frame_idx + 1, total_traj_frames)
            if executed_range > 0:
                executed_points = trajectory_pixels[:executed_range].copy()
                # Rows with any non-finite coordinate become NaN so the line breaks there.
                executed_valid = np.isfinite(executed_points[:, 0]) & np.isfinite(executed_points[:, 1])
                executed_points[~executed_valid] = np.nan
                ax.plot(
                    executed_points[:, 0],
                    executed_points[:, 1],
                    color="#1976d2",
                    linewidth=2,
                    label="Executed",
                )

            valid_indices = np.flatnonzero(np.isfinite(xs) & np.isfinite(ys))
            if valid_indices.size > 0:
                start_point = trajectory_pixels[valid_indices[0]]
                goal_point = trajectory_pixels[valid_indices[-1]]
                ax.scatter(start_point[0], start_point[1], color="#2e7d32", s=40, label="Start")
                ax.scatter(goal_point[0], goal_point[1], color="#c62828", s=40, label="Goal")

            if total_traj_frames > 0:
                # Frames past the end of the trajectory keep showing the last point.
                current_point = trajectory_pixels[min(frame_idx, total_traj_frames - 1)]
                if np.all(np.isfinite(current_point)):
                    ax.scatter(
                        current_point[0],
                        current_point[1],
                        color="#ffb300",
                        s=60,
                        label="Current",
                        edgecolors="black",
                        linewidths=0.5,
                    )

        fig.tight_layout(pad=0.2)
        return _figure_to_rgb(fig)
    finally:
        # Close even when rendering fails so pyplot does not accumulate figures.
        plt.close(fig)
def _create_frame_with_arrow_traj(
    image,
    trajectory_pixels,
    figsize,
):
    """Render ``image`` with arrows drawn along ``trajectory_pixels``.

    Args:
        image: Image frame; its channel order is reversed before display
            (treated as BGR -> RGB).
        trajectory_pixels: ``(N, 2)`` pixel coordinates; columns 0 and 1 are
            passed to ``arrowline`` as x and y.
        figsize: Matplotlib figure size in inches.

    Returns:
        np.ndarray: The rendered figure as an RGB uint8 array.
    """
    fig, ax = plt.subplots(figsize=figsize)
    try:
        rgb_image = image[..., ::-1]
        ax.imshow(rgb_image)
        height, width = rgb_image.shape[:2]
        ax.set_xlim([0, width - 1])
        ax.set_ylim([height - 1, 0])  # keep row 0 at the top
        ax.set_axis_off()
        # ``arrowline`` is the project-local helper (``from arrowline import *``).
        # NOTE(review): projected rows can be NaN for points outside the image;
        # confirm arrowline tolerates NaN input.
        arrowline(
            ax,
            trajectory_pixels[:, 0],
            trajectory_pixels[:, 1],
            style="equal_d",
            interval=16,
            arrow_size=1.5,
            color="b",
        )
        fig.tight_layout(pad=0.02)
        return _figure_to_rgb(fig)
    finally:
        # Close even when rendering fails so pyplot does not accumulate figures.
        plt.close(fig)
def _project_trajectory(images, eef_positions, model_xml, camera_name):
    """Project end-effector positions into pixel coordinates of ``images``.

    Best effort: prints a message and returns None when the model XML is
    missing, when the projection raises, or when no points result.
    """
    if model_xml is None:
        print("Model XML not found in dataset; skipping trajectory overlay.")
        return None
    try:
        image_height, image_width = images.shape[1:3]
        camera_params = _get_camera_parameters(model_xml, camera_name, (image_height, image_width))
        trajectory_pixels = _project_points_to_image(eef_positions, camera_params)
    except Exception as exc:  # deliberate best effort: render without the overlay
        print(f"Failed to project trajectory for camera '{camera_name}': {exc}")
        return None
    if trajectory_pixels.size == 0:
        return None
    return trajectory_pixels


def generate_task_video(
    task_index=0,
    benchmark_name="libero_10",
    output_video="task_demo.mp4",
    include_trajectory=True,
    camera_name="agentview",
    fps=60,
    figsize=(4, 4),
    only_image=False,
    datasets_path="/home/zechen/Data/Robo/LIBERO_Regen",
    xml_datasets_path="/home/zechen/Data/Robo/OriginalLIBERO",
):
    """Render a LIBERO demonstration, optionally overlaying the end-effector path.

    Frames and end-effector positions are read from the demo file under
    ``datasets_path``. The MuJoCo model XML is read from the file with the same
    relative path under ``xml_datasets_path``; it is needed for the camera model.

    Args:
        task_index: 0-based index of the task within the benchmark.
        benchmark_name: Key into ``benchmark.get_benchmark_dict()``.
        output_video: Output MP4 path. With ``only_image=True``, a PNG is
            written instead, at the same path with a ``.png`` extension.
        include_trajectory: If True, project the end-effector positions into the
            image and draw them.
        camera_name: MuJoCo camera whose parameters are used for the projection.
        fps: Frames per second of the output video.
        figsize: Matplotlib figure size used when drawing overlays.
        only_image: If True, save one PNG of the first frame. Arrows are drawn
            along the first 8 projected trajectory points. No video is written.
        datasets_path: Root directory of the demonstration files to render.
        xml_datasets_path: Root directory of the files holding the model XML.

    Returns:
        str: Path of the file that was written (MP4, or PNG when ``only_image``).

    Raises:
        KeyError: If ``benchmark_name`` is unknown.
        ValueError: If ``task_index`` is out of range, or if ``only_image`` is
            requested but no projected trajectory is available.
        FileNotFoundError: If a demonstration file is missing.
    """
    benchmark_dict = benchmark.get_benchmark_dict()
    if benchmark_name not in benchmark_dict:
        raise KeyError(
            f"Unknown benchmark '{benchmark_name}'. Available keys: {list(benchmark_dict.keys())}"
        )
    benchmark_instance = benchmark_dict[benchmark_name]()
    num_tasks = benchmark_instance.get_num_tasks()
    # Reject negative indices too: they would silently select a task from the end.
    if not 0 <= task_index < num_tasks:
        raise ValueError(
            f"Task index {task_index} out of range. Benchmark has {num_tasks} tasks."
        )

    demo_relpath = benchmark_instance.get_task_demonstration(task_index)
    demo_file = os.path.join(datasets_path, demo_relpath)
    print("Task name: ", demo_relpath)
    if not os.path.exists(demo_file):
        raise FileNotFoundError(f"Demo file not found: {demo_file}")
    demo_file_for_xml = os.path.join(xml_datasets_path, demo_relpath)
    if not os.path.exists(demo_file_for_xml):
        raise FileNotFoundError(f"Demo file for xml not found: {demo_file_for_xml}")
    print(f"Using demo file: {demo_file}")

    images, eef_positions, model_xml = _load_demo_data(demo_file, demo_file_for_xml, include_trajectory)
    trajectory_pixels = None
    if include_trajectory and eef_positions is not None:
        trajectory_pixels = _project_trajectory(images, eef_positions, model_xml, camera_name)
    has_traj_overlay = trajectory_pixels is not None
    traj_frame_count = trajectory_pixels.shape[0] if has_traj_overlay else 0
    print("Length of trajectory_pixels: ", traj_frame_count)

    if only_image:
        if not has_traj_overlay:
            raise ValueError(
                "only_image=True needs a projected end-effector trajectory, "
                "but none is available for this demo."
            )
        image_path = os.path.splitext(output_video)[0] + ".png"
        os.makedirs(os.path.dirname(image_path) or ".", exist_ok=True)
        frame = _create_frame_with_arrow_traj(
            images[0],
            trajectory_pixels[:8],  # arrows along the first 8 trajectory points
            figsize,
        )
        # NOTE(review): the rendered figure already shows image[..., ::-1];
        # reversing the channels again here restores the dataset's channel
        # order (and also swaps the drawn arrow colour), unlike the video path
        # below. Confirm which channel order is intended.
        frame = frame[..., ::-1]
        Image.fromarray(frame).transpose(Image.FLIP_TOP_BOTTOM).save(image_path)
        print(f"Image saved as: {image_path}")
        return image_path

    os.makedirs(os.path.dirname(output_video) or ".", exist_ok=True)
    with imageio.get_writer(output_video, fps=fps) as video_writer:
        for frame_idx, image in enumerate(images):
            if has_traj_overlay:
                frame = _create_frame_with_trajectory(
                    image,
                    trajectory_pixels,
                    frame_idx,
                    traj_frame_count,
                    figsize,
                )
            else:
                frame = image[..., ::-1]  # BGR to RGB
            # Flip vertically. Presumably the stored frames are upside down (the
            # projection also mirrors v) — confirm against the dataset.
            frame = np.asarray(Image.fromarray(frame).transpose(Image.FLIP_TOP_BOTTOM))
            video_writer.append_data(frame)
    print(f"Video saved as: {output_video}")
    return output_video
if __name__ == "__main__":
    # Render task index 1 (the second task) of libero_object. Because
    # only_image=True, this writes a single PNG (task_1_demo_with_traj.png).
    # The PNG shows arrows along the start of the end-effector trajectory.
    # No MP4 is written.
    generate_task_video(
        task_index=1,
        benchmark_name="libero_object",
        output_video="task_1_demo_with_traj.mp4",
        include_trajectory=True,
        camera_name="agentview",
        only_image=True,
    )