File size: 2,871 Bytes
2eeb78f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Cache LTX-2 text embeddings across same-prompt rollout segments.

The continuous rollout calls the pipeline once per segment, and FastVideo
re-encodes the prompt through the ~25 GB Gemma text encoder every time — even
when the prompt is unchanged (the common case while the user isn't editing).
This monkeypatches `LTX2TextEncodingStage.forward` to memoize the encoded
embeddings by (prompt, negative_prompt) and skip Gemma on a hit.

Installed at import time so the patch is live in FastVideo's spawned worker
(spawn re-imports the app module). Disable with DREAMVERSE_TEXT_CACHE=0.
"""
from __future__ import annotations
import os

# Feature switch: the cache is active only when DREAMVERSE_TEXT_CACHE is unset
# or is exactly the string "1"; any other value disables it.
ENABLED = os.getenv("DREAMVERSE_TEXT_CACHE", "1") == "1"

# Upper bound on the number of cached entries. A malformed value must not
# crash the import of an optional optimisation, so fall back to the default.
_raw_max = os.getenv("DREAMVERSE_TEXT_CACHE_MAX", "32")
try:
    _MAX = int(_raw_max)
except ValueError:
    print(f"[textcache] ignoring invalid DREAMVERSE_TEXT_CACHE_MAX={_raw_max!r}; using 32", flush=True)
    _MAX = 32


def _key(batch):
    return (repr(getattr(batch, "prompt", None)),
            repr(getattr(batch, "negative_prompt", None)))


def install():
    """Monkeypatch ``LTX2TextEncodingStage.forward`` with a memoizing wrapper.

    Does nothing when the cache is disabled (``ENABLED`` is false), when the
    FastVideo LTX-2 text-encoding stage cannot be imported, or when the class
    was already patched (tracked by its ``_textcache_patched`` attribute), so
    calling it more than once is safe.

    The wrapper looks entries up by ``_key(batch)``:

    * Hit: the cached prompt embeddings, attention mask, negative embeddings
      and, when stored, the ``ltx2_audio_prompt_embeds`` extra are copied onto
      ``batch``, which is returned without calling the original ``forward``.
    * Miss: the original ``forward`` runs and the resulting fields are stored.
      A failure while storing is logged and never propagates.

    The cache lives in this process only, holds at most ``_MAX`` entries and
    evicts the least recently used entry first.
    """
    if not ENABLED:
        return
    try:
        from fastvideo.pipelines.basic.ltx2.stages.ltx2_text_encoding import LTX2TextEncodingStage
    except Exception as e:
        print(f"[textcache] LTX2 stage not importable here ({e}); skipping", flush=True)
        return
    if getattr(LTX2TextEncodingStage, "_textcache_patched", False):
        return
    LTX2TextEncodingStage._textcache_patched = True
    orig = LTX2TextEncodingStage.forward
    # A plain dict doubles as the LRU structure: insertion order is recency
    # order, with the oldest entry first.
    cache: dict = {}

    def patched(self, batch, fastvideo_args):
        key = _key(batch)
        # NOTE(review): the key covers only the two prompt strings. If other
        # batch settings change which fields the encoder fills in (e.g.
        # whether negative embeddings are produced), a hit may replay an entry
        # that lacks them — confirm against the callers.
        hit = cache.pop(key, None)
        if hit is not None:
            # Re-insert so the entry becomes the most recently used; without
            # this the eviction below would be FIFO and could drop a prompt
            # that is still in active use.
            cache[key] = hit
            # list(...) copies only the container; the cached elements are
            # shared with every batch served from this entry.
            batch.prompt_embeds = list(hit["pe"])
            if hit.get("pm") is not None:
                batch.prompt_attention_mask = list(hit["pm"])
            if hit.get("ne") is not None:
                batch.negative_prompt_embeds = list(hit["ne"])
            if "audio" in hit:
                batch.extra["ltx2_audio_prompt_embeds"] = hit["audio"]
            print("[textcache] hit (skipped Gemma encode)", flush=True)
            return batch
        out = orig(self, batch, fastvideo_args)
        # Some stages may mutate the batch in place and return None.
        batch = out if out is not None else batch
        try:
            embeds = list(batch.prompt_embeds or [])
            if not embeds:
                # Never cache an empty result: a later hit would skip the
                # encoder and hand back no embeddings at all.
                return batch
            mask = getattr(batch, "prompt_attention_mask", None)
            negative = getattr(batch, "negative_prompt_embeds", None)
            entry = {
                "pe": embeds,
                "pm": list(mask) if mask else None,
                "ne": list(negative) if negative else None,
            }
            if "ltx2_audio_prompt_embeds" in batch.extra:
                entry["audio"] = batch.extra["ltx2_audio_prompt_embeds"]
            cache[key] = entry
            if len(cache) > _MAX:
                # The first key in iteration order is the least recently used.
                cache.pop(next(iter(cache)))
        except Exception as e:
            # Caching is best-effort; the freshly encoded batch is still valid.
            print(f"[textcache] store skipped ({e})", flush=True)
        return batch

    LTX2TextEncodingStage.forward = patched
    print("[textcache] installed Gemma text-embedding cache", flush=True)