TRME / app.py
rsax's picture
Update app.py
2472db7 verified
import gradio as gr
import clip
import torch
import numpy as np
import tempfile
import models.vqvae as vqvae
import options.option_transformer as option_trans
from utils.motion_process import recover_from_ric
import models.t2m_trans as trans
import matplotlib.pyplot as plt
import matplotlib
import base64
from PIL import Image
import io
import mpl_toolkits.mplot3d as p3
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from visualization.plot_3d_global import draw_to_batch
import imageio
import sys
import os
import warnings
warnings.filterwarnings('ignore')
device = "cuda" if torch.cuda.is_available() else "cpu"
args = option_trans.get_args_parser()
args.dataname = 't2m'
args.down_t = 2
args.depth = 3
args.block_size = 51
vqvae_model = vqvae.HumanVQVAE(args).to(device)
transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
embed_dim=1024,
clip_dim=args.clip_dim,
block_size=args.block_size,
num_layers=9,
n_head=16,
drop_out_rate=args.drop_out_rate,
fc_rate=args.ff_rate).to(device)
vqvae_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device)
transformed_vqvae_state_dict = {k.replace("vqvae.", ""): v for k, v in vqvae_checkpoint['net'].items()}
vqvae_model.load_state_dict(transformed_vqvae_state_dict, strict=False)
transformer_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth", map_location=device)
transformed_transformer_state_dict = {k.replace("trans.", ""): v for k, v in transformer_checkpoint['trans'].items()}
transformer_model.load_state_dict(transformed_transformer_state_dict, strict=False)
vqvae_model.eval()
transformer_model.eval()
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)
def generate_motion(text, vqvae_model, transformer_model):
clip_text = [text]
text_encoded = clip.tokenize(clip_text, truncate=True).to(device)
with torch.no_grad():
clip_model, _ = clip.load("ViT-B/32", device=device)
clip_model.eval()
clip_features = clip_model.encode_text(text_encoded).float()
motion_indices = transformer_model.sample(clip_features, False)
pred_pose = vqvae_model.forward_decoder(motion_indices)
pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
def infer(text):
print("Received text:", text)
try:
motion_data = generate_motion(text, vqvae_model, transformer_model)
if motion_data.size == 0:
raise ValueError("Generated motion data is empty")
except Exception as e:
print(f"Failed during motion generation: {str(e)}")
return "Error in motion generation."
try:
gif_data = draw_to_batch([motion_data], [text], None)
if gif_data:
gif_filename = "output.gif"
gif_path = os.path.join(tempfile.gettempdir(), gif_filename)
with open(gif_path, "wb") as gif_file:
gif_file.write(gif_data)
print("GIF successfully saved to:", gif_path)
return gif_path
else:
print("Failed to generate GIF data.")
return "Error generating GIF. Please try again."
except Exception as e:
print(f"Error generating GIF: {str(e)}")
return "Error generating GIF. Please try again."
css = ".container { max-width: 800px; margin: auto; }"
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("## 3D Human Motion Generation")
with gr.Row():
text_input = gr.Textbox(label="Enter the human action to generate", placeholder="Enter text description for the action here...", show_label=True)
submit_button = gr.Button("Generate Motion")
output_image = gr.Image(label="Generated Human Motion", type="filepath", show_label=False)
submit_button.click(
fn=infer,
inputs=[text_input],
outputs=[output_image]
)
if __name__ == "__main__":
demo.launch()