Spaces:
Runtime error
Runtime error
| import pathlib | |
| import gradio as gr | |
| import imageio | |
| from yt_dlp import YoutubeDL | |
| import cv2 | |
| import torch | |
| import torchvision | |
| import tempfile | |
| import numpy as np | |
| import smplx | |
| import pyrender | |
| import trimesh | |
| import trimesh.transformations as tra | |
| from dataclasses import dataclass | |
| from typing import List, Dict, Any | |
| import SkeletonDiffusion_Demo.plot_several_meshes as plot_several_meshes | |
| import SkeletonDiffusion_Demo.combine_video as combine | |
| import os | |
# Headless rendering: request the EGL backend because no display is attached.
# NOTE(review): pyrender's offscreen guide says PYOPENGL_PLATFORM should be set
# before pyrender is imported, but `import pyrender` appears earlier in this
# file -- consider moving this assignment above the imports; confirm on target.
os.environ['PYOPENGL_PLATFORM'] = 'egl'

# Export the ImageMagick binary location to THIS process (and its children).
# `os.system('export ...')` would only modify a short-lived subshell and leave
# the current environment untouched.
# NOTE(review): libraries that read this variable at import time will only see
# it if they are imported after this line -- verify against the helper modules.
os.environ['IMAGEMAGICK_BINARY'] = (
    '/home/stud/yaji/storage/user/yaji/NonisotropicSkeletonDiffusion/magick'
)

# Pick the compute device once; everything below is placed on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the TorchScript pose-estimation model (ensure the path is correct).
# `map_location` lets the archive deserialize on hosts without CUDA instead of
# failing on an unconditional `.cuda()` call.
# NOTE(review): whether this particular model actually runs on CPU is untested.
nlf_model = (
    torch.jit.load("./models/nlf_l_multi.torchscript", map_location=device)
    .to(device)
    .eval()
)

# Load SMPL model (ensure the pkl file is correct).
# This code assumes the SMPL_NEUTRAL.pkl model has its canonical orientation
# with the face along +X and up along +Y.
smpl_model = smplx.create(
    "./models/SMPL_NEUTRAL.pkl", model_type="smpl", gender="neutral"
).to(device)

smpl_params = []
DESCRIPTION = "# SMPL Visualization Demo"
# Upper bound on the number of video frames handled per request.
FRAME_LIMIT = 100
@dataclass
class SMPLParams:
    """
    Data structure to hold SMPL parameters.

    The ``@dataclass`` decorator generates the keyword-argument constructor
    that the rest of this file relies on (``SMPLParams(global_orient=..., ...)``).
    """

    # First 3 columns of the predicted pose vector (root orientation).
    global_orient: torch.Tensor
    # Remaining columns of the predicted pose vector (per-joint rotations).
    body_pose: torch.Tensor
    # Shape coefficients as returned by the detector.
    betas: torch.Tensor
    # Translation as returned by the detector.
    transl: torch.Tensor
def handle_video_input(video_file, youtube_url):
    """
    Resolve the video to process: a YouTube download or a local file.

    A non-blank ``youtube_url`` takes precedence over ``video_file``. The
    video is downloaded into ``downloads/`` and the resulting file path is
    returned.

    Args:
        video_file: Path of an uploaded/local video, or a falsy value.
        youtube_url: URL to download, or a falsy/blank value.

    Returns:
        The path of the video to read, or ``None`` when neither input is set.
    """
    # A textbox may deliver None, "" or stray whitespace; none of those is a URL.
    url = (youtube_url or "").strip()
    if url:
        ydl_opts = {
            "format": "best",
            "outtmpl": "downloads/%(title)s.%(ext)s",
            # The Python API key is "cookiefile"; "cookies" is only the name of
            # the command-line flag and is not read by YoutubeDL.
            "cookiefile": "cookies/cookies.txt",
        }
        with YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            return ydl.prepare_filename(info)
    if video_file:
        return video_file
    return None
def correct_vertices(vertices):
    """
    Convert SMPL vertices into the renderer's coordinate system.

    Applies a 180-degree rotation about the X axis, i.e. maps every vertex
    (x, y, z) to (x, -y, -z). Only the first mesh of the batch
    (``vertices[0]``) is transformed.
    NOTE(review): presumably this converts a y-down / z-forward camera frame
    into the y-up / z-backward frame used by the renderer -- confirm against
    the detector's output convention.

    Args:
        vertices: Array-like whose first element has shape (N, 3).

    Returns:
        A float64 NumPy array of shape (1, N, 3) with the rotated vertices.
    """
    # A half-turn about X is exactly a sign flip of Y and Z, so no 4x4
    # homogeneous matrix (and no sin(pi) round-off) is needed.
    flip_yz = np.array([1.0, -1.0, -1.0])
    first_mesh = np.asarray(vertices[0], dtype=np.float64)
    return (first_mesh * flip_yz).reshape(1, -1, 3)
