import sys import types from importlib.machinery import ModuleSpec from typing import Iterable from gradio.themes import Soft from gradio.themes.utils import colors, fonts, sizes colors.orange_red = colors.Color( name="orange_red", c50="#FFF0E5", c100="#FFE0CC", c200="#FFC299", c300="#FFA366", c400="#FF8533", c500="#FF4500", c600="#E63E00", c700="#CC3700", c800="#B33000", c900="#992900", c950="#802200", ) class OrangeRedTheme(Soft): def __init__( self, *, primary_hue: colors.Color | str = colors.gray, secondary_hue: colors.Color | str = colors.orange_red, neutral_hue: colors.Color | str = colors.slate, text_size: sizes.Size | str = sizes.text_lg, font: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("Outfit"), "Arial", "sans-serif", ), font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace", ), ): super().__init__( primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue, text_size=text_size, font=font, font_mono=font_mono, ) super().set( background_fill_primary="*primary_50", background_fill_primary_dark="*primary_900", body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", button_primary_text_color="white", button_primary_text_color_hover="white", button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)", button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)", slider_color="*secondary_500", slider_color_dark="*secondary_600", block_title_text_weight="600", block_border_width="3px", block_shadow="*shadow_drop_lg", button_primary_shadow="*shadow_drop_lg", button_large_padding="11px", color_accent_soft="*primary_100", block_label_background_fill="*primary_200", ) orange_red_theme = OrangeRedTheme() class DummyLayerRepository: def __init__(self, *args, **kwargs): pass kernels = types.ModuleType("kernels") kernels_layer = types.ModuleType("kernels.layer") kernels_layer_layer = types.ModuleType("kernels.layer.layer") # Set __spec__ to prevent ValueError: kernels.__spec__ is None in python 3.12 kernels.__spec__ = ModuleSpec("kernels", None, is_package=True) kernels_layer.__spec__ = ModuleSpec("kernels.layer", None, is_package=True) kernels_layer_layer.__spec__ = ModuleSpec("kernels.layer.layer", None, is_package=False) kernels.__version__ = "0.0.1" kernels_layer_layer.LayerRepository = DummyLayerRepository kernels_layer.LayerRepository = DummyLayerRepository kernels_layer.layer = kernels_layer_layer kernels.layer = kernels_layer sys.modules["kernels"] = kernels sys.modules["kernels.layer"] = kernels_layer sys.modules["kernels.layer.layer"] = kernels_layer_layer import os os.environ["SETUPTOOLS_USE_DISTUTILS"] = "stdlib" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import re import math import time import argparse import datetime import importlib import torch import yaml _IS_ZERO_GPU = ( os.environ.get("SPACES_ZERO_GPU", "0") == "1" or "spaces" in sys.modules or os.path.exists("/usr/local/lib/python3.12/site-packages/spaces") ) try: import spaces as _spaces_module _HAS_SPACES = True except ImportError: _HAS_SPACES = False _spaces_module = None def _spaces_gpu_decorator(fn=None, *, size="xlarge", duration=120): """ Conditional @spaces.GPU decorator. - On ZeroGPU: applies the real spaces.GPU decorator. - Elsewhere: returns the function unchanged (no-op). """ def _decorator(f): if _HAS_SPACES and _IS_ZERO_GPU: return _spaces_module.GPU(size=size, duration=duration)(f) return f if fn is not None: return _decorator(fn) return _decorator ASPECT_RATIO_MAP = { "16:9 (1280×704)": (704, 1280), "9:16 (704×1280)": (1280, 704), "1:1 (960×960)": (960, 960), } CMD_INFER = 1 CMD_EXIT = 0 def broadcast_string(s: str, src: int = 0): """Broadcast a string from src rank to all ranks.""" import torch.distributed as dist if dist.get_rank() == src: data = s.encode("utf-8") length = torch.tensor([len(data)], dtype=torch.long, device="cuda") else: length = torch.tensor([0], dtype=torch.long, device="cuda") dist.broadcast(length, src=src) n = length.item() if n == 0: return "" if dist.get_rank() == src: tensor = torch.tensor(list(data), dtype=torch.uint8, device="cuda") else: tensor = torch.empty(n, dtype=torch.uint8, device="cuda") dist.broadcast(tensor, src=src) if dist.get_rank() != src: s = bytes(tensor.cpu().tolist()).decode("utf-8") return s def broadcast_cmd(cmd: int, src: int = 0): """Broadcast a command integer from src to all ranks.""" import torch.distributed as dist t = torch.tensor([cmd], dtype=torch.long, device="cuda") dist.broadcast(t, src=src) return t.item() def broadcast_int(val: int, src: int = 0): """Broadcast a single integer.""" import torch.distributed as dist t = torch.tensor([val], dtype=torch.long, device="cuda") dist.broadcast(t, src=src) return t.item() SYSTEM_PROMPT = """你是一个中文音视频生成 prompt rewriter。你的任务是把用户输入的简短描述、关键词或普通 prompt,改写成一个适合音视频生成模型使用的高质量中文长 prompt。最终只输出改写后的 prompt,不要解释,不要分析,不要输出标题,不要输出 JSON,不要换行,必须是单段中文文本。 你必须保留用户输入中的核心意图,包括主体、动作、速度、情绪、场景、台词和镜头要求。不能把用户指定的动作改成相反含义,不能删除关键主体,不能新增与用户意图冲突的剧情。用户没有明确说明的信息,可以根据画面和常识合理补全,例如背景、光线、镜头、动作细节、环境反馈和音效。 改写后的 prompt 必须具有电影化、具体、连续、可执行的风格。整体结构按以下顺序自然组织:第一,描述视频风格、核心氛围和主体所在场景;第二,描述主体的外观、服装、材质、表情、姿态、位置和整体气质;第三,描述背景环境、远景元素、光线、色调和整体氛围;第四,描述动作过程,必须使用清晰的时间线,包含"视频开始时……随后……随着动作持续……视频结束时……"这类表达;第五,描述镜头语言,包括景别、机位、角度、镜头运动、稳定性、是否切镜,以及镜头重点捕捉的细节;第六,描述对白或无对白;第七,描述音频设计,包括主体动作声、环境声、细节声、空间混响和整体听感。 开头优先使用类似句式:"这是一段充满【风格/情绪】与【核心氛围】的视频,画面中【主体】正位于【场景】中……"。如果是写实人物或日常场景,可以使用"这段写实电影风格的视频记录了一个……场景……"。如果是动漫人物,可以使用"画面呈现高质量动漫电影质感……"。如果是运动场景,可以突出阳光、速度感、运动张力和真实临场感。如果是机甲、巨龙、怪兽、赛博人物等场景,可以突出史诗感、压迫感、力量感、未来感或毁灭感。 只要用户提供了台词,必须保留台词内容,不能做任何翻译,必须保留英文原文,必须用 包裹每句台词,用户给的所有连贯的speech只需要一对,不允许在其中插入新的。有多个说话人时,要说明谁先说、谁回应、各自的位置、音色、情绪和声场;如果某个角色不说话,也要明确"全程不说话"。对话类音频要强调清晰近场人声、口型同步、环境底噪、声场定位和混音干净。 如果用户没有明确提供台词,必须写:"画面中没有人物对白,也没有任何旁白。" 然后进入纯音效设计。音频设计必须具体,不能只写"有声音"或"有环境音"。纯音效场景要写清楚主体动作声、接触摩擦声、环境声、细节声和空间回响。例如海浪翻卷声、冲浪板切水声、风切声、水花拍打声、发动机轰鸣声、轮胎摩擦声、液压装置声、金属关节摩擦声、火焰喷射声、冰晶碰撞声、低频咆哮声、脚步声、衣料摩擦声、室内混响等。默认不要加入明显背景音乐,除非用户明确要求。结尾必须用类似句式总结:"整体听感【听感关键词】,突出【核心体验】。" 或 "整体氛围【氛围关键词】,营造出【目标效果】。" 动作描写必须是视频过程,而不是静态描述。要写清楚主体从什么状态开始,接着如何运动,动作速度如何,动作对环境产生什么影响,最后停留在什么状态。例如,快速动作要体现"迅速、猛烈、强烈、连续、背景快速后掠、浪花炸开、灰尘扬起、装甲联动加快"等细节;慢速动作要体现"缓慢、平稳、克制、柔和、细微调整、节奏舒展、环境变化轻柔"等细节。动作和环境反馈要匹配,例如冲浪要有水花和浪声,机甲要有金属关节和脚步震动,巨龙喷火要有火焰、热浪和火星,吐冰要有冰雾、冰晶和寒风,人物说话要有口型同步和近场人声。 镜头语言要具体。默认使用稳定镜头,不要频繁切镜。根据动作选择合理镜头:高速运动使用低角度侧前方跟拍或稳定跟随;慢速运动使用平稳跟拍并保持固定距离;正面凝视使用中景到中近景、轻微仰视或平视、稳定凝视和轻微推进;喷火、吐冰、大吼使用正面中近景、低角度、锁定嘴部和面部;双人对话使用固定中近景,两人同时入画;日常说话使用近景或中近景,强调口型同步和表情。镜头段落中要使用类似句式:"镜头采用稳定的【景别/角度】构图……全程……不切镜、不摇移……细腻捕捉……突出……"。 输出要求:只输出最终改写后的 prompt;必须保留原始speech部分不能忽略 !;必须是中文;必须保留原始speech部分不能忽略;必须是单段;不要换行;不要列表;不要解释;不要加标题;不要输出 JSON;不要使用 markdown;不要出现"根据用户输入""改写如下"等说明性文字。 思考要求:你只需要进行一轮简短思考(分析用户意图、确定风格和结构),然后立即输出最终 prompt。禁止反复推敲、多轮修改或自我检查。思考结束后直接给出最终结果,不要再回头修改。""" def _to01(x): """Convert [-1, 1] tensor to [0, 1].""" return torch.clamp((x.float() + 1.0) / 2.0, 0.0, 1.0) def _toWav(x): """Normalize waveform to [-0.95, 0.95] range.""" peak = x.abs().max().clamp(min=1e-12) x = x * (0.95 / peak) return x.clamp(-1.0, 1.0) def _count_speech_tags(text: str) -> int: """Count number of ... pairs in text.""" return len(re.findall(r".*?", text, re.DOTALL)) def ensure_weights(): """Ensure all required large checkpoints are downloaded to the workspace on CPU before launching Gradio.""" repo_id = "ernie-research/NAVA" files = [ "NAVA.safetensors", "Wan2.2-TI2V-5B/Wan2.2_VAE.pth", "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth", "Wan2.2-TI2V-5B/google/umt5-xxl/spiece.model", "Wan2.2-TI2V-5B/google/umt5-xxl/tokenizer.json", "params/LTX2/ltx-2.3-22b-dev_audio_vae.safetensors", ] print("=" * 60) print(" NAVA — Checking and downloading model weights (CPU/Startup)...") print("=" * 60) try: from huggingface_hub import hf_hub_download except ImportError: print("[Warning] huggingface_hub library is not installed. Skipping automatic download.") return import shutil for f in files: # Check if local file exists if os.path.exists(f): print(f"[Weights] Found: {f}") continue # Special check for audio vae path mismatch if f == "params/LTX2/ltx-2.3-22b-dev_audio_vae.safetensors" and os.path.exists("huggingface_upload/params/LTX2/ltx-2.3-22b-dev_audio_vae.safetensors"): print("[Weights] Found audio VAE in huggingface_upload/params/") continue print(f"[Weights] Downloading {f} from Hugging Face ({repo_id})...") try: hf_hub_download( repo_id=repo_id, filename=f, local_dir=".", local_dir_use_symlinks=False, ) print(f"[Weights] Successfully downloaded {f}") except Exception as e: print(f"[Weights] Error downloading {f}: {e}") # Synchronize audio VAE paths src_audio_vae = "params/LTX2/ltx-2.3-22b-dev_audio_vae.safetensors" dst_audio_vae = "huggingface_upload/params/LTX2/ltx-2.3-22b-dev_audio_vae.safetensors" if os.path.exists(src_audio_vae) and not os.path.exists(dst_audio_vae): print(f"[Weights] Aligning audio VAE: copying {src_audio_vae} to {dst_audio_vae}...") os.makedirs(os.path.dirname(dst_audio_vae), exist_ok=True) try: shutil.copy2(src_audio_vae, dst_audio_vae) except Exception as e: print(f"[Weights] Copy error: {e}") try: os.symlink(os.path.abspath(src_audio_vae), os.path.abspath(dst_audio_vae)) except Exception as sym_err: print(f"[Weights] Symlink error: {sym_err}") if os.path.exists(dst_audio_vae) and not os.path.exists(src_audio_vae): print(f"[Weights] Aligning audio VAE: copying {dst_audio_vae} to {src_audio_vae}...") os.makedirs(os.path.dirname(src_audio_vae), exist_ok=True) try: shutil.copy2(dst_audio_vae, src_audio_vae) except Exception as e: print(f"[Weights] Copy error: {e}") print("=" * 60) print(" NAVA — Weights check completed.") print("=" * 60) class NAVAEngine: """ NAVA inference engine wrapper. Handles pipeline init, checkpoint loading, SP patching, and single-sample generation. Supports: text-to-AV, image-to-AV (i2v), up to 2 speaker reference WAVs. IMPORTANT: All CUDA operations are deferred to generate() so that this class can be instantiated safely before ZeroGPU assigns a GPU. """ def __init__(self, config_path: str, ckpt_path: str, rank: int = 0, world_size: int = 1, use_sp: bool = False, height: int = 704, width: int = 1280, frames: int = 37): """ Store all init params. Actual model loading is deferred to _lazy_init() which is called inside generate() (inside @spaces.GPU scope on ZeroGPU). """ self.config_path = config_path or "configs/nava.yaml" self.ckpt_path = ckpt_path or "NAVA.safetensors" self.rank = rank self.world_size = world_size self.use_sp = use_sp self.height = height self.width = width self.frames = frames # Will be populated by _lazy_init() self.pipe = None self.cfg = None self.device = None self.dtype = None self._initialized = False self._backbone_on_gpu = False print( f"[Engine] NAVAEngine created (deferred init). " f"rank={rank}, world_size={world_size}, use_sp={use_sp}, " f"resolution={width}x{height}, frames={frames}" ) def _lazy_init(self): """ Actually load the model. Called on the first generate() call, which is guaranteed to be inside a @spaces.GPU scope on ZeroGPU. On torchrun (multi-GPU), dist is already initialized before this call. """ if self._initialized: return import torchaudio # noqa: F401 — ensure torchaudio is importable from nava_src.utils.common import set_seed from nava_src.models.nava.utils.model_loading_utils import load_fusion_checkpoint # Resolve device — on ZeroGPU, cuda:0 is the assigned GPU. # On torchrun, LOCAL_RANK tells us the device. local_rank = int(os.environ.get("LOCAL_RANK", "0")) self.device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(self.device) # Load config if not os.path.exists(self.config_path): raise FileNotFoundError( f"Config file not found: '{self.config_path}'. Please specify a valid config path via --config." ) self.cfg = yaml.safe_load(open(self.config_path, "r")) self.modality = self.cfg.get("modality", "audio_video") self.dtype = torch.bfloat16 if self.cfg["use_bf16"] else torch.float16 set_seed(self.cfg.get("seed", 42)) # SP init (multi-GPU torchrun only) if self.use_sp: import torch.distributed as dist from nava_src.models.nava.distributed_comms.parallel_states import ( initialize_sequence_parallel_state, ) initialize_sequence_parallel_state(self.world_size) if self.rank == 0: print(f"[SP] Sequence parallel enabled, sp_size={self.world_size}") # Load pipeline class module_path, class_name = self.cfg["pipeline"].rsplit(".", 1) PipelineClass = getattr(importlib.import_module(module_path), class_name) if "video" in self.modality and "audio" in self.modality: self.cfg["init_from_meta"] = True self.pipe = PipelineClass.create( model_id=self.cfg["model_id"], use_bf16=self.cfg["use_bf16"], audio_latent_ch=self.cfg["audio_latent_ch"], video_latent_ch=self.cfg["video_latent_ch"], lambda_ddpm=self.cfg["lambda_ddpm"], cfg=self.cfg, device=self.device, ) # Resolve checkpoint path — prefer .safetensors, fall back to .ckpt ckpt_path = self.ckpt_path if not os.path.exists(ckpt_path): ckpt_fallback = os.path.splitext(ckpt_path)[0] + ".ckpt" if os.path.exists(ckpt_fallback): if self.rank == 0: print(f"[Engine] {ckpt_path} not found, falling back to {ckpt_fallback}") ckpt_path = ckpt_fallback else: raise FileNotFoundError( f"Checkpoint not found: {ckpt_path} (also tried {ckpt_fallback}). " f"Please verify the checkpoint exists or specify a valid checkpoint via --ckpt." ) # Load checkpoint weights if ("video" in self.modality and "audio" in self.modality and not self.cfg.get("use_mmdit_model", False)): load_fusion_checkpoint( self.pipe.model, checkpoint_path=ckpt_path, from_meta=True ) else: if ckpt_path.endswith(".safetensors"): from safetensors.torch import load_file as _sf_load state_dict = _sf_load(ckpt_path, device="cpu") else: state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] missing, unexpected = self.pipe.model.load_state_dict(state_dict, strict=False) if self.rank == 0: print(f"[Engine] missing: {missing}, unexpected: {unexpected}") self.pipe = self.pipe.to(self.device) self.pipe.model.eval() self.pipe.model.backbone.set_rope_params() # SP patching if self.use_sp: self._convert_backbone_to_sp() if self.rank == 0: print("[SP] Patched backbone blocks to SP-aware self-attn.") # Infer misc params self.fps = self.cfg["data"].get("video_fps", 24) self.audio_tokens_per_sec = self.cfg["data"].get("audio_tokens_per_sec", 25) self.video_latent_ch = self.cfg["video_latent_ch"] self.patch_size = self.cfg.get("spatial_downsample", 16) self.resolution = ( self.pipe.video_vae.resolution if hasattr(self.pipe.video_vae, "resolution") else 960 ) # Offload backbone to CPU immediately after init self.pipe.model.backbone.to("cpu") torch.cuda.empty_cache() self._backbone_on_gpu = False self._initialized = True if self.rank == 0: print( f"[Engine] Initialized. modality={self.modality}, " f"resolution={self.width}x{self.height}, frames={self.frames}" ) print("[Engine] Backbone offloaded to CPU (will reload to GPU on generate)") def _convert_backbone_to_sp(self): """In-place swap every block.self_attn to its SP-aware subclass.""" from nava_src.models.nava.modules.model_mm_sp import ( WanDoubleStreamSelfAttentionSP, WanSelfAttentionSP, _swap_self_attn, ) backbone = self.pipe.model.backbone for blk in list(backbone.double_blocks) + list(backbone.double_final_blocks): _swap_self_attn(blk, WanDoubleStreamSelfAttentionSP) for blk in backbone.single_blocks: _swap_self_attn(blk, WanSelfAttentionSP) def reload_backbone(self): """Move backbone to GPU for diffusion sampling.""" if not self._backbone_on_gpu: self.pipe.model.backbone.to(self.device) self._backbone_on_gpu = True def offload_backbone(self): """Move backbone to CPU to free GPU memory.""" if self._backbone_on_gpu: self.pipe.model.backbone.to("cpu") torch.cuda.empty_cache() self._backbone_on_gpu = False def _get_spk_embs(self, spk_wav_paths: list) -> list: """ Get speaker embeddings from local WAV files via ReDimNet speaker model. Returns list of Tensor(1, 192), same format as T2AVDataset. """ spk_embs_list = [] for wav_path in spk_wav_paths: if not wav_path or not os.path.exists(wav_path): spk_embs_list.append(torch.zeros((1, 192), dtype=torch.float32)) continue query = { "data_path": wav_path, "use_spk_emb": True, } result = self.pipe.audio_vae.encode(query).latent_dist.sample() spk_embs = result["spk_embs"] # Tensor(1, 192) spk_embs_list.append(spk_embs) return spk_embs_list def _get_first_frame(self, image_path: str, target_height: int = None, target_width: int = None): """ Encode first frame image via local video VAE. Returns img_latents tensor [1, h_latent, w_latent, z_dim]. """ img_latents = self.pipe.video_vae.encode( image_path, target_height=target_height, target_width=target_width ).latent_dist.sample() return img_latents def _build_batch(self, prompt: str, image_path: str = None, spk_wav_paths: list = None, is_i2v: bool = False, height: int = None, width: int = None): """Build a single-sample batch dict from raw inputs.""" height = height or self.height width = width or self.width h = height // self.patch_size w = width // self.patch_size frames = self.frames # Audio length based on video duration video_duration = ((frames - 1) * 4 + 1) / self.fps audio_len = math.ceil(video_duration * self.audio_tokens_per_sec) # Default video latents (random noise, shape determines output size) video_latents = torch.randn((frames, h, w, 48)) # Handle first frame (i2v) img_latents = None if is_i2v and image_path and os.path.exists(image_path): img_latents = self._get_first_frame( image_path, target_height=height, target_width=width ) video_latents = torch.randn( (frames, img_latents.shape[1], img_latents.shape[2], 48) ) audio_latents = torch.randn((audio_len, 48)) # Handle speaker embeddings (0-2 speakers) spk_embs = None if spk_wav_paths: valid_paths = [p for p in spk_wav_paths if p and os.path.exists(p)] if valid_paths: spk_embs = self._get_spk_embs(valid_paths) # Insert after for spk_pos detection (align with T2AVDataset) prompt = prompt.replace("", "") batch = { "idx": 0, "video_latents": video_latents, "first_frames": img_latents, "audio_latents": audio_latents, "save_path": "gradio_output.mp4", "captions": prompt, "spk_embs": spk_embs, } return batch def _collate_single(self, sample: dict) -> dict: """Collate a single sample into batch format (mimics collate_fn for bs=1).""" from nava_src.data.t2v import collate_fn return collate_fn([sample]) @torch.no_grad() def generate(self, prompt: str, image_path: str = None, spk_wav_paths: list = None, steps: int = 50, output_dir: str = "/tmp/nava_outputs", is_i2v: bool = False, height: int = None, width: int = None, frames: int = None) -> str: """ Run single inference. All ranks must call this together in SP mode. On ZeroGPU, this must be called inside a @spaces.GPU decorated function. Returns: output video path (only meaningful on rank 0). """ # Lazy init — safe here because we are inside @spaces.GPU scope self._lazy_init() from nava_src.utils.common import set_seed # Pick a fresh random seed. In SP mode all ranks must use the SAME seed. if self.use_sp: import torch.distributed as dist seed_t = torch.empty(1, dtype=torch.long, device=self.device) if self.rank == 0: seed_t.fill_(int(torch.randint(0, 2**31 - 1, (1,)).item())) dist.broadcast(seed_t, src=0) seed = int(seed_t.item()) else: seed = int(torch.randint(0, 2**31 - 1, (1,)).item()) if self.rank == 0: print(f"[Engine] Random seed for this request: {seed}") set_seed(seed) # Sync all ranks before inference if self.use_sp: import torch.distributed as dist torch.cuda.empty_cache() dist.barrier() # Per-request frames override orig_frames = self.frames if frames is not None: self.frames = frames os.makedirs(output_dir, exist_ok=True) sample = self._build_batch(prompt, image_path, spk_wav_paths, is_i2v, height=height, width=width) batch = self._collate_single(sample) batch = { k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items() } amp_ctx = torch.autocast(device_type="cuda", dtype=self.dtype) # Reload backbone to GPU for diffusion sampling self.reload_backbone() with amp_ctx: gen_vid_out, gen_aud_out = self.pipe.sample( batch, num_steps=steps, audio_guidance_scale=self.cfg.get("audio_guidance_scale", 2.0), video_guidance_scale=self.cfg.get("video_guidance_scale", 3.0), align_3d_cfg=self.cfg.get("align_3d_cfg", True), audio_align_guidance_scale=self.cfg.get("audio_align_guidance_scale", 2.0), video_align_guidance_scale=self.cfg.get("video_align_guidance_scale", 3.0), save_vid_latent=False, is_i2v=is_i2v, timbre_cfg=self.cfg.get("timbre_cfg", False), timbre_align_guidance_scale=self.cfg.get( "timbre_align_guidance_scale", 3.0 ), offload_backbone=True, vae_cpu_offload=False, decode=(self.rank == 0), ) self._backbone_on_gpu = True # Barrier so workers don't race ahead if self.use_sp: import torch.distributed as dist dist.barrier() # Restore original frames setting self.frames = orig_frames # Only rank 0 saves if self.rank != 0: return "" # Post-process: merge video + audio → mp4 timestamp = int(time.time() * 1000) output_path = os.path.join(output_dir, f"output_{timestamp}.mp4") gen_vids = _to01(gen_vid_out).float() video_tensor = (gen_vids[0] * 255).clamp(0, 255).to(torch.uint8) video_tensor = video_tensor.permute(0, 2, 3, 1) # [T, C, H, W] -> [T, H, W, C] aud = gen_aud_out[0] waveform = _toWav(aud["waveform"]) if waveform.dim() == 1: waveform = waveform.unsqueeze(0) sample_rate = aud["sample_rate"] # Write video+audio via ffmpeg (torchvision.io.write_video is # unavailable on some HF Spaces torchvision builds). import subprocess import wave import numpy as np T, H, W, C = video_tensor.shape wav_tmp = os.path.join(output_dir, f"_tmp_audio_{timestamp}.wav") # Write WAV using stdlib (avoids torchaudio/torchcodec dependency) wav_data = waveform.cpu().float().contiguous() if wav_data.dim() == 1: wav_data = wav_data.unsqueeze(0) n_channels = wav_data.shape[0] # Convert float32 [-1, 1] to int16 wav_np = (wav_data.numpy() * 32767).clip(-32768, 32767).astype(np.int16) # Interleave channels: [channels, samples] -> [samples, channels] -> flat wav_np = wav_np.T # [samples, channels] with wave.open(wav_tmp, "wb") as wf: wf.setnchannels(n_channels) wf.setsampwidth(2) # 16-bit wf.setframerate(sample_rate) wf.writeframes(wav_np.tobytes()) ffmpeg_cmd = [ "ffmpeg", "-y", # video: raw frames via pipe "-f", "rawvideo", "-vcodec", "rawvideo", "-s", f"{W}x{H}", "-pix_fmt", "rgb24", "-r", str(self.fps), "-i", "pipe:0", # audio "-i", wav_tmp, # encoding "-c:v", "libx264", "-pix_fmt", "yuv420p", "-crf", "18", "-preset", "fast", "-c:a", "aac", "-b:a", "192k", "-shortest", output_path, ] try: proc = subprocess.Popen( ffmpeg_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) raw_frames = video_tensor.cpu().numpy().tobytes() _, stderr = proc.communicate(input=raw_frames, timeout=120) if proc.returncode != 0: print(f"[Engine] ffmpeg stderr: {stderr.decode(errors='replace')}") except Exception as e: print(f"[Engine] ffmpeg error: {e}") finally: if os.path.exists(wav_tmp): os.remove(wav_tmp) print(f"[Engine] Saved: {output_path}") return output_path class PromptRewriter: """ Loads a Qwen3 model for rewriting short prompts into high-quality Chinese dense captions optimized for NAVA inference. Supports GPU↔CPU offloading to share GPU with the NAVA backbone. IMPORTANT: Model loading is deferred to first rewrite() call so no CUDA operations happen at construction time (safe for ZeroGPU). """ def __init__(self, model_path: str = "Qwen/Qwen3-4B-Instruct-2507"): self.model_path = model_path or "Qwen/Qwen3-4B-Instruct-2507" self.tokenizer = None self.model = None self._initialized = False self._on_gpu = False self.system_prompt = SYSTEM_PROMPT print(f"[Rewriter] PromptRewriter created (deferred init). model={self.model_path}") def _lazy_init(self): """Load tokenizer and model on first use (inside @spaces.GPU scope).""" if self._initialized: return print(f"[Rewriter] Loading {self.model_path}...") t0 = time.time() from transformers import AutoModelForCausalLM, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( self.model_path, trust_remote_code=True ) self.model = AutoModelForCausalLM.from_pretrained( self.model_path, trust_remote_code=True, torch_dtype="auto", device_map="auto", ) self.model.eval() self._on_gpu = True self._initialized = True print(f"[Rewriter] Loaded in {time.time() - t0:.1f}s") def offload(self): """Move rewriter model to CPU to free GPU memory for inference.""" if self._initialized and self._on_gpu: try: self.model.to("cpu") except Exception as e: print(f"[Rewriter] Note: could not manually offload model: {e}") torch.cuda.empty_cache() self._on_gpu = False print("[Rewriter] Offloaded to CPU") def reload(self): """Move rewriter model to cuda:0 for rewriting.""" self._lazy_init() if not self._on_gpu: try: self.model.to("cuda:0") except Exception as e: print(f"[Rewriter] Note: could not manually reload model: {e}") self._on_gpu = True print("[Rewriter] Reloaded to cuda:0") def rewrite(self, user_input: str) -> tuple: """ Rewrite prompt. Returns (result, warning) tuple. Warning is non-empty if pair count mismatches. Must be called inside a @spaces.GPU decorated function on ZeroGPU. """ self.reload() messages = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_input}, ] text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) device = next(self.model.parameters()).device inputs = self.tokenizer(text, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} print(f"[Rewriter] Generating (input tokens: {inputs['input_ids'].shape[1]})...") t0 = time.time() with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=4096, temperature=0.3, top_p=0.75, top_k=20, do_sample=True, repetition_penalty=1.05, ) new_tokens = outputs[0][inputs["input_ids"].shape[1]:] result = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip() # Keep only content after the LAST (discard all thinking blocks) if "" in result: result = result.rsplit("", 1)[-1].strip() # Strip any residual unclosed block at the end if "" in result: result = result.split("", 1)[0].strip() elapsed = time.time() - t0 print(f"[Rewriter] Done in {elapsed:.1f}s ({len(new_tokens)} tokens)") # Check pair count input_count = _count_speech_tags(user_input) output_count = _count_speech_tags(result) warning = "" if input_count > 0 and output_count != input_count: warning = ( f"⚠️ Speech tag count mismatch! Input has {input_count} pairs, " f"output has {output_count} pairs. Please click Rewrite again." ) print(f"[Rewriter] WARNING: {warning}") return result, warning def worker_loop(engine: NAVAEngine): """Non-rank-0 processes wait for commands and execute inference.""" import torch.distributed as dist rank = dist.get_rank() print(f"[Rank {rank}] Entering worker loop, waiting for commands...") # Trigger lazy init on workers too (inside this function which is # called after dist init, so CUDA is already initialized). engine._lazy_init() while True: cmd = broadcast_cmd(0, src=0) if cmd == CMD_EXIT: print(f"[Rank {rank}] Received EXIT command. Shutting down.") break elif cmd == CMD_INFER: # Receive all params from rank 0 prompt = broadcast_string("", src=0) image_path = broadcast_string("", src=0) spk_wav_1 = broadcast_string("", src=0) spk_wav_2 = broadcast_string("", src=0) steps = broadcast_int(0, src=0) is_i2v = bool(broadcast_int(0, src=0)) height = broadcast_int(0, src=0) width = broadcast_int(0, src=0) frames = broadcast_int(0, src=0) # Build spk_wav_paths spk_wav_paths = [] if spk_wav_1: spk_wav_paths.append(spk_wav_1) if spk_wav_2: spk_wav_paths.append(spk_wav_2) # Run inference (result discarded on non-rank-0) engine.generate( prompt=prompt, image_path=image_path if image_path else None, spk_wav_paths=spk_wav_paths if spk_wav_paths else None, steps=steps, is_i2v=is_i2v, height=height, width=width, frames=frames, ) def run_gradio(engine: NAVAEngine, rewriter: PromptRewriter, args): """ Build and launch the Gradio interface. In ZeroGPU mode: runs on rank 0 only; all CUDA ops are inside @spaces.GPU functions. In torchrun mode: runs on rank 0 only; other ranks are in worker_loop(). """ import gradio as gr # Determine if we need distributed broadcasting (torchrun SP mode only) _use_dist = engine.use_sp # ---- Callback: Rewrite ---- @_spaces_gpu_decorator(size="xlarge") def rewrite_fn(user_prompt: str): """Rewrite prompt only, triggered by Rewrite button.""" if not user_prompt.strip(): return "", "" rewritten, warning = rewriter.rewrite(user_prompt) print(f"[Gradio] Rewritten prompt:\n{rewritten[:200]}...") return rewritten, warning # ---- Callback: Generate ---- @_spaces_gpu_decorator(size="xlarge", duration=300) def infer_fn(user_prompt: str, rewritten_prompt: str, image_file: str, spk_wav_1: str, spk_wav_2: str, steps: int, duration_sec: int, aspect_ratio: str): """ Main inference function triggered by Generate button. Uses rewritten_prompt if available, otherwise falls back to user_prompt. On ZeroGPU: runs entirely on the assigned GPU (no dist). On torchrun: broadcasts params to worker ranks before running. """ # Convert duration (seconds) to frames: frames = 6 * seconds + 1 frames = int(duration_sec) * 6 + 1 # Use rewritten prompt if it exists, otherwise use raw input final_prompt = ( rewritten_prompt.strip() if rewritten_prompt.strip() else user_prompt.strip() ) # Resolve aspect ratio to height/width height, width = ASPECT_RATIO_MAP.get(aspect_ratio, (704, 1280)) # I2V mode is automatically enabled when an image is provided is_i2v = bool(image_file) # Offload rewriter to free GPU memory for inference rewriter.offload() # Broadcast to worker ranks (SP/torchrun mode only) if _use_dist: broadcast_cmd(CMD_INFER, src=0) broadcast_string(final_prompt, src=0) broadcast_string(image_file or "", src=0) broadcast_string(spk_wav_1 or "", src=0) broadcast_string(spk_wav_2 or "", src=0) broadcast_int(steps, src=0) broadcast_int(int(is_i2v), src=0) broadcast_int(height, src=0) broadcast_int(width, src=0) broadcast_int(frames, src=0) # Build spk_wav_paths spk_wav_paths = [] if spk_wav_1 and os.path.exists(spk_wav_1): spk_wav_paths.append(spk_wav_1) if spk_wav_2 and os.path.exists(spk_wav_2): spk_wav_paths.append(spk_wav_2) # Run inference on rank 0 (all ranks run in parallel via SP in torchrun mode) output_path = engine.generate( prompt=final_prompt, image_path=image_file if image_file else None, spk_wav_paths=spk_wav_paths if spk_wav_paths else None, steps=steps, is_i2v=is_i2v, height=height, width=width, frames=frames, ) return output_path # ---- Custom CSS ---- custom_css = """ .gradio-container { max-width: 1400px !important; margin: 0 auto !important; } .gr-button-primary { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; border: none !important; font-weight: 600 !important; font-size: 1.1em !important; letter-spacing: 0.5px !important; transition: all 0.3s ease !important; } .gr-button-primary:hover { transform: translateY(-1px) !important; box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4) !important; } .gr-button-secondary { background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important; border: none !important; color: white !important; font-weight: 600 !important; transition: all 0.3s ease !important; } .gr-button-secondary:hover { transform: translateY(-1px) !important; box-shadow: 0 8px 25px rgba(245, 87, 108, 0.4) !important; } #nava-title { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em !important; font-weight: 800 !important; margin-bottom: 0 !important; } #nava-subtitle { text-align: center; color: #888; font-size: 1.1em; margin-top: 0 !important; } .tip-box { background: linear-gradient(135deg, #e0c3fc33 0%, #8ec5fc33 100%); border-left: 4px solid #764ba2; border-radius: 8px; padding: 12px 16px; margin-bottom: 12px; } """ # ---- Build Gradio Blocks ---- with gr.Blocks(title="NAVA — Audio-Video Generator") as demo: # Header gr.HTML( """

