alexnasa commited on
Commit
8764c13
·
verified ·
1 Parent(s): cdbef26

session-based aspect ratio as opposed to the global one

Browse files
Files changed (1) hide show
  1. app.py +23 -27
app.py CHANGED
@@ -321,8 +321,6 @@ class WanInferencePipeline(nn.Module):
321
  image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
322
 
323
  _, _, h, w = image.shape
324
- select_size = match_size(getattr( self.args, f'image_sizes_{ self.args.max_hw}'), h, w)
325
- image = resize_pad(image, (h, w), select_size)
326
  image = image * 2.0 - 1.0
327
  image = image[:, :, None]
328
 
@@ -330,7 +328,7 @@ class WanInferencePipeline(nn.Module):
330
  image = None
331
  select_size = [height, width]
332
  num = self.args.max_tokens * 16 * 16 * 4
333
- den = select_size[0] * select_size[1]
334
  L0 = num // den
335
  diff = (L0 - 1) % 4
336
  L = L0 - diff
@@ -394,29 +392,22 @@ class WanInferencePipeline(nn.Module):
394
  image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
395
 
396
  _, _, h, w = image.shape
397
- select_size = match_size(getattr(self.args, f'image_sizes_{self.args.max_hw}'), h, w)
398
- image = resize_pad(image, (h, w), select_size)
399
  image = image * 2.0 - 1.0
400
  image = image[:, :, None]
401
 
402
  else:
403
  image = None
404
- select_size = [height, width]
405
- # L = int(self.args.max_tokens * 16 * 16 * 4 / select_size[0] / select_size[1])
406
- # L = L // 4 * 4 + 1 if L % 4 != 0 else L - 3 # video frames
407
- # T = (L + 3) // 4 # latent frames
408
 
409
  # step 1: numerator and denominator as ints
410
  num = args.max_tokens * 16 * 16 * 4
411
- den = select_size[0] * select_size[1]
412
 
413
  # step 2: integer division
414
  L0 = num // den # exact floor division, no float in sight
415
 
416
  # step 3: make it ≡ 1 mod 4
417
- # if L0 % 4 == 1, keep L0;
418
- # otherwise subtract the difference so that (L0 - diff) % 4 == 1,
419
- # but ensure the result stays positive.
420
  diff = (L0 - 1) % 4
421
  L = L0 - diff
422
  if L < 1:
@@ -615,7 +606,7 @@ def get_duration(image_path, audio_path, text, num_steps, session_id, progress):
615
 
616
  return int(duration_s)
617
 
618
- def preprocess_img(input_image_path, raw_image_path, session_id = None):
619
 
620
  if session_id is None:
621
  session_id = uuid.uuid4().hex
@@ -631,7 +622,7 @@ def preprocess_img(input_image_path, raw_image_path, session_id = None):
631
  image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
632
 
633
  _, _, h, w = image.shape
634
- select_size = match_size(getattr( args, f'image_sizes_{ args.max_hw}'), h, w)
635
  image = resize_pad(image, (h, w), select_size)
636
  image = image * 2.0 - 1.0
637
  image = image[:, :, None]
@@ -649,13 +640,12 @@ def preprocess_img(input_image_path, raw_image_path, session_id = None):
649
 
650
  def infer_example(image_path, audio_path, text, num_steps, raw_image_path, session_id = None, progress=gr.Progress(track_tqdm=True),):
651
 
652
- current_image_size = args.image_sizes_720
653
- args.image_sizes_720 = [[720, 400]]
654
 
 
655
  result = infer(image_path, audio_path, text, num_steps, session_id, progress)
656
 
657
- args.image_sizes_720 = current_image_size
658
-
659
  return result
660
 
661
  @spaces.GPU(duration=get_duration)