def render_smpl(vertices, width, height):
    """
    Renders the SMPL 3D model using PyRender.

    - Applies a coordinate correction to the SMPL vertices.
    - Builds a trimesh object and adds it to a pyrender scene.
    - Sets up an orthographic camera placed at [0, 0, 5] with identity rotation.
    - Renders the scene offscreen and releases the renderer afterwards.
    - Converts the output image from RGB to BGR.

    Args:
        vertices: SMPL vertices; only the first mesh of the batch is drawn.
        width: Output image width in pixels.
        height: Output image height in pixels.

    Returns:
        The rendered colour image with its channels in BGR order.
        NOTE(review): callers that blend this with an RGB frame mix channel
        orders -- confirm this is intended.
    """
    # Correct the vertices and wrap them in a mesh using the SMPL faces.
    vertices_corrected = correct_vertices(vertices)
    mesh = trimesh.Trimesh(vertices_corrected[0], smpl_model.faces)

    # White background with alpha 0.9.
    scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 0.9])
    scene.add(pyrender.Mesh.from_trimesh(mesh))

    # Orthographic camera at [0, 0, distance]; identity rotation in the pose.
    camera = pyrender.OrthographicCamera(xmag=1.0, ymag=1.0)
    camera_pose = np.eye(4)
    distance = 5.0
    camera_pose[:3, 3] = [0, 0, distance]
    scene.add(camera, pose=camera_pose)

    # This function runs once per video frame, so the renderer must be freed
    # explicitly; otherwise every call leaves its GL resources allocated.
    renderer = pyrender.OffscreenRenderer(width, height)
    try:
        color, _ = renderer.render(scene)
    finally:
        renderer.delete()

    # Convert from RGB to BGR.
    return cv2.cvtColor(color, cv2.COLOR_RGB2BGR)
def process_video(video_file, youtube_url):
    """
    Processes the input video and outputs a GIF plus the SMPL parameters.

    - Obtains the video path.
    - Reads up to FRAME_LIMIT frames from the video.
    - Runs SMPL detection on each frame.
    - Generates the SMPL mesh using the SMPL model.
    - Renders the SMPL mesh.
    - Blends the rendered SMPL visualization with the original frame. When a
      frame has no detection, the most recent render is reused; before the
      first detection the frame is kept unmodified.
    - Saves the processed frames as a GIF.

    Args:
        video_file: Path of an uploaded/local video, or a falsy value.
        youtube_url: URL of a video to download, or a falsy value.

    Returns:
        ``(gif_path, smpl_params)`` where ``smpl_params`` is a list of
        JSON-compatible dicts (one per frame with a detection). Returns
        ``(None, None)`` when there is no input or no frame could be read, so
        that both wired outputs always receive a value.
    """
    input_path = handle_video_input(video_file, youtube_url)
    if not input_path:
        return None, None

    frames = []
    smpl_params_list = []
    rendered_smpl = None
    frame_count = 0
    alpha = 0.6

    cap = cv2.VideoCapture(input_path)
    try:
        while cap.isOpened() and frame_count < FRAME_LIMIT:
            ret, frame = cap.read()
            if not ret:
                break

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # HWC -> CHW, then a batch of one for the detector.
            image_tensor = (
                torch.from_numpy(frame_rgb).permute(2, 0, 1).int().to(device)
            )
            with torch.inference_mode():
                pred = nlf_model.detect_smpl_batched(image_tensor.unsqueeze(0))
            # Index 0 selects the only image in the batch; the leading axis of
            # each array is then used as the number of detections.
            pose_params = pred["pose"][0].cpu().numpy()
            betas = pred["betas"][0].cpu().numpy()
            transl = pred["trans"][0].cpu().numpy()

            if pose_params.shape[0] > 0:
                # Build the tensors once and reuse them for the SMPL forward pass.
                smpl_param = SMPLParams(
                    global_orient=torch.tensor(pose_params[:, :3]).to(device),
                    body_pose=torch.tensor(pose_params[:, 3:]).to(device),
                    betas=torch.tensor(betas).to(device),
                    transl=torch.tensor(transl).to(device),
                )
                output_smpl = smpl_model(
                    global_orient=smpl_param.global_orient,
                    body_pose=smpl_param.body_pose,
                    betas=smpl_param.betas,
                    transl=smpl_param.transl,
                )
                vertices = output_smpl.vertices.detach().cpu().numpy()
                rendered_smpl = render_smpl(vertices, frame.shape[1], frame.shape[0])
                smpl_params_list.append(smpl_param)
            elif rendered_smpl is None:
                print(f"No SMPL detected in frame {frame_count}")

            if rendered_smpl is None:
                # Nothing to overlay yet: keep the raw frame.
                frames.append(frame_rgb)
            else:
                frames.append(
                    cv2.addWeighted(frame_rgb, 1 - alpha, rendered_smpl, alpha, 0)
                )
            # Count every consumed frame, including those without a detection,
            # so FRAME_LIMIT bounds the work even when nobody is visible.
            frame_count += 1
    finally:
        # Release the capture even if detection or rendering raised.
        cap.release()

    if not frames:
        # Unreadable or empty video: there is nothing to encode.
        return None, None

    # Serialize SMPL parameters into a JSON-compatible format
    smpl_params_serialized = [
        {
            "global_orient": p.global_orient.tolist(),
            "body_pose": p.body_pose.tolist(),
            "betas": p.betas.tolist(),
            "transl": p.transl.tolist(),
        }
        for p in smpl_params_list
    ]

    # Reserve a persistent temp file name and close the handle before writing.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        output_path = tmp.name

    # Save as GIF
    imageio.mimsave(output_path, frames, fps=30, loop=0)
    print(f"Output GIF saved to {output_path}")
    return output_path, smpl_params_serialized
