rtx6000test / text_cache.py
multimodalart's picture
multimodalart HF Staff
Speed: text-embed cache + lighter 384x640 profile + PyAV (VAE compile opt-in)
2eeb78f
Raw
History Blame Contribute Delete
2.87 kB
"""Cache LTX-2 text embeddings across same-prompt rollout segments.
The continuous rollout calls the pipeline once per segment, and FastVideo
re-encodes the prompt through the ~25 GB Gemma text encoder every time — even
when the prompt is unchanged (the common case while the user isn't editing).
This module monkeypatches `LTX2TextEncodingStage.forward` to memoize the
encoded embeddings by (prompt, negative_prompt) and skip Gemma on a hit.

The patch is meant to be installed at import time so it is live in
FastVideo's spawned worker (spawn re-imports the app module). The cache is
enabled only when DREAMVERSE_TEXT_CACHE is "1" (the default); set
DREAMVERSE_TEXT_CACHE=0 to disable it. DREAMVERSE_TEXT_CACHE_MAX (default 32)
caps the number of cached prompt pairs.
"""
from __future__ import annotations
import os
# The cache is on only when DREAMVERSE_TEXT_CACHE is exactly "1" (the default);
# any other value disables it.
ENABLED = os.getenv("DREAMVERSE_TEXT_CACHE", "1") == "1"


def _env_int(name, default):
    """Return environment variable ``name`` parsed as an ``int``.

    Returns ``default`` when the variable is unset. When it is set but is not
    a valid integer, a warning is printed and ``default`` is returned as well.
    This module is meant to be imported during app start-up, so a typo in the
    environment must not raise ``ValueError`` there.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        print(f"[textcache] invalid {name}={raw!r}; using {default}", flush=True)
        return default


# Maximum number of (prompt, negative_prompt) entries kept by the cache.
_MAX = _env_int("DREAMVERSE_TEXT_CACHE_MAX", 32)
def _key(batch):
    """Build the cache key for ``batch``.

    The key is a 2-tuple holding ``repr()`` of ``batch.prompt`` and of
    ``batch.negative_prompt``, in that order. An attribute that is missing
    counts as ``None``. Because only strings end up in the tuple, prompts of
    any type (even unhashable ones such as lists) yield a hashable key.
    """
    return tuple(
        repr(getattr(batch, field_name, None))
        for field_name in ("prompt", "negative_prompt")
    )
def install():
    """Patch ``LTX2TextEncodingStage.forward`` with a (prompt, negative_prompt) cache.

    Returns without patching when ``ENABLED`` is false, when the FastVideo
    LTX-2 text-encoding stage cannot be imported in this process, or when the
    class already carries ``_textcache_patched``. Calling this repeatedly is
    therefore harmless.

    The patched ``forward`` looks up ``_key(batch)``:

    * On a hit, it copies the cached embeddings onto ``batch`` and returns
      ``batch`` without running the original ``forward``.
    * On a miss, it runs the original ``forward`` and caches copies of the
      resulting embeddings. Storing is best effort: failures are printed,
      never raised.

    At most ``_MAX`` entries are kept, and the least recently used entry is
    evicted first.
    """
    if not ENABLED:
        return
    try:
        from fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding import LTX2TextEncodingStage
    except Exception as e:  # best effort: any import-time failure just skips the patch
        print(f"[textcache] LTX2 stage not importable here ({e}); skipping", flush=True)
        return
    if getattr(LTX2TextEncodingStage, "_textcache_patched", False):
        return
    LTX2TextEncodingStage._textcache_patched = True

    import functools
    from collections import OrderedDict

    # Entry slot -> batch attribute, for the per-prompt embedding fields.
    # NOTE(review): "negative_attention_mask" is assumed to be a ForwardBatch
    # field. It is only cached/restored when the stage actually set it.
    fields = (
        ("pe", "prompt_embeds"),
        ("pm", "prompt_attention_mask"),
        ("ne", "negative_prompt_embeds"),
        ("nm", "negative_attention_mask"),
    )
    audio_key = "ltx2_audio_prompt_embeds"
    orig = LTX2TextEncodingStage.forward
    # Insertion order doubles as recency order (oldest first) for LRU eviction.
    cache = OrderedDict()

    def _copy(value):
        # Shallow-copy lists so later in-place list mutation can't alter the cache.
        return list(value) if isinstance(value, list) else value

    def _snapshot(batch):
        # Build a cache entry from an encoded batch; None if no prompt_embeds came out.
        entry = {slot: _copy(getattr(batch, attr, None)) for slot, attr in fields}
        pe = entry["pe"]
        if pe is None or (isinstance(pe, list) and not pe):
            return None
        if audio_key in batch.extra:
            entry["audio"] = _copy(batch.extra[audio_key])
        return entry

    def _restore(batch, entry):
        # Write a cached entry back onto batch, leaving never-set fields untouched.
        for slot, attr in fields:
            if entry[slot] is not None:
                setattr(batch, attr, _copy(entry[slot]))
        if "audio" in entry:
            batch.extra[audio_key] = _copy(entry["audio"])

    @functools.wraps(orig)
    def patched(self, batch, fastvideo_args):
        key = _key(batch)
        hit = cache.get(key)
        if hit is not None:
            cache.move_to_end(key)  # mark as most recently used
            _restore(batch, hit)
            print("[textcache] hit (skipped Gemma encode)", flush=True)
            return batch
        out = orig(self, batch, fastvideo_args)
        batch = out if out is not None else batch
        try:
            entry = _snapshot(batch)
            if entry is None:
                # Caching an empty result would replay it on every later hit.
                print("[textcache] store skipped (no prompt_embeds)", flush=True)
            else:
                cache[key] = entry
                if len(cache) > _MAX:
                    cache.popitem(last=False)  # evict least recently used
        except Exception as e:  # caching is best effort; never fail the encode
            print(f"[textcache] store skipped ({e})", flush=True)
        return batch

    LTX2TextEncodingStage.forward = patched
    print("[textcache] installed Gemma text-embedding cache", flush=True)