Spaces:
Sleeping
Sleeping
Commit ·
38e4fb0
1
Parent(s): 698f59b
update reset
Browse files
app.py
CHANGED
|
@@ -151,36 +151,55 @@ def draw_arrow(ax, start, end):
|
|
| 151 |
ls="-", lw=1)
|
| 152 |
|
| 153 |
def draw_reset_result(num_data=16, num_code=12):
|
| 154 |
-
|
|
|
|
| 155 |
x = torch.randn(num_data, 1) * 2 + 5
|
| 156 |
y = torch.randn(num_data, 1) * 2 - 5
|
| 157 |
data = torch.cat([x, y], dim=1)
|
| 158 |
quantizer = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1)
|
| 159 |
optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
-
i_list = [1,
|
| 164 |
|
| 165 |
count = 0
|
| 166 |
for i in range(1000):
|
| 167 |
optimizer.zero_grad()
|
|
|
|
| 168 |
output_dict = quantizer(data.unsqueeze(1))
|
|
|
|
| 169 |
quant_data = output_dict["x_quant"].squeeze()
|
|
|
|
| 170 |
indices = output_dict["indices"].squeeze()
|
|
|
|
| 171 |
loss = torch.mean((quant_data - data) ** 2)
|
|
|
|
| 172 |
loss.backward()
|
|
|
|
| 173 |
optimizer.step()
|
|
|
|
| 174 |
|
| 175 |
if (i+1) in i_list:
|
| 176 |
count += 1
|
| 177 |
-
draw_fig(
|
| 178 |
-
draw_arrow(
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
quantizer.reset()
|
| 181 |
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
| 183 |
|
|
|
|
| 184 |
|
| 185 |
|
| 186 |
class Handler:
|
|
@@ -290,7 +309,10 @@ if __name__ == "__main__":
|
|
| 290 |
gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1),
|
| 291 |
gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1),
|
| 292 |
],
|
| 293 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
| 294 |
title="Demo 2: Codebook Reset Strategy Visualization",
|
| 295 |
description="Visualizes codebook and data movement at different training steps."
|
| 296 |
)
|
|
|
|
| 151 |
ls="-", lw=1)
|
| 152 |
|
| 153 |
def draw_reset_result(num_data=16, num_code=12):
    """Train two VQ quantizers on the same 2-D Gaussian data cloud --
    one with codebook reset enabled, one with ``auto_reset=False`` --
    and plot codebook/data snapshots at selected training iterations.

    Args:
        num_data: number of 2-D data points to sample.
        num_code: codebook size for both quantizers.

    Returns:
        ``(fig_reset, fig_nreset)``: matplotlib figures (1x6 panel strips)
        for the runs with and without codebook reset, respectively.
    """
    # One 6-panel strip per setting: panel 0 shows the initial state, the
    # remaining 5 panels are snapshots at the iterations listed in i_list.
    fig_reset, ax_reset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)
    fig_nreset, ax_nreset = plt.subplots(1, 6, figsize=(36, 6), dpi=400)

    # Synthetic data: Gaussian cluster centred at (5, -5), std 2 per axis.
    x = torch.randn(num_data, 1) * 2 + 5
    y = torch.randn(num_data, 1) * 2 - 5
    data = torch.cat([x, y], dim=1)

    # Two identical quantizer/optimizer pairs; the second disables reset.
    quantizer = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1)
    optimizer = torch.optim.SGD(quantizer.parameters(), lr=0.1)
    quantizer_nreset = Quantizer(TYPE='vq', code_dim=2, num_code=num_code, num_group=1, tokens_per_data=1, auto_reset=False)
    optimizer_nreset = torch.optim.SGD(quantizer_nreset.parameters(), lr=0.1)

    draw_fig(ax_reset[0], quantizer, data, title="Initialization")
    draw_fig(ax_nreset[0], quantizer_nreset, data, title="Initialization")
    ax_reset[0].legend(["Data", "Code"], loc="upper right", fontsize=18)
    ax_nreset[0].legend(["Data", "Code"], loc="upper right", fontsize=18)

    # Iterations at which a snapshot is drawn (5 entries -> panels 1..5).
    i_list = [1, 3, 10, 50, 200]

    count = 0
    for i in range(1000):
        optimizer.zero_grad()
        optimizer_nreset.zero_grad()
        output_dict = quantizer(data.unsqueeze(1))
        output_dict_nreset = quantizer_nreset(data.unsqueeze(1))
        quant_data = output_dict["x_quant"].squeeze()
        quant_data_nreset = output_dict_nreset["x_quant"].squeeze()
        indices = output_dict["indices"].squeeze()
        # BUG FIX: this line previously re-assigned `indices`, clobbering
        # the with-reset indices with the no-reset ones. Both tensors are
        # currently unused downstream, but keep them bound to distinct
        # names so future code reads the right one.
        indices_nreset = output_dict_nreset["indices"].squeeze()
        loss = torch.mean((quant_data - data) ** 2)
        loss_nreset = torch.mean((quant_data_nreset - data) ** 2)
        loss.backward()
        loss_nreset.backward()
        optimizer.step()
        optimizer_nreset.step()

        if (i + 1) in i_list:
            count += 1
            draw_fig(ax_reset[count], quantizer, data, title=f"Iters: {i+1}, MSE: {loss.item():.1f}")
            draw_arrow(ax_reset[count], quant_data.detach().numpy(), data.numpy())

            draw_fig(ax_nreset[count], quantizer_nreset, data, title=f"Iters: {i+1}, MSE: {loss_nreset.item():.1f}")
            draw_arrow(ax_nreset[count], quant_data_nreset.detach().numpy(), data.numpy())

        # Called every step only on the reset-enabled quantizer; presumably
        # re-seeds dead codebook entries -- confirm against Quantizer.reset().
        quantizer.reset()

    fig_reset.suptitle("VQ Codebook Training with Reset", fontsize=24, y=1.05)
    fig_nreset.suptitle("VQ Codebook Training without Reset", fontsize=24, y=1.05)

    return fig_reset, fig_nreset

# end
|
| 203 |
|
| 204 |
|
| 205 |
class Handler:
|
|
|
|
| 309 |
gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1),
|
| 310 |
gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1),
|
| 311 |
],
|
| 312 |
+
outputs=[
|
| 313 |
+
gr.Plot(label="With Reset"),
|
| 314 |
+
gr.Plot(label="Without Reset")
|
| 315 |
+
],
|
| 316 |
title="Demo 2: Codebook Reset Strategy Visualization",
|
| 317 |
description="Visualizes codebook and data movement at different training steps."
|
| 318 |
)
|