rsax commited on
Commit
f7bdc17
·
verified ·
1 Parent(s): 19c1f3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -45
app.py CHANGED
@@ -17,27 +17,14 @@ warnings.filterwarnings('ignore')
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- # args = option_trans.get_args_parser()
21
-
22
  args = option_trans.get_args_parser()
23
 
24
  args.dataname = 't2m'
25
- # args.resume_pth = './output/VQVAE_imp_resnet_100k_hml3d/net_last_VQVAE_res1.pth'
26
- # args.resume_trans = './output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth'
27
  args.down_t = 2
28
  args.depth = 3
29
  args.block_size = 51
30
 
31
  vqvae_model = vqvae.HumanVQVAE(args).to(device)
32
- # ,
33
- # nb_code=512,
34
- # code_dim=512,
35
- # output_emb_width=512,
36
- # down_t=args.down_t,
37
- # stride_t=2,
38
- # width=512,
39
- # depth=args.depth,
40
- # dilation_growth_rate=3).to(device)
41
 
42
  transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
43
  embed_dim=1024,
@@ -95,7 +82,7 @@ def generate_motion(text, vqvae_model, transformer_model):
95
  if motion_data.size == 0:
96
  raise ValueError("Generated motion data is empty")
97
  return motion_data
98
-
99
  def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
100
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
101
 
@@ -107,37 +94,27 @@ def create_animation(joints, title="3D Motion", save_path="static/animation.gif"
107
  ax = fig.add_subplot(111, projection='3d')
108
  ax.set_title(title)
109
 
110
- data = np.array(joints).reshape(-1, 22, 3) # Ensure data shape is (frames, 22, 3)
111
- print(f"data shape: {data.shape}")
112
-
113
- # Re-orient the data if necessary
114
- data = data[:, :, [0, 2, 1]] # This swaps Y and Z for example if necessary
115
- data[:, :, 2] = -data[:, :, 2] # Inverting Z if necessary
116
-
117
- # Setting axes limits
118
- x_limits = [np.min(data[:, :, 0]), np.max(data[:, :, 0])]
119
- y_limits = [np.min(data[:, :, 1]), np.max(data[:, :, 1])]
120
- z_limits = [np.min(data[:, :, 2]), np.max(data[:, :, 2])]
121
- ax.set_xlim(x_limits)
122
- ax.set_ylim(y_limits)
123
- ax.set_zlim(z_limits)
124
-
125
- # Preparing lines to plot based on connections
126
- lines = []
127
- for conn in connections:
128
- x_coords = data[0, [conn[0], conn[1]], 0]
129
- y_coords = data[0, [conn[0], conn[1]], 1]
130
- z_coords = data[0, [conn[0], conn[1]], 2]
131
- line, = ax.plot(x_coords, y_coords, z_coords, marker='o', markersize=5)
132
- lines.append(line)
133
 
134
  def update(num):
135
  for line, conn in zip(lines, connections):
136
- x = data[num, [conn[0], conn[1]], 0]
137
- y = data[num, [conn[0], conn[1]], 1]
138
- z = data[num, [conn[0], conn[1]], 2]
139
- line.set_data(x, y)
140
- line.set_3d_properties(z)
141
  return lines
142
 
143
  ani = FuncAnimation(fig, update, frames=len(data), interval=50, blit=True)
@@ -145,7 +122,6 @@ def create_animation(joints, title="3D Motion", save_path="static/animation.gif"
145
  plt.close(fig)
146
  return save_path
147
 
148
-
149
  def infer(text):
150
  motion_data = generate_motion(text, vqvae_model, transformer_model)
151
  if motion_data.size == 0:
@@ -155,8 +131,8 @@ def infer(text):
155
 
156
  with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
157
  with gr.Column():
158
- gr.Markdown("## 3D Human Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
159
- text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description for the action here...")
160
  output_image = gr.Image(label="Generated Human Motion")
161
  submit_button = gr.Button("Generate Motion")
162
 
 
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
 
 
20
  args = option_trans.get_args_parser()
21
 
22
  args.dataname = 't2m'
 
 
23
  args.down_t = 2
24
  args.depth = 3
25
  args.block_size = 51
26
 
27
  vqvae_model = vqvae.HumanVQVAE(args).to(device)
 
 
 
 
 
 
 
 
 
28
 
29
  transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
30
  embed_dim=1024,
 
82
  if motion_data.size == 0:
83
  raise ValueError("Generated motion data is empty")
84
  return motion_data
85
+
86
  def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
87
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
88
 
 
94
  ax = fig.add_subplot(111, projection='3d')
95
  ax.set_title(title)
96
 
97
+ data = np.array(joints).reshape(-1, 22, 3)
98
+ data[:, :, [1, 2]] = data[:, :, [2, 1]]
99
+ data[:, :, 2] = -data[:, :, 2]
100
+
101
+ x_min, x_max = data[:, :, 0].min(), data[:, :, 0].max()
102
+ y_min, y_max = data[:, :, 1].min(), data[:, :, 1].max()
103
+ z_min, z_max = data[:, :, 2].min(), data[:, :, 2].max()
104
+ ax.set_xlim([x_min, x_max])
105
+ ax.set_ylim([y_min, y_max])
106
+ ax.set_zlim([z_min, z_max])
107
+
108
+ lines = [ax.plot([data[0, conn[0], 0], data[0, conn[1], 0]],
109
+ [data[0, conn[0], 1], data[0, conn[1], 1]],
110
+ [data[0, conn[0], 2], data[0, conn[1], 2]],
111
+ marker='o', markersize=5)[0] for conn in connections]
 
 
 
 
 
 
 
 
112
 
113
  def update(num):
114
  for line, conn in zip(lines, connections):
115
+ line.set_data([data[num, conn[0], 0], data[num, conn[1], 0]],
116
+ [data[num, conn[0], 1], data[num, conn[1], 1]])
117
+ line.set_3d_properties([data[num, conn[0], 2], data[num, conn[1], 2]])
 
 
118
  return lines
119
 
120
  ani = FuncAnimation(fig, update, frames=len(data), interval=50, blit=True)
 
122
  plt.close(fig)
123
  return save_path
124
 
 
125
  def infer(text):
126
  motion_data = generate_motion(text, vqvae_model, transformer_model)
127
  if motion_data.size == 0:
 
131
 
132
  with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
133
  with gr.Column():
134
+ gr.Markdown("## 3D Human Motion" + ("Generation" if device == "cuda" else "CPU"))
135
+ text_input = gr.Textbox(label="Enter the human action to generate", placeholder="Enter text description for the action here...")
136
  output_image = gr.Image(label="Generated Human Motion")
137
  submit_button = gr.Button("Generate Motion")
138