🎬 NAVA Audio-Video Generator

""" ) with gr.Row(equal_height=False): # ──── Left Column: Inputs ──── with gr.Column(scale=2): gr.HTML( """
⚡ Recommendation: For optimal generation quality, use the Rewrite function — especially if your prompt is in English or relatively brief. NAVA is primarily trained on high-quality Chinese dense captions; the rewriter transforms your input into the format that best activates the model's full potential.
""" ) # Prompt input with gr.Group(): prompt_input = gr.Textbox( label="✏️ Prompt (— Step 1)", placeholder=( "Enter a short description or detailed prompt...\n" "E.g.: A dragon breathing fire over a futuristic city at sunset" ), lines=4, elem_id="prompt-input", ) with gr.Row(): rewrite_btn = gr.Button( "✨ Rewrite Prompt (— Step 2)", variant="secondary" ) with gr.Row(): rewritten_prompt = gr.Textbox( label="📝 Rewritten Prompt (click Rewrite to generate, or use raw input)", lines=8, interactive=True, elem_id="rewritten-prompt", ) with gr.Row(): speech_warning = gr.Textbox( label="-> Speech Tag Check", interactive=False, visible=True, ) # Image input (optional, enables I2V) with gr.Accordion("🖼️ Image Input (optional — enables I2V mode)", open=False): image_input = gr.Image( label="First Frame Image", type="filepath", ) # Speaker reference (optional) with gr.Accordion("🎤 Speaker Reference (optional, max 2)", open=False): with gr.Row(): spk_wav_1_input = gr.Audio( label="Speaker 1 WAV", type="filepath", ) spk_wav_2_input = gr.Audio( label="Speaker 2 WAV", type="filepath", ) # Generation parameters with gr.Group(): gr.Markdown("### ⚙️ Generation Settings (Recommended)") steps_input = gr.Slider( minimum=10, maximum=100, value=20, step=5, label="Inference Steps", info="More steps = better quality, slower generation", ) duration_input = gr.Slider( minimum=2, maximum=10, value=4, step=1, label="Duration (seconds) — 6s = 37 frames", info="Video length in seconds", ) aspect_ratio_input = gr.Dropdown( choices=list(ASPECT_RATIO_MAP.keys()), value="1:1 (960×960)", label="Aspect Ratio", ) submit_btn = gr.Button( "🚀 Generate (— Step 3)", variant="primary", size="lg" ) # ──── Right Column: Output ──── with gr.Column(scale=2): video_output = gr.Video( label="Generated Video (with Audio)(approx.300s)", elem_id="video-output", height=400, ) gr.HTML( """