@@ -713,7 +703,8 @@ def infer(image_path, audio_path, text, num_steps, session_id = None, progress=g
713
 
714
  def apply_image(request):
715
  print('image applied')
716
- return request, None
 
717
 
718
  def apply_audio(request):
719
  print('audio applied')
@@ -739,13 +730,15 @@ def orientation_changed(session_id, evt: gr.EventData):
739
  detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
740
 
741
  if detail['value'] == "9:16":
742
- args.image_sizes_720 = [[720, 400]]
743
  elif detail['value'] == "1:1":
744
- args.image_sizes_720 = [[720, 720]]
745
  elif detail['value'] == "16:9":
746
- args.image_sizes_720 = [[400, 720]]
 
 
747
 
748
- print(f'{session_id} has {args.image_sizes_720} orientation')
749
 
750
  def clear_raw_image():
751
  return ''
@@ -819,6 +812,7 @@ css = """
819
  with gr.Blocks(css=css) as demo:
820
 
821
  session_state = gr.State()
 
822
  demo.load(start_session, outputs=[session_state])
823
 
824
 
@@ -936,7 +930,9 @@ with gr.Blocks(css=css) as demo:
936
  ],
937
  label="Image Samples",
938
  inputs=[image_input],
939
- cache_examples=False
 
 
940
  )
941
 
942
  audio_examples = gr.Examples(
@@ -981,9 +977,9 @@ with gr.Blocks(css=css) as demo:
981
  inputs=[audio_input, limit_on, session_state],
982
  outputs=[audio_input],
983
  )
984
- image_input.orientation(fn=orientation_changed, inputs=[session_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
985
  image_input.clear(fn=clear_raw_image, outputs=[raw_img_text])
986
- image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, session_state], outputs=[image_input, raw_img_text])
987
  image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
988
  audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
989
  num_steps.change(fn=slider_value_change, inputs=[image_input, audio_input, text_input, num_steps, session_state, adaptive_text], outputs=[time_required, text_input])
 
321
  image = self.transform(image).unsqueeze(0).to(dtype=self.dtype)
322
 
323
  _, _, h, w = image.shape
 
 
324
  image = image * 2.0 - 1.0
325
  image = image[:, :, None]
326
 
 
328
  image = None
329
  select_size = [height, width]
330
  num = self.args.max_tokens * 16 * 16 * 4
331
+ den = h * w
332
  L0 = num // den
333
  diff = (L0 - 1) % 4
334
  L = L0 - diff
 
392
  image = self.transform(image).unsqueeze(0).to(self.device, dtype=self.dtype)
393
 
394
  _, _, h, w = image.shape
 
 
395
  image = image * 2.0 - 1.0
396
  image = image[:, :, None]
397
 
398
  else:
399
  image = None
400
+ h = height
401
+ w = width
 
 
402
 
403
  # step 1: numerator and denominator as ints
404
  num = args.max_tokens * 16 * 16 * 4
405
+ den = h * w
406
 
407
  # step 2: integer division
408
  L0 = num // den # exact floor division, no float in sight
409
 
410
  # step 3: make it ≡ 1 mod 4
 
 
 
411
  diff = (L0 - 1) % 4
412
  L = L0 - diff
413
  if L < 1:
 
606
 
607
  return int(duration_s)
608
 
609
+ def preprocess_img(input_image_path, raw_image_path, orientation_state, session_id = None):
610
 
611
  if session_id is None:
612
  session_id = uuid.uuid4().hex
 
622
  image = inferpipe.transform(image).unsqueeze(0).to(dtype=inferpipe.dtype)
623
 
624
  _, _, h, w = image.shape
625
+ select_size = match_size(orientation_state, h, w)
626
  image = resize_pad(image, (h, w), select_size)
627
  image = image * 2.0 - 1.0
628
  image = image[:, :, None]
 
640
 
641
  def infer_example(image_path, audio_path, text, num_steps, raw_image_path, session_id = None, progress=gr.Progress(track_tqdm=True),):
642
 
643
+ if session_id is None:
644
+ session_id = uuid.uuid4().hex
645
 
646
+ image_path, _ = preprocess_img(image_path, image_path, [[720, 400]], session_id)
647
  result = infer(image_path, audio_path, text, num_steps, session_id, progress)
648
 
 
 
649
  return result
650
 
651
  @spaces.GPU(duration=get_duration)
 
703
 
704
  def apply_image(request):
705
  print('image applied')
706
+
707
+ return request, request
708
 
709
  def apply_audio(request):
710
  print('audio applied')
 
730
  detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
731
 
732
  if detail['value'] == "9:16":
733
+ orientation_state = [[720, 400]]
734
  elif detail['value'] == "1:1":
735
+ orientation_state = [[720, 720]]
736
  elif detail['value'] == "16:9":
737
+ orientation_state = [[400, 720]]
738
+
739
+ print(f'{session_id} has {orientation_state} orientation')
740
 
741
+ return orientation_state
742
 
743
  def clear_raw_image():
744
  return ''
 
812
  with gr.Blocks(css=css) as demo:
813
 
814
  session_state = gr.State()
815
+ orientation_state = gr.State([[720, 400]])
816
  demo.load(start_session, outputs=[session_state])
817
 
818
 
 
930
  ],
931
  label="Image Samples",
932
  inputs=[image_input],
933
+ outputs=[image_input, raw_img_text],
934
+ fn=apply_image,
935
+ cache_examples=True
936
  )
937
 
938
  audio_examples = gr.Examples(
 
977
  inputs=[audio_input, limit_on, session_state],
978
  outputs=[audio_input],
979
  )
980
+ image_input.orientation(fn=orientation_changed, inputs=[session_state], outputs=[orientation_state]).then(fn=preprocess_img, inputs=[image_input, raw_img_text, orientation_state, session_state], outputs=[image_input, raw_img_text])
981
  image_input.clear(fn=clear_raw_image, outputs=[raw_img_text])
982
+ image_input.upload(fn=preprocess_img, inputs=[image_input, raw_img_text, orientation_state, session_state], outputs=[image_input, raw_img_text])
983
  image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
984
  audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps, session_state], outputs=[time_required])
985
  num_steps.change(fn=slider_value_change, inputs=[image_input, audio_input, text_input, num_steps, session_state, adaptive_text], outputs=[time_required, text_input])