"""Cache LTX-2 text embeddings across same-prompt rollout segments. The continuous rollout calls the pipeline once per segment, and FastVideo re-encodes the prompt through the ~25 GB Gemma text encoder every time — even when the prompt is unchanged (the common case while the user isn't editing). This monkeypatches `LTX2TextEncodingStage.forward` to memoize the encoded embeddings by (prompt, negative_prompt) and skip Gemma on a hit. Installed at import time so the patch is live in FastVideo's spawned worker (spawn re-imports the app module). Disable with DREAMVERSE_TEXT_CACHE=0. """ from __future__ import annotations import os ENABLED = os.getenv("DREAMVERSE_TEXT_CACHE", "1") == "1" _MAX = int(os.getenv("DREAMVERSE_TEXT_CACHE_MAX", "32")) def _key(batch): return (repr(getattr(batch, "prompt", None)), repr(getattr(batch, "negative_prompt", None))) def install(): if not ENABLED: return try: from fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding import LTX2TextEncodingStage except Exception as e: print(f"[textcache] LTX2 stage not importable here ({e}); skipping", flush=True) return if getattr(LTX2TextEncodingStage, "_textcache_patched", False): return LTX2TextEncodingStage._textcache_patched = True orig = LTX2TextEncodingStage.forward cache: dict = {} def patched(self, batch, fastvideo_args): key = _key(batch) hit = cache.get(key) if hit is not None: batch.prompt_embeds = list(hit["pe"]) if hit.get("pm") is not None: batch.prompt_attention_mask = list(hit["pm"]) if hit.get("ne") is not None: batch.negative_prompt_embeds = list(hit["ne"]) if "audio" in hit: batch.extra["ltx2_audio_prompt_embeds"] = hit["audio"] print("[textcache] hit (skipped Gemma encode)", flush=True) return batch out = orig(self, batch, fastvideo_args) batch = out if out is not None else batch try: entry = { "pe": list(batch.prompt_embeds or []), "pm": list(batch.prompt_attention_mask) if getattr(batch, "prompt_attention_mask", None) else None, "ne": list(batch.negative_prompt_embeds) if getattr(batch, "negative_prompt_embeds", None) else None, } if "ltx2_audio_prompt_embeds" in batch.extra: entry["audio"] = batch.extra["ltx2_audio_prompt_embeds"] cache[key] = entry if len(cache) > _MAX: cache.pop(next(iter(cache))) except Exception as e: print(f"[textcache] store skipped ({e})", flush=True) return batch LTX2TextEncodingStage.forward = patched print("[textcache] installed Gemma text-embedding cache", flush=True)