AndyRaoTHU committed on
Commit
864e595
·
1 Parent(s): 9888fa4
Files changed (1) hide show
  1. app.py +49 -36
app.py CHANGED
@@ -225,9 +225,9 @@ class Handler:
225
  self.vqgan.to(self.device)
226
  self.vqgan.eval()
227
 
228
- self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
229
- self.optvq.to(self.device)
230
- self.optvq.eval()
231
 
232
  # self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
233
  # self.vae.to(self.device)
@@ -236,9 +236,13 @@ class Handler:
236
  # self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
237
  # self.revq.to(self.device)
238
  # self.revq.eval()
239
- self.sftok = SFTok.from_pretrained("AndyRaoTHU/SFTok-B")
240
- self.sftok.to(self.device)
241
- self.sftok.eval()
 
 
 
 
242
  print("Models loaded successfully!")
243
 
244
  def tensor_v2_to_image(self, tensor):
@@ -276,17 +280,20 @@ class Handler:
276
  # revq_rec = self.vae.decode(revq_rec).sample
277
  # sftok_rec = revq_rec
278
 
279
- encoded_tokens = self.sftok.encode(img)[1]["min_encoding_indices"]
280
- # encoded_tokens, _ = self.sftok.encode(img)
281
- sftok_rec = self.sftok.decode_tokens(encoded_tokens)
 
 
282
 
283
  # tensor to PIL image
284
  img = self.tensor_to_image(img)
285
  basevq_rec = self.tensor_v2_to_image(basevq_rec)
286
  vqgan_rec = self.tensor_v2_to_image(vqgan_rec)
287
- sftok_rec = self.tensor_to_image(sftok_rec)
 
288
 
289
- return basevq_rec, vqgan_rec, sftok_rec
290
 
291
  if __name__ == "__main__":
292
  # create the model handler
@@ -301,48 +308,54 @@ if __name__ == "__main__":
301
  outputs=[
302
  gr.Image(label="BaseVQ Reconstruction", type="numpy"),
303
  gr.Image(label="VQGAN Reconstruction", type="numpy"),
304
- gr.Image(label="SFTok Reconstruction", type="numpy"),
 
305
  ],
306
  title="Demo 1: Image Reconstruction",
307
  description="Upload an image to see how different VQ models (BaseVQ, VQGAN, SFTok) reconstruct it from latent codes."
308
  )
309
 
310
- with gr.Blocks() as demo2:
311
- gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
312
- gr.Markdown("Visualizes codebook and data movement at different training steps with or without codebook reset strategy.")
 
 
 
 
313
 
314
- with gr.Row():
315
- num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
316
- num_code = gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1)
317
 
318
- submit_btn = gr.Button("Run Visualization")
 
 
319
 
320
- with gr.Column(): # 垂直输出
321
- out_without_reset = gr.Image(label="Without Reset")
322
- out_with_reset = gr.Image(label="With Reset")
323
 
324
- submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])
325
 
 
 
 
326
 
327
- with gr.Blocks() as demo3:
328
- gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
329
- gr.Markdown("Visualizes codebook and data movement at different training steps with or without multi-group strategy.")
330
 
331
- with gr.Row():
332
- num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
333
- num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
334
 
335
- submit_btn = gr.Button("Run Visualization")
 
 
336
 
337
- with gr.Column(): # 垂直输出
338
- out_s = gr.Image(label="Single Group")
339
- out_m = gr.Image(label="Multi Group")
340
 
341
- submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
 
 
 
342
 
343
  demo = gr.TabbedInterface(
344
- interface_list=[demo1, demo2, demo3],
345
- tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
346
  )
347
 
348
  demo.launch(share=True)
 
225
  self.vqgan.to(self.device)
226
  self.vqgan.eval()
227
 
228
+ # self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
229
+ # self.optvq.to(self.device)
230
+ # self.optvq.eval()
231
 
232
  # self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