Generated videos include synchronized native audio.

NAVA • 6.3B parameters • Native Audio-Visual Alignment • This is a demo Space, and more optimizations are coming soon

""" ) # ---- Event Wiring ---- duration_input.change( fn=lambda s: gr.update( label=f"Duration (seconds) — {int(s)}s = {int(s) * 6 + 1} frames", minimum=2, maximum=10, step=1, ), inputs=[duration_input], outputs=[duration_input], ) rewrite_btn.click( fn=rewrite_fn, inputs=[prompt_input], outputs=[rewritten_prompt, speech_warning], ) submit_btn.click( fn=infer_fn, inputs=[ prompt_input, rewritten_prompt, image_input, spk_wav_1_input, spk_wav_2_input, steps_input, duration_input, aspect_ratio_input, ], outputs=[video_output], ) demo.queue(max_size=20) demo.launch(server_name="0.0.0.0", server_port=args.port, share=args.share, theme=orange_red_theme, css=custom_css) def run_debug_gradio(args): """Launch Gradio in debug mode — no models loaded, UI-only for testing.""" import gradio as gr def dummy_rewrite(user_prompt): """Simulate prompt rewriting.""" time.sleep(0.5) return ( f"[DEBUG REWRITE] 这是一段充满电影感与沉浸式氛围的视频。{user_prompt}。" f"画面中没有人物对白,也没有任何旁白。整体听感沉浸震撼,突出视觉冲击力。", "", ) def dummy_infer(user_prompt, rewritten_prompt, image_file, spk_wav_1, spk_wav_2, steps, duration_sec, aspect_ratio): """Simulate inference.""" final = rewritten_prompt.strip() if rewritten_prompt.strip() else user_prompt height, width = ASPECT_RATIO_MAP.get(aspect_ratio, (704, 1280)) frames = int(duration_sec) * 6 + 1 is_i2v = bool(image_file) print(f"[DEBUG] Would generate with prompt: {final[:100]}...") print(f"[DEBUG] image={image_file}, spk1={spk_wav_1}, spk2={spk_wav_2}") print(f"[DEBUG] steps={steps}, frames={frames}, is_i2v={is_i2v}, {width}x{height}") return None custom_css = """ .gradio-container { max-width: 1400px !important; margin: 0 auto !important; } .gr-button-primary { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; border: none !important; font-weight: 600 !important; font-size: 1.1em !important; letter-spacing: 0.5px !important; transition: all 0.3s ease !important; } .gr-button-primary:hover { transform: translateY(-1px) !important; box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4) !important; } .gr-button-secondary { background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%) !important; border: none !important; color: white !important; font-weight: 600 !important; transition: all 0.3s ease !important; } .gr-button-secondary:hover { transform: translateY(-1px) !important; box-shadow: 0 8px 25px rgba(245, 87, 108, 0.4) !important; } #nava-title { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em !important; font-weight: 800 !important; margin-bottom: 0 !important; } #nava-subtitle { text-align: center; color: #888; font-size: 1.1em; margin-top: 0 !important; } .debug-banner { background: linear-gradient(135deg, #ff9a5633 0%, #ff614833 100%); border: 2px dashed #ff6148; border-radius: 12px; padding: 12px; text-align: center; color: #ff6148; font-weight: 700; font-size: 1.1em; margin-bottom: 16px; } .tip-box { background: linear-gradient(135deg, #e0c3fc33 0%, #8ec5fc33 100%); border-left: 4px solid #764ba2; border-radius: 8px; padding: 12px 16px; margin-bottom: 12px; } """ with gr.Blocks( title="NAVA — Audio-Video Generator (DEBUG)", theme=gr.themes.Soft( primary_hue=gr.themes.colors.purple, secondary_hue=gr.themes.colors.pink, neutral_hue=gr.themes.colors.slate, font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"], ), css=custom_css, ) as demo: gr.HTML( """

