File size: 4,509 Bytes
57eef5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import torch
from torch import nn, Tensor


def _bf16_u16(x: Tensor) -> Tensor:
    # reinterpret bf16 storage as int16 -> unsigned 0..65535 in int32
    return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF


class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via 64k LUT; invalid sigma => OOB index error (no silent wrong).

    ``base`` is evaluated once, at construction, on every entry of ``sigmas``
    (cast to bf16), and the results are stored as a ``[S, D]`` bf16 table.
    ``forward`` maps each incoming bf16 sigma to its table row by using the
    sigma's raw 16-bit pattern as an index into a 65536-entry lookup table.

    A sigma whose bf16 bits are not one of the cached levels resolves to row
    index ``S`` (one past the end of the table), so the final gather fails
    with an out-of-bounds indexing error rather than returning a wrong row.
    Matching is on exact bits, so e.g. ``-0.0`` does not match a cached ``0.0``.

    Args:
        base: module evaluated on a ``[S, 1]`` bf16 tensor of sigmas. Its output
            is ``squeeze(1)``-ed and cast to bf16 to form the ``[S, D]`` table
            (assumed output layout — TODO confirm against the base module).
            The cache is built on the device of its first parameter, else its
            first buffer, else CPU.
        sigmas: the discrete noise levels that ``forward`` will be queried with.

    Raises:
        ValueError: if two entries of ``sigmas`` round to the same bf16 value.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        # Build on the device ``base`` lives on. A parameter-free base has an
        # empty ``parameters()``; a bare ``next()`` would then raise an
        # uninformative StopIteration, so fall back to buffers and then CPU.
        anchor = next(base.parameters(), None)
        if anchor is None:
            anchor = next(base.buffers(), None)
        device = anchor.device if anchor is not None else torch.device("cpu")

        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)  # [S]
        bits = _bf16_u16(levels)  # [S]
        # Two sigmas that round to the same bf16 value would share one LUT slot.
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        with torch.no_grad():
            table = (
                base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
            )  # [S,D]

        # lut[bit_pattern] -> row of ``table``; -1 marks "not a cached sigma".
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)

        self.register_buffer("table", table, persistent=False)  # [S,D] bf16
        self.register_buffer("lut", lut, persistent=False)  # [65536] int32
        # Sentinel row index S: one past the last row of ``table``.
        self.register_buffer(
            "oob",
            torch.tensor(bits.numel(), device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        """Return the cached embedding for each bf16 ``sigma``.

        Args:
            sigma: bf16 tensor of any shape ``[...]``.

        Returns:
            bf16 tensor of shape ``[..., D]``.

        Raises:
            RuntimeError: if ``sigma`` is not bf16.
            An out-of-bounds indexing error if any sigma is not a cached level.
        """
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        idx = self.lut[_bf16_u16(sigma)]
        idx = torch.where(idx >= 0, idx, self.oob)  # invalid -> S (OOB)
        return self.table[idx.to(torch.int64)]  # [...,D] bf16


class CachedCondHead(nn.Module):
    """bf16 cond -> cached (s0,b0,g0,s1,b1,g1), keyed on a single embedding dim.

    At construction, ``base`` is run once on every row of the bf16 embedding
    table held by ``cached_denoise_step_emb``, and its outputs are stacked and
    stored. The constructor then scans the first ``max_key_dims`` embedding
    dims for one whose bf16 bit patterns are distinct across all rows; that dim
    becomes the lookup key used by ``forward``.

    ``forward`` raises an out-of-bounds indexing error when the key value's
    bits match no cached row. Only the key dim is inspected: a ``cond`` that
    differs from a cached row elsewhere but shares its key-dim bits is NOT
    detected. Inputs are therefore expected to come from
    ``cached_denoise_step_emb`` itself.

    Args:
        base: head module. It is called on a ``[S, 1, D]`` bf16 tensor and must
            return an iterable of tensors, each ``squeeze(1)``-able to
            ``[S, D]`` (presumably six, matching the returned tuple — TODO
            confirm).
        cached_denoise_step_emb: supplies the ``[S, D]`` bf16 ``table``.
        max_key_dims: how many leading embedding dims to try as the key.

    Raises:
        ValueError: if none of the scanned dims has unique bf16 bits.
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        emb_table = cached_denoise_step_emb.table  # [S,D] bf16
        n_levels, width = emb_table.shape
        dev = emb_table.device

        # Precompute every head output for every cached embedding row.
        with torch.no_grad():
            per_output = [out.squeeze(1) for out in base(emb_table.unsqueeze(1))]
            cache = (
                torch.stack(per_output, dim=0).to(torch.bfloat16).contiguous()
            )  # [6,S,D]

        # Take the first dim whose bf16 bits tell every cached row apart.
        for dim in range(min(width, max_key_dims)):
            dim_bits = _bf16_u16(emb_table[:, dim])
            if torch.unique(dim_bits).numel() == n_levels:
                break
        else:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
            )

        # slot_of_bits[bit_pattern] -> cached row; -1 marks "unknown key".
        slot_of_bits = torch.full((65536,), -1, device=dev, dtype=torch.int32)
        slot_of_bits[dim_bits] = torch.arange(n_levels, device=dev, dtype=torch.int32)

        self.key_dim = int(dim)
        self.register_buffer("cache", cache, persistent=False)  # [6,S,D] bf16
        self.register_buffer("lut", slot_of_bits, persistent=False)  # [65536] int32
        # Sentinel row index S: one past the last cached row.
        self.register_buffer(
            "oob",
            torch.tensor(n_levels, device=dev, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        """Return the cached head outputs selected by ``cond``'s key dim.

        Args:
            cond: bf16 tensor of shape ``[..., D]``.

        Returns:
            Tuple of bf16 tensors, one per cached head output, each ``[..., D]``.

        Raises:
            RuntimeError: if ``cond`` is not bf16.
            An out-of-bounds indexing error if a key value matches no cached row.
        """
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        key_bits = _bf16_u16(cond[..., self.key_dim])  # [...]
        slot = self.lut[key_bits]
        # Unknown keys (-1) are redirected to S, one past the last cached row.
        slot = torch.where(slot < 0, self.oob, slot)
        picked = self.cache[:, slot.long()]  # [6,...,D] bf16 (or errors)
        return tuple(picked.unbind(0))  # (s0,b0,g0,s1,b1,g1)