Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,6 +34,7 @@ from tqdm import tqdm
|
|
| 34 |
import imageio
|
| 35 |
import av
|
| 36 |
import uuid
|
|
|
|
| 37 |
|
| 38 |
from pipeline import CausalInferencePipeline
|
| 39 |
from demo_utils.constant import ZERO_VAE_CACHE
|
|
@@ -68,7 +69,7 @@ T2V_CINEMATIC_PROMPT = \
|
|
| 68 |
'''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
|
| 69 |
'''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
|
| 70 |
'''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
|
| 71 |
-
'''4. Prompts should match the user
|
| 72 |
'''5. Emphasize motion information and different camera movements present in the input description;\n''' \
|
| 73 |
'''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
|
| 74 |
'''7. The revised prompt should be around 80-100 words long.\n''' \
|
|
@@ -146,75 +147,58 @@ APP_STATE = {
|
|
| 146 |
"fp8_applied": False,
|
| 147 |
"current_use_taehv": False,
|
| 148 |
"current_vae_decoder": None,
|
|
|
|
| 149 |
}
|
| 150 |
|
| 151 |
-
|
| 152 |
-
DOWNLOAD_FRAMES = []
|
| 153 |
-
|
| 154 |
-
def frames_to_mp4_chunk(frames, filepath, fps=15):
|
| 155 |
"""
|
| 156 |
-
Convert frames to
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
"""
|
| 158 |
if not frames:
|
| 159 |
return filepath
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
try:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
return filepath
|
| 168 |
-
|
| 169 |
-
except Exception as e:
|
| 170 |
-
print(f"❌ Error creating MP4 chunk: {e}")
|
| 171 |
-
# Fallback to PyAV if imageio fails
|
| 172 |
-
try:
|
| 173 |
-
height, width = frames[0].shape[:2]
|
| 174 |
-
container = av.open(filepath, mode='w', format='mp4')
|
| 175 |
-
|
| 176 |
-
stream = container.add_stream('h264', rate=fps)
|
| 177 |
-
stream.width = width
|
| 178 |
-
stream.height = height
|
| 179 |
-
stream.pix_fmt = 'yuv420p'
|
| 180 |
-
stream.options = {
|
| 181 |
-
'preset': 'ultrafast',
|
| 182 |
-
'tune': 'zerolatency',
|
| 183 |
-
'crf': '28'
|
| 184 |
-
}
|
| 185 |
-
|
| 186 |
-
for frame_np in frames:
|
| 187 |
-
frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
|
| 188 |
-
frame = frame.reformat(format=stream.pix_fmt)
|
| 189 |
-
for packet in stream.encode(frame):
|
| 190 |
-
container.mux(packet)
|
| 191 |
-
|
| 192 |
-
for packet in stream.encode():
|
| 193 |
container.mux(packet)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def create_download_mp4():
|
| 203 |
-
global DOWNLOAD_FRAMES
|
| 204 |
-
if not DOWNLOAD_FRAMES:
|
| 205 |
-
return None
|
| 206 |
-
try:
|
| 207 |
-
os.makedirs("downloads", exist_ok=True)
|
| 208 |
-
timestamp = int(time.time())
|
| 209 |
-
mp4_path = f"downloads/video_{timestamp}.mp4"
|
| 210 |
-
with imageio.get_writer(mp4_path, fps=args.fps, codec='libx264', quality=8) as writer:
|
| 211 |
-
for frame in DOWNLOAD_FRAMES:
|
| 212 |
-
writer.append_data(frame)
|
| 213 |
-
print(f"✅ Download MP4 created: {mp4_path}")
|
| 214 |
-
return mp4_path
|
| 215 |
-
except Exception as e:
|
| 216 |
-
print(f"❌ Download error: {e}")
|
| 217 |
-
return None
|
| 218 |
|
| 219 |
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
| 220 |
if use_trt:
|
|
@@ -275,17 +259,15 @@ pipeline.to(dtype=torch.float16).to(gpu)
|
|
| 275 |
|
| 276 |
@torch.no_grad()
|
| 277 |
@spaces.GPU
|
| 278 |
-
def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
| 279 |
"""
|
| 280 |
-
Generator function that yields
|
|
|
|
| 281 |
"""
|
| 282 |
-
global DOWNLOAD_FRAMES
|
| 283 |
-
DOWNLOAD_FRAMES = [] # Reset frames
|
| 284 |
-
|
| 285 |
if seed == -1:
|
| 286 |
seed = random.randint(0, 2**32 - 1)
|
| 287 |
|
| 288 |
-
print(f"🎬 Starting
|
| 289 |
|
| 290 |
# Setup
|
| 291 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
|
@@ -372,13 +354,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 372 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
| 373 |
|
| 374 |
all_frames_from_block.append(frame_np)
|
| 375 |
-
DOWNLOAD_FRAMES.append(frame_np) # Store for download
|
| 376 |
total_frames_yielded += 1
|
| 377 |
|
| 378 |
-
# Yield status update for each frame
|
| 379 |
blocks_completed = idx
|
| 380 |
current_block_progress = (frame_idx + 1) / pixels.shape[1]
|
| 381 |
total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
|
|
|
|
|
|
|
| 382 |
total_progress = min(total_progress, 100.0)
|
| 383 |
|
| 384 |
frame_status_html = (
|
|
@@ -393,21 +376,25 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 393 |
f"</div>"
|
| 394 |
)
|
| 395 |
|
|
|
|
| 396 |
yield None, frame_status_html
|
| 397 |
|
| 398 |
-
#
|
| 399 |
if all_frames_from_block:
|
| 400 |
print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
|
| 401 |
|
| 402 |
try:
|
| 403 |
chunk_uuid = str(uuid.uuid4())[:8]
|
| 404 |
-
|
| 405 |
-
|
| 406 |
|
| 407 |
-
|
| 408 |
|
| 409 |
-
#
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
except Exception as e:
|
| 413 |
print(f"⚠️ Error encoding block {idx}: {e}")
|
|
@@ -428,13 +415,41 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
|
| 428 |
f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
|
| 429 |
f" </p>"
|
| 430 |
f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
|
| 431 |
-
f" 🎬 Playback: {fps} FPS • 📁 Format:
|
| 432 |
f" </p>"
|
| 433 |
f" </div>"
|
| 434 |
f"</div>"
|
| 435 |
)
|
| 436 |
yield None, final_status_html
|
| 437 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
# --- Gradio UI Layout ---
|
| 440 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
@@ -504,20 +519,31 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
| 504 |
label="Generation Status"
|
| 505 |
)
|
| 506 |
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
label="📥 Download MP4",
|
| 510 |
-
value=create_download_mp4,
|
| 511 |
-
variant="secondary"
|
| 512 |
-
)
|
| 513 |
|
| 514 |
# Connect the generator to the streaming video
|
| 515 |
start_btn.click(
|
| 516 |
-
fn=video_generation_handler_streaming,
|
| 517 |
inputs=[prompt, seed, fps],
|
| 518 |
outputs=[streaming_video, status_display]
|
| 519 |
)
|
| 520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
enhance_button.click(
|
| 522 |
fn=enhance_prompt,
|
| 523 |
inputs=[prompt],
|
|
@@ -530,12 +556,10 @@ if __name__ == "__main__":
|
|
| 530 |
import shutil
|
| 531 |
shutil.rmtree("gradio_tmp")
|
| 532 |
os.makedirs("gradio_tmp", exist_ok=True)
|
| 533 |
-
os.makedirs("downloads", exist_ok=True)
|
| 534 |
|
| 535 |
print("🚀 Starting Self-Forcing Streaming Demo")
|
| 536 |
print(f"📁 Temporary files will be stored in: gradio_tmp/")
|
| 537 |
-
print(f"
|
| 538 |
-
print(f"🎯 Chunk encoding: MP4/H.264 (more compatible)")
|
| 539 |
print(f"⚡ GPU acceleration: {gpu}")
|
| 540 |
|
| 541 |
demo.queue().launch(
|
|
@@ -546,8 +570,6 @@ if __name__ == "__main__":
|
|
| 546 |
max_threads=40,
|
| 547 |
mcp_server=True
|
| 548 |
)
|
| 549 |
-
|
| 550 |
-
|
| 551 |
# import subprocess
|
| 552 |
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 553 |
|
|
|
|
| 34 |
import imageio
|
| 35 |
import av
|
| 36 |
import uuid
|
| 37 |
+
import tempfile
|
| 38 |
|
| 39 |
from pipeline import CausalInferencePipeline
|
| 40 |
from demo_utils.constant import ZERO_VAE_CACHE
|
|
|
|
| 69 |
'''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
|
| 70 |
'''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
|
| 71 |
'''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
|
| 72 |
+
'''4. Prompts should match the user’s intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
|
| 73 |
'''5. Emphasize motion information and different camera movements present in the input description;\n''' \
|
| 74 |
'''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
|
| 75 |
'''7. The revised prompt should be around 80-100 words long.\n''' \
|
|
|
|
| 147 |
"fp8_applied": False,
|
| 148 |
"current_use_taehv": False,
|
| 149 |
"current_vae_decoder": None,
|
| 150 |
+
"current_frames": [],
|
| 151 |
}
|
| 152 |
|
| 153 |
+
def frames_to_ts_file(frames, filepath, fps=15):
    """
    Convert frames directly to an MPEG-TS (.ts) file using PyAV.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        filepath: Output file path
        fps: Frames per second

    Returns:
        The filepath of the created file (returned unchanged for empty input)
    """
    if not frames:
        return filepath

    height, width = frames[0].shape[:2]

    # Create container for MPEG-TS format
    container = av.open(filepath, mode='w', format='mpegts')

    try:
        # Stream setup lives inside the try so the container is closed even
        # if add_stream or option assignment raises (previously it leaked).
        stream = container.add_stream('h264', rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = 'yuv420p'

        # Optimize for low latency streaming
        stream.options = {
            'preset': 'ultrafast',
            'tune': 'zerolatency',
            'crf': '23',
            'profile': 'baseline',
            'level': '3.0'
        }

        for frame_np in frames:
            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
            # Convert RGB input to the stream's pixel format (yuv420p)
            frame = frame.reformat(format=stream.pix_fmt)
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush any packets still buffered in the encoder
        for packet in stream.encode():
            container.mux(packet)

    finally:
        # Always release the container so the file descriptor is not leaked
        container.close()

    return filepath
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
| 204 |
if use_trt:
|
|
|
|
| 259 |
|
| 260 |
@torch.no_grad()
|
| 261 |
@spaces.GPU
|
| 262 |
+
def video_generation_handler_streaming(prompt, seed=42, fps=15, save_frames=True):
|
| 263 |
"""
|
| 264 |
+
Generator function that yields .ts video chunks using PyAV for streaming.
|
| 265 |
+
Now optimized for block-based processing.
|
| 266 |
"""
|
|
|
|
|
|
|
|
|
|
| 267 |
if seed == -1:
|
| 268 |
seed = random.randint(0, 2**32 - 1)
|
| 269 |
|
| 270 |
+
print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
|
| 271 |
|
| 272 |
# Setup
|
| 273 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
|
|
|
| 354 |
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
| 355 |
|
| 356 |
all_frames_from_block.append(frame_np)
|
|
|
|
| 357 |
total_frames_yielded += 1
|
| 358 |
|
| 359 |
+
# Yield status update for each frame (cute tracking!)
|
| 360 |
blocks_completed = idx
|
| 361 |
current_block_progress = (frame_idx + 1) / pixels.shape[1]
|
| 362 |
total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
|
| 363 |
+
|
| 364 |
+
# Cap at 100% to avoid going over
|
| 365 |
total_progress = min(total_progress, 100.0)
|
| 366 |
|
| 367 |
frame_status_html = (
|
|
|
|
| 376 |
f"</div>"
|
| 377 |
)
|
| 378 |
|
| 379 |
+
# Yield None for video but update status (frame-by-frame tracking)
|
| 380 |
yield None, frame_status_html
|
| 381 |
|
| 382 |
+
# Encode entire block as one chunk immediately
|
| 383 |
if all_frames_from_block:
|
| 384 |
print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
|
| 385 |
|
| 386 |
try:
|
| 387 |
chunk_uuid = str(uuid.uuid4())[:8]
|
| 388 |
+
ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
|
| 389 |
+
ts_path = os.path.join("gradio_tmp", ts_filename)
|
| 390 |
|
| 391 |
+
frames_to_ts_file(all_frames_from_block, ts_path, fps)
|
| 392 |
|
| 393 |
+
# Calculate final progress for this block
|
| 394 |
+
total_progress = (idx + 1) / num_blocks * 100
|
| 395 |
+
|
| 396 |
+
# Yield the actual video chunk
|
| 397 |
+
yield ts_path, gr.update()
|
| 398 |
|
| 399 |
except Exception as e:
|
| 400 |
print(f"⚠️ Error encoding block {idx}: {e}")
|
|
|
|
| 415 |
f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
|
| 416 |
f" </p>"
|
| 417 |
f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
|
| 418 |
+
f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
|
| 419 |
f" </p>"
|
| 420 |
f" </div>"
|
| 421 |
f"</div>"
|
| 422 |
)
|
| 423 |
yield None, final_status_html
|
| 424 |
+
print(f" PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
|
| 425 |
+
|
| 426 |
+
def save_frames_as_video(frames, fps=15):
    """
    Convert frames to a downloadable MP4 video file.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        fps: Frames per second

    Returns:
        Path to the saved video file, or None if there are no frames or
        writing fails.
    """
    if not frames:
        return None

    # Ensure the temp dir exists (idempotent), then pick a unique file name
    os.makedirs("gradio_tmp", exist_ok=True)
    temp_file = os.path.join("gradio_tmp", f"download_{uuid.uuid4()}.mp4")

    try:
        # Context manager guarantees the writer is closed (and the MP4
        # finalized) even if append_data raises mid-way — the previous
        # version leaked the writer on error.
        with imageio.get_writer(temp_file, fps=fps, codec='h264', quality=9) as writer:
            for frame in frames:
                writer.append_data(frame)
        return temp_file
    except Exception as e:
        print(f"Error saving video: {e}")
        return None
|
| 453 |
|
| 454 |
# --- Gradio UI Layout ---
|
| 455 |
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
|
|
|
| 519 |
label="Generation Status"
|
| 520 |
)
|
| 521 |
|
| 522 |
+
download_btn = gr.Button("💾 Download Video", variant="secondary")
|
| 523 |
+
download_output = gr.File(label="Download")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
|
| 525 |
# Connect the generator to the streaming video
|
| 526 |
start_btn.click(
|
| 527 |
+
fn=lambda p, s, f: (APP_STATE.update({"current_frames": []}) or video_generation_handler_streaming(p, s, f)),
|
| 528 |
inputs=[prompt, seed, fps],
|
| 529 |
outputs=[streaming_video, status_display]
|
| 530 |
)
|
| 531 |
|
| 532 |
+
# Function to handle download button click
|
| 533 |
+
def download_video(fps):
    """Render the frames captured during the last run as an MP4 and return its path (None if nothing was captured)."""
    captured = APP_STATE.get("current_frames")
    if not captured:
        return None
    return save_frames_as_video(captured, fps)
|
| 538 |
+
|
| 539 |
+
# Connect download button
|
| 540 |
+
download_btn.click(
|
| 541 |
+
fn=download_video,
|
| 542 |
+
inputs=[fps],
|
| 543 |
+
outputs=[download_output],
|
| 544 |
+
show_progress=True
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
enhance_button.click(
|
| 548 |
fn=enhance_prompt,
|
| 549 |
inputs=[prompt],
|
|
|
|
| 556 |
import shutil
|
| 557 |
shutil.rmtree("gradio_tmp")
|
| 558 |
os.makedirs("gradio_tmp", exist_ok=True)
|
|
|
|
| 559 |
|
| 560 |
print("🚀 Starting Self-Forcing Streaming Demo")
|
| 561 |
print(f"📁 Temporary files will be stored in: gradio_tmp/")
|
| 562 |
+
print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
|
|
|
|
| 563 |
print(f"⚡ GPU acceleration: {gpu}")
|
| 564 |
|
| 565 |
demo.queue().launch(
|
|
|
|
| 570 |
max_threads=40,
|
| 571 |
mcp_server=True
|
| 572 |
)
|
|
|
|
|
|
|
| 573 |
# import subprocess
|
| 574 |
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 575 |
|