# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import torch
from torch import nn, Tensor


def _bf16_u16(x: Tensor) -> Tensor:
    # reinterpret bf16 storage as int16 -> unsigned 0..65535 in int32
    return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF
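
# Note on the bit trick above (illustrative values only, nothing is executed at import
# time): bf16 keeps the top 16 bits of an fp32, so every distinct bf16 value maps to a
# unique integer in [0, 65535] and can index a 64k lookup table directly, e.g.
#
#   _bf16_u16(torch.tensor([0.0, 1.0], dtype=torch.bfloat16))
#   # -> tensor([0, 16256], dtype=torch.int32)   (bits 0x0000 and 0x3F80)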


class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via a 64k LUT.

    A sigma whose bf16 bits are not in the cached schedule maps to an out-of-bounds
    index and raises, rather than silently returning a wrong embedding.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device
        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)  # [S]
        bits = _bf16_u16(levels)  # [S]
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )
        with torch.no_grad():
            table = (
                base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
            )  # [S,D]
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)
        self.register_buffer("table", table, persistent=False)  # [S,D] bf16
        self.register_buffer("lut", lut, persistent=False)  # [65536] int32
        self.register_buffer(
            "oob",
            torch.tensor(bits.numel(), device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma in bf16")
        idx = self.lut[_bf16_u16(sigma)]
        idx = torch.where(idx >= 0, idx, self.oob)  # unknown sigma -> index S (OOB)
        return self.table[idx.to(torch.int64)]  # [...,D] bf16
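
# Usage sketch (hedged): `base_step_emb` and `scheduler_sigmas` below are illustrative
# names for the original sigma-embedding module and the discrete sigma schedule; they
# are not defined in this file.
#
#   step_emb = CachedDenoiseStepEmb(base_step_emb, scheduler_sigmas)
#   emb = step_emb(sigma.to(torch.bfloat16))  # [..., D] bf16, pure table lookup
#
# A sigma whose bf16 bits are not in the schedule maps to index S, which is out of
# bounds for the [S, D] table and raises instead of returning a wrong embedding.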


class CachedCondHead(nn.Module):
    """bf16 cond -> cached (s0, b0, g0, s1, b1, g1) modulation tensors.

    A cond whose key bits are not in the cache maps to an out-of-bounds index and
    raises, rather than silently returning wrong values.
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table  # [S,D] bf16
        S, D = table.shape
        with torch.no_grad():
            emb = table[:, None, :]  # [S,1,D]
            cache = (
                torch.stack([t.squeeze(1) for t in base(emb)], 0)
                .to(torch.bfloat16)
                .contiguous()
            )  # [6,S,D]
        # pick a single embedding dimension whose bf16 bits uniquely identify sigma
        key_dim = None
        for d in range(min(D, max_key_dims)):
            b = _bf16_u16(table[:, d])
            if torch.unique(b).numel() == S:
                key_dim = d
                key_bits = b
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
            )
        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)
        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)  # [6,S,D] bf16
        self.register_buffer("lut", lut, persistent=False)  # [65536] int32
        self.register_buffer(
            "oob",
            torch.tensor(S, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond in bf16")
        idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
        idx = torch.where(idx >= 0, idx, self.oob)  # unknown cond -> index S (OOB)
        g = self.cache[:, idx.to(torch.int64)]  # [6,...,D] bf16 (or raises on OOB)
        return tuple(g.unbind(0))  # (s0, b0, g0, s1, b1, g1)
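
# Usage sketch (hedged): `base_cond_head` below is an illustrative name for the
# original module that maps a sigma embedding to the six modulation tensors
# (s0, b0, g0, s1, b1, g1); it is not defined in this file.
#
#   step_emb = CachedDenoiseStepEmb(base_step_emb, scheduler_sigmas)
#   cond_head = CachedCondHead(base_cond_head, step_emb)
#   s0, b0, g0, s1, b1, g1 = cond_head(step_emb(sigma))  # all bf16 table lookups
#
# The cache is keyed on a single embedding dimension whose bf16 bits are unique across
# the S sigmas, so the lookup never needs to reconstruct sigma itself.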