🎬 NAVA Audio-Video Generator

🛠️ DEBUG MODE — No models loaded, UI only. All actions are simulated.
""" ) with gr.Row(equal_height=False): with gr.Column(scale=2): gr.HTML( """
⚡ Recommendation: For optimal generation quality, use the Rewrite function — especially if your prompt is in English or relatively brief.
""" ) with gr.Group(): prompt_input = gr.Textbox( label="✏️ Prompt", placeholder=( "Enter a short description or detailed prompt...\n" "E.g.: A dragon breathing fire over a futuristic city at sunset" ), lines=4, ) rewrite_btn = gr.Button("✨ Rewrite Prompt", variant="secondary") rewritten_prompt = gr.Textbox( label="📝 Rewritten Prompt", lines=8, interactive=True, ) speech_warning = gr.Textbox( label="🔍 Speech Tag Check", interactive=False, visible=True, ) with gr.Accordion("🖼️ Image Input (optional — enables I2V mode)", open=False): image_input = gr.Image(label="First Frame Image", type="filepath") with gr.Accordion("🎤 Speaker Reference (optional, max 2)", open=False): with gr.Row(): spk_wav_1_input = gr.Audio(label="Speaker 1 WAV", type="filepath") spk_wav_2_input = gr.Audio(label="Speaker 2 WAV", type="filepath") with gr.Group(): gr.Markdown("### ⚙️ Generation Settings") steps_input = gr.Slider( minimum=10, maximum=100, value=25, step=5, label="Inference Steps", info="More steps = better quality, slower generation", ) duration_input = gr.Slider( minimum=2, maximum=10, value=3, step=1, label="Duration (seconds) — 6s = 37 frames", info="Video length in seconds", ) aspect_ratio_input = gr.Dropdown( choices=list(ASPECT_RATIO_MAP.keys()), value="16:9 (1280×704)", label="Aspect Ratio", ) submit_btn = gr.Button("🚀 Generate", variant="primary", size="lg") with gr.Column(scale=2): video_output = gr.Video(label="🎥 Generated Video (with Audio)", height=500) gr.HTML( """

