dagloop5 committed
Commit db91dcb · verified · 1 Parent(s): 00fea93

Update app.py

Files changed (1):
  1. app.py +16 -60
app.py CHANGED

@@ -101,7 +101,7 @@ RESOLUTIONS = {
 }
 
 
-class LTX23DistilledA2VPipeline(DistilledPipeline):
+class LTX23DistilledA2VPipeline:
     """DistilledPipeline with optional audio conditioning."""
 
     def __call__(
@@ -113,24 +113,9 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
         num_frames: int,
         frame_rate: float,
         images: list[ImageConditioningInput],
-        audio_path: str | None = None,
         tiling_config: TilingConfig | None = None,
         enhance_prompt: bool = False,
     ):
-        # Standard path when no audio input is provided.
-        print(prompt)
-        if audio_path is None:
-            return super().__call__(
-                prompt=prompt,
-                seed=seed,
-                height=height,
-                width=width,
-                num_frames=num_frames,
-                frame_rate=frame_rate,
-                images=images,
-                tiling_config=tiling_config,
-                enhance_prompt=enhance_prompt,
-            )
 
         generator = torch.Generator(device=self.device).manual_seed(seed)
         noiser = GaussianNoiser(generator=generator)
@@ -141,38 +126,18 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
             [prompt],
             self.model_ledger,
             enhance_first_prompt=enhance_prompt,
-            enhance_prompt_image=images[0].path if len(images) > 0 else None,
+            enhance_prompt_image=images[0][0] if len(images) > 0 else None,
         )
         video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding
 
-        video_duration = num_frames / frame_rate
-        decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
-        if decoded_audio is None:
-            raise ValueError(f"Could not extract audio stream from {audio_path}")
-
-        encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
-        audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
-        expected_frames = audio_shape.frames
-        actual_frames = encoded_audio_latent.shape[2]
-
-        if actual_frames > expected_frames:
-            encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
-        elif actual_frames < expected_frames:
-            pad = torch.zeros(
-                encoded_audio_latent.shape[0],
-                encoded_audio_latent.shape[1],
-                expected_frames - actual_frames,
-                encoded_audio_latent.shape[3],
-                device=encoded_audio_latent.device,
-                dtype=encoded_audio_latent.dtype,
-            )
-            encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
-
+        # Stage 1: Initial low resolution video generation.
         video_encoder = self.model_ledger.video_encoder()
         transformer = self.model_ledger.transformer()
-        stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)
+        stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
 
-        def denoising_loop(sigmas, video_state, audio_state, stepper):
+        def denoising_loop(
+            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
+        ) -> tuple[LatentState, LatentState]:
             return euler_denoising_loop(
                 sigmas=sigmas,
                 video_state=video_state,
@@ -181,7 +146,7 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
                 denoise_fn=simple_denoising_func(
                     video_context=video_context,
                     audio_context=audio_context,
-                    transformer=transformer,
+                    transformer=transformer,  # noqa: F821
                 ),
             )
 
@@ -200,7 +165,8 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
             dtype=dtype,
             device=self.device,
         )
-        video_state = denoise_video_only(
+
+        video_state, audio_state = denoise_audio_video(
             output_shape=stage_1_output_shape,
             conditionings=stage_1_conditionings,
             noiser=noiser,
@@ -210,7 +176,6 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
             components=self.pipeline_components,
             dtype=dtype,
             device=self.device,
-            initial_audio_latent=encoded_audio_latent,
         )
 
         torch.cuda.synchronize()
@@ -219,16 +184,12 @@ class LTX23DistilledA2VPipeline(DistilledPipeline):
         cleanup_memory()
 
         decoded_video = vae_decode_video(
-            video_state.latent,
-            self.model_ledger.video_decoder(),
-            tiling_config,
-            generator,
+            video_state.latent, self.model_ledger.video_decoder(), tiling_config, generator
         )
-        original_audio = Audio(
-            waveform=decoded_audio.waveform.squeeze(0),
-            sampling_rate=decoded_audio.sampling_rate,
+        decoded_audio = vae_decode_audio(
+            audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
         )
-        return decoded_video, original_audio
+        return decoded_video, decoded_audio
 
 
 # Model repos
@@ -566,7 +527,6 @@ def on_highres_toggle(first_image, last_image, high_res):
 def get_gpu_duration(
     first_image,
     last_image,
-    input_audio,
     prompt: str,
     duration: float,
     gpu_duration: float,
@@ -598,7 +558,6 @@
 def generate_video(
     first_image,
     last_image,
-    input_audio,
     prompt: str,
     duration: float,
     gpu_duration: float,
@@ -670,7 +629,6 @@ def generate_video(
         num_frames=num_frames,
         frame_rate=frame_rate,
         images=images,
-        audio_path=input_audio,
         tiling_config=tiling_config,
         enhance_prompt=enhance_prompt,
     )
@@ -705,7 +663,6 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
     with gr.Row():
         first_image = gr.Image(label="First Frame (Optional)", type="pil")
         last_image = gr.Image(label="Last Frame (Optional)", type="pil")
-        input_audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
     prompt = gr.Textbox(
         label="Prompt",
         info="for best results - make it as elaborate as possible",
@@ -807,7 +764,6 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
            [
                None,
                "pinkknit.jpg",
-                None,
                "The camera falls downward through darkness as if dropped into a tunnel. "
                "As it slows, five friends wearing pink knitted hats and sunglasses lean "
                "over and look down toward the camera with curious expressions. The lens "
@@ -838,7 +794,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
            ],
        ],
        inputs=[
-            first_image, last_image, input_audio, prompt, duration, gpu_duration,
+            first_image, last_image, prompt, duration, gpu_duration,
            enhance_prompt, seed, randomize_seed, height, width,
            pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength,
        ],
@@ -871,7 +827,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
    generate_btn.click(
        fn=generate_video,
        inputs=[
-            first_image, last_image, input_audio, prompt, duration, gpu_duration, enhance_prompt,
+            first_image, last_image, prompt, duration, gpu_duration, enhance_prompt,
            seed, randomize_seed, height, width,
            pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength,
        ],
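
After this commit, __call__ no longer takes an audio_path: instead of encoding a user-supplied audio file as conditioning, the pipeline denoises audio and video latents jointly (denoise_audio_video) and returns the decoded pair. A minimal usage sketch under assumptions: the diff shows neither the pipeline's __init__ nor the return types, so the zero-argument construction and the concrete values below are hypothetical; the keyword names are taken from the updated __call__ signature and the call site in generate_video.

    # Sketch only: LTX23DistilledA2VPipeline.__init__ is not part of this
    # diff, so constructing it with no arguments is an assumption.
    pipeline = LTX23DistilledA2VPipeline()

    # Keyword names follow the updated __call__ signature; the values are
    # illustrative, not taken from the app's defaults.
    decoded_video, decoded_audio = pipeline(
        prompt="A serene mountain lake at sunrise, mist drifting over the water",
        seed=42,
        height=512,
        width=768,
        num_frames=121,
        frame_rate=24.0,
        images=[],           # optional list[ImageConditioningInput]
        tiling_config=None,
        enhance_prompt=False,
    )

Per the new return statement, decoded_audio now comes from vae_decode_audio applied to the generated audio latent rather than from an input file.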