Spaces:
Running on RTX PRO 6000
Running on RTX PRO 6000
multimodalart HF Staff
Speed: text-embed cache + lighter 384x640 profile + PyAV (VAE compile opt-in)
2eeb78f | """Cache LTX-2 text embeddings across same-prompt rollout segments. | |
| The continuous rollout calls the pipeline once per segment, and FastVideo | |
| re-encodes the prompt through the ~25 GB Gemma text encoder every time — even | |
| when the prompt is unchanged (the common case while the user isn't editing). | |
| This monkeypatches `LTX2TextEncodingStage.forward` to memoize the encoded | |
| embeddings by (prompt, negative_prompt) and skip Gemma on a hit. | |
| Installed at import time so the patch is live in FastVideo's spawned worker | |
| (spawn re-imports the app module). Disable with DREAMVERSE_TEXT_CACHE=0. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| ENABLED = os.getenv("DREAMVERSE_TEXT_CACHE", "1") == "1" | |
| _MAX = int(os.getenv("DREAMVERSE_TEXT_CACHE_MAX", "32")) | |
| def _key(batch): | |
| return (repr(getattr(batch, "prompt", None)), | |
| repr(getattr(batch, "negative_prompt", None))) | |
def _copy_container(value):
    """Shallow-copy a list/tuple into a new list; return any other value as-is.

    Copying the container means later appends or item reassignments on the
    batch's list cannot alter the cached list. Non-container values (``None``
    or, presumably, a bare tensor) are passed through unchanged instead of
    being split element-wise by ``list()``.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return value


def _snapshot_outputs(batch):
    """Build a cache entry from the text-encoder outputs stored on *batch*."""
    embeds = batch.prompt_embeds
    entry = {
        # A missing (None) prompt_embeds is normalised to an empty list.
        "pe": [] if embeds is None else _copy_container(embeds),
        # Presence is decided by ``is None`` (inside _copy_container), not by
        # truthiness. Truth-testing a multi-element tensor raises, and an
        # empty list should round-trip as an empty list.
        "pm": _copy_container(getattr(batch, "prompt_attention_mask", None)),
        "ne": _copy_container(getattr(batch, "negative_prompt_embeds", None)),
    }
    extra = batch.extra
    if "ltx2_audio_prompt_embeds" in extra:
        entry["audio"] = extra["ltx2_audio_prompt_embeds"]
    return entry


def _apply_outputs(batch, entry):
    """Write a cache entry produced by ``_snapshot_outputs`` back onto *batch*."""
    batch.prompt_embeds = _copy_container(entry["pe"])
    if entry["pm"] is not None:
        batch.prompt_attention_mask = _copy_container(entry["pm"])
    if entry["ne"] is not None:
        batch.negative_prompt_embeds = _copy_container(entry["ne"])
    if "audio" in entry:
        batch.extra["ltx2_audio_prompt_embeds"] = entry["audio"]


def install():
    """Patch ``LTX2TextEncodingStage.forward`` to memoize text embeddings.

    This function does nothing in any of these cases:

    * ``ENABLED`` is false.
    * FastVideo's LTX-2 text-encoding stage cannot be imported in this process.
    * The class is already flagged as patched, so repeated calls are safe.

    The wrapper keys each call with ``_key(batch)``.

    * On a hit, it copies the cached ``prompt_embeds``,
      ``prompt_attention_mask``, ``negative_prompt_embeds`` and
      ``batch.extra["ltx2_audio_prompt_embeds"]`` onto the batch. It then
      returns without calling the original ``forward``.
    * On a miss, it runs the original, snapshots those same outputs and stores
      them. Once more than ``_MAX`` entries are held, the least recently used
      one is evicted.

    NOTE(review): only the container lists are copied. The embedding objects
    themselves (presumably tensors) are shared with the cache. This assumes
    later stages never modify them in place -- TODO confirm.

    NOTE(review): the key covers only the prompt and negative prompt. Any
    other batch setting that changes what the encoder produces (e.g. whether
    negative embeddings are computed at all) is not part of the key. Verify
    that callers keep such settings fixed across segments.
    """
    if not ENABLED:
        return
    try:
        from fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding import LTX2TextEncodingStage
    except Exception as e:
        print(f"[textcache] LTX2 stage not importable here ({e}); skipping", flush=True)
        return
    if getattr(LTX2TextEncodingStage, "_textcache_patched", False):
        return
    LTX2TextEncodingStage._textcache_patched = True

    import functools

    orig = LTX2TextEncodingStage.forward
    # Insertion-ordered dict used as an LRU. A hit is re-inserted at the end,
    # and eviction removes the front (least recently used) entry.
    cache: dict = {}

    @functools.wraps(orig)
    def patched(self, batch, fastvideo_args):
        key = _key(batch)
        hit = cache.pop(key, None)
        if hit is not None:
            cache[key] = hit  # re-insert at the end: now most recently used
            _apply_outputs(batch, hit)
            print("[textcache] hit (skipped Gemma encode)", flush=True)
            return batch
        out = orig(self, batch, fastvideo_args)
        # The original forward may return the batch or mutate it in place.
        batch = out if out is not None else batch
        try:
            entry = _snapshot_outputs(batch)
        except Exception as e:
            # Best effort: an unexpected batch layout must never break generation.
            print(f"[textcache] store skipped ({e})", flush=True)
        else:
            cache[key] = entry
            # max(..., 0) keeps the loop safe for a zero or negative _MAX
            # (the cache then stays empty).
            while len(cache) > max(_MAX, 0):
                cache.pop(next(iter(cache)))
        return batch

    LTX2TextEncodingStage.forward = patched
    print("[textcache] installed Gemma text-embedding cache", flush=True)