import torch
from torch import nn, Tensor


def _bf16_u16(x: Tensor) -> Tensor:
    """Reinterpret bf16 values as their raw 16-bit patterns, as int32 in [0, 0xFFFF]."""
    return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF
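
# Note on the key scheme: a bf16 value is fully determined by its 16-bit
# pattern, so the pattern is an exact LUT index with no float-tolerance
# issues (-0.0 vs 0.0 and NaN aside). Illustrative sketch, arbitrary values:
#
#   sigmas = torch.tensor([14.6146, 7.0, 0.0292], dtype=torch.bfloat16)
#   keys = _bf16_u16(sigmas)  # three distinct ints in [0, 65535]
#   assert torch.unique(keys).numel() == sigmas.numel()
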
class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via 64k LUT; invalid sigma => OOB index error (no silent wrong)."""

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device

        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)
        bits = _bf16_u16(levels)
        # Two sigmas that round to the same bf16 pattern would share one cache
        # slot, so refuse to build an ambiguous cache.
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        # Precompute the embedding for every scheduler sigma once, at init time.
        with torch.no_grad():
            table = (
                base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
            )

        # Full 2^16-entry map from bf16 bit pattern to row index in `table`;
        # unseen patterns stay at -1.
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)

        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # Sentinel index equal to len(table): gathering it raises an
        # out-of-bounds error instead of returning a wrong embedding.
        self.register_buffer(
            "oob",
            torch.tensor(bits.numel(), device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects bf16 sigma")
        idx = self.lut[_bf16_u16(sigma)]
        # Route unknown sigmas to the sentinel so the gather below fails loudly.
        idx = torch.where(idx >= 0, idx, self.oob)
        return self.table[idx.to(torch.int64)]
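
# Minimal usage sketch (hypothetical `timestep_mlp` and `scheduler_sigmas`;
# assumes `timestep_mlp` maps the (S, 1) sigma column to per-sigma embedding
# rows, matching the squeeze in __init__ above):
#
#   step_emb = CachedDenoiseStepEmb(timestep_mlp, scheduler_sigmas)
#   emb = step_emb(sigma_bf16)  # pure LUT gather at sampling time, no MLP call
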
class CachedCondHead(nn.Module):
    """bf16 cond -> cached (s0,b0,g0,s1,b1,g1); invalid cond => OOB index error (no silent wrong)."""

    def __init__(
        self,
        base: nn.Module,
        cached_denoise_step_emb: CachedDenoiseStepEmb,
        max_key_dims: int = 8,
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        S, D = table.shape

        # Run the head once per cached sigma embedding and stack the six
        # modulation tensors into one (6, S, D') cache.
        with torch.no_grad():
            emb = table[:, None, :]
            cache = (
                torch.stack([t.squeeze(1) for t in base(emb)], 0)
                .to(torch.bfloat16)
                .contiguous()
            )

        # Find one embedding dimension whose bf16 bit patterns already
        # distinguish all S sigmas; that single scalar then keys the
        # cond -> sigma lookup.
        key_dim = None
        for d in range(min(D, max_key_dims)):
            b = _bf16_u16(table[:, d])
            if torch.unique(b).numel() == S:
                key_dim = d
                key_bits = b
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
            )

        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)

        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # Sentinel index S: indexing the cache with it fails loudly rather than
        # returning modulations for the wrong sigma.
        self.register_buffer(
            "oob",
            torch.tensor(S, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects bf16 cond")
        # Key on a single scalar of the cond embedding; unknown keys hit the
        # OOB sentinel.
        idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
        idx = torch.where(idx >= 0, idx, self.oob)
        g = self.cache[:, idx.to(torch.int64)]
        return tuple(g.unbind(0))
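

if __name__ == "__main__":
    # Runnable smoke test with hypothetical stand-in modules (not the real
    # model): `_StepMLP` maps a (S, 1) sigma batch to (S, 1, D) embeddings and
    # `_CondHead` maps (S, 1, D) cond to six (S, 1, D) modulation tensors,
    # matching the shapes the caches above assume. CPU, seeded, illustrative.
    torch.manual_seed(0)

    class _StepMLP(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            self.proj = nn.Linear(1, dim)

        def forward(self, sigma: Tensor) -> Tensor:
            # (S, 1) bf16 -> (S, 1, D) float; the cache builder casts to bf16.
            return self.proj(sigma.float())[:, None, :]

    class _CondHead(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            self.proj = nn.Linear(dim, 6 * dim)

        def forward(self, emb: Tensor):
            # (S, 1, D) -> six (S, 1, D) modulation tensors.
            return self.proj(emb.float()).chunk(6, dim=-1)

    sigmas = [14.6146, 7.0, 1.0, 0.0292]
    step_emb = CachedDenoiseStepEmb(_StepMLP(), sigmas)
    cond_head = CachedCondHead(_CondHead(), step_emb)

    sigma = torch.tensor([sigmas[0]], dtype=torch.bfloat16)
    emb = step_emb(sigma)                    # (1, 16) bf16, pure LUT gather
    s0, b0, g0, s1, b1, g1 = cond_head(emb)  # six (1, 16) bf16 tensors
    print(emb.shape, s0.shape)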