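# Gradio demo: text-to-3D-human-motion generation with a T2M-GPT-style
# VQ-VAE + transformer. Assumes this repo provides the models/, options/,
# utils/ and visualization/ packages and trained checkpoints under output/.
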
import os
import tempfile
import warnings

import gradio as gr
import clip
import torch
import numpy as np

import models.vqvae as vqvae
import models.t2m_trans as trans
import options.option_transformer as option_trans
from utils.motion_process import recover_from_ric
from visualization.plot_3d_global import draw_to_batch

warnings.filterwarnings('ignore')

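# Run on GPU when available; CPU inference works but is noticeably slower.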
device = "cuda" if torch.cuda.is_available() else "cpu"

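# These hyperparameters must match the ones used at training time; the values
# below follow the T2M-GPT defaults for the HumanML3D ('t2m') dataset.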
args = option_trans.get_args_parser()

args.dataname = 't2m'
args.down_t = 2
args.depth = 3
args.block_size = 51

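# VQ-VAE: decodes discrete motion-token sequences back into pose features.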
vqvae_model = vqvae.HumanVQVAE(args).to(device)

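# GPT-style transformer: autoregressively samples motion tokens conditioned
# on CLIP text features.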
transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
                                                  embed_dim=1024,
                                                  clip_dim=args.clip_dim,
                                                  block_size=args.block_size,
                                                  num_layers=9,
                                                  n_head=16,
                                                  drop_out_rate=args.drop_out_rate,
                                                  fc_rate=args.ff_rate).to(device)

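# Load the trained weights. The "vqvae."/"trans." key prefixes are stripped
# because the checkpoints were saved from wrapper modules; strict=False
# tolerates any remaining naming differences between checkpoint and model.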
vqvae_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device)
transformed_vqvae_state_dict = {k.replace("vqvae.", ""): v for k, v in vqvae_checkpoint['net'].items()}
vqvae_model.load_state_dict(transformed_vqvae_state_dict, strict=False)

transformer_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth", map_location=device)
transformed_transformer_state_dict = {k.replace("trans.", ""): v for k, v in transformer_checkpoint['trans'].items()}
transformer_model.load_state_dict(transformed_transformer_state_dict, strict=False)

vqvae_model.eval()
transformer_model.eval()

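# Dataset statistics used to de-normalize the network's pose output.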
mean = torch.from_numpy(np.load('output/Mean.npy', allow_pickle=True)).to(device)
std = torch.from_numpy(np.load('output/Std.npy', allow_pickle=True)).to(device)

# Load CLIP once at startup; the original code reloaded it inside
# generate_motion on every request, which costs several seconds per call.
clip_model, _ = clip.load("ViT-B/32", device=device)
clip_model.eval()

def generate_motion(text, vqvae_model, transformer_model):
    """Return generated joint positions of shape (frames, 22, 3) for a prompt."""
    text_encoded = clip.tokenize([text], truncate=True).to(device)

    with torch.no_grad():
        # Text features -> motion tokens -> pose features -> 3D joint positions.
        clip_features = clip_model.encode_text(text_encoded).float()
        motion_indices = transformer_model.sample(clip_features, False)
        pred_pose = vqvae_model.forward_decoder(motion_indices)
        pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)

    return pred_xyz.cpu().numpy().reshape(-1, 22, 3)

def infer(text):
    print("Received text:", text)
    try:
        motion_data = generate_motion(text, vqvae_model, transformer_model)
        if motion_data.size == 0:
            raise ValueError("Generated motion data is empty")
    except Exception as e:
        print(f"Failed during motion generation: {e}")
        # Raising gr.Error surfaces the message in the UI; returning an error
        # string to an Image output (as the original did) cannot be rendered.
        raise gr.Error("Error in motion generation.")

    try:
        gif_path = os.path.join(tempfile.gettempdir(), "output.gif")
        # draw_to_batch takes batched joints, titles and output filenames and
        # writes the rendered animation to the given path (T2M-GPT convention);
        # it does not return raw GIF bytes as the original code assumed.
        draw_to_batch([motion_data], [text], [gif_path])
        if not os.path.exists(gif_path):
            raise RuntimeError("draw_to_batch produced no output file")
        print("GIF successfully saved to:", gif_path)
        return gif_path
    except Exception as e:
        print(f"Error generating GIF: {e}")
        raise gr.Error("Error generating GIF. Please try again.")

css = ".container { max-width: 800px; margin: auto; }"
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## 3D Human Motion Generation")
        with gr.Row():
            text_input = gr.Textbox(label="Enter the human action to generate", placeholder="Enter text description for the action here...", show_label=True)
            submit_button = gr.Button("Generate Motion")
        output_image = gr.Image(label="Generated Human Motion", type="filepath", show_label=False)

    submit_button.click(
        fn=infer,
        inputs=[text_input],
        outputs=[output_image]
    )
    
if __name__ == "__main__":
    demo.launch()
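
# Usage note: run this script directly (e.g. `python app.py`, or whatever
# filename it is saved under) and open the printed local URL; pass
# share=True to demo.launch() if a public link is needed.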