Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -148,7 +148,7 @@ APP_STATE = {
|
|
| 148 |
"current_vae_decoder": None,
|
| 149 |
}
|
| 150 |
|
| 151 |
-
#
|
| 152 |
DOWNLOAD_FRAMES = []
|
| 153 |
|
| 154 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
|
@@ -201,59 +201,22 @@ def frames_to_ts_file(frames, filepath, fps = 15):
|
|
| 201 |
|
| 202 |
return filepath
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
Create HLS playlist (.m3u8) file for streaming.
|
| 207 |
-
"""
|
| 208 |
-
playlist_path = os.path.join(playlist_dir, "playlist.m3u8")
|
| 209 |
-
segment_duration = 2.0 # Each segment duration in seconds
|
| 210 |
-
|
| 211 |
-
playlist_content = [
|
| 212 |
-
"#EXTM3U",
|
| 213 |
-
"#EXT-X-VERSION:3",
|
| 214 |
-
f"#EXT-X-TARGETDURATION:{int(segment_duration) + 1}",
|
| 215 |
-
"#EXT-X-MEDIA-SEQUENCE:0",
|
| 216 |
-
"#EXT-X-PLAYLIST-TYPE:VOD"
|
| 217 |
-
]
|
| 218 |
-
|
| 219 |
-
for ts_file in ts_files:
|
| 220 |
-
ts_filename = os.path.basename(ts_file)
|
| 221 |
-
playlist_content.extend([
|
| 222 |
-
f"#EXTINF:{segment_duration:.1f},",
|
| 223 |
-
ts_filename
|
| 224 |
-
])
|
| 225 |
-
|
| 226 |
-
playlist_content.append("#EXT-X-ENDLIST")
|
| 227 |
-
|
| 228 |
-
with open(playlist_path, 'w') as f:
|
| 229 |
-
f.write('\n'.join(playlist_content))
|
| 230 |
-
|
| 231 |
-
return playlist_path
|
| 232 |
-
|
| 233 |
-
def create_mp4_download():
|
| 234 |
-
"""Create MP4 file from stored frames for download."""
|
| 235 |
global DOWNLOAD_FRAMES
|
| 236 |
-
|
| 237 |
if not DOWNLOAD_FRAMES:
|
| 238 |
return None
|
| 239 |
-
|
| 240 |
try:
|
| 241 |
os.makedirs("downloads", exist_ok=True)
|
| 242 |
-
|
| 243 |
timestamp = int(time.time())
|
| 244 |
-
|
| 245 |
-
mp4_path = os.path.join("downloads", mp4_filename)
|
| 246 |
-
|
| 247 |
-
# Use imageio to create MP4
|
| 248 |
with imageio.get_writer(mp4_path, fps=args.fps, codec='libx264', quality=8) as writer:
|
| 249 |
for frame in DOWNLOAD_FRAMES:
|
| 250 |
writer.append_data(frame)
|
| 251 |
-
|
| 252 |
-
print(f"β
MP4 created for download: {mp4_path}")
|
| 253 |
return mp4_path
|
| 254 |
-
|
| 255 |
except Exception as e:
|
| 256 |
-
print(f"β
|
| 257 |
return None
|
| 258 |
|
| 259 |
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
|
@@ -317,20 +280,16 @@ pipeline.to(dtype=torch.float16).to(gpu)
|
|
| 317 |
@spaces.GPU
|
| 318 |
def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
| 319 |
"""
|
| 320 |
-
Generator function that
|
|
|
|
| 321 |
"""
|
| 322 |
global DOWNLOAD_FRAMES
|
| 323 |
-
DOWNLOAD_FRAMES = [] # Reset frames
|
| 324 |
|
| 325 |
if seed == -1:
|
| 326 |
seed = random.randint(0, 2**32 - 1)
|
| 327 |
|
| 328 |
-
print(f"π¬ Starting
|
| 329 |
-
|
| 330 |
-
# Create unique session directory for HLS files
|
| 331 |
-
session_id = str(uuid.uuid4())[:8]
|
| 332 |
-
session_dir = os.path.join("gradio_tmp", f"session_{session_id}")
|
| 333 |
-
os.makedirs(session_dir, exist_ok=True)
|
| 334 |
|
| 335 |
# Setup
|
| 336 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
|
@@ -351,7 +310,9 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 351 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
| 352 |
|
| 353 |
total_frames_yielded = 0
|
| 354 |
-
|
|
|
|
|
|
|
| 355 |
|
| 356 |
# Generation loop
|
| 357 |
for idx, current_num_frames in enumerate(all_num_frames):
|
|
@@ -415,7 +376,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 415 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
| 416 |
|
| 417 |
all_frames_from_block.append(frame_np)
|
| 418 |
-
DOWNLOAD_FRAMES.append(frame_np) # Store for download
|
| 419 |
total_frames_yielded += 1
|
| 420 |
|
| 421 |
# Yield status update for each frame (cute tracking!)
|
|
@@ -441,25 +402,25 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 441 |
# Yield None for video but update status (frame-by-frame tracking)
|
| 442 |
yield None, frame_status_html
|
| 443 |
|
| 444 |
-
#
|
| 445 |
if all_frames_from_block:
|
| 446 |
print(f"πΉ Encoding block {idx} with {len(all_frames_from_block)} frames")
|
| 447 |
|
| 448 |
try:
|
| 449 |
-
|
| 450 |
-
|
|
|
|
| 451 |
|
| 452 |
frames_to_ts_file(all_frames_from_block, ts_path, fps)
|
| 453 |
-
ts_files.append(ts_path)
|
| 454 |
|
| 455 |
-
#
|
| 456 |
-
|
| 457 |
|
| 458 |
-
# Yield the
|
| 459 |
-
yield
|
| 460 |
|
| 461 |
except Exception as e:
|
| 462 |
-
print(f"β οΈ Error
|
| 463 |
import traceback
|
| 464 |
traceback.print_exc()
|
| 465 |
|
|
@@ -477,13 +438,13 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 477 |
f" π Generated {total_frames_yielded} frames across {num_blocks} blocks"
|
| 478 |
f" </p>"
|
| 479 |
f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
|
| 480 |
-
f" π¬ Playback: {fps} FPS β’ π Format:
|
| 481 |
f" </p>"
|
| 482 |
f" </div>"
|
| 483 |
f"</div>"
|
| 484 |
)
|
| 485 |
yield None, final_status_html
|
| 486 |
-
print(f"β
|
| 487 |
|
| 488 |
# --- Gradio UI Layout ---
|
| 489 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
@@ -553,10 +514,10 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
| 553 |
label="Generation Status"
|
| 554 |
)
|
| 555 |
|
| 556 |
-
# Download button
|
| 557 |
download_btn = gr.DownloadButton(
|
| 558 |
label="π₯ Download MP4",
|
| 559 |
-
value=
|
| 560 |
variant="secondary"
|
| 561 |
)
|
| 562 |
|
|
@@ -579,13 +540,13 @@ if __name__ == "__main__":
|
|
| 579 |
import shutil
|
| 580 |
shutil.rmtree("gradio_tmp")
|
| 581 |
os.makedirs("gradio_tmp", exist_ok=True)
|
| 582 |
-
os.makedirs("downloads", exist_ok=True)
|
| 583 |
|
| 584 |
print("π Starting Self-Forcing Streaming Demo")
|
| 585 |
print(f"π Temporary files will be stored in: gradio_tmp/")
|
| 586 |
-
print(f"π₯ Download files will be stored in: downloads/")
|
| 587 |
-
print(f"π―
|
| 588 |
-
print(f"
|
| 589 |
|
| 590 |
demo.queue().launch(
|
| 591 |
server_name=args.host,
|
|
|
|
| 148 |
"current_vae_decoder": None,
|
| 149 |
}
|
| 150 |
|
| 151 |
+
# ONLY ADDITION: Store frames for download
|
| 152 |
DOWNLOAD_FRAMES = []
|
| 153 |
|
| 154 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
|
|
|
| 201 |
|
| 202 |
return filepath
|
| 203 |
|
| 204 |
+
# ONLY ADDITION: Download function
|
| 205 |
+
def create_download_mp4():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
global DOWNLOAD_FRAMES
|
|
|
|
| 207 |
if not DOWNLOAD_FRAMES:
|
| 208 |
return None
|
|
|
|
| 209 |
try:
|
| 210 |
os.makedirs("downloads", exist_ok=True)
|
|
|
|
| 211 |
timestamp = int(time.time())
|
| 212 |
+
mp4_path = f"downloads/video_{timestamp}.mp4"
|
|
|
|
|
|
|
|
|
|
| 213 |
with imageio.get_writer(mp4_path, fps=args.fps, codec='libx264', quality=8) as writer:
|
| 214 |
for frame in DOWNLOAD_FRAMES:
|
| 215 |
writer.append_data(frame)
|
| 216 |
+
print(f"β
Download MP4 created: {mp4_path}")
|
|
|
|
| 217 |
return mp4_path
|
|
|
|
| 218 |
except Exception as e:
|
| 219 |
+
print(f"β Download error: {e}")
|
| 220 |
return None
|
| 221 |
|
| 222 |
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
|
|
|
| 280 |
@spaces.GPU
|
| 281 |
def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
| 282 |
"""
|
| 283 |
+
Generator function that yields .ts video chunks using PyAV for streaming.
|
| 284 |
+
Now optimized for block-based processing.
|
| 285 |
"""
|
| 286 |
global DOWNLOAD_FRAMES
|
| 287 |
+
DOWNLOAD_FRAMES = [] # ONLY ADDITION: Reset frames
|
| 288 |
|
| 289 |
if seed == -1:
|
| 290 |
seed = random.randint(0, 2**32 - 1)
|
| 291 |
|
| 292 |
+
print(f"π¬ Starting PyAV streaming: '{prompt}', seed: {seed}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
# Setup
|
| 295 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
|
|
|
| 310 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
| 311 |
|
| 312 |
total_frames_yielded = 0
|
| 313 |
+
|
| 314 |
+
# Ensure temp directory exists
|
| 315 |
+
os.makedirs("gradio_tmp", exist_ok=True)
|
| 316 |
|
| 317 |
# Generation loop
|
| 318 |
for idx, current_num_frames in enumerate(all_num_frames):
|
|
|
|
| 376 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
| 377 |
|
| 378 |
all_frames_from_block.append(frame_np)
|
| 379 |
+
DOWNLOAD_FRAMES.append(frame_np) # ONLY ADDITION: Store for download
|
| 380 |
total_frames_yielded += 1
|
| 381 |
|
| 382 |
# Yield status update for each frame (cute tracking!)
|
|
|
|
| 402 |
# Yield None for video but update status (frame-by-frame tracking)
|
| 403 |
yield None, frame_status_html
|
| 404 |
|
| 405 |
+
# Encode entire block as one chunk immediately
|
| 406 |
if all_frames_from_block:
|
| 407 |
print(f"πΉ Encoding block {idx} with {len(all_frames_from_block)} frames")
|
| 408 |
|
| 409 |
try:
|
| 410 |
+
chunk_uuid = str(uuid.uuid4())[:8]
|
| 411 |
+
ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
|
| 412 |
+
ts_path = os.path.join("gradio_tmp", ts_filename)
|
| 413 |
|
| 414 |
frames_to_ts_file(all_frames_from_block, ts_path, fps)
|
|
|
|
| 415 |
|
| 416 |
+
# Calculate final progress for this block
|
| 417 |
+
total_progress = (idx + 1) / num_blocks * 100
|
| 418 |
|
| 419 |
+
# Yield the actual video chunk
|
| 420 |
+
yield ts_path, gr.update()
|
| 421 |
|
| 422 |
except Exception as e:
|
| 423 |
+
print(f"β οΈ Error encoding block {idx}: {e}")
|
| 424 |
import traceback
|
| 425 |
traceback.print_exc()
|
| 426 |
|
|
|
|
| 438 |
f" π Generated {total_frames_yielded} frames across {num_blocks} blocks"
|
| 439 |
f" </p>"
|
| 440 |
f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
|
| 441 |
+
f" π¬ Playback: {fps} FPS β’ π Format: MPEG-TS/H.264 β’ π₯ Download ready!"
|
| 442 |
f" </p>"
|
| 443 |
f" </div>"
|
| 444 |
f"</div>"
|
| 445 |
)
|
| 446 |
yield None, final_status_html
|
| 447 |
+
print(f"β
PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
|
| 448 |
|
| 449 |
# --- Gradio UI Layout ---
|
| 450 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
|
|
| 514 |
label="Generation Status"
|
| 515 |
)
|
| 516 |
|
| 517 |
+
# ONLY ADDITION: Download button
|
| 518 |
download_btn = gr.DownloadButton(
|
| 519 |
label="π₯ Download MP4",
|
| 520 |
+
value=create_download_mp4,
|
| 521 |
variant="secondary"
|
| 522 |
)
|
| 523 |
|
|
|
|
| 540 |
import shutil
|
| 541 |
shutil.rmtree("gradio_tmp")
|
| 542 |
os.makedirs("gradio_tmp", exist_ok=True)
|
| 543 |
+
os.makedirs("downloads", exist_ok=True) # ONLY ADDITION
|
| 544 |
|
| 545 |
print("π Starting Self-Forcing Streaming Demo")
|
| 546 |
print(f"π Temporary files will be stored in: gradio_tmp/")
|
| 547 |
+
print(f"π₯ Download files will be stored in: downloads/") # ONLY ADDITION
|
| 548 |
+
print(f"π― Chunk encoding: PyAV (MPEG-TS/H.264)")
|
| 549 |
+
print(f"β‘ GPU acceleration: {gpu}")
|
| 550 |
|
| 551 |
demo.queue().launch(
|
| 552 |
server_name=args.host,
|