linoyts HF Staff committed on
Commit
aabdf9e
Β·
verified Β·
1 Parent(s): fc8f41d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -9
app.py CHANGED
@@ -9,6 +9,9 @@ os.environ["TORCHDYNAMO_DISABLE"] = "1"
9
  # Install xformers for memory-efficient attention
10
  subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
 
 
 
12
  # Clone LTX-2 repo and install packages
13
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
14
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
@@ -91,6 +94,152 @@ except Exception as e:
91
  logging.getLogger().setLevel(logging.INFO)
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # ─────────────────────────────────────────────────────────────────────────────
95
  # Helper: read reference downscale factor from IC-LoRA metadata
96
  # ─────────────────────────────────────────────────────────────────────────────
@@ -547,7 +696,7 @@ pipeline = LTX23UnifiedPipeline(
547
  distilled_checkpoint_path=checkpoint_path,
548
  spatial_upsampler_path=spatial_upsampler_path,
549
  gemma_root=gemma_root,
550
- # ic_loras=ic_loras,
551
  quantization=QuantizationPolicy.fp8_cast(),
552
  )
553
 
@@ -680,6 +829,31 @@ def on_highres_toggle(image, video, high_res):
680
  # ─────────────────────────────────────────────────────────────────────────────
681
  # Generation
682
  # ─────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
  @spaces.GPU(duration=180)
684
  @torch.inference_mode()
