# NOTE(review): the three lines that were here ("File size: 4,582 Bytes", a run of
# git-blame commit hashes, and a bare 1..117 line-number listing) were extraction
# artifacts from a code-viewer, not part of the program; they have been commented out
# so the module parses.
import gradio as gr
import clip
import torch
import numpy as np
import tempfile
import models.vqvae as vqvae
import options.option_transformer as option_trans
from utils.motion_process import recover_from_ric
import models.t2m_trans as trans
import matplotlib.pyplot as plt
import matplotlib
import base64
from PIL import Image
import io
import mpl_toolkits.mplot3d as p3
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from visualization.plot_3d_global import draw_to_batch
import imageio
import sys
import os
import warnings
warnings.filterwarnings('ignore')
device = "cuda" if torch.cuda.is_available() else "cpu"
args = option_trans.get_args_parser()
args.dataname = 't2m'
args.down_t = 2
args.depth = 3
args.block_size = 51
vqvae_model = vqvae.HumanVQVAE(args).to(device)
transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
embed_dim=1024,
clip_dim=args.clip_dim,
block_size=args.block_size,
num_layers=9,
n_head=16,
drop_out_rate=args.drop_out_rate,
fc_rate=args.ff_rate).to(device)
vqvae_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device)
transformed_vqvae_state_dict = {k.replace("vqvae.", ""): v for k, v in vqvae_checkpoint['net'].items()}
vqvae_model.load_state_dict(transformed_vqvae_state_dict, strict=False)
transformer_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth", map_location=device)
transformed_transformer_state_dict = {k.replace("trans.", ""): v for k, v in transformer_checkpoint['trans'].items()}
transformer_model.load_state_dict(transformed_transformer_state_dict, strict=False)
vqvae_model.eval()
transformer_model.eval()
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)
def generate_motion(text, vqvae_model, transformer_model):
    """Generate a 3D human motion sequence from a text prompt.

    Args:
        text: Natural-language description of the action to perform.
        vqvae_model: VQ-VAE whose decoder maps sampled code indices to poses.
        transformer_model: Text-conditioned transformer that samples motion tokens.

    Returns:
        numpy array of shape (frames, 22, 3) — per-frame xyz positions for the
        22 HumanML3D joints (22 is the joint count passed to recover_from_ric).
    """
    clip_text = [text]
    text_encoded = clip.tokenize(clip_text, truncate=True).to(device)
    with torch.no_grad():
        # FIX: the original called clip.load("ViT-B/32") on every invocation,
        # re-loading the full CLIP model per request. Load it once and cache it
        # on the function object; subsequent calls reuse the cached model.
        if not hasattr(generate_motion, "_clip_model"):
            clip_model, _ = clip.load("ViT-B/32", device=device)
            clip_model.eval()
            generate_motion._clip_model = clip_model
        clip_features = generate_motion._clip_model.encode_text(text_encoded).float()
        motion_indices = transformer_model.sample(clip_features, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        # De-normalize with the dataset mean/std loaded at module level, then
        # recover absolute joint positions from the RIC representation.
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
def infer(text):
    """Gradio callback: generate motion for *text* and render it as a GIF.

    Returns:
        Path to the rendered GIF on success, or a short error string on
        failure (kept as strings to preserve the original caller contract).
    """
    print("Received text:", text)
    try:
        motion_data = generate_motion(text, vqvae_model, transformer_model)
        if motion_data.size == 0:
            raise ValueError("Generated motion data is empty")
    except Exception as e:
        print(f"Failed during motion generation: {str(e)}")
        return "Error in motion generation."
    try:
        gif_data = draw_to_batch([motion_data], [text], None)
        if gif_data:
            # FIX: the original always wrote to a fixed <tmpdir>/output.gif,
            # so concurrent or successive requests overwrote each other's
            # results. Write to a unique temp file instead (delete=False so
            # Gradio can read it back by path).
            with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as gif_file:
                gif_file.write(gif_data)
                gif_path = gif_file.name
            print("GIF successfully saved to:", gif_path)
            return gif_path
        else:
            print("Failed to generate GIF data.")
            return "Error generating GIF. Please try again."
    except Exception as e:
        print(f"Error generating GIF: {str(e)}")
        return "Error generating GIF. Please try again."
# ---------------------------------------------------------------------------
# Gradio UI: a text box, a "Generate Motion" button, and an image component
# that displays the generated GIF (type="filepath" matches infer() returning
# a path string).
# ---------------------------------------------------------------------------
css = ".container { max-width: 800px; margin: auto; }"
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## 3D Human Motion Generation")
        with gr.Row():
            text_input = gr.Textbox(label="Enter the human action to generate", placeholder="Enter text description for the action here...", show_label=True)
            submit_button = gr.Button("Generate Motion")
        output_image = gr.Image(label="Generated Human Motion", type="filepath", show_label=False)
    # Wire the button to the inference callback: text in, GIF path out.
    submit_button.click(
        fn=infer,
        inputs=[text_input],
        outputs=[output_image]
    )
if __name__ == "__main__":
    demo.launch()