physctrl / src /utils /visualization.py
chenwang's picture
update
4724018
import numpy as np
import matplotlib.pyplot as plt
import os
import json
from PIL import Image
from io import BytesIO
import sys, pathlib, html
def camera_view_dir_y(elev, azim):
"""Unit vector for camera direction with Y as 'up'."""
elev_rad = np.radians(elev)
azim_rad = np.radians(azim)
dx = np.sin(azim_rad) * np.cos(elev_rad)
dy = np.sin(elev_rad)
dz = np.cos(azim_rad) * np.cos(elev_rad)
return np.array([dx, dy, dz])
def compute_depth(points, elev, azim):
"""Project points onto the camera's view direction (Y as 'up')."""
view_dir = camera_view_dir_y(elev, azim)
# depth = p · view_dir
depth = points @ view_dir
return depth
def save_pointcloud_video(points_pred, points_gt, save_path, drag_mask=None, fps=48, point_color='blue', vis_flag=''):
# Configure the figure
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection='3d')
ax.set_box_aspect([1, 1, 1])
if 'objaverse' in vis_flag:
x_max, y_max, z_max = 1.5, 1.5, 1.5
x_min, y_min, z_min = -1.5, -1.5, -1.5
else:
x_max, y_max, z_max = 1, 1, 1
x_min, y_min, z_min = -1, -1, -1
if 'shapenet' or 'objaverse' in vis_flag:
elev, azim = 45, 225
ax.view_init(elev=elev, azim=azim, vertical_axis='y')
# Plot and save each frame
cmap_1 = plt.colormaps.get_cmap('cool')
cmap_2 = plt.colormaps.get_cmap('autumn')
frames_pred = []
frames_gt = []
if drag_mask is not None and drag_mask.sum() == 0:
drag_mask = None
for label, points in [('pred', points_pred), ('gt', points_gt)]:
for i in range(points.shape[0]):
frame_points = points[i]
if drag_mask is not None and not (drag_mask == True).all():
drag_mask = (drag_mask == 1.0)
drag_points = frame_points[drag_mask]
frame_points = frame_points[~drag_mask]
depth_frame_points = compute_depth(frame_points, elev=elev, azim=azim)
depth_frame_points_normalized = (depth_frame_points - depth_frame_points.min()) / \
(depth_frame_points.max() - depth_frame_points.min())
color_frame_points = cmap_1(depth_frame_points_normalized)
if drag_mask is not None and not (drag_mask == True).all():
frame_points_drag = drag_points
depth_frame_points_drag = compute_depth(frame_points_drag, elev=elev, azim=azim)
depth_frame_points_drag_normalized = (depth_frame_points_drag - depth_frame_points_drag.min()) / \
(depth_frame_points_drag.max() - depth_frame_points_drag.min())
color_frame_points_drag = cmap_2(np.ones_like(depth_frame_points_drag_normalized) * -10)
all_points = np.concatenate([frame_points, frame_points_drag], axis=0)
all_color = np.concatenate([color_frame_points, color_frame_points_drag], axis=0)
else:
all_points, all_color = frame_points, color_frame_points
ax.clear()
ax.scatter(all_points[:, 0], all_points[:, 1], all_points[:, 2], c=all_color, s=1, depthshade=False)
ax.axis('off') # Turn off the axes
ax.grid(False) # Hide the grid
# Set equal aspect ratio
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_zlim(z_min, z_max)
# Adjust margins for tight layout
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
# Save frame
buf = BytesIO()
plt.savefig(buf, bbox_inches='tight', pad_inches=0.0, dpi=300)
buf.seek(0)
if label == 'pred':
frames_pred.append(Image.open(buf))
else:
frames_gt.append(Image.open(buf))
plt.close()
frames = []
for i in range(len(frames_pred)):
frame = np.concatenate([np.array(frames_pred[i]), np.array(frames_gt[i])], axis=1)
frames.append(Image.fromarray(frame))
frames[0].save(save_path, save_all=True, append_images=frames[1:], fps=fps, loop=0)
def save_pointcloud_json(points, output_json):
"""
Generate and save a point cloud sequence to a JSON file.
Parameters:
num_frames (int): Number of frames in the sequence.
num_points (int): Number of points per frame.
output_json (str): Output JSON file path.
"""
sequence = []
for frame in range(points.shape[0]):
# points = np.random.uniform(-1.5, 1.5, size=(num_points, 3)).tolist()
sequence.append({"frame": frame, "points": points[frame].tolist()})
# Save the sequence to a JSON file
with open(output_json, "w") as json_file:
json.dump({"sequence": sequence}, json_file)
def save_threejs_html(path1, path2, output_html):
html_template = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Three.js Point Cloud Animation</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
</head>
<body>
<script>
const scene = new THREE.Scene();
const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
const renderer = new THREE.WebGLRenderer();
renderer.setSize(window.innerWidth, window.innerHeight);
document.body.appendChild(renderer.domElement);
const controls = new THREE.OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;
const pointCloudGrids = [];
// Load the point cloud sequence from a JSON file
function loadSequence(filePath, offset, callback) {{
const loader = new THREE.FileLoader();
loader.load(filePath, function (data) {{
const json = JSON.parse(data);
const sequence = json.sequence;
const pointClouds = [];
sequence.forEach(frameData => {{
const points = frameData.points.map(p => new THREE.Vector3(p[0], p[1], p[2]));
const geometry = new THREE.BufferGeometry().setFromPoints(points);
const material = new THREE.PointsMaterial({{ size: 0.01, color: 0xff0000 }});
const pointCloud = new THREE.Points(geometry, material);
pointCloud.visible = false;
pointCloud.position.x = offset;
scene.add(pointCloud);
pointClouds.push(pointCloud);
}});
pointCloudGrids.push(pointClouds);
console.log("Loaded point clouds:", filePath);
callback();
}}, undefined, function (err) {{
console.error("Error loading JSON file:", err);
}});
}}
camera.position.z = 5;
let currentFrame = 0;
function renderFrame(pointClouds, frameIndex) {{
pointClouds.forEach(pc => pc.visible = false);
if (pointClouds[frameIndex]) {{
pointClouds[frameIndex].visible = true;
}} else {{
console.error("No point cloud for frame:", frameIndex);
}}
}}
function animate() {{
requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
}}
animate();
function playSequence() {{
setInterval(() => {{
renderFrame(pointCloudGrids[0], currentFrame);
renderFrame(pointCloudGrids[1], currentFrame);
currentFrame = (currentFrame + 1) % pointCloudGrids[0].length;
}}, 50);
}}
loadSequence("{path1}", -2, () => {{
loadSequence("{path2}", 2, playSequence);
}});
</script>
</body>
</html>
"""
with open(output_html, 'w') as file:
file.write(html_template)
def vis_pcl_grid(point_clouds, save_path, grid_shape=(2, 2), bounds=[[0, 0, 0], [3, 3, 3]]):
"""
Visualizes multiple 3D point clouds in a grid.
Args:
point_clouds (list of np.ndarray): A list of point clouds, each as a (N, 3) numpy array.
grid_shape (tuple): Shape of the grid (rows, cols).
"""
rows, cols = grid_shape
fig = plt.figure(figsize=(5 * cols, 5 * rows))
for i, point_cloud in enumerate(point_clouds):
if i >= rows * cols:
break # Prevent overpopulation of the grid
ax = fig.add_subplot(rows, cols, i + 1, projection='3d')
ax.view_init(elev=0, azim=90)
ax.scatter(point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2], s=1)
if bounds == None:
ax.set_xlim()
else:
ax.set_xlim(bounds[0][0], bounds[1][0])
ax.set_ylim(bounds[0][1], bounds[1][1])
ax.set_zlim(bounds[0][2], bounds[1][2])
ax.set_title(f"Point Cloud {i + 1}")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.grid(False)
plt.tight_layout()
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=300)
plt.close()
def generate_html_from_exts(vis_dir, output_html, exts):
gifs = sorted(pathlib.Path(vis_dir).glob(f'*.{exts}'))
rows = [
"<!DOCTYPE html>",
"<html lang='en'>",
"<head>",
" <meta charset='utf-8'>",
" <title>GIF gallery</title>",
" <style>",
" body{margin:0;font-family:sans-serif;background:#fafafa;color:#333}",
" .row{padding:16px;text-align:center;border-bottom:1px solid #eee;}",
" img{max-width:50%;height:auto;display:block;margin:0 auto;}",
" .caption{margin-top:8px;font-size:0.9rem;word-break:break-all;}",
" </style>",
"</head>",
"<body>",
]
# 4) one <div> per gif with caption
for gif in gifs:
name = gif.name # full file name (incl. .gif)
alt = html.escape(gif.stem) # alt text sans extension
rows.append(
f" <div class='row'>"
f"<img src='{name}' alt='{alt}'>"
f"<p class='caption'>{html.escape(name)}</p>"
f"</div>"
)
rows += ["</body>", "</html>"]
with open(output_html, 'w') as f:
f.write('\n'.join(rows))
# Example usage
if __name__ == "__main__":
# Generate random point clouds for demonstration
import h5py
# model_metas_1 = h5py.File(os.path.join('/mnt/lingjie_cache/sphere_traj/sphere/sphere_00001.h5'), 'r')
# model_metas_2 = h5py.File(os.path.join('/mnt/lingjie_cache/sphere_traj/sphere/sphere_00002.h5'), 'r')
model_metas_1 = h5py.File(os.path.join('/mnt/kostas-graid/datasets/chenwang/traj/sphere_traj_force0.1/sphere/sphere_13445.h5'), 'r')
model_metas_2 = h5py.File(os.path.join('/mnt/kostas-graid/datasets/chenwang/traj/sphere_traj_force0.1/sphere/sphere_00002.h5'), 'r')
model_pcls_1 = np.array(model_metas_1['x'])
model_pcls_2 = np.array(model_metas_2['x'])
num_drag_points = int(np.array(model_metas_1['drag_mask']).astype(np.float32).sum(axis=-1))
import torch
mask = torch.cat([torch.zeros(10000 - num_drag_points, dtype=torch.bool), torch.ones(num_drag_points, dtype=torch.bool)]).cpu().numpy()
# Visualize in a 2x3 grid
# vis_pcl_grid([model_pcls[0], model_pcls[30]], 'test.png', grid_shape=(1, 2))
save_pointcloud_video(model_pcls_1, model_pcls_2, 'debug/visualize_pcls.gif', drag_mask=mask)