| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| |
|
| | import imageio |
| | import numpy as np |
| | import torch |
| |
|
| | from ar_model import AutoRegressiveModel |
| | from text2world_prompt_upsampler_inference import ( |
| | create_prompt_upsampler, |
| | run_chat_completion, |
| | ) |
| | from presets import ( |
| | create_text_guardrail_runner, |
| | create_video_guardrail_runner, |
| | run_text_guardrail, |
| | run_video_guardrail, |
| | ) |
| | from .log import log |
| |
|
| |
|
def get_upsampled_prompt(
    prompt_upsampler_model: AutoRegressiveModel, input_prompt: str, temperature: float = 0.01
) -> str:
    """
    Ask the prompt upsampler to expand a short caption into a long caption.

    A single one-turn user dialog is built around ``input_prompt`` and passed to
    ``run_chat_completion``.

    Args:
        prompt_upsampler_model: The prompt upsampler model instance.
        input_prompt (str): Short caption to expand.
        temperature (float): Sampling temperature forwarded to generation (default: 0.01).

    Returns:
        str: Whatever ``run_chat_completion`` returns for the dialog (the upsampled prompt).
    """
    instruction = f"Upsample the short caption to a long caption: {input_prompt}"
    user_turn = {"role": "user", "content": instruction}
    # run_chat_completion takes a batch of dialogs; this batch holds exactly one dialog.
    dialog_batch = [[user_turn]]
    return run_chat_completion(prompt_upsampler_model, dialog_batch, temperature=temperature)
| |
|
| |
|
def print_rank_0(string: str):
    """
    Log ``string`` only from the rank-0 process.

    ``torch.distributed.get_rank()`` raises when no default process group has
    been initialized. When torch.distributed is unavailable or not initialized
    (e.g. a plain single-process run), this process is therefore treated as
    rank 0, so the message is logged instead of crashing.

    Args:
        string (str): Message to log.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        # Not running under an initialized process group: behave as the only rank.
        rank = 0
    if rank == 0:
        log.info(string)
| |
|
| |
|
def _ensure_text_is_safe(text: str, text_guardrail) -> None:
    """Raise ValueError if ``text_guardrail`` rejects ``text``."""
    if not run_text_guardrail(str(text), text_guardrail):
        raise ValueError("Guardrail blocked world generation.")


def process_prompt(
    prompt: str,
    checkpoint_dir: str,
    prompt_upsampler_dir: str,
    guardrails_dir: str,
    image_path: str = None,
    enable_prompt_upsampler: bool = True,
) -> str:
    """
    Check a prompt with the text guardrail and optionally upsample it.

    The original prompt is always checked first. When upsampling is enabled, the
    prompt upsampler is loaded, used once, and released. The upsampled prompt is
    then checked by the same guardrail before it is returned.

    Args:
        prompt (str): The original text prompt.
        checkpoint_dir (str): Base checkpoint directory.
        prompt_upsampler_dir (str): Directory (relative to ``checkpoint_dir``)
            containing prompt upsampler weights.
        guardrails_dir (str): Directory (relative to ``checkpoint_dir``)
            containing guardrails weights.
        image_path (str, optional): Path to an image, or None. Prompt upsampling
            is not implemented when an image path is given.
        enable_prompt_upsampler (bool): Whether to upsample the prompt.

    Returns:
        str: The upsampled prompt if upsampling is enabled, otherwise the
        original prompt. Upsampler errors are not caught; they propagate to
        the caller.

    Raises:
        ValueError: If the guardrail blocks the original or the upsampled prompt.
        NotImplementedError: If upsampling is enabled and ``image_path`` is set.
    """
    text_guardrail = create_text_guardrail_runner(os.path.join(checkpoint_dir, guardrails_dir))
    _ensure_text_is_safe(prompt, text_guardrail)

    if not enable_prompt_upsampler:
        return prompt

    if image_path:
        raise NotImplementedError("Prompt upsampling is not supported for image generation")

    prompt_upsampler = create_prompt_upsampler(
        checkpoint_dir=os.path.join(checkpoint_dir, prompt_upsampler_dir)
    )
    upsampled_prompt = get_upsampled_prompt(prompt_upsampler, prompt)
    print_rank_0(f"Original prompt: {prompt}\nUpsampled prompt: {upsampled_prompt}\n")
    # Drop the local reference so the upsampler can be garbage-collected
    # once nothing else holds on to it.
    del prompt_upsampler

    # The upsampler may introduce new content, so it must pass the guardrail too.
    _ensure_text_is_safe(upsampled_prompt, text_guardrail)
    return upsampled_prompt
| |
|
| |
|
def save_video(
    grid: np.ndarray,
    fps: int,
    H: int,
    W: int,
    video_save_quality: int,
    video_save_path: str,
    checkpoint_dir: str,
    guardrails_dir: str,
):
    """
    Run the video guardrail on the frames, then write them to an MP4 file.

    Args:
        grid (np.ndarray): Video frames array [T, H, W, C].
        fps (int): Frames per second.
        H (int): Frame height, passed to ffmpeg as part of the ``-s WxH`` size.
        W (int): Frame width, passed to ffmpeg as part of the ``-s WxH`` size.
        video_save_quality (int): imageio/ffmpeg encoding quality (0-10).
        video_save_path (str): Output video file path.
        checkpoint_dir (str): Directory containing model checkpoints.
        guardrails_dir (str): Directory (relative to ``checkpoint_dir``)
            containing guardrails weights.

    Raises:
        ValueError: If the video guardrail yields ``None`` instead of frames,
            so nothing is written.
    """
    video_classifier_guardrail = create_video_guardrail_runner(os.path.join(checkpoint_dir, guardrails_dir))

    grid = run_video_guardrail(grid, video_classifier_guardrail)
    # NOTE(review): assumes run_video_guardrail returns None for a blocked video
    # (presets is not visible here). Without this check, None would reach
    # imageio and fail with an unrelated error.
    if grid is None:
        raise ValueError("Guardrail blocked world generation.")

    kwargs = {
        "fps": fps,
        "quality": video_save_quality,
        # 1 stops imageio from resizing frames up to a multiple of 16.
        "macro_block_size": 1,
        "ffmpeg_params": ["-s", f"{W}x{H}"],
        "output_params": ["-f", "mp4"],
    }

    imageio.mimsave(video_save_path, grid, "mp4", **kwargs)
| |
|