tsi-org commited on
Commit
a30355f
Β·
verified Β·
1 Parent(s): 13cd217

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -67
app.py CHANGED
@@ -148,60 +148,57 @@ APP_STATE = {
148
  "current_vae_decoder": None,
149
  }
150
 
151
- # ONLY ADDITION: Store frames for download
152
  DOWNLOAD_FRAMES = []
153
 
154
- def frames_to_ts_file(frames, filepath, fps = 15):
155
  """
156
- Convert frames directly to .ts file using PyAV.
157
-
158
- Args:
159
- frames: List of numpy arrays (HWC, RGB, uint8)
160
- filepath: Output file path
161
- fps: Frames per second
162
-
163
- Returns:
164
- The filepath of the created file
165
  """
166
  if not frames:
167
  return filepath
168
 
169
- height, width = frames[0].shape[:2]
170
-
171
- # Create container for MPEG-TS format
172
- container = av.open(filepath, mode='w', format='mpegts')
173
-
174
- # Add video stream with optimized settings for streaming
175
- stream = container.add_stream('h264', rate=fps)
176
- stream.width = width
177
- stream.height = height
178
- stream.pix_fmt = 'yuv420p'
179
-
180
- # Optimize for low latency streaming
181
- stream.options = {
182
- 'preset': 'ultrafast',
183
- 'tune': 'zerolatency',
184
- 'crf': '23',
185
- 'profile': 'baseline',
186
- 'level': '3.0'
187
- }
188
-
189
  try:
190
- for frame_np in frames:
191
- frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
192
- frame = frame.reformat(format=stream.pix_fmt)
193
- for packet in stream.encode(frame):
194
- container.mux(packet)
195
 
196
- for packet in stream.encode():
197
- container.mux(packet)
 
 
 
 
 
 
198
 
199
- finally:
200
- container.close()
201
-
202
- return filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
- # ONLY ADDITION: Download function
205
  def create_download_mp4():
206
  global DOWNLOAD_FRAMES
207
  if not DOWNLOAD_FRAMES:
@@ -280,16 +277,15 @@ pipeline.to(dtype=torch.float16).to(gpu)
280
  @spaces.GPU
281
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
282
  """
283
- Generator function that yields .ts video chunks using PyAV for streaming.
284
- Now optimized for block-based processing.
285
  """
286
  global DOWNLOAD_FRAMES
287
- DOWNLOAD_FRAMES = [] # ONLY ADDITION: Reset frames
288
 
289
  if seed == -1:
290
  seed = random.randint(0, 2**32 - 1)
291
 
292
- print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
293
 
294
  # Setup
295
  conditional_dict = text_encoder(text_prompts=[prompt])
@@ -376,15 +372,13 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
376
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
377
 
378
  all_frames_from_block.append(frame_np)
379
- DOWNLOAD_FRAMES.append(frame_np) # ONLY ADDITION: Store for download
380
  total_frames_yielded += 1
381
 
382
- # Yield status update for each frame (cute tracking!)
383
  blocks_completed = idx
384
  current_block_progress = (frame_idx + 1) / pixels.shape[1]
385
  total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
386
-
387
- # Cap at 100% to avoid going over
388
  total_progress = min(total_progress, 100.0)
389
 
390
  frame_status_html = (
@@ -399,25 +393,21 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
399
  f"</div>"
400
  )
401
 
402
- # Yield None for video but update status (frame-by-frame tracking)
403
  yield None, frame_status_html
404
 
405
- # Encode entire block as one chunk immediately
406
  if all_frames_from_block:
407
  print(f"πŸ“Ή Encoding block {idx} with {len(all_frames_from_block)} frames")
408
 
409
  try:
410
  chunk_uuid = str(uuid.uuid4())[:8]
411
- ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
412
- ts_path = os.path.join("gradio_tmp", ts_filename)
413
-
414
- frames_to_ts_file(all_frames_from_block, ts_path, fps)
415
 
416
- # Calculate final progress for this block
417
- total_progress = (idx + 1) / num_blocks * 100
418
 
419
- # Yield the actual video chunk
420
- yield ts_path, gr.update()
421
 
422
  except Exception as e:
423
  print(f"⚠️ Error encoding block {idx}: {e}")
@@ -438,13 +428,13 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
438
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
439
  f" </p>"
440
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
441
- f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: MPEG-TS/H.264 β€’ πŸ“₯ Download ready!"
442
  f" </p>"
443
  f" </div>"
444
  f"</div>"
445
  )
446
  yield None, final_status_html
447
- print(f"βœ… PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
448
 
449
  # --- Gradio UI Layout ---
450
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
@@ -514,7 +504,7 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
514
  label="Generation Status"
515
  )
516
 