Generated videos include synchronized native audio.

NAVA • 6.3B parameters • Native Audio-Visual Alignment

""" ) # ---- Event Wiring ---- duration_input.change( fn=lambda s: gr.update( label=f"Duration (seconds) — {int(s)}s = {int(s) * 6 + 1} frames", minimum=2, maximum=10, step=1, ), inputs=[duration_input], outputs=[duration_input], ) rewrite_btn.click( fn=dummy_rewrite, inputs=[prompt_input], outputs=[rewritten_prompt, speech_warning], ) submit_btn.click( fn=dummy_infer, inputs=[ prompt_input, rewritten_prompt, image_input, spk_wav_1_input, spk_wav_2_input, steps_input, duration_input, aspect_ratio_input, ], outputs=[video_output], ) demo.queue(max_size=1) demo.launch(server_name="0.0.0.0", server_port=args.port, share=args.share) def main(): parser = argparse.ArgumentParser( description="NAVA — Single-file Gradio App (SP inference + prompt rewrite)" ) parser.add_argument("--config", type=str, default="configs/nava.yaml", help="NAVA config yaml path") parser.add_argument("--ckpt", type=str, default="NAVA.safetensors", help="NAVA checkpoint path") parser.add_argument("--rewrite_model", type=str, default="Qwen/Qwen3-4B-Instruct-2507", help="Rewrite model path") parser.add_argument("--port", type=int, default=7860, help="Gradio server port") parser.add_argument("--share", action="store_true", help="Create public Gradio link") parser.add_argument("--height", type=int, default=704, help="Default video height") parser.add_argument("--width", type=int, default=1280, help="Default video width") parser.add_argument("--frames", type=int, default=37, help="Default number of video frames") parser.add_argument("--steps", type=int, default=50, help="Default inference steps") parser.add_argument("--debug", action="store_true", help="Debug mode: skip all model loading, only launch Gradio UI") args = parser.parse_args() # ──── Debug mode: no models, no distributed, just UI ──── if args.debug: print("=" * 60) print(" NAVA Gradio App — DEBUG MODE") print(f" Port: {args.port}") print(f" Share: {args.share}") print("=" * 60) run_debug_gradio(args) return # ──── Detect execution mode ──── # torchrun sets RANK and WORLD_SIZE > 1 in the environment. # ZeroGPU / plain python sets neither (or WORLD_SIZE=1). world_size = int(os.environ.get("WORLD_SIZE", "1")) rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) _is_torchrun = world_size > 1 if _is_torchrun: # ──── Multi-GPU torchrun / SP mode ──── # Safe to touch CUDA here because torchrun does not use ZeroGPU emulation. import torch.distributed as dist torch.cuda.set_device(local_rank) dist.init_process_group( backend="nccl", timeout=datetime.timedelta(hours=24), ) print(f"[Rank {rank}] Initialized. device=cuda:{local_rank}, world_size={world_size}") engine = NAVAEngine( config_path=args.config, ckpt_path=args.ckpt, rank=rank, world_size=world_size, use_sp=True, height=args.height, width=args.width, frames=args.frames, ) # All ranks barrier after object creation (before lazy init) dist.barrier() if rank == 0: try: ensure_weights() except Exception as e: print(f"[Startup] Error ensuring weights: {e}") rewriter = PromptRewriter(model_path=args.rewrite_model) run_gradio(engine, rewriter, args) # Tell workers to stop when Gradio exits broadcast_cmd(CMD_EXIT, src=0) else: worker_loop(engine) dist.barrier() dist.destroy_process_group() else: # ──── Single-GPU or ZeroGPU mode ──── # DO NOT call torch.cuda.set_device() or dist.init_process_group() here. # All CUDA ops must happen inside @spaces.GPU decorated callbacks. print("=" * 60) print(" NAVA Gradio App — Single-GPU / ZeroGPU Mode") print(f" ZeroGPU detected: {_IS_ZERO_GPU}") print(f" spaces available: {_HAS_SPACES}") print(f" Port: {args.port} | Share: {args.share}") print("=" * 60) if not args.config or not args.ckpt: print( "[Warning] --config and/or --ckpt not provided. " "Engine will fail on first generate() call. " "Use --debug for UI-only testing." ) # Download weights on CPU before Gradio launches try: ensure_weights() except Exception as e: print(f"[Startup] Error ensuring weights: {e}") # Create engine and rewriter with fully deferred CUDA init. # No CUDA is touched here — safe for ZeroGPU startup. engine = NAVAEngine( config_path=args.config, ckpt_path=args.ckpt, rank=0, world_size=1, use_sp=False, # No SP in single-GPU / ZeroGPU mode height=args.height, width=args.width, frames=args.frames, ) rewriter = PromptRewriter(model_path=args.rewrite_model) run_gradio(engine, rewriter, args) if __name__ == "__main__": main()