import os
import tempfile
import warnings

import gradio as gr
import clip
import numpy as np
import torch

import models.t2m_trans as trans
import models.vqvae as vqvae
import options.option_transformer as option_trans
from utils.motion_process import recover_from_ric
from visualization.plot_3d_global import draw_to_batch

warnings.filterwarnings('ignore')
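# --- One-time model setup -----------------------------------------------------
# Everything below runs at import time so that every Gradio request reuses the
# already-loaded VQ-VAE, transformer, and CLIP weights.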
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameters for the HumanML3D ('t2m') setting.
args = option_trans.get_args_parser()
args.dataname = 't2m'
args.down_t = 2
args.depth = 3
args.block_size = 51
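# Assumed to mirror the released T2M-GPT configuration: down_t=2 corresponds to
# a 4x temporal downsampling in the VQ-VAE, and block_size=51 caps the motion
# token sequence the transformer attends over.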
vqvae_model = vqvae.HumanVQVAE(args).to(device)
transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
                                                  embed_dim=1024,
                                                  clip_dim=args.clip_dim,
                                                  block_size=args.block_size,
                                                  num_layers=9,
                                                  n_head=16,
                                                  drop_out_rate=args.drop_out_rate,
                                                  fc_rate=args.ff_rate).to(device)
# Checkpoints store parameters under "vqvae."/"trans." prefixes; strip them so
# the keys line up with the bare module definitions.
vqvae_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device)
vqvae_state_dict = {k.replace("vqvae.", ""): v for k, v in vqvae_checkpoint['net'].items()}
vqvae_model.load_state_dict(vqvae_state_dict, strict=False)

transformer_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth", map_location=device)
transformer_state_dict = {k.replace("trans.", ""): v for k, v in transformer_checkpoint['trans'].items()}
transformer_model.load_state_dict(transformer_state_dict, strict=False)

vqvae_model.eval()
transformer_model.eval()

# Load CLIP once here instead of on every request.
clip_model, _ = clip.load("ViT-B/32", device=device, jit=False)
clip_model.eval()

# Dataset mean/std used to de-normalize the decoded motion features.
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)
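# The transformer samples motion-token indices conditioned on the CLIP text
# embedding; the VQ-VAE decodes them into normalized HumanML3D features, which
# are de-normalized with mean/std and converted by recover_from_ric into
# 22 joint positions per frame.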
def generate_motion(text, vqvae_model, transformer_model):
    """Encode the prompt with CLIP, sample motion tokens, and decode to joints."""
    text_encoded = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        clip_features = clip_model.encode_text(text_encoded).float()
        motion_indices = transformer_model.sample(clip_features, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
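# e.g. generate_motion("a person walks forward", vqvae_model, transformer_model)
# returns an array of shape (num_frames, 22, 3).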
def infer(text):
    print("Received text:", text)
    try:
        motion_data = generate_motion(text, vqvae_model, transformer_model)
        if motion_data.size == 0:
            raise ValueError("Generated motion data is empty")
    except Exception as e:
        print(f"Failed during motion generation: {e}")
        raise gr.Error("Error in motion generation.")
    try:
        gif_path = os.path.join(tempfile.gettempdir(), "output.gif")
        # draw_to_batch renders the skeleton animation and, when given output
        # names, writes the GIFs itself; it returns rendered frames, not
        # encoded bytes, so nothing needs to be written to the file manually.
        draw_to_batch([motion_data], [text], [gif_path])
        if not os.path.exists(gif_path):
            raise RuntimeError("draw_to_batch did not produce a GIF")
        print("GIF successfully saved to:", gif_path)
        return gif_path
    except Exception as e:
        print(f"Error generating GIF: {e}")
        raise gr.Error("Error generating GIF. Please try again.")
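# --- Gradio UI ------------------------------------------------------------------
# A single text box wired to infer(); the resulting GIF is shown as an image.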
# Target the column's elem_id so the width constraint actually applies.
css = "#col-container { max-width: 800px; margin: auto; }"

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## 3D Human Motion Generation")
        with gr.Row():
            text_input = gr.Textbox(label="Enter the human action to generate",
                                    placeholder="Enter text description for the action here...")
            submit_button = gr.Button("Generate Motion")
        output_image = gr.Image(label="Generated Human Motion", type="filepath")
        submit_button.click(
            fn=infer,
            inputs=[text_input],
            outputs=[output_image]
        )
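# Note: demo.launch(share=True) would also expose a temporary public URL; the
# defaults serve on localhost only.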
if __name__ == "__main__":
    demo.launch()