517
- # ONLY ADDITION: Download button
518
  download_btn = gr.DownloadButton(
519
  label="πŸ“₯ Download MP4",
520
  value=create_download_mp4,
@@ -540,12 +530,12 @@ if __name__ == "__main__":
540
  import shutil
541
  shutil.rmtree("gradio_tmp")
542
  os.makedirs("gradio_tmp", exist_ok=True)
543
- os.makedirs("downloads", exist_ok=True) # ONLY ADDITION
544
 
545
  print("πŸš€ Starting Self-Forcing Streaming Demo")
546
  print(f"πŸ“ Temporary files will be stored in: gradio_tmp/")
547
- print(f"πŸ“₯ Download files will be stored in: downloads/") # ONLY ADDITION
548
- print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
549
  print(f"⚑ GPU acceleration: {gpu}")
550
 
551
  demo.queue().launch(
 
148
  "current_vae_decoder": None,
149
  }
150
 
151
+ # Store frames for download
152
  DOWNLOAD_FRAMES = []
153
 
154
+ def frames_to_mp4_chunk(frames, filepath, fps=15):
155
  """
156
+ Convert frames to MP4 chunk using imageio (more compatible than .ts for Gradio streaming)
 
 
 
 
 
 
 
 
157
  """
158
  if not frames:
159
  return filepath
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  try:
162
+ # Use imageio to create MP4 chunk
163
+ with imageio.get_writer(filepath, fps=fps, codec='libx264', quality=6) as writer:
164
+ for frame_np in frames:
165
+ writer.append_data(frame_np)
 
166
 
167
+ return filepath
168
+
169
+ except Exception as e:
170
+ print(f"❌ Error creating MP4 chunk: {e}")
171
+ # Fallback to PyAV if imageio fails
172
+ try:
173
+ height, width = frames[0].shape[:2]
174
+ container = av.open(filepath, mode='w', format='mp4')
175
 
176
+ stream = container.add_stream('h264', rate=fps)
177
+ stream.width = width
178
+ stream.height = height
179
+ stream.pix_fmt = 'yuv420p'
180
+ stream.options = {
181
+ 'preset': 'ultrafast',
182
+ 'tune': 'zerolatency',
183
+ 'crf': '28'
184
+ }
185
+
186
+ for frame_np in frames:
187
+ frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
188
+ frame = frame.reformat(format=stream.pix_fmt)
189
+ for packet in stream.encode(frame):
190
+ container.mux(packet)
191
+
192
+ for packet in stream.encode():
193
+ container.mux(packet)
194
+
195
+ container.close()
196
+ return filepath
197
+
198
+ except Exception as e2:
199
+ print(f"❌ Both imageio and PyAV failed: {e2}")
200
+ return filepath
201
 
 
202
  def create_download_mp4():
203
  global DOWNLOAD_FRAMES
204
  if not DOWNLOAD_FRAMES:
 
277
  @spaces.GPU
278
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
279
  """
280
+ Generator function that yields MP4 video chunks for streaming.
 
281
  """
282
  global DOWNLOAD_FRAMES
283
+ DOWNLOAD_FRAMES = [] # Reset frames
284
 
285
  if seed == -1:
286
  seed = random.randint(0, 2**32 - 1)
287
 
288
+ print(f"🎬 Starting MP4 streaming: '{prompt}', seed: {seed}")
289
 
290
  # Setup
291
  conditional_dict = text_encoder(text_prompts=[prompt])
 
372
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
373
 
374
  all_frames_from_block.append(frame_np)
375
+ DOWNLOAD_FRAMES.append(frame_np) # Store for download
376
  total_frames_yielded += 1
377
 
378
+ # Yield status update for each frame
379
  blocks_completed = idx
380
  current_block_progress = (frame_idx + 1) / pixels.shape[1]
381
  total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
 
 
382
  total_progress = min(total_progress, 100.0)
383
 
384
  frame_status_html = (
 
393
  f"</div>"
394
  )
395
 
 
396
  yield None, frame_status_html
397
 
398
+ # Create MP4 chunk for this block
399
  if all_frames_from_block:
400
  print(f"πŸ“Ή Encoding block {idx} with {len(all_frames_from_block)} frames")
401
 
402
  try:
403
  chunk_uuid = str(uuid.uuid4())[:8]
404
+ mp4_filename = f"block_{idx:04d}_{chunk_uuid}.mp4"
405
+ mp4_path = os.path.join("gradio_tmp", mp4_filename)
 
 
406
 
407
+ frames_to_mp4_chunk(all_frames_from_block, mp4_path, fps)
 
408
 
409
+ # Yield the MP4 chunk
410
+ yield mp4_path, gr.update()
411
 
412
  except Exception as e:
413
  print(f"⚠️ Error encoding block {idx}: {e}")
 
428
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
429
  f" </p>"
430
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
431
+ f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: MP4/H.264 β€’ πŸ“₯ Download ready!"
432
  f" </p>"
433
  f" </div>"
434
  f"</div>"
435
  )
436
  yield None, final_status_html
437
+ print(f"βœ… MP4 streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
438
 
439
  # --- Gradio UI Layout ---
440
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
 
504
  label="Generation Status"
505
  )
506
 
507
+ # Download button
508
  download_btn = gr.DownloadButton(
509
  label="πŸ“₯ Download MP4",
510
  value=create_download_mp4,
 
530
  import shutil
531
  shutil.rmtree("gradio_tmp")
532
  os.makedirs("gradio_tmp", exist_ok=True)
533
+ os.makedirs("downloads", exist_ok=True)
534
 
535
  print("πŸš€ Starting Self-Forcing Streaming Demo")
536
  print(f"πŸ“ Temporary files will be stored in: gradio_tmp/")
537
+ print(f"πŸ“₯ Download files will be stored in: downloads/")
538
+ print(f"🎯 Chunk encoding: MP4/H.264 (more compatible)")
539
  print(f"⚑ GPU acceleration: {gpu}")
540
 
541
  demo.queue().launch(