mhnakif's picture
Upload folder using huggingface_hub
18cc273 verified
"""Shared sampling utilities for LCS intervention hooks."""
import comfy.utils
import torch
import torch.nn.functional as F
def find_step_index(sigma, sigmas):
"""Find the step index for a given sigma value in the sigma schedule.
Uses torch.isclose for robust matching across dtype differences (e.g. bfloat16
sigma vs float32 sample_sigmas), with argmin fallback for edge cases.
"""
sigma_val = sigma.flatten()[0].float()
sigmas_f = sigmas.float()
matched = torch.isclose(sigmas_f, sigma_val, rtol=1e-3, atol=1e-5).nonzero()
if len(matched) > 0:
return matched[0].item()
return (sigmas_f - sigma_val).abs().argmin().item()
def denoised_to_raw(denoised, model):
"""Convert denoised tensor from process_in space to raw VAE space.
Uses the model's latent_format.process_out (inverse of process_in).
Works for any model: FLUX (scale+shift), LTXV (identity), SD (scale), etc.
"""
return model.latent_format.process_out(denoised)
def raw_to_denoised(raw, model):
"""Convert raw VAE space tensor back to process_in space.
Uses the model's latent_format.process_in.
"""
return model.latent_format.process_in(raw)
def unpack_video_if_needed(denoised, args):
"""Unpack LTXAV-style packed latents if detected.
LTXAV packs video [B,128,F,H,W] + audio [B,ch,T,freq] into [B,1,flat].
Returns (tensor_to_process, pack_info) where pack_info is None for
non-packed formats or a dict for repacking.
"""
# Detect packed format: shape [B, 1, flat] with very large last dim
if denoised.ndim == 3 and denoised.shape[1] == 1:
# Try to find latent_shapes from cond data
cond = args.get("cond")
latent_shapes = _extract_latent_shapes(cond)
if latent_shapes is not None and len(latent_shapes) > 1:
tensors = comfy.utils.unpack_latents(denoised, latent_shapes)
# tensors[0] = video [B, 128, F, H, W], tensors[1] = audio [B, ch, T, freq]
return tensors[0], {"other_tensors": tensors[1:]}
return denoised, None
def repack_video_if_needed(modified, pack_info):
"""Repack video tensor back into LTXAV packed format if it was unpacked.
modified: the video tensor after intervention [B, 128, F, H, W]
pack_info: from unpack_video_if_needed
"""
if pack_info is None:
return modified
all_tensors = [modified] + pack_info["other_tensors"]
packed, _ = comfy.utils.pack_latents(all_tensors)
return packed
def downsample_mask(mask, h_len, w_len, device, dtype):
"""Downsample a mask to patch grid and flatten to [1, L, 1]."""
mask_dev = mask.to(device=device, dtype=dtype)
if mask_dev.ndim == 3:
mask_dev = mask_dev[:1]
if mask_dev.ndim == 2:
mask_4d = mask_dev.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
elif mask_dev.ndim == 3:
mask_4d = mask_dev.unsqueeze(1) # [B, 1, H, W]
else:
mask_4d = mask_dev
mask_resized = F.interpolate(
mask_4d, size=(h_len, w_len), mode="bilinear", align_corners=False
)
return mask_resized.reshape(1, -1, 1) # [1, L, 1]
def _extract_latent_shapes(cond):
"""Try to extract latent_shapes from conditioning data.
After convert_cond, cond is a list of dicts with 'model_conds' containing
CONDConstant-wrapped values like 'latent_shapes'.
"""
if cond is None:
return None
for c in cond:
if isinstance(c, dict):
model_conds = c.get('model_conds', {})
if 'latent_shapes' in model_conds:
ls = model_conds['latent_shapes']
# CONDConstant wraps the value in .cond
if hasattr(ls, 'cond'):
return ls.cond
return ls
return None