# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import torch
from torch import nn, Tensor


def _bf16_u16(x: Tensor) -> Tensor:
    # reinterpret bf16 storage as int16 -> unsigned 0..65535 in int32
    return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF


class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via a 64k LUT; an invalid sigma raises an out-of-bounds index error (never a silently wrong result)."""

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device
        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)  # [S]
        bits = _bf16_u16(levels)  # [S]
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )
        with torch.no_grad():
            table = (
                base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
            )  # [S,D]
        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)
        self.register_buffer("table", table, persistent=False)  # [S,D] bf16
        self.register_buffer("lut", lut, persistent=False)  # [65536] int32
        self.register_buffer(
            "oob",
            torch.tensor(bits.numel(), device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        idx = self.lut[_bf16_u16(sigma)]
        idx = torch.where(idx >= 0, idx, self.oob)  # invalid -> S (OOB)
        return self.table[idx.to(torch.int64)]  # [...,D] bf16


class CachedCondHead(nn.Module):
    """bf16 cond -> cached (s0,b0,g0,s1,b1,g1); an invalid cond raises an out-of-bounds index error (never a silently wrong result)."""

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table  # [S,D] bf16
        S, D = table.shape
        with torch.no_grad():
            emb = table[:, None, :]  # [S,1,D]
            cache = (
                torch.stack([t.squeeze(1) for t in base(emb)], 0)
                .to(torch.bfloat16)
                .contiguous()
            )  # [6,S,D]
        # pick a single embedding dimension whose bf16 bits uniquely identify sigma
        key_dim = None
        for d in range(min(D, max_key_dims)):
            b = _bf16_u16(table[:, d])
            if torch.unique(b).numel() == S:
                key_dim = d
                key_bits = b
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
            )
        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)
        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)  # [6,S,D] bf16
        self.register_buffer("lut", lut, persistent=False)  # [65536] int32
        self.register_buffer(
            "oob",
            torch.tensor(S, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
        idx = torch.where(idx >= 0, idx, self.oob)  # invalid -> S (OOB)
        g = self.cache[:, idx.to(torch.int64)]  # [6,...,D] bf16 (or errors)
        return tuple(g.unbind(0))  # (s0,b0,g0,s1,b1,g1)
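# ---------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the original module): the real `base`
# modules come from the surrounding codebase; `ToyStepEmb` and `ToyCondHead`
# below are hypothetical stand-ins that only mimic the expected shapes
# (scalar sigma -> [..., D] embedding; cond [..., D] -> six modulation tensors
# of the same shape). The sigma list is likewise illustrative.
# ---------------------------------------------------------------------------


class ToyStepEmb(nn.Module):
    # hypothetical stand-in: embeds each scalar sigma into a D-dim vector, [...] -> [..., D]
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(1, dim)

    def forward(self, sigma: Tensor) -> Tensor:
        return self.proj(sigma[..., None].float()).to(sigma.dtype)


class ToyCondHead(nn.Module):
    # hypothetical stand-in: maps cond [..., D] to six modulation tensors of the same shape
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, 6 * dim)

    def forward(self, cond: Tensor):
        return self.proj(cond.float()).to(cond.dtype).chunk(6, dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    dim = 32
    sigmas = [0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0]  # illustrative scheduler sigmas

    step_emb = CachedDenoiseStepEmb(ToyStepEmb(dim), sigmas)  # precomputes [S, dim] table
    cond_head = CachedCondHead(ToyCondHead(dim), step_emb)    # precomputes [6, S, dim] cache

    sigma = torch.tensor([sigmas[3]], dtype=torch.bfloat16)
    emb = step_emb(sigma)                    # [1, dim] bf16, pure LUT lookup
    s0, b0, g0, s1, b1, g1 = cond_head(emb)  # six [1, dim] bf16 tensors, pure cache lookup
    print(emb.shape, s0.shape)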