Update app.py
Browse files
app.py
CHANGED
|
@@ -17,27 +17,14 @@ warnings.filterwarnings('ignore')
|
|
| 17 |
|
| 18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
|
| 20 |
-
# args = option_trans.get_args_parser()
|
| 21 |
-
|
| 22 |
args = option_trans.get_args_parser()
|
| 23 |
|
| 24 |
args.dataname = 't2m'
|
| 25 |
-
# args.resume_pth = './output/VQVAE_imp_resnet_100k_hml3d/net_last_VQVAE_res1.pth'
|
| 26 |
-
# args.resume_trans = './output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth'
|
| 27 |
args.down_t = 2
|
| 28 |
args.depth = 3
|
| 29 |
args.block_size = 51
|
| 30 |
|
| 31 |
vqvae_model = vqvae.HumanVQVAE(args).to(device)
|
| 32 |
-
# ,
|
| 33 |
-
# nb_code=512,
|
| 34 |
-
# code_dim=512,
|
| 35 |
-
# output_emb_width=512,
|
| 36 |
-
# down_t=args.down_t,
|
| 37 |
-
# stride_t=2,
|
| 38 |
-
# width=512,
|
| 39 |
-
# depth=args.depth,
|
| 40 |
-
# dilation_growth_rate=3).to(device)
|
| 41 |
|
| 42 |
transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
|
| 43 |
embed_dim=1024,
|
|
@@ -95,7 +82,7 @@ def generate_motion(text, vqvae_model, transformer_model):
|
|
| 95 |
if motion_data.size == 0:
|
| 96 |
raise ValueError("Generated motion data is empty")
|
| 97 |
return motion_data
|
| 98 |
-
|
| 99 |
def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
|
| 100 |
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 101 |
|
|
@@ -107,37 +94,27 @@ def create_animation(joints, title="3D Motion", save_path="static/animation.gif"
|
|
| 107 |
ax = fig.add_subplot(111, projection='3d')
|
| 108 |
ax.set_title(title)
|
| 109 |
|
| 110 |
-
data = np.array(joints).reshape(-1, 22, 3)
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
data[:, :,
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
ax.
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
# Preparing lines to plot based on connections
|
| 126 |
-
lines = []
|
| 127 |
-
for conn in connections:
|
| 128 |
-
x_coords = data[0, [conn[0], conn[1]], 0]
|
| 129 |
-
y_coords = data[0, [conn[0], conn[1]], 1]
|
| 130 |
-
z_coords = data[0, [conn[0], conn[1]], 2]
|
| 131 |
-
line, = ax.plot(x_coords, y_coords, z_coords, marker='o', markersize=5)
|
| 132 |
-
lines.append(line)
|
| 133 |
|
| 134 |
def update(num):
|
| 135 |
for line, conn in zip(lines, connections):
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
line.set_data(x, y)
|
| 140 |
-
line.set_3d_properties(z)
|
| 141 |
return lines
|
| 142 |
|
| 143 |
ani = FuncAnimation(fig, update, frames=len(data), interval=50, blit=True)
|
|
@@ -145,7 +122,6 @@ def create_animation(joints, title="3D Motion", save_path="static/animation.gif"
|
|
| 145 |
plt.close(fig)
|
| 146 |
return save_path
|
| 147 |
|
| 148 |
-
|
| 149 |
def infer(text):
|
| 150 |
motion_data = generate_motion(text, vqvae_model, transformer_model)
|
| 151 |
if motion_data.size == 0:
|
|
@@ -155,8 +131,8 @@ def infer(text):
|
|
| 155 |
|
| 156 |
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
|
| 157 |
with gr.Column():
|
| 158 |
-
gr.Markdown("## 3D Human Motion
|
| 159 |
-
text_input = gr.Textbox(label="
|
| 160 |
output_image = gr.Image(label="Generated Human Motion")
|
| 161 |
submit_button = gr.Button("Generate Motion")
|
| 162 |
|
|
|
|
| 17 |
|
| 18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
|
|
|
|
|
|
|
| 20 |
args = option_trans.get_args_parser()
|
| 21 |
|
| 22 |
args.dataname = 't2m'
|
|
|
|
|
|
|
| 23 |
args.down_t = 2
|
| 24 |
args.depth = 3
|
| 25 |
args.block_size = 51
|
| 26 |
|
| 27 |
vqvae_model = vqvae.HumanVQVAE(args).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
|
| 30 |
embed_dim=1024,
|
|
|
|
| 82 |
if motion_data.size == 0:
|
| 83 |
raise ValueError("Generated motion data is empty")
|
| 84 |
return motion_data
|
| 85 |
+
|
| 86 |
def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
|
| 87 |
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 88 |
|
|
|
|
| 94 |
ax = fig.add_subplot(111, projection='3d')
|
| 95 |
ax.set_title(title)
|
| 96 |
|
| 97 |
+
data = np.array(joints).reshape(-1, 22, 3)
|
| 98 |
+
data[:, :, [1, 2]] = data[:, :, [2, 1]]
|
| 99 |
+
data[:, :, 2] = -data[:, :, 2]
|
| 100 |
+
|
| 101 |
+
x_min, x_max = data[:, :, 0].min(), data[:, :, 0].max()
|
| 102 |
+
y_min, y_max = data[:, :, 1].min(), data[:, :, 1].max()
|
| 103 |
+
z_min, z_max = data[:, :, 2].min(), data[:, :, 2].max()
|
| 104 |
+
ax.set_xlim([x_min, x_max])
|
| 105 |
+
ax.set_ylim([y_min, y_max])
|
| 106 |
+
ax.set_zlim([z_min, z_max])
|
| 107 |
+
|
| 108 |
+
lines = [ax.plot([data[0, conn[0], 0], data[0, conn[1], 0]],
|
| 109 |
+
[data[0, conn[0], 1], data[0, conn[1], 1]],
|
| 110 |
+
[data[0, conn[0], 2], data[0, conn[1], 2]],
|
| 111 |
+
marker='o', markersize=5)[0] for conn in connections]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def update(num):
|
| 114 |
for line, conn in zip(lines, connections):
|
| 115 |
+
line.set_data([data[num, conn[0], 0], data[num, conn[1], 0]],
|
| 116 |
+
[data[num, conn[0], 1], data[num, conn[1], 1]])
|
| 117 |
+
line.set_3d_properties([data[num, conn[0], 2], data[num, conn[1], 2]])
|
|
|
|
|
|
|
| 118 |
return lines
|
| 119 |
|
| 120 |
ani = FuncAnimation(fig, update, frames=len(data), interval=50, blit=True)
|
|
|
|
| 122 |
plt.close(fig)
|
| 123 |
return save_path
|
| 124 |
|
|
|
|
| 125 |
def infer(text):
|
| 126 |
motion_data = generate_motion(text, vqvae_model, transformer_model)
|
| 127 |
if motion_data.size == 0:
|
|
|
|
| 131 |
|
| 132 |
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
|
| 133 |
with gr.Column():
|
| 134 |
+
gr.Markdown("## 3D Human Motion" + ("Generation" if device == "cuda" else "CPU"))
|
| 135 |
+
text_input = gr.Textbox(label="Enter the human action to generate", placeholder="Enter text description for the action here...")
|
| 136 |
output_image = gr.Image(label="Generated Human Motion")
|
| 137 |
submit_button = gr.Button("Generate Motion")
|
| 138 |
|