Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -141,26 +141,12 @@ transformer.eval().to(dtype=torch.float16).requires_grad_(False)
|
|
| 141 |
text_encoder.to(gpu)
|
| 142 |
transformer.to(gpu)
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
"fp8_applied": False,
|
| 147 |
-
"current_use_taehv": False,
|
| 148 |
-
"current_vae_decoder": None,
|
| 149 |
-
"last_generated_frames": [], # Store frames for download
|
| 150 |
-
"last_generation_info": {} # Store metadata
|
| 151 |
-
}
|
| 152 |
|
| 153 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
| 154 |
"""
|
| 155 |
Convert frames directly to .ts file using PyAV.
|
| 156 |
-
|
| 157 |
-
Args:
|
| 158 |
-
frames: List of numpy arrays (HWC, RGB, uint8)
|
| 159 |
-
filepath: Output file path
|
| 160 |
-
fps: Frames per second
|
| 161 |
-
|
| 162 |
-
Returns:
|
| 163 |
-
The filepath of the created file
|
| 164 |
"""
|
| 165 |
if not frames:
|
| 166 |
return filepath
|
|
@@ -200,83 +186,52 @@ def frames_to_ts_file(frames, filepath, fps = 15):
|
|
| 200 |
|
| 201 |
return filepath
|
| 202 |
|
| 203 |
-
def
|
| 204 |
"""
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
Args:
|
| 208 |
-
frames: List of numpy arrays (HWC, RGB, uint8)
|
| 209 |
-
filepath: Output file path
|
| 210 |
-
fps: Frames per second
|
| 211 |
-
|
| 212 |
-
Returns:
|
| 213 |
-
The filepath of the created file
|
| 214 |
"""
|
| 215 |
-
|
| 216 |
-
return filepath
|
| 217 |
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
-
#
|
| 224 |
-
stream = container.add_stream('h264', rate=fps)
|
| 225 |
-
stream.width = width
|
| 226 |
-
stream.height = height
|
| 227 |
-
stream.pix_fmt = 'yuv420p'
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
'preset': 'medium',
|
| 232 |
-
'crf': '18', # Higher quality
|
| 233 |
-
'profile': 'high',
|
| 234 |
-
'level': '4.0'
|
| 235 |
-
}
|
| 236 |
|
| 237 |
-
|
| 238 |
-
for frame_np in frames:
|
| 239 |
-
frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
|
| 240 |
-
frame = frame.reformat(format=stream.pix_fmt)
|
| 241 |
-
for packet in stream.encode(frame):
|
| 242 |
-
container.mux(packet)
|
| 243 |
-
|
| 244 |
-
for packet in stream.encode():
|
| 245 |
-
container.mux(packet)
|
| 246 |
-
|
| 247 |
-
finally:
|
| 248 |
-
container.close()
|
| 249 |
-
|
| 250 |
-
return filepath
|
| 251 |
|
| 252 |
-
def
|
| 253 |
"""
|
| 254 |
-
|
| 255 |
"""
|
| 256 |
-
if not
|
| 257 |
return None
|
| 258 |
|
| 259 |
try:
|
| 260 |
-
#
|
| 261 |
-
|
|
|
|
|
|
|
| 262 |
|
| 263 |
-
|
| 264 |
-
timestamp = int(time.time())
|
| 265 |
-
prompt_hash = hashlib.md5(APP_STATE["last_generation_info"].get("prompt", "").encode()).hexdigest()[:8]
|
| 266 |
-
filename = f"pixio_video_{timestamp}_{prompt_hash}.mp4"
|
| 267 |
-
filepath = os.path.join("downloads", filename)
|
| 268 |
-
|
| 269 |
-
# Create MP4 file
|
| 270 |
-
fps = APP_STATE["last_generation_info"].get("fps", 15)
|
| 271 |
-
frames_to_mp4_file(APP_STATE["last_generated_frames"], filepath, fps)
|
| 272 |
-
|
| 273 |
-
print(f"β
Download video created: {filepath}")
|
| 274 |
return filepath
|
| 275 |
|
| 276 |
except Exception as e:
|
| 277 |
-
print(f"β Error creating
|
| 278 |
-
import traceback
|
| 279 |
-
traceback.print_exc()
|
| 280 |
return None
|
| 281 |
|
| 282 |
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
|
@@ -326,6 +281,13 @@ def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
|
| 326 |
APP_STATE["current_vae_decoder"] = vae_decoder
|
| 327 |
print(f"β
VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
|
| 328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
# Initialize with default VAE
|
| 330 |
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
|
| 331 |
|
|
@@ -340,17 +302,14 @@ pipeline.to(dtype=torch.float16).to(gpu)
|
|
| 340 |
@spaces.GPU
|
| 341 |
def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
| 342 |
"""
|
| 343 |
-
Generator function that
|
| 344 |
-
Now optimized for block-based processing and stores frames for download.
|
| 345 |
"""
|
|
|
|
|
|
|
| 346 |
if seed == -1:
|
| 347 |
seed = random.randint(0, 2**32 - 1)
|
| 348 |
|
| 349 |
-
print(f"π¬ Starting
|
| 350 |
-
|
| 351 |
-
# Clear previous generation data
|
| 352 |
-
APP_STATE["last_generated_frames"] = []
|
| 353 |
-
APP_STATE["last_generation_info"] = {"prompt": prompt, "seed": seed, "fps": fps}
|
| 354 |
|
| 355 |
# Setup
|
| 356 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
|
@@ -371,9 +330,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 371 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
| 372 |
|
| 373 |
total_frames_yielded = 0
|
|
|
|
|
|
|
| 374 |
|
| 375 |
-
#
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
# Generation loop
|
| 379 |
for idx, current_num_frames in enumerate(all_num_frames):
|
|
@@ -424,10 +388,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 424 |
elif APP_STATE["current_use_taehv"] and idx > 0:
|
| 425 |
pixels = pixels[:, 12:]
|
| 426 |
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
# Process all frames from this block at once
|
| 430 |
-
all_frames_from_block = []
|
| 431 |
for frame_idx in range(pixels.shape[1]):
|
| 432 |
frame_tensor = pixels[0, frame_idx]
|
| 433 |
|
|
@@ -436,17 +398,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 436 |
frame_np = frame_np.to(torch.uint8).cpu().numpy()
|
| 437 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
| 438 |
|
| 439 |
-
|
| 440 |
-
# Store
|
| 441 |
-
APP_STATE["last_generated_frames"].append(frame_np)
|
| 442 |
total_frames_yielded += 1
|
| 443 |
|
| 444 |
-
#
|
| 445 |
blocks_completed = idx
|
| 446 |
current_block_progress = (frame_idx + 1) / pixels.shape[1]
|
| 447 |
total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
|
| 448 |
-
|
| 449 |
-
# Cap at 100% to avoid going over
|
| 450 |
total_progress = min(total_progress, 100.0)
|
| 451 |
|
| 452 |
frame_status_html = (
|
|
@@ -461,34 +420,51 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 461 |
f"</div>"
|
| 462 |
)
|
| 463 |
|
| 464 |
-
|
| 465 |
-
yield None, frame_status_html, gr.update(visible=False)
|
| 466 |
|
| 467 |
-
#
|
| 468 |
-
if
|
| 469 |
-
print(f"πΉ Encoding block {idx} with {len(all_frames_from_block)} frames")
|
| 470 |
-
|
| 471 |
try:
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
ts_path = os.path.join("gradio_tmp", ts_filename)
|
| 475 |
|
| 476 |
-
frames_to_ts_file(
|
|
|
|
| 477 |
|
| 478 |
-
#
|
| 479 |
-
|
|
|
|
| 480 |
|
| 481 |
-
# Yield the
|
| 482 |
-
yield
|
| 483 |
|
| 484 |
except Exception as e:
|
| 485 |
-
print(f"β οΈ Error
|
| 486 |
-
import traceback
|
| 487 |
-
traceback.print_exc()
|
| 488 |
|
| 489 |
current_start_frame += current_num_frames
|
| 490 |
|
| 491 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
final_status_html = (
|
| 493 |
f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
|
| 494 |
f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
|
|
@@ -500,27 +476,33 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 500 |
f" π Generated {total_frames_yielded} frames across {num_blocks} blocks"
|
| 501 |
f" </p>"
|
| 502 |
f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
|
| 503 |
-
f" π¬ Playback: {fps} FPS β’ π Format:
|
| 504 |
f" </p>"
|
| 505 |
f" </div>"
|
| 506 |
f"</div>"
|
| 507 |
)
|
| 508 |
-
yield None, final_status_html
|
| 509 |
-
print(f"β
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
# --- Gradio UI Layout ---
|
| 512 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
| 513 |
gr.Markdown("# π Pixio Streaming Video Generation")
|
| 514 |
-
gr.Markdown("Real-time video generation with
|
| 515 |
|
| 516 |
with gr.Row():
|
| 517 |
with gr.Column(scale=2):
|
| 518 |
with gr.Group():
|
| 519 |
prompt = gr.Textbox(
|
| 520 |
label="Prompt",
|
| 521 |
-
placeholder="A
|
| 522 |
lines=4,
|
| 523 |
-
value=""
|
| 524 |
)
|
| 525 |
enhance_button = gr.Button("β¨ Enhance Prompt", variant="secondary")
|
| 526 |
|
|
@@ -576,20 +558,20 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
| 576 |
label="Generation Status"
|
| 577 |
)
|
| 578 |
|
| 579 |
-
# Download button
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
|
| 588 |
-
# Connect the
|
| 589 |
start_btn.click(
|
| 590 |
fn=video_generation_handler_streaming,
|
| 591 |
inputs=[prompt, seed, fps],
|
| 592 |
-
outputs=[streaming_video, status_display
|
| 593 |
)
|
| 594 |
|
| 595 |
enhance_button.click(
|
|
@@ -607,18 +589,17 @@ if __name__ == "__main__":
|
|
| 607 |
os.makedirs("downloads", exist_ok=True)
|
| 608 |
|
| 609 |
print("π Starting Self-Forcing Streaming Demo")
|
| 610 |
-
print(f"π Temporary files
|
| 611 |
-
print(f"π₯ Download files
|
| 612 |
-
print(f"π―
|
| 613 |
-
print(f"
|
| 614 |
|
| 615 |
demo.queue().launch(
|
| 616 |
server_name=args.host,
|
| 617 |
server_port=args.port,
|
| 618 |
share=args.share,
|
| 619 |
show_error=True,
|
| 620 |
-
max_threads=40
|
| 621 |
-
mcp_server=True
|
| 622 |
)
|
| 623 |
|
| 624 |
# import subprocess
|
|
|
|
| 141 |
text_encoder.to(gpu)
|
| 142 |
transformer.to(gpu)
|
| 143 |
|
| 144 |
+
# Global state for download
|
| 145 |
+
CURRENT_DOWNLOAD_PATH = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
| 148 |
"""
|
| 149 |
Convert frames directly to .ts file using PyAV.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
"""
|
| 151 |
if not frames:
|
| 152 |
return filepath
|
|
|
|
| 186 |
|
| 187 |
return filepath
|
| 188 |
|
| 189 |
+
def create_hls_playlist(ts_files, playlist_path, fps=15):
|
| 190 |
"""
|
| 191 |
+
Create HLS playlist (.m3u8) file for streaming.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
"""
|
| 193 |
+
segment_duration = 1.0 # Each segment duration in seconds
|
|
|
|
| 194 |
|
| 195 |
+
playlist_content = [
|
| 196 |
+
"#EXTM3U",
|
| 197 |
+
"#EXT-X-VERSION:3",
|
| 198 |
+
f"#EXT-X-TARGETDURATION:{int(segment_duration) + 1}",
|
| 199 |
+
"#EXT-X-MEDIA-SEQUENCE:0",
|
| 200 |
+
"#EXT-X-PLAYLIST-TYPE:VOD"
|
| 201 |
+
]
|
| 202 |
|
| 203 |
+
for ts_file in ts_files:
|
| 204 |
+
ts_filename = os.path.basename(ts_file)
|
| 205 |
+
playlist_content.extend([
|
| 206 |
+
f"#EXTINF:{segment_duration:.1f},",
|
| 207 |
+
ts_filename
|
| 208 |
+
])
|
| 209 |
|
| 210 |
+
playlist_content.append("#EXT-X-ENDLIST")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
+
with open(playlist_path, 'w') as f:
|
| 213 |
+
f.write('\n'.join(playlist_content))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
return playlist_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
+
def frames_to_mp4_file(frames, filepath, fps=15):
|
| 218 |
"""
|
| 219 |
+
Convert frames to MP4 file using imageio.
|
| 220 |
"""
|
| 221 |
+
if not frames:
|
| 222 |
return None
|
| 223 |
|
| 224 |
try:
|
| 225 |
+
# Use imageio for reliable MP4 creation
|
| 226 |
+
with imageio.get_writer(filepath, fps=fps, codec='libx264', quality=8) as writer:
|
| 227 |
+
for frame in frames:
|
| 228 |
+
writer.append_data(frame)
|
| 229 |
|
| 230 |
+
print(f"β
MP4 created successfully: {filepath}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
return filepath
|
| 232 |
|
| 233 |
except Exception as e:
|
| 234 |
+
print(f"β Error creating MP4: {e}")
|
|
|
|
|
|
|
| 235 |
return None
|
| 236 |
|
| 237 |
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
|
|
|
| 281 |
APP_STATE["current_vae_decoder"] = vae_decoder
|
| 282 |
print(f"β
VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
|
| 283 |
|
| 284 |
+
APP_STATE = {
|
| 285 |
+
"torch_compile_applied": False,
|
| 286 |
+
"fp8_applied": False,
|
| 287 |
+
"current_use_taehv": False,
|
| 288 |
+
"current_vae_decoder": None,
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
# Initialize with default VAE
|
| 292 |
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
|
| 293 |
|
|
|
|
| 302 |
@spaces.GPU
|
| 303 |
def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
| 304 |
"""
|
| 305 |
+
Generator function that creates HLS stream and final MP4.
|
|
|
|
| 306 |
"""
|
| 307 |
+
global CURRENT_DOWNLOAD_PATH
|
| 308 |
+
|
| 309 |
if seed == -1:
|
| 310 |
seed = random.randint(0, 2**32 - 1)
|
| 311 |
|
| 312 |
+
print(f"π¬ Starting HLS streaming: '{prompt}', seed: {seed}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
# Setup
|
| 315 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
|
|
|
| 330 |
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
| 331 |
|
| 332 |
total_frames_yielded = 0
|
| 333 |
+
all_frames_for_download = [] # Store frames for final MP4
|
| 334 |
+
ts_files = [] # Store TS files for HLS playlist
|
| 335 |
|
| 336 |
+
# Create unique session directory
|
| 337 |
+
session_id = str(uuid.uuid4())[:8]
|
| 338 |
+
session_dir = os.path.join("gradio_tmp", f"session_{session_id}")
|
| 339 |
+
os.makedirs(session_dir, exist_ok=True)
|
| 340 |
+
os.makedirs("downloads", exist_ok=True)
|
| 341 |
|
| 342 |
# Generation loop
|
| 343 |
for idx, current_num_frames in enumerate(all_num_frames):
|
|
|
|
| 388 |
elif APP_STATE["current_use_taehv"] and idx > 0:
|
| 389 |
pixels = pixels[:, 12:]
|
| 390 |
|
| 391 |
+
# Process frames from this block
|
| 392 |
+
block_frames = []
|
|
|
|
|
|
|
| 393 |
for frame_idx in range(pixels.shape[1]):
|
| 394 |
frame_tensor = pixels[0, frame_idx]
|
| 395 |
|
|
|
|
| 398 |
frame_np = frame_np.to(torch.uint8).cpu().numpy()
|
| 399 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
| 400 |
|
| 401 |
+
block_frames.append(frame_np)
|
| 402 |
+
all_frames_for_download.append(frame_np) # Store for final MP4
|
|
|
|
| 403 |
total_frames_yielded += 1
|
| 404 |
|
| 405 |
+
# Progress tracking
|
| 406 |
blocks_completed = idx
|
| 407 |
current_block_progress = (frame_idx + 1) / pixels.shape[1]
|
| 408 |
total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
|
|
|
|
|
|
|
| 409 |
total_progress = min(total_progress, 100.0)
|
| 410 |
|
| 411 |
frame_status_html = (
|
|
|
|
| 420 |
f"</div>"
|
| 421 |
)
|
| 422 |
|
| 423 |
+
yield None, frame_status_html
|
|
|
|
| 424 |
|
| 425 |
+
# Create TS segment for this block
|
| 426 |
+
if block_frames:
|
|
|
|
|
|
|
| 427 |
try:
|
| 428 |
+
ts_filename = f"segment_{idx:04d}.ts"
|
| 429 |
+
ts_path = os.path.join(session_dir, ts_filename)
|
|
|
|
| 430 |
|
| 431 |
+
frames_to_ts_file(block_frames, ts_path, fps)
|
| 432 |
+
ts_files.append(ts_path)
|
| 433 |
|
| 434 |
+
# Create/update HLS playlist
|
| 435 |
+
playlist_path = os.path.join(session_dir, "playlist.m3u8")
|
| 436 |
+
create_hls_playlist(ts_files, playlist_path, fps)
|
| 437 |
|
| 438 |
+
# Yield the HLS playlist for streaming
|
| 439 |
+
yield playlist_path, gr.update()
|
| 440 |
|
| 441 |
except Exception as e:
|
| 442 |
+
print(f"β οΈ Error creating HLS segment {idx}: {e}")
|
|
|
|
|
|
|
| 443 |
|
| 444 |
current_start_frame += current_num_frames
|
| 445 |
|
| 446 |
+
# Create final MP4 for download
|
| 447 |
+
print("π¬ Creating final MP4 for download...")
|
| 448 |
+
try:
|
| 449 |
+
timestamp = int(time.time())
|
| 450 |
+
prompt_hash = hashlib.md5(prompt.encode()).hexdigest()[:8]
|
| 451 |
+
mp4_filename = f"pixio_video_{timestamp}_{prompt_hash}.mp4"
|
| 452 |
+
mp4_path = os.path.join("downloads", mp4_filename)
|
| 453 |
+
|
| 454 |
+
final_mp4 = frames_to_mp4_file(all_frames_for_download, mp4_path, fps)
|
| 455 |
+
if final_mp4:
|
| 456 |
+
CURRENT_DOWNLOAD_PATH = final_mp4
|
| 457 |
+
print(f"β
Final MP4 created: {final_mp4}")
|
| 458 |
+
else:
|
| 459 |
+
print("β Failed to create final MP4")
|
| 460 |
+
CURRENT_DOWNLOAD_PATH = None
|
| 461 |
+
|
| 462 |
+
except Exception as e:
|
| 463 |
+
print(f"β Error creating final MP4: {e}")
|
| 464 |
+
CURRENT_DOWNLOAD_PATH = None
|
| 465 |
+
|
| 466 |
+
# Final completion status with download info
|
| 467 |
+
download_info = "π₯ Download ready!" if CURRENT_DOWNLOAD_PATH else "β Download failed"
|
| 468 |
final_status_html = (
|
| 469 |
f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
|
| 470 |
f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
|
|
|
|
| 476 |
f" π Generated {total_frames_yielded} frames across {num_blocks} blocks"
|
| 477 |
f" </p>"
|
| 478 |
f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
|
| 479 |
+
f" π¬ Playback: {fps} FPS β’ π Format: HLS/H.264 β’ {download_info}"
|
| 480 |
f" </p>"
|
| 481 |
f" </div>"
|
| 482 |
f"</div>"
|
| 483 |
)
|
| 484 |
+
yield None, final_status_html
|
| 485 |
+
print(f"β
HLS streaming complete! {total_frames_yielded} frames")
|
| 486 |
+
|
| 487 |
+
def download_video():
|
| 488 |
+
"""Return the current download file path."""
|
| 489 |
+
if CURRENT_DOWNLOAD_PATH and os.path.exists(CURRENT_DOWNLOAD_PATH):
|
| 490 |
+
return CURRENT_DOWNLOAD_PATH
|
| 491 |
+
return None
|
| 492 |
|
| 493 |
# --- Gradio UI Layout ---
|
| 494 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
| 495 |
gr.Markdown("# π Pixio Streaming Video Generation")
|
| 496 |
+
gr.Markdown("Real-time video generation with distilled Wan2-1.3B [[Model]](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B), [[Project page]](https://pixio.myapps.ai), [[Paper]](https://arxiv.org/abs/2412.09738)")
|
| 497 |
|
| 498 |
with gr.Row():
|
| 499 |
with gr.Column(scale=2):
|
| 500 |
with gr.Group():
|
| 501 |
prompt = gr.Textbox(
|
| 502 |
label="Prompt",
|
| 503 |
+
placeholder="A close-up shot of a ceramic teacup slowly pouring water into a glass mug.",
|
| 504 |
lines=4,
|
| 505 |
+
value="A close-up shot of a ceramic teacup slowly pouring water into a glass mug."
|
| 506 |
)
|
| 507 |
enhance_button = gr.Button("β¨ Enhance Prompt", variant="secondary")
|
| 508 |
|
|
|
|
| 558 |
label="Generation Status"
|
| 559 |
)
|
| 560 |
|
| 561 |
+
# Download button that appears after completion
|
| 562 |
+
with gr.Row():
|
| 563 |
+
download_btn = gr.DownloadButton(
|
| 564 |
+
label="π₯ Download MP4 Video",
|
| 565 |
+
value=download_video,
|
| 566 |
+
variant="secondary",
|
| 567 |
+
size="lg"
|
| 568 |
+
)
|
| 569 |
|
| 570 |
+
# Connect the streaming function
|
| 571 |
start_btn.click(
|
| 572 |
fn=video_generation_handler_streaming,
|
| 573 |
inputs=[prompt, seed, fps],
|
| 574 |
+
outputs=[streaming_video, status_display]
|
| 575 |
)
|
| 576 |
|
| 577 |
enhance_button.click(
|
|
|
|
| 589 |
os.makedirs("downloads", exist_ok=True)
|
| 590 |
|
| 591 |
print("π Starting Self-Forcing Streaming Demo")
|
| 592 |
+
print(f"π Temporary files: gradio_tmp/")
|
| 593 |
+
print(f"π₯ Download files: downloads/")
|
| 594 |
+
print(f"π― Streaming: HLS (.m3u8 + .ts segments)")
|
| 595 |
+
print(f"π± Download: MP4 (imageio)")
|
| 596 |
|
| 597 |
demo.queue().launch(
|
| 598 |
server_name=args.host,
|
| 599 |
server_port=args.port,
|
| 600 |
share=args.share,
|
| 601 |
show_error=True,
|
| 602 |
+
max_threads=40
|
|
|
|
| 603 |
)
|
| 604 |
|
| 605 |
# import subprocess
|