def generate_motion_video(smpl_params_json: List[Dict[str, Any]]):
    """
    Generate a motion video from the given SMPL parameters.

    Currently the parameters are only deserialized; the video is produced from
    the bundled sample meshes in ``./9622_GRAB/`` (see the TODO below).

    Args:
        smpl_params_json: List of dicts with the keys ``global_orient``,
            ``body_pose``, ``betas`` and ``transl``. ``None`` (e.g. when no
            video has been processed yet) is treated as an empty list.

    Returns:
        Whatever ``combine.combine_video`` returns for the sample directory.
    """
    # Deserialize JSON back into SMPLParams objects. `or []` keeps a click on
    # the button before any video was processed from raising on None.
    smpl_params_list = [
        SMPLParams(
            global_orient=torch.tensor(p["global_orient"]),
            body_pose=torch.tensor(p["body_pose"]),
            betas=torch.tensor(p["betas"]),
            transl=torch.tensor(p["transl"]),
        )
        for p in (smpl_params_json or [])
    ]
    # TODO: Using the SMPL parameters obtained from video, generate motion and save as .obj format, rank
    # and find the closest to the ground truth and the farthest from the ground truth, just like the samples.
    sample_obj_path = "./9622_GRAB/"
    plot_several_meshes.main(sample_obj_path)
    return combine.combine_video(sample_obj_path)
# Gradio UI: one tab with the inputs on the left and the results on the right.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.Tab("Video Processing"):
            with gr.Row():
                # Left column: video source and the trigger button.
                with gr.Column():
                    input_video = gr.Video(label="Input Video")
                    youtube_url = gr.Textbox(label="YouTube URL")
                    process_btn = gr.Button("Process Video")
                # Right column: overlay preview and motion generation.
                with gr.Column():
                    video_to_smpl = gr.Image(label="SMPL Visualization")
                    # Hidden JSON component carrying the SMPL parameters from
                    # the processing step to the generation step.
                    obs_smpl_params = gr.JSON(label="SMPL Parameters")
                    obs_smpl_params.visible = False
                    generate_btn = gr.Button("Generate Motion")
                    output_video = gr.Image(label="Generated Motion")

            # Offer every previously downloaded .mp4 as a clickable example.
            example_videos = sorted(pathlib.Path("downloads").glob("*.mp4"))
            gr.Examples(
                examples=example_videos,
                inputs=input_video,
                outputs=video_to_smpl,
                cache_examples=False,
            )

            # Step 1: video -> overlay GIF + serialized SMPL parameters.
            process_btn.click(
                fn=process_video,
                inputs=[input_video, youtube_url],
                outputs=[video_to_smpl, obs_smpl_params],
            )
            # Step 2: serialized SMPL parameters -> generated motion.
            generate_btn.click(
                fn=generate_motion_video,
                inputs=[obs_smpl_params],
                outputs=[output_video],
            )

# Listen on all interfaces and request a public share link.
demo.launch(server_name="0.0.0.0", share=True)