Spaces:
Sleeping
Sleeping
Commit ·
864e595
1
Parent(s): 9888fa4
update
Browse files
app.py
CHANGED
|
@@ -225,9 +225,9 @@ class Handler:
|
|
| 225 |
self.vqgan.to(self.device)
|
| 226 |
self.vqgan.eval()
|
| 227 |
|
| 228 |
-
self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
|
| 229 |
-
self.optvq.to(self.device)
|
| 230 |
-
self.optvq.eval()
|
| 231 |
|
| 232 |
# self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
|
| 233 |
# self.vae.to(self.device)
|
|
@@ -236,9 +236,13 @@ class Handler:
|
|
| 236 |
# self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
|
| 237 |
# self.revq.to(self.device)
|
| 238 |
# self.revq.eval()
|
| 239 |
-
self.
|
| 240 |
-
self.
|
| 241 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
print("Models loaded successfully!")
|
| 243 |
|
| 244 |
def tensor_v2_to_image(self, tensor):
|
|
@@ -276,17 +280,20 @@ class Handler:
|
|
| 276 |
# revq_rec = self.vae.decode(revq_rec).sample
|
| 277 |
# sftok_rec = revq_rec
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
|
|
|
|
|
|
| 282 |
|
| 283 |
# tensor to PIL image
|
| 284 |
img = self.tensor_to_image(img)
|
| 285 |
basevq_rec = self.tensor_v2_to_image(basevq_rec)
|
| 286 |
vqgan_rec = self.tensor_v2_to_image(vqgan_rec)
|
| 287 |
-
|
|
|
|
| 288 |
|
| 289 |
-
return basevq_rec, vqgan_rec,
|
| 290 |
|
| 291 |
if __name__ == "__main__":
|
| 292 |
# create the model handler
|
|
@@ -301,48 +308,54 @@ if __name__ == "__main__":
|
|
| 301 |
outputs=[
|
| 302 |
gr.Image(label="BaseVQ Reconstruction", type="numpy"),
|
| 303 |
gr.Image(label="VQGAN Reconstruction", type="numpy"),
|
| 304 |
-
gr.Image(label="SFTok Reconstruction", type="numpy"),
|
|
|
|
| 305 |
],
|
| 306 |
title="Demo 1: Image Reconstruction",
|
| 307 |
description="Upload an image to see how different VQ models (BaseVQ, VQGAN, SFTok) reconstruct it from latent codes."
|
| 308 |
)
|
| 309 |
|
| 310 |
-
with gr.Blocks() as demo2:
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
-
|
| 315 |
-
num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
|
| 316 |
-
num_code = gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1)
|
| 317 |
|
| 318 |
-
|
|
|
|
|
|
|
| 319 |
|
| 320 |
-
|
| 321 |
-
out_without_reset = gr.Image(label="Without Reset")
|
| 322 |
-
out_with_reset = gr.Image(label="With Reset")
|
| 323 |
|
| 324 |
-
submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])
|
| 325 |
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
-
with gr.
|
| 328 |
-
|
| 329 |
-
|
| 330 |
|
| 331 |
-
|
| 332 |
-
num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
|
| 333 |
-
num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
|
| 334 |
|
| 335 |
-
|
|
|
|
|
|
|
| 336 |
|
| 337 |
-
|
| 338 |
-
out_s = gr.Image(label="Single Group")
|
| 339 |
-
out_m = gr.Image(label="Multi Group")
|
| 340 |
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
demo = gr.TabbedInterface(
|
| 344 |
-
interface_list=[demo1
|
| 345 |
-
tab_names=["Image Reconstruction"
|
| 346 |
)
|
| 347 |
|
| 348 |
demo.launch(share=True)
|
|
|
|
| 225 |
self.vqgan.to(self.device)
|
| 226 |
self.vqgan.eval()
|
| 227 |
|
| 228 |
+
# self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
|
| 229 |
+
# self.optvq.to(self.device)
|
| 230 |
+
# self.optvq.eval()
|
| 231 |
|
| 232 |
# self.vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
|
| 233 |
# self.vae.to(self.device)
|
|
|
|
| 236 |
# self.revq = ReVQ.from_pretrained("AndyRaoTHU/revq-512T")
|
| 237 |
# self.revq.to(self.device)
|
| 238 |
# self.revq.eval()
|
| 239 |
+
self.sftok_b = SFTok.from_pretrained("AndyRaoTHU/SFTok-B")
|
| 240 |
+
self.sftok_b.to(self.device)
|
| 241 |
+
self.sftok_b.eval()
|
| 242 |
+
|
| 243 |
+
self.sftok_l = SFTok.from_pretrained("AndyRaoTHU/SFTok-L")
|
| 244 |
+
self.sftok_l.to(self.device)
|
| 245 |
+
self.sftok_l.eval()
|
| 246 |
print("Models loaded successfully!")
|
| 247 |
|
| 248 |
def tensor_v2_to_image(self, tensor):
|
|
|
|
| 280 |
# revq_rec = self.vae.decode(revq_rec).sample
|
| 281 |
# sftok_rec = revq_rec
|
| 282 |
|
| 283 |
+
encoded_tokens_b = self.sftok_b.encode(img)[1]["min_encoding_indices"]
|
| 284 |
+
sftok_rec_b = self.sftok_b.decode_tokens(encoded_tokens_b)
|
| 285 |
+
|
| 286 |
+
encoded_tokens_l = self.sftok_l.encode(img)[1]["min_encoding_indices"]
|
| 287 |
+
sftok_rec_l = self.sftok_l.decode_tokens(encoded_tokens_l)
|
| 288 |
|
| 289 |
# tensor to PIL image
|
| 290 |
img = self.tensor_to_image(img)
|
| 291 |
basevq_rec = self.tensor_v2_to_image(basevq_rec)
|
| 292 |
vqgan_rec = self.tensor_v2_to_image(vqgan_rec)
|
| 293 |
+
sftok_rec_b = self.tensor_to_image(sftok_rec_b)
|
| 294 |
+
sftok_rec_l = self.tensor_to_image(sftok_rec_l)
|
| 295 |
|
| 296 |
+
return basevq_rec, vqgan_rec, sftok_rec_b, sftok_rec_l
|
| 297 |
|
| 298 |
if __name__ == "__main__":
|
| 299 |
# create the model handler
|
|
|
|
| 308 |
outputs=[
|
| 309 |
gr.Image(label="BaseVQ Reconstruction", type="numpy"),
|
| 310 |
gr.Image(label="VQGAN Reconstruction", type="numpy"),
|
| 311 |
+
gr.Image(label="SFTok-B Reconstruction", type="numpy"),
|
| 312 |
+
gr.Image(label="SFTok-L Reconstruction", type="numpy"),
|
| 313 |
],
|
| 314 |
title="Demo 1: Image Reconstruction",
|
| 315 |
description="Upload an image to see how different VQ models (BaseVQ, VQGAN, SFTok) reconstruct it from latent codes."
|
| 316 |
)
|
| 317 |
|
| 318 |
+
# with gr.Blocks() as demo2:
|
| 319 |
+
# gr.Markdown("## Demo 2: Codebook Reset Strategy Visualization")
|
| 320 |
+
# gr.Markdown("Visualizes codebook and data movement at different training steps with or without codebook reset strategy.")
|
| 321 |
+
|
| 322 |
+
# with gr.Row():
|
| 323 |
+
# num_data = gr.Slider(label="num_data", value=16, minimum=10, maximum=20, step=1)
|
| 324 |
+
# num_code = gr.Slider(label="num_code", value=12, minimum=8, maximum=16, step=1)
|
| 325 |
|
| 326 |
+
# submit_btn = gr.Button("Run Visualization")
|
|
|
|
|
|
|
| 327 |
|
| 328 |
+
# with gr.Column(): # 垂直输出
|
| 329 |
+
# out_without_reset = gr.Image(label="Without Reset")
|
| 330 |
+
# out_with_reset = gr.Image(label="With Reset")
|
| 331 |
|
| 332 |
+
# submit_btn.click(fn=draw_reset_result, inputs=[num_data, num_code], outputs=[out_without_reset, out_with_reset])
|
|
|
|
|
|
|
| 333 |
|
|
|
|
| 334 |
|
| 335 |
+
# with gr.Blocks() as demo3:
|
| 336 |
+
# gr.Markdown("## Demo 3: Channel Multi-Group Strategy Visualization")
|
| 337 |
+
# gr.Markdown("Visualizes codebook and data movement at different training steps with or without multi-group strategy.")
|
| 338 |
|
| 339 |
+
# with gr.Row():
|
| 340 |
+
# num_data = gr.Slider(label="num_data", value=32, minimum=28, maximum=40, step=1)
|
| 341 |
+
# num_code = gr.Slider(label="num_code", value=8, minimum=6, maximum=10, step=1)
|
| 342 |
|
| 343 |
+
# submit_btn = gr.Button("Run Visualization")
|
|
|
|
|
|
|
| 344 |
|
| 345 |
+
# with gr.Column(): # 垂直输出
|
| 346 |
+
# out_s = gr.Image(label="Single Group")
|
| 347 |
+
# out_m = gr.Image(label="Multi Group")
|
| 348 |
|
| 349 |
+
# submit_btn.click(fn=draw_multi_group_result, inputs=[num_data, num_code], outputs=[out_s, out_m])
|
|
|
|
|
|
|
| 350 |
|
| 351 |
+
# demo = gr.TabbedInterface(
|
| 352 |
+
# interface_list=[demo1, demo2, demo3],
|
| 353 |
+
# tab_names=["Image Reconstruction", "Reset Strategy", "Channel Multi-Group Strategy"]
|
| 354 |
+
# )
|
| 355 |
|
| 356 |
demo = gr.TabbedInterface(
|
| 357 |
+
interface_list=[demo1],
|
| 358 |
+
tab_names=["Image Reconstruction"]
|
| 359 |
)
|
| 360 |
|
| 361 |
demo.launch(share=True)
|