AndyRaoTHU committed on
Commit
38e4fb0
·
1 Parent(s): 698f59b

update reset

Browse files
Files changed (1) hide show
  1. app.py +30 -8
app.py CHANGED
@@ -151,36 +151,55 @@ def draw_arrow(ax, start, end):
151
  ls="-", lw=1)
152
 
153
  def draw_reset_result(num_data=16, num_code=12):
154
- fig, ax = plt.subplots(1, 6, figsize=(22, 4))
 
155
  x = torch.randn(num_data, 1) * 2 + 5
156
  y = torch.randn(num_data, 1) * 2 - 5
157
  data = torch.cat([x, y], dim=1)
158
  quantizer = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1)
159
  optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
160
- draw_fig(ax[0], quantizer, data, title=f"Initialization")
161
- ax[0].legend(["Data", "Code"], loc="upper right", fontsize=18)
 
 
 
 
162
 
163
- i_list = [1, 2, 3, 10, 100]
164
 
165
  count = 0
166
  for i in range(1000):
167
  optimizer.zero_grad()
 
168
  output_dict = quantizer(data.unsqueeze(1))
 
169
  quant_data = output_dict["x_quant"].squeeze()
 
170
  indices = output_dict["indices"].squeeze()
 
171
  loss = torch.mean((quant_data - data) ** 2)
 
172
  loss.backward()
 
173
  optimizer.step()
 
174
 
175
  if (i+1) in i_list:
176
  count += 1
177
- draw_fig(ax[count], quantizer, data, title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
178
- draw_arrow(ax[count], quant_data.detach().numpy(), data.numpy())
 
 
 
179
 
180
  quantizer.reset()
181
 
182
- return fig
 
 
 
183
 
 
184
 
185
 
186
  class Handler:
@@ -290,7 +309,10 @@ if __name__ == "__main__":
290
  gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1),
291
  gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1),
292
  ],
293
- outputs=gr.Plot(label="Training Visualization"),
 
 
 
294
  title="Demo 2: Codebook Reset Strategy Visualization",
295
  description="Visualizes codebook and data movement at different training steps."
296
  )
 
151
  ls="-", lw=1)
152
 
153
  def draw_reset_result(num_data=16, num_code=12):
154
+ fig_reset, ax_reset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
155
+ fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
156
  x = torch.randn(num_data, 1) * 2 + 5
157
  y = torch.randn(num_data, 1) * 2 - 5
158
  data = torch.cat([x, y], dim=1)
159
  quantizer = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1)
160
  optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
161
+ quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
162
+ optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)
163
+ draw_fig(ax_reset[0], quantizer, data, title=f"Initialization")
164
+ draw_fig(ax_nreset[0], quantizer_nreset, data, title=f"Initialization")
165
+ ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=18)
166
+ ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=18)
167
 
168
+ i_list = [1, 3, 10, 50, 200]
169
 
170
  count = 0
171
  for i in range(1000):
172
  optimizer.zero_grad()
173
+ optimizer_nreset.zero_grad()
174
  output_dict = quantizer(data.unsqueeze(1))
175
+ output_dict_nreset = quantizer_nreset(data.unsqueeze(1))
176
  quant_data = output_dict["x_quant"].squeeze()
177
+ quant_data_nreset = output_dict_nreset["x_quant"].squeeze()
178
  indices = output_dict["indices"].squeeze()
179
+ indices = output_dict_nreset["indices"].squeeze()
180
  loss = torch.mean((quant_data - data) ** 2)
181
+ loss_nreset = torch.mean((quant_data_nreset - data) ** 2)
182
  loss.backward()
183
+ loss_nreset.backward()
184
  optimizer.step()
185
+ optimizer_nreset.step()
186
 
187
  if (i+1) in i_list:
188
  count += 1
189
+ draw_fig(ax_reset[count], quantizer, data, title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
190
+ draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())
191
+
192
+ draw_fig(ax_nreset[count], quantizer_nreset, data, title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
193
+ draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())
194
 
195
  quantizer.reset()
196
 
197
+ fig_reset.suptitle("VQ Codebook Training with Reset", fontsize=24, y=1.05)
198
+ fig_nreset.suptitle("VQ Codebook Training without Reset", fontsize=24, y=1.05)
199
+
200
+ return fig_reset, fig_nreset
201
 
202
+ # end
203
 
204
 
205
  class Handler:
 
309
  gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1),
310
  gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1),
311
  ],
312
+ outputs=[
313
+ gr.Plot(label="With Reset"),
314
+ gr.Plot(label="Without Reset")
315
+ ],
316
  title="Demo 2: Codebook Reset Strategy Visualization",
317
  description="Visualizes codebook and data movement at different training steps."
318
  )