233
  # self.vae.to(self.device)
 
236
  # self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
237
  # self.revq.to(self.device)
238
  # self.revq.eval()
239
+ self.sftok_b = SFTok.from_pretrained("AndyRaoTHU/SFTok-B")
240
+ self.sftok_b.to(self.device)
241
+ self.sftok_b.eval()
242
+
243
+ self.sftok_l = SFTok.from_pretrained("AndyRaoTHU/SFTok-L")
244
+ self.sftok_l.to(self.device)
245
+ self.sftok_l.eval()
246
  print("Models loaded successfully!")
247
 
248
  def tensor_v2_to_image(self, tensor):
 
280
  # revq_rec = self.vae.decode(revq_rec).sample
281
  # sftok_rec = revq_rec
282
 
283
+ encoded_tokens_b = self.sftok_b.encode(img)[1]["min_encoding_indices"]
284
+ sftok_rec_b = self.sftok_b.decode_tokens(encoded_tokens_b)
285
+
286
+ encoded_tokens_l = self.sftok_l.encode(img)[1]["min_encoding_indices"]
287
+ sftok_rec_l = self.sftok_l.decode_tokens(encoded_tokens_l)
288
 
289
  # tensor to PIL image
290
  img = self.tensor_to_image(img)
291
  basevq_rec = self.tensor_v2_to_image(basevq_rec)
292
  vqgan_rec = self.tensor_v2_to_image(vqgan_rec)
293
+ sftok_rec_b = self.tensor_to_image(sftok_rec_b)
294
+ sftok_rec_l = self.tensor_to_image(sftok_rec_l)
295
 
296
+ return basevq_rec, vqgan_rec, sftok_rec_b, sftok_rec_l
297
 
298
  if __name__ == "__main__":
299
  # create the model handler
 
308
  outputs=[
309
  gr.Image(label="BaseVQ Reconstruction", type="numpy"),
310
  gr.Image(label="VQGAN Reconstruction", type="numpy"),
311
+ gr.Image(label="SFTok-B Reconstruction", type="numpy"),
312
+ gr.Image(label="SFTok-L Reconstruction", type="numpy"),
313
  ],
314
  title="Demo 1: Image Reconstruction",
315
  description="Upload an image to see how different VQ models (BaseVQ, VQGAN, SFTok) reconstruct it from latent codes."
316
  )
317
 
318
+ # with gr.Blocks() as demo2:
319
+ # gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
320
+ # gr.Markdown("Visualizes codebook and data movement at different training steps with or without codebook reset strategy.")
321
+
322
+ # with gr.Row():
323
+ # num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
324
+ # num_code = gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1)
325
 
326
+ # submit_btn = gr.Button("Run Visualization")
 
 
327
 
328
+ # with gr.Column(): # 垂直输出
329
+ # out_without_reset = gr.Image(label="Without Reset")
330
+ # out_with_reset = gr.Image(label="With Reset")
331
 
332
+ # submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])
 
 
333
 
 
334
 
335
+ # with gr.Blocks() as demo3:
336
+ # gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
337
+ # gr.Markdown("Visualizes codebook and data movement at different training steps with or without multi-group strategy.")
338
 
339
+ # with gr.Row():
340
+ # num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
341
+ # num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
342
 
343
+ # submit_btn = gr.Button("Run Visualization")
 
 
344
 
345
+ # with gr.Column(): # 垂直输出
346
+ # out_s = gr.Image(label="Single Group")
347
+ # out_m = gr.Image(label="Multi Group")
348
 
349
+ # submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
 
 
350
 
351
+ # demo = gr.TabbedInterface(
352
+ # interface_list=[demo1, demo2, demo3],
353
+ # tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
354
+ # )
355
 
356
  demo = gr.TabbedInterface(
357
+ interface_list=[demo1],
358
+ tab_names=["Image Reconstruction"]
359
  )
360
 
361
  demo.launch(share=True)