685
  def generate_video(
@@ -689,6 +863,7 @@ def generate_video(
689
  prompt: str,
690
  duration: float,
691
  conditioning_strength: float,
 
692
  enhance_prompt: bool,
693
  seed: int,
694
  randomize_seed: bool,
@@ -708,7 +883,7 @@ def generate_video(
708
  if input_image is not None:
709
  mode_parts.append("Image")
710
  if input_video is not None:
711
- mode_parts.append("Video(IC-LoRA)")
712
  if input_audio is not None:
713
  mode_parts.append("Audio")
714
  if not mode_parts:
@@ -723,10 +898,40 @@ def generate_video(
723
  if input_image is not None:
724
  images = [ImageConditioningInput(path=str(input_image), frame_idx=0, strength=1.0)]
725
 
726
- # Build video conditionings for IC-LoRA / V2V
727
  video_conditioning = None
728
  if input_video is not None:
729
- video_conditioning = [(str(input_video), 1.0)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
 
731
  tiling_config = TilingConfig.default()
732
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
@@ -783,14 +988,31 @@ with gr.Blocks(title="LTX-2.3 Unified: V2V + I2V + A2V") as demo:
783
  label="πŸ–ΌοΈ Input Image (I2V β€” first frame)",
784
  type="filepath",
785
  )
786
- input_video = gr.Video(
787
- label="🎬 Reference Video (V2V β€” IC-LoRA)",
788
- sources=["upload"],
789
- )
 
 
 
 
 
 
 
 
 
 
 
 
790
  input_audio = gr.Audio(
791
  label="πŸ”Š Input Audio (A2V β€” lipsync / BGM)",
792
  type="filepath",
793
  )
 
 
 
 
 
794
 
795
  prompt = gr.Textbox(
796
  label="Prompt",
@@ -849,7 +1071,7 @@ with gr.Blocks(title="LTX-2.3 Unified: V2V + I2V + A2V") as demo:
849
  fn=generate_video,
850
  inputs=[
851
  input_image, input_video, input_audio, prompt, duration,
852
- conditioning_strength, enhance_prompt,
853
  seed, randomize_seed, height, width,
854
  ],
855
  outputs=[output_video, seed],
 
9
  # Install xformers for memory-efficient attention
10
  subprocess.run([sys.executable, "-m", "pip", "install", "xformers==0.0.32.post2", "--no-build-isolation"], check=False)
11
 
12
+ # Install video preprocessing dependencies (pose/canny/depth extraction)
13
+ subprocess.run([sys.executable, "-m", "pip", "install", "controlnet_aux", "imageio[ffmpeg]"], check=False)
14
+
15
  # Clone LTX-2 repo and install packages
16
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
17
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
 
94
  logging.getLogger().setLevel(logging.INFO)
95
 
96
 
97
+ # ─────────────────────────────────────────────────────────────────────────────
98
+ # Video Preprocessing: Strip appearance, keep structure
99
+ # ─────────────────────────────────────────────────────────────────────────────
100
+ import imageio
101
+ import cv2
102
+ from PIL import Image
103
+
104
+ # Lazy-loaded processors (heavy models, only init when needed)
105
+ _pose_processor = None
106
+ _depth_processor = None
107
+
108
+
109
def _get_pose_processor():
    """Return the shared DWPose detector, creating it on first use.

    Lazy singleton: the model is heavy, so it is only instantiated when a
    pose-preprocessing request actually arrives.
    """
    global _pose_processor
    if _pose_processor is not None:
        return _pose_processor
    from controlnet_aux import DWposeDetector
    _pose_processor = DWposeDetector.from_pretrained_default()
    print("[Preprocess] DWPose processor loaded")
    return _pose_processor
116
+
117
+
118
def _get_depth_processor():
    """Return the shared MiDaS depth detector, creating it on first use.

    Lazy singleton mirroring _get_pose_processor: avoids loading the
    annotator weights unless depth preprocessing is requested.
    """
    global _depth_processor
    if _depth_processor is not None:
        return _depth_processor
    from controlnet_aux import MidasDetector
    _depth_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
    print("[Preprocess] MiDaS depth processor loaded")
    return _depth_processor
125
+
126
+
127
def load_video_frames(video_path: str) -> list[np.ndarray]:
    """Load video frames as a list of HWC uint8 numpy arrays."""
    # imageio readers are iterable; materialize every decoded frame.
    with imageio.get_reader(video_path) as reader:
        return [frame for frame in reader]
134
+
135
+
136
def write_video_mp4(frames_float_01: list[np.ndarray], fps: float, out_path: str) -> str:
    """Write float [0, 1] RGB frames to an mp4 file.

    Args:
        frames_float_01: HWC float frames with values nominally in [0, 1].
        fps: output frame rate.
        out_path: destination .mp4 path.

    Returns:
        out_path, for convenient chaining.
    """
    # Clip before casting: values even marginally outside [0, 1] would wrap
    # around under a bare uint8 cast (e.g. 256 -> 0), producing speckle
    # artifacts; rint avoids the systematic darkening of pure truncation.
    frames_uint8 = [
        np.clip(np.rint(f * 255.0), 0, 255).astype(np.uint8)
        for f in frames_float_01
    ]
    # macro_block_size=1 keeps arbitrary (non multiple-of-16) resolutions intact.
    with imageio.get_writer(out_path, fps=fps, macro_block_size=1) as writer:
        for fr in frames_uint8:
            writer.append_data(fr)
    return out_path
143
+
144
+
145
def extract_first_frame(video_path: str) -> str:
    """Extract the first frame of a video to a temp PNG and return its path.

    Raises:
        ValueError: if the video contains no decodable frames.
    """
    frames = load_video_frames(video_path)
    if not frames:
        raise ValueError("No frames in video")
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp: the file
    # is created atomically; only the name is needed, so close the fd at once.
    fd, out_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    Image.fromarray(frames[0]).save(out_path)
    return out_path
153
+
154
+
155
def preprocess_video_pose(frames: list[np.ndarray], width: int, height: int) -> list[np.ndarray]:
    """Extract DWPose skeletons from each frame. Returns float [0,1] frames."""
    detector = _get_pose_processor()
    out_frames = []
    for raw in frames:
        rgb = Image.fromarray(raw.astype(np.uint8)).convert("RGB")
        skeleton = detector(rgb, include_body=True, include_hand=True, include_face=True)
        # Some controlnet_aux versions return ndarrays rather than PIL images.
        if not isinstance(skeleton, Image.Image):
            skeleton = Image.fromarray(skeleton.astype(np.uint8))
        skeleton = skeleton.convert("RGB").resize((width, height), Image.BILINEAR)
        out_frames.append(np.array(skeleton).astype(np.float32) / 255.0)
    return out_frames
167
+
168
+
169
def preprocess_video_canny(frames: list[np.ndarray], width: int, height: int,
                           low_threshold: int = 50, high_threshold: int = 100) -> list[np.ndarray]:
    """Extract Canny edges from each frame. Returns float [0,1] frames."""
    out_frames = []
    for raw in frames:
        # Resize first so the edge thresholds act at the target resolution.
        small = cv2.resize(raw, (width, height), interpolation=cv2.INTER_AREA)
        gray = cv2.cvtColor(small, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, low_threshold, high_threshold)
        # Replicate the single edge channel to RGB and normalize to [0, 1].
        rgb_edges = np.stack([edges] * 3, axis=-1)
        out_frames.append(rgb_edges.astype(np.float32) / 255.0)
    return out_frames
182
+
183
+
184
def preprocess_video_depth(frames: list[np.ndarray], width: int, height: int) -> list[np.ndarray]:
    """Extract MiDaS depth maps from each frame. Returns float [0,1] frames.

    Raises:
        ValueError: if frames is empty (the detect resolution is derived
            from the first frame, so an empty list cannot be processed).
    """
    if not frames:
        raise ValueError("preprocess_video_depth requires at least one frame")
    processor = _get_depth_processor()
    # Detect at the source resolution, emit at the target resolution.
    detect_res = max(frames[0].shape[0], frames[0].shape[1])
    image_res = max(width, height)
    result = []
    for frame in frames:
        depth = processor(frame, detect_resolution=detect_res,
                          image_resolution=image_res, output_type="np")
        # Normalize channel layout: (H, W) or (H, W, 1) -> (H, W, 3).
        if depth.ndim == 2:
            depth = np.stack([depth, depth, depth], axis=-1)
        elif depth.shape[-1] == 1:
            depth = np.repeat(depth, 3, axis=-1)
        result.append(depth)
    return result
199
+
200
+
201
def preprocess_conditioning_video(
    video_path: str,
    mode: str,
    width: int,
    height: int,
    num_frames: int,
    fps: float,
) -> tuple[str, str]:
    """
    Preprocess a video for conditioning. Strips appearance, keeps structure.

    Args:
        video_path: source video to condition on.
        mode: UI dropdown label selecting the extractor; unknown labels fall
            through to raw (no structural extraction).
        width: output width of the conditioning frames.
        height: output height of the conditioning frames.
        num_frames: maximum number of frames kept from the source.
        fps: frame rate of the written conditioning clip.

    Returns:
        (conditioning_mp4_path, first_frame_png_path)

    Raises:
        ValueError: if no frames could be decoded from the video.
    """
    frames = load_video_frames(video_path)
    if not frames:
        raise ValueError("No frames decoded from video")

    # Trim to the requested clip length.
    frames = frames[:num_frames]

    # Save first frame (original appearance) for image conditioning.
    # mkstemp replaces the deprecated, race-prone tempfile.mktemp: it creates
    # the file atomically; only the name is needed, so close the fd at once.
    fd, first_png = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    Image.fromarray(frames[0]).save(first_png)

    # Dispatch on the UI mode label.
    if mode == "Pose (DWPose)":
        processed = preprocess_video_pose(frames, width, height)
    elif mode == "Canny Edge":
        processed = preprocess_video_canny(frames, width, height)
    elif mode == "Depth (MiDaS)":
        processed = preprocess_video_depth(frames, width, height)
    else:
        # "Raw" mode — no structural extraction, just normalize to [0, 1].
        processed = [f.astype(np.float32) / 255.0 for f in frames]

    fd, cond_mp4 = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    write_video_mp4(processed, fps=fps, out_path=cond_mp4)

    return cond_mp4, first_png
241
+
242
+
243
  # ─────────────────────────────────────────────────────────────────────────────
244
  # Helper: read reference downscale factor from IC-LoRA metadata
245
  # ─────────────────────────────────────────────────────────────────────────────
 
696
  distilled_checkpoint_path=checkpoint_path,
697
  spatial_upsampler_path=spatial_upsampler_path,
698
  gemma_root=gemma_root,
699
+ ic_loras=ic_loras,
700
  quantization=QuantizationPolicy.fp8_cast(),
701
  )
702
 
 
829
  # ─────────────────────────────────────────────────────────────────────────────
830
  # Generation
831
  # ─────────────────────────────────────────────────────────────────────────────
832
+ def _extract_audio_from_video(video_path: str) -> str | None:
833
+ """Extract audio from video as a temp WAV file. Returns None if no audio."""
834
+ out_path = tempfile.mktemp(suffix=".wav")
835
+ try:
836
+ # Check if video has an audio stream
837
+ probe = subprocess.run(
838
+ ["ffprobe", "-v", "error", "-select_streams", "a:0",
839
+ "-show_entries", "stream=codec_type", "-of", "default=nw=1:nk=1",
840
+ video_path],
841
+ capture_output=True, text=True,
842
+ )
843
+ if not probe.stdout.strip():
844
+ return None
845
+
846
+ # Extract audio
847
+ subprocess.run(
848
+ ["ffmpeg", "-y", "-v", "error", "-i", video_path,
849
+ "-vn", "-ac", "1", "-ar", "48000", "-c:a", "pcm_s16le", out_path],
850
+ check=True,
851
+ )
852
+ return out_path
853
+ except (subprocess.CalledProcessError, FileNotFoundError):
854
+ return None
855
+
856
+
857
  @spaces.GPU(duration=180)
858
  @torch.inference_mode()
859
  def generate_video(
 
863
  prompt: str,
864
  duration: float,
865
  conditioning_strength: float,
866
+ video_preprocess: str,
867
  enhance_prompt: bool,
868
  seed: int,
869
  randomize_seed: bool,
 
883
  if input_image is not None:
884
  mode_parts.append("Image")
885
  if input_video is not None:
886
+ mode_parts.append(f"Video({video_preprocess})")
887
  if input_audio is not None:
888
  mode_parts.append("Audio")
889
  if not mode_parts:
 
898
  if input_image is not None:
899
  images = [ImageConditioningInput(path=str(input_image), frame_idx=0, strength=1.0)]
900
 
901
+ # Build video conditionings β€” preprocess to strip appearance
902
  video_conditioning = None
903
  if input_video is not None:
904
+ video_path = str(input_video)
905
+
906
+ if video_preprocess != "Raw (no preprocessing)":
907
+ print(f"[Preprocess] Running {video_preprocess} on input video...")
908
+ cond_mp4, first_frame_png = preprocess_conditioning_video(
909
+ video_path=video_path,
910
+ mode=video_preprocess,
911
+ width=int(width) // 2, # Stage 1 operates at half res
912
+ height=int(height) // 2,
913
+ num_frames=num_frames,
914
+ fps=frame_rate,
915
+ )
916
+ video_conditioning = [(cond_mp4, 1.0)]
917
+
918
+ # If no image was provided, use the video's first frame
919
+ # (original appearance) as the image conditioning
920
+ if input_image is None:
921
+ images = [ImageConditioningInput(
922
+ path=first_frame_png, frame_idx=0, strength=1.0,
923
+ )]
924
+ print(f"[Preprocess] Using video first frame as image conditioning")
925
+ else:
926
+ # Raw mode β€” pass video as-is
927
+ video_conditioning = [(video_path, 1.0)]
928
+
929
+ # If no audio was provided, try to extract audio from the video
930
+ if input_audio is None:
931
+ extracted_audio = _extract_audio_from_video(video_path)
932
+ if extracted_audio is not None:
933
+ input_audio = extracted_audio
934
+ print(f"[Preprocess] Extracted audio from input video")
935
 
936
  tiling_config = TilingConfig.default()
937
  video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
 
988
  label="πŸ–ΌοΈ Input Image (I2V β€” first frame)",
989
  type="filepath",
990
  )
991
+ with gr.Column():
992
+ input_video = gr.Video(
993
+ label="🎬 Reference Video (V2V)",
994
+ sources=["upload"],
995
+ )
996
+ video_preprocess = gr.Dropdown(
997
+ label="Video Preprocessing",
998
+ choices=[
999
+ "Pose (DWPose)",
1000
+ "Canny Edge",
1001
+ "Depth (MiDaS)",
1002
+ "Raw (no preprocessing)",
1003
+ ],
1004
+ value="Pose (DWPose)",
1005
+ info="Strips appearance from video β†’ style comes from image/prompt instead",
1006
+ )
1007
  input_audio = gr.Audio(
1008
  label="πŸ”Š Input Audio (A2V β€” lipsync / BGM)",
1009
  type="filepath",
1010
  )
1011
+ gr.Markdown(
1012
+ "*When a video is uploaded: its first frame auto-becomes the image input "
1013
+ "(if none provided), and its audio track auto-becomes the audio input "
1014
+ "(if none provided).*"
1015
+ )
1016
 
1017
  prompt = gr.Textbox(
1018
  label="Prompt",
 
1071
  fn=generate_video,
1072
  inputs=[
1073
  input_image, input_video, input_audio, prompt, duration,
1074
+ conditioning_strength, video_preprocess, enhance_prompt,
1075
  seed, randomize_seed, height, width,
1076
  ],
1077
  outputs=[output_video, seed],