import matplotlib
matplotlib.use('Agg')  # Non-interactive backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Not referenced directly; older Matplotlib releases need this import to
# register the '3d' projection.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import numpy as np
from sklearn.decomposition import PCA
from scipy.spatial.transform import Rotation as R


# SMPL kinematic tree (approximate for visualization)
#  0: Pelvis
#  1: L_Hip,       2: R_Hip,       3: Spine1
#  4: L_Knee,      5: R_Knee,      6: Spine2
#  7: L_Ankle,     8: R_Ankle,     9: Spine3
# 10: L_Foot,     11: R_Foot,     12: Neck
# 13: L_Collar,   14: R_Collar,   15: Head
# 16: L_Shoulder, 17: R_Shoulder
# 18: L_Elbow,    19: R_Elbow
# 20: L_Wrist,    21: R_Wrist
# 22: L_Hand,     23: R_Hand
_PELVIS = 0
_HEAD = 15
_FEET = (10, 11)
_MIN_JOINTS = 24

# Connectivity for drawing bones, as (parent, child) joint indices.
_SMPL_BONES = (
    (0, 1), (0, 2), (0, 3),
    (1, 4), (2, 5), (3, 6),
    (4, 7), (5, 8), (6, 9),
    (7, 10), (8, 11), (9, 12),
    (9, 13), (9, 14), (12, 15),
    (13, 16), (14, 17),
    (16, 18), (17, 19),
    (18, 20), (19, 21),
    (20, 22), (21, 23),
)

# Codecs tried in order by the OpenCV fallback, each paired with the message
# printed when it cannot be opened (None for the last resort).
_CODEC_FALLBACKS = (
    ('avc1', "avc1 failed. Trying h264..."),
    ('h264', "h264 failed. Trying vp80..."),
    ('vp80', "vp80 failed. Trying mp4v (less compatible)..."),
    ('mp4v', None),
)

# Margin added around the trajectory when sizing the ground plane.
_GROUND_PADDING = 1.0


def _estimate_ground_normal(pose_data):
    """Return a vector normal to the ground, oriented from pelvis towards head."""
    pelvis_to_head = pose_data[:, _HEAD, :] - pose_data[:, _PELVIS, :]
    avg_body_up = pelvis_to_head.mean(axis=0)

    feet_points = pose_data[:, list(_FEET), :].reshape(-1, 3)
    if feet_points.shape[0] < 3:
        # A single frame yields only two feet samples: a 3-component PCA
        # cannot be fitted and no plane is defined, so fall back to the
        # body's own up direction.
        length = np.linalg.norm(avg_body_up)
        if length == 0:
            return np.array([0.0, 0.0, 1.0])
        return avg_body_up / length

    # Fit a plane to every feet position; the principal component with the
    # least variance is the plane normal.
    # NOTE: if the feet barely move over the sequence the plane is poorly
    # constrained and this estimate can be unreliable.
    pca = PCA(n_components=3)
    pca.fit(feet_points)
    normal = pca.components_[2]

    # PCA leaves the sign arbitrary; make the normal agree with body-up.
    if np.dot(normal, avg_body_up) < 0:
        normal = -normal
    return normal


def _align_and_center(pose_data):
    """Rotate the ground normal onto +Z, center X/Y at 0 and put the lowest point at z=0.

    Returns a new array; the input is not modified.
    """
    normal = _estimate_ground_normal(pose_data)
    target_normal = np.array([0, 0, 1])

    # align_vectors(a, b) finds the rotation mapping b onto a, so this maps
    # the estimated ground normal onto the Z axis.
    rot = R.align_vectors([target_normal], [normal])[0]

    # Points are (Frames, Joints, 3): flatten, rotate, restore the shape.
    aligned = rot.apply(pose_data.reshape(-1, 3)).reshape(pose_data.shape)

    # Center X/Y on the mean of all points and shift Z so the minimum is the
    # ground level (z=0).
    aligned[:, :, 0] -= aligned[:, :, 0].mean()
    aligned[:, :, 1] -= aligned[:, :, 1].mean()
    aligned[:, :, 2] -= aligned[:, :, 2].min()
    return aligned


def _new_figure():
    """Create the 10x10 inch figure with a single 3D axes used for rendering."""
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    return fig, ax


def _save_with_ffmpeg(draw, n_frames, output_path, fps):
    """Render all frames through Matplotlib's ffmpeg writer."""
    if not animation.writers.is_available('ffmpeg'):
        raise RuntimeError("ffmpeg not available")

    fig, ax = _new_figure()
    try:
        ani = animation.FuncAnimation(
            fig,
            lambda frame: draw(ax, frame),
            frames=n_frames,
            interval=1000 / fps,
        )
        writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
        ani.save(output_path, writer=writer)
    finally:
        # Always release the figure so repeated calls do not accumulate
        # open figures in pyplot.
        plt.close(fig)


def _open_video_writer(cv2, output_path, fps, size):
    """Open a cv2.VideoWriter with the first codec that works (H.264 first)."""
    for codec, failure_message in _CODEC_FALLBACKS:
        fourcc = cv2.VideoWriter_fourcc(*codec)
        out = cv2.VideoWriter(output_path, fourcc, fps, size)
        if out.isOpened():
            return out
        out.release()
        if failure_message is not None:
            print(failure_message)
    raise RuntimeError("Failed to open VideoWriter with any compatible codec.")


def _canvas_to_bgr(cv2, fig, width, height):
    """Return the figure's current canvas contents as a BGR image for OpenCV."""
    try:
        buffer = fig.canvas.buffer_rgba()
    except AttributeError:
        # Fallback for older matplotlib or a different backend.
        rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        return cv2.cvtColor(rgb.reshape(height, width, 3), cv2.COLOR_RGB2BGR)
    rgba = np.frombuffer(buffer, dtype=np.uint8).reshape(height, width, 4)
    return cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGR)


def _save_with_opencv(draw, n_frames, output_path, fps):
    """Render all frames one by one and encode them with OpenCV."""
    import cv2

    fig, ax = _new_figure()
    out = None
    try:
        # Draw once so the canvas reports its pixel size
        # (10 in * dpi, typically 1000x1000 at dpi=100).
        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()
        out = _open_video_writer(cv2, output_path, fps, (width, height))

        print("Rendering frames directly to OpenCV VideoWriter...")
        for frame in range(n_frames):
            draw(ax, frame)
            fig.canvas.draw()
            out.write(_canvas_to_bgr(cv2, fig, width, height))
    finally:
        # Release the encoder and the figure even when rendering fails.
        if out is not None:
            out.release()
        plt.close(fig)


def render_smpl(pose_data, output_path, fps=30):
    """
    Render SMPL 3D pose data to a video file.

    The poses are first rotated so that the estimated ground plane is
    horizontal (normal along +Z), centered in X/Y and shifted so the lowest
    point sits at z=0. Each frame is drawn as a joint scatter plus bone
    segments above a translucent ground plane. The video is written with
    Matplotlib's ffmpeg writer; if that is unavailable or fails, frames are
    rendered individually and encoded with OpenCV instead.

    Args:
        pose_data (np.ndarray): Shape (Frames, 24, 3). The input is not modified.
        output_path (str): Path to save the MP4 video.
        fps (int): Frames per second.

    Returns:
        str: ``output_path``.

    Raises:
        ValueError: If ``pose_data`` is empty or does not have shape
            (Frames, >=24, 3).
        Exception: Whatever the OpenCV fallback raised, when both the ffmpeg
            and the OpenCV paths fail.
    """
    pose_data = np.asarray(pose_data, dtype=float)
    if (
        pose_data.ndim != 3
        or pose_data.shape[1] < _MIN_JOINTS
        or pose_data.shape[2] != 3
    ):
        raise ValueError(
            f"pose_data must have shape (Frames, {_MIN_JOINTS}, 3), "
            f"got {pose_data.shape}"
        )
    n_frames = pose_data.shape[0]
    if n_frames == 0:
        raise ValueError("pose_data contains no frames")

    # --- Alignment & Centering ---
    pose_data = _align_and_center(pose_data)

    # --- Plot bounds ---
    # A cube centered on the bounding box of the whole trajectory, so the
    # camera framing stays fixed for every frame.
    points = pose_data.reshape(-1, 3)
    mins = points.min(axis=0)
    maxs = points.max(axis=0)
    mid_x, mid_y, mid_z = (mins + maxs) / 2
    max_range = (maxs - mins).max() / 2.0
    if max_range == 0:
        # All points coincide; keep the axis limits non-degenerate.
        max_range = 1.0

    # Ground plane at z=0 covering the whole trajectory. It never changes,
    # so build the mesh once instead of once per frame.
    ground_x, ground_y = np.meshgrid(
        np.linspace(mins[0] - _GROUND_PADDING, maxs[0] + _GROUND_PADDING, 2),
        np.linspace(mins[1] - _GROUND_PADDING, maxs[1] + _GROUND_PADDING, 2),
    )
    ground_z = np.zeros_like(ground_x)

    def draw(ax, frame):
        """Draw frame number `frame` onto `ax`, replacing its previous content."""
        ax.clear()
        ax.set_axis_off()

        # Transparent gray ground plane
        ax.plot_surface(ground_x, ground_y, ground_z,
                        color='gray', alpha=0.2, shade=False)

        current_pose = pose_data[frame]

        # Scatter points for joints
        ax.scatter(current_pose[:, 0], current_pose[:, 1], current_pose[:, 2],
                   c='blue', s=20)

        # Draw bones
        for start, end in _SMPL_BONES:
            segment = current_pose[[start, end]]
            ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], c='red')

        # ax.clear() resets the limits, so they are reapplied on every frame.
        ax.set_xlim(mid_x - max_range, mid_x + max_range)
        ax.set_ylim(mid_y - max_range, mid_y + max_range)
        ax.set_zlim(mid_z - max_range, mid_z + max_range)
        ax.set_title(f"Frame {frame}")

    print(f"Saving video to {output_path}...")
    try:
        _save_with_ffmpeg(draw, n_frames, output_path, fps)
    except Exception as e:
        # Deliberately broad: any ffmpeg problem (missing binary, encoder
        # error) should trigger the OpenCV fallback rather than abort.
        print(f"ffmpeg failed or not found ({e}). Using OpenCV fallback...")
        try:
            _save_with_opencv(draw, n_frames, output_path, fps)
            print("OpenCV fallback rendering complete.")
        except Exception as cv_e:
            print(f"OpenCV fallback also failed: {cv_e}")
            raise

    return output_path