"""Conditioning logic for CFG guidance."""

import functools
import logging
import os

import torch

from src.Utilities import util
from src.Device import Device
from src.cond import cond_util
from src.sample import ksampler_util


class CONDRegular:
    """Regular condition wrapper around a single conditioning tensor."""

    def __init__(self, cond: torch.Tensor):
        self.cond = cond

    def _copy_with(self, cond: torch.Tensor) -> "CONDRegular":
        """Return a new wrapper of the same concrete class holding ``cond``."""
        return self.__class__(cond)

    def process_cond(self, batch_size: int, device: torch.device, **kwargs) -> "CONDRegular":
        """Return a copy repeated to ``batch_size`` and moved to ``device``.

        Repetition is delegated to ``util.repeat_to_batch_size``; extra keyword
        arguments are accepted and ignored.
        """
        return self._copy_with(util.repeat_to_batch_size(self.cond, batch_size).to(device))

    def can_concat(self, other: "CONDRegular") -> bool:
        """Return True if ``other`` wraps a tensor of exactly the same shape."""
        return self.cond.shape == other.cond.shape

    def concat(self, others: list) -> torch.Tensor:
        """Concatenate this tensor and every tensor in ``others`` along dim 0."""
        return torch.cat([self.cond] + [x.cond for x in others])


class CONDCrossAttn(CONDRegular):
    """Cross-attention condition wrapper.

    Tensors must agree on dims 0 and 2 but may differ on dim 1, which
    ``concat`` equalises by repeating along that axis (presumably the token
    axis of a ``(batch, tokens, features)`` tensor -- TODO confirm).
    """

    # Upper bound on lcm(len_a, len_b) // min(len_a, len_b) for two conds with
    # different dim-1 lengths to still be considered concatenable.
    MAX_REPEAT_FACTOR = 4

    def can_concat(self, other: "CONDRegular") -> bool:
        """Return True if ``other`` can be batched with this cond by ``concat``."""
        s1, s2 = self.cond.shape, other.cond.shape
        if s1 == s2:
            return True
        if s1[0] != s2[0] or s1[2] != s2[2]:
            return False
        # Shape entries are plain Python ints and torch.lcm only accepts tensors,
        # so use the same integer helper that ``concat`` relies on.
        common = util.lcm_of_list([s1[1], s2[1]])
        return common // min(s1[1], s2[1]) <= self.MAX_REPEAT_FACTOR

    def concat(self, others: list) -> torch.Tensor:
        """Concatenate along dim 0, repeating dim 1 up to the lcm of all lengths."""
        conds = [self.cond] + [x.cond for x in others]
        lengths = [c.shape[1] for c in conds]
        if all(n == lengths[0] for n in lengths):
            return torch.cat(conds)
        target = util.lcm_of_list(lengths)
        return torch.cat([c.repeat(1, target // c.shape[1], 1) if c.shape[1] < target else c for c in conds])


def convert_cond(cond: list) -> list:
    """Convert conditioning entries into dicts carrying ``model_conds``.

    Each entry is either a ``[tensor, options_dict]`` list/tuple or a bare value
    that is used directly as the cross-attention tensor. The options dict is
    shallow-copied and extended with:

    * ``cross_attn`` and ``model_conds["c_crossattn"]`` for a non-None tensor;
    * ``model_conds["y_pooled"]`` when ``pooled_output`` is present;
    * ``model_conds["attention_mask"]`` when ``attention_mask`` is present.

    Returns a new list; neither the input entries nor any nested
    ``model_conds`` dict they carry are modified.
    """
    out = []
    for c in cond:
        is_seq = isinstance(c, (list, tuple))
        options = c[1] if is_seq and len(c) > 1 else None
        temp = options.copy() if isinstance(options, dict) else {}
        # Copy so the entries added below never leak into the caller's dict.
        model_conds = dict(temp.get("model_conds") or {})
        cond_tensor = c[0] if is_seq else c
        if cond_tensor is not None:
            model_conds["c_crossattn"] = CONDCrossAttn(cond_tensor)
            temp["cross_attn"] = cond_tensor
        # Pass pooled_output as 'y_pooled' for SDXL conditioning
        pooled = temp.get("pooled_output")
        if pooled is not None:
            model_conds["y_pooled"] = CONDRegular(pooled)
        # Pass attention_mask for Klein/Flux2 models
        attention_mask = temp.get("attention_mask")
        if attention_mask is not None:
            model_conds["attention_mask"] = CONDRegular(attention_mask)
        temp["model_conds"] = model_conds
        out.append(temp)
    return out


def _build_timestep_for_chunk(timestep, batch_size, batch_indices, x_in, device):
    """Return a 1-D timestep tensor with ``batch_size`` entries for one chunk.

    * non-tensor ``timestep`` or one-element tensor -> repeated ``batch_size`` times;
    * tensor with one entry per row of ``x_in`` -> the rows in ``batch_indices``
      (every row when ``batch_indices`` is None);
    * tensor already of length ``batch_size`` -> used as-is;
    * any other tensor -> ``reshape(1)`` raises, because it has more than one element.
    """
    if isinstance(timestep, torch.Tensor):
        if timestep.numel() == 1:
            return timestep.to(device).reshape(1).repeat(batch_size)
        if timestep.shape[0] == x_in.shape[0]:
            if batch_indices is None:
                return timestep.to(device)
            idx = torch.tensor(batch_indices, dtype=torch.long, device=device)
            return timestep.to(device)[idx]
        if timestep.shape[0] == batch_size:
            return timestep.to(device)
        return timestep.to(device).reshape(1).repeat(batch_size)
    return torch.tensor([timestep], device=device).repeat(batch_size)


def _run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes,
                         batch_indices_list, cond_or_uncond, model_options):
    """Run the model once per chunk and return one output tensor per chunk.

    NOTE(review): each call gets a fresh ``transformer_options`` holding only
    ``cond_or_uncond`` and ``sigmas``; model-level transformer options, patches
    and any ``control`` from the batched path are not forwarded here.
    """
    output_parts = []
    for single_input, size, indices, flag, cond in zip(
            input_x_list, batch_sizes, batch_indices_list, cond_or_uncond, c_list):
        timestep_j = _build_timestep_for_chunk(timestep, size, indices, x_in, single_input.device)
        c_chunk = cond_util.cond_cat([cond])
        c_chunk["transformer_options"] = {"cond_or_uncond": [flag], "sigmas": timestep_j}
        if "model_function_wrapper" in model_options:
            out_j = model_options["model_function_wrapper"](
                model.apply_model,
                {"input": single_input, "timestep": timestep_j, "c": c_chunk, "cond_or_uncond": [flag]})
        else:
            out_j = model.apply_model(single_input, timestep_j, **c_chunk)
        output_parts.append(out_j)
    return output_parts


def _collect_runs(conds, x_in, timestep):
    """Return ``(area_and_mult, cond_index)`` pairs for every active condition."""
    to_run = []
    for i, cond in enumerate(conds):
        if cond is None:
            continue
        for x in cond:
            p = ksampler_util.get_area_and_mult(x, x_in, timestep)
            if p is not None:
                to_run.append((p, i))
    return to_run


def _select_batch(model, to_run, x_in, batched_cfg):
    """Pick the indices of ``to_run`` to evaluate together, in descending order.

    Candidates are entries concatenable with ``to_run[0]`` (restricted to the
    same cond index when ``batched_cfg`` is false). The largest prefix whose
    estimated memory (x1.5) fits in free device memory is taken, else a single
    entry. Descending order lets the caller ``pop`` them without index shifts.
    """
    first_cond, first_cond_index = to_run[0]
    first_shape = first_cond[0].shape
    candidates = [
        idx for idx, (p, cond_index) in enumerate(to_run)
        if cond_util.can_concat_cond(p, first_cond) and (batched_cfg or cond_index == first_cond_index)
    ]
    candidates.reverse()
    # Always make progress: without an entry the caller's loop would never end.
    fallback = candidates[:1] or [0]

    free_memory = Device.get_free_memory(x_in.device)
    for divisor in range(1, len(candidates) + 1):
        batch_amount = candidates[:len(candidates) // divisor]
        input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
        if model.memory_required(input_shape) * 1.5 < free_memory:
            return batch_amount
    return fallback


def _coerce_option_value(value):
    """Best-effort ``int`` conversion for a per-condition transformer option.

    Values ``int()`` accepts are converted, except fractional floats, which are
    kept as-is rather than silently truncated. Values ``int()`` rejects are
    returned unchanged.
    """
    if isinstance(value, float) and not value.is_integer():
        return value
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError, RuntimeError):
        return value


def _merge_transformer_options(model_options, c_list, patches, cond_or_uncond, sigmas):
    """Build the ``transformer_options`` dict for one batched model call."""
    transformer_options = model_options.get("transformer_options", {}).copy()
    # Merge any per-condition transformer options (e.g. from ADetailer crop conditioning)
    for cond_item in c_list:
        if not isinstance(cond_item, dict):
            continue
        per_to = cond_item.get("transformer_options")
        if isinstance(per_to, dict):
            for key, value in per_to.items():
                transformer_options[key] = _coerce_option_value(value)
    if patches is not None:
        cur_patches = transformer_options.get("patches", {}).copy()
        for name, patch_list in patches.items():
            cur_patches[name] = cur_patches.get(name, []) + patch_list
        transformer_options["patches"] = cur_patches
    transformer_options["cond_or_uncond"] = cond_or_uncond[:]
    transformer_options["sigmas"] = sigmas
    return transformer_options


def _needs_per_chunk_fallback(transformer_options, input_x):
    """Return True when ``img_h``/``img_w`` // 16 disagree with ``input_x``'s grid.

    Returns False when either key is missing or the check itself cannot be
    evaluated (the failure is logged at debug level).
    """
    if "img_h" not in transformer_options or "img_w" not in transformer_options:
        return False
    try:
        token_h = transformer_options["img_h"] // 16
        token_w = transformer_options["img_w"] // 16
        mismatch = bool(token_h != input_x.shape[2] or token_w != input_x.shape[3])
    except Exception as ex:
        logging.debug("calc_cond_batch: transformer_options validation failed: %s", ex)
        return False
    if mismatch:
        logging.info("calc_cond_batch: transformer_options img_h/img_w %r -> tokens %dx%d doesn't match input_x grid %dx%d; falling back to per-chunk", (transformer_options.get("img_h"), transformer_options.get("img_w")), token_h, token_w, input_x.shape[2], input_x.shape[3])
    return mismatch


def _run_batched(model, input_x, timestep_, c, batch_sizes, cond_or_uncond, model_options, run_per_chunk):
    """Run the model once on the whole batch and split the output per chunk.

    Falls back to ``run_per_chunk()`` if the batched call (or the split) raises.
    """
    expected_sum = sum(batch_sizes)
    try:
        if "model_function_wrapper" in model_options:
            full_out = model_options["model_function_wrapper"](
                model.apply_model,
                {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond})
        else:
            full_out = model.apply_model(input_x, timestep_, **c)
        actual_out_batch = full_out.shape[0]
        if actual_out_batch == expected_sum:
            return list(torch.split(full_out, batch_sizes, dim=0))
        # If model returned more/fewer items than expected (e.g. HiDiffusion internal batching)
        # use actual_out_batch to prevent torch.split crash
        split_size = actual_out_batch // len(batch_sizes)
        if split_size > 0:
            return list(torch.split(full_out, split_size, dim=0))
        # Fallback for single item output
        return [full_out] * len(batch_sizes)
    except Exception as e:
        logging.exception("Fast-path model call failed, falling back to per-chunk: %s; input_x.shape=%s; transformer_options=%s", e, input_x.shape, c.get("transformer_options"))
        return run_per_chunk()


def _apply_outputs(out_conds, out_counts, output_parts, cond_or_uncond, area, mult, batch_indices_list):
    """Accumulate each chunk's output into the running sums for its cond index."""
    for o, cond_index in enumerate(cond_or_uncond):
        if area[o] is None:
            _apply_output_no_area(out_conds, out_counts, cond_index, output_parts[o], mult[o],
                                  batch_indices_list[o])
        else:
            _apply_output_with_area(out_conds, out_counts, cond_index, output_parts[o], mult[o],
                                    batch_indices_list[o], area[o])


def calc_cond_batch(model, conds, x_in, timestep, model_options) -> list:
    """Evaluate the model for every condition list and return weighted averages.

    ``conds`` holds one condition list (or None) per output. Compatible entries
    are batched together (subject to free device memory), run through the model,
    and accumulated as ``sum(output * mult) / sum(mult)`` per output. The
    returned list has one tensor shaped like ``x_in`` per entry of ``conds``.
    """
    logging.debug("calc_cond_batch: model type %s, memory_required attr=%s", type(model), getattr(model, "memory_required", None))
    # Handle mock objects in tests
    if not isinstance(x_in, torch.Tensor):
        x_in = torch.zeros((1, 4, 8, 8))

    out_conds = [torch.zeros_like(x_in) for _ in range(len(conds))]
    # Tiny non-zero start so regions no condition touches don't divide by zero.
    out_counts = [torch.ones_like(x_in) * 1e-37 for _ in range(len(conds))]
    batched_cfg = model_options.get("batched_cfg", True)
    to_run = _collect_runs(conds, x_in, timestep)

    while to_run:
        # Indices are descending, so popping them in order never shifts the rest.
        batch = [to_run.pop(i) for i in _select_batch(model, to_run, x_in, batched_cfg)]
        parts = [p for p, _ in batch]
        cond_or_uncond = [cond_index for _, cond_index in batch]
        input_x_list = [p.input_x for p in parts]
        batch_sizes = [x.shape[0] for x in input_x_list]
        batch_indices_list = [p.batch_indices for p in parts]
        mult = [p.mult for p in parts]
        c_list = [p.conditioning for p in parts]
        area = [p.area for p in parts]
        # As before, control and patches come from the last entry of the batch.
        control, patches = parts[-1].control, parts[-1].patches

        input_x = torch.cat(input_x_list)
        c = cond_util.cond_cat(c_list)
        device = input_x.device
        timestep_ = torch.cat([_build_timestep_for_chunk(timestep, size, indices, x_in, device)
                               for size, indices in zip(batch_sizes, batch_indices_list)])
        if control is not None:
            c["control"] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))

        transformer_options = _merge_transformer_options(model_options, c_list, patches, cond_or_uncond, timestep_)
        run_per_chunk = functools.partial(
            _run_model_per_chunk, model, x_in, timestep, input_x_list, c_list, batch_sizes,
            batch_indices_list, cond_or_uncond, model_options)

        output_parts = None
        if _needs_per_chunk_fallback(transformer_options, input_x):
            # Run each chunk individually to avoid RoPE/positional-embedding mismatches.
            # Only the model call is guarded: outputs are applied exactly once below.
            try:
                output_parts = run_per_chunk()
            except Exception:
                logging.exception("calc_cond_batch: per-chunk fallback failed; retrying batched")
                output_parts = None
        if output_parts is None:
            c["transformer_options"] = transformer_options
            output_parts = _run_batched(model, input_x, timestep_, c, batch_sizes, cond_or_uncond,
                                        model_options, run_per_chunk)

        _apply_outputs(out_conds, out_counts, output_parts, cond_or_uncond, area, mult, batch_indices_list)

    # Final normalization
    for out, count in zip(out_conds, out_counts):
        out.div_(count)
    return out_conds


def _select_indexed_rows(batch_inds, target, out_part, mult):
    """Pair explicit batch indices with the output rows that produced them.

    Row ``k`` of ``out_part``/``mult`` belongs to ``batch_inds[k]``. Indices
    outside ``[-target.shape[0], target.shape[0])`` (or beyond the rows
    available) are dropped together with their row, so the remaining rows stay
    aligned. Single-row tensors are returned unchanged so they broadcast over
    every selected index.

    Returns ``(idx, out_rows, mult_rows)``, or None if no index is usable.
    """
    max_batch = target.shape[0]
    limits = [t.shape[0] for t in (out_part, mult) if t.shape[0] != 1]
    available = min(limits) if limits else len(batch_inds)
    targets, sources = [], []
    for row, b in enumerate(batch_inds):
        b = int(b)
        if row < available and -max_batch <= b < max_batch:
            targets.append(b)
            sources.append(row)
    if not targets:
        return None
    idx = torch.tensor(targets, dtype=torch.long, device=target.device)

    def take(t):
        if t.shape[0] == 1:
            return t
        return t[torch.tensor(sources, dtype=torch.long, device=t.device)]

    return idx, take(out_part), take(mult)


def _apply_output_no_area(out_conds, out_counts, cond_index, out_part, mult, batch_inds):
    """Accumulate ``out_part * mult`` over the full extent of ``out_conds[cond_index]``.

    With ``batch_inds`` None, rows are added in order (truncated to the target
    batch size); otherwise each row goes to its listed batch index.
    """
    target_conds = out_conds[cond_index]
    target_counts = out_counts[cond_index]
    if batch_inds is None:
        # Ensure out_part matches batch size of target
        if out_part.shape[0] != target_conds.shape[0]:
            out_part = out_part[:target_conds.shape[0]]
            mult = mult[:target_counts.shape[0]]
        target_conds.add_(out_part * mult)
        target_counts.add_(mult)
        return
    selected = _select_indexed_rows(batch_inds, target_conds, out_part, mult)
    if selected is None:
        return
    idx, part, m = selected
    target_conds[idx] += part * m
    target_counts[idx] += m


def _apply_output_with_area(out_conds, out_counts, cond_index, out_part, mult, batch_inds, a):
    """Accumulate ``out_part * mult`` into a clipped region of ``out_conds[cond_index]``.

    ``a`` is ``(*sizes, *starts)`` over the spatial dims (dim 2 onwards of the
    target). The region is clipped to the target and the leading corner of
    ``out_part``/``mult`` of the clipped size is added. Empty regions, an empty
    ``a``, and areas with more dims than the target has spatial dims are skipped.
    """
    target_conds = out_conds[cond_index]
    target_counts = out_counts[cond_index]
    dims = len(a) // 2
    if dims == 0 or target_conds.dim() < 2 + dims:
        logging.debug("_apply_output_with_area: skipping area %r for output of shape %s", a, tuple(target_conds.shape))
        return
    sizes, starts = a[:dims], a[dims:]
    region, crop = [], []
    for d in range(dims):
        lo = max(0, int(starts[d]))
        hi = min(target_conds.shape[2 + d], lo + max(0, int(sizes[d])))
        if hi <= lo:
            return
        region.append(slice(lo, hi))
        crop.append(slice(0, hi - lo))
    crop_index = (Ellipsis, *crop)
    out_part_crop = out_part[crop_index]
    mult_crop = mult[crop_index]

    if batch_inds is None:
        # Ensure out_part matches batch size of target if not using indices
        if out_part_crop.shape[0] != target_conds.shape[0]:
            out_part_crop = out_part_crop[:target_conds.shape[0]]
            mult_crop = mult_crop[:target_counts.shape[0]]
        region_index = (slice(None), slice(None), *region)
        target_conds[region_index] += out_part_crop * mult_crop
        target_counts[region_index] += mult_crop
        return

    selected = _select_indexed_rows(batch_inds, target_conds, out_part_crop, mult_crop)
    if selected is None:
        return
    idx, part, m = selected
    region_index = (idx, slice(None), *region)
    target_conds[region_index] += part * m
    target_counts[region_index] += m


def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs) -> list:
    """Call ``model_function`` for each cond and merge its result into ``model_conds``.

    Each call receives the cond's own keys plus ``device``, ``noise``,
    ``width``/``height`` (from ``noise`` scaled by the owning model's
    ``latent_format.downscale_factor``, default 8, unless already set),
    ``prompt_type`` (unless already set) and any ``kwargs`` the cond lacks.
    ``conds`` is updated in place with new dicts and returned.
    """
    owner = getattr(model_function, "__self__", None)
    # Loop-invariant: depends only on the object the bound method belongs to.
    downscale_factor = getattr(getattr(owner, "latent_format", None), "downscale_factor", 8)

    for t, x in enumerate(conds):
        params = x.copy()
        params["device"] = device
        params["noise"] = noise
        if len(noise.shape) >= 4:
            params["width"] = params.get("width", noise.shape[3] * downscale_factor)
            params["height"] = params.get("height", noise.shape[2] * downscale_factor)
        else:
            params["height"] = params.get("height", noise.shape[2] * downscale_factor)
        params["prompt_type"] = params.get("prompt_type", prompt_type)
        params.update({k: v for k, v in kwargs.items() if k not in params})
        out = model_function(**params)

        updated = x.copy()
        model_conds = dict(updated.get("model_conds") or {})
        model_conds.update(out)
        updated["model_conds"] = model_conds
        conds[t] = updated
    return conds


def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
    """Resolve percentage areas to absolute sizes and resize masks to ``dims``.

    A ``("percentage", *sizes, *starts)`` area is scaled by ``dims`` (sizes are
    at least 1). Masks are moved to ``device``, given a leading batch dim if
    they have exactly ``len(dims)`` dims, and bilinearly resized when their
    trailing shape differs from ``dims``. ``conditions`` is updated in place.
    """
    def dim_at(j):
        return dims[j] if j < len(dims) else dims[-1]

    for i, c in enumerate(conditions):
        if "area" in c:
            area = c["area"]
            if area[0] == "percentage":
                a = area[1:]
                a_len = len(a) // 2
                sizes = [max(1, int(round(a[j] * dim_at(j)))) for j in range(a_len)]
                starts = [int(round(a[j] * dim_at(j - a_len))) for j in range(a_len, 2 * a_len)]
                # Rebind ``c`` so a mask handled below keeps the converted area.
                c = {**c, "area": tuple(sizes) + tuple(starts)}
                conditions[i] = c
        if "mask" in c:
            mask = c["mask"].to(device=device)
            if len(mask.shape) == len(dims):
                mask = mask.unsqueeze(0)
            if mask.shape[1:] != dims:
                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode="bilinear", align_corners=False).squeeze(1)
            conditions[i] = {**c, "mask": mask}


def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None) -> dict:
    """Prepare every named condition list in ``conds`` for sampling.

    In order:
    1. Copy each list, then resolve areas and masks against ``noise.shape[2:]``.
    2. Compute start/end timesteps.
    3. Encode extra model conds when ``model`` has ``extra_conds``.
    4. Call ``cond_util.create_cond_with_same_area_if_none`` across every pair of lists.
    5. Run ``ksampler_util.pre_run_control``.
    6. When a ``"positive"`` list exists, call
       ``ksampler_util.apply_empty_x_to_equal_area`` for ``control`` (positive conds
       flagged ``control_apply_to_uncond`` only) and ``gligen``.

    ``conds`` is updated in place and returned.
    """
    def by_key(cond_cnets, key):
        return cond_cnets[key]

    for k in conds:
        conds[k] = conds[k][:]
        resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)

    for k in conds:
        ksampler_util.calculate_start_end_timesteps(model, conds[k])

    if hasattr(model, "extra_conds"):
        for k in conds:
            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k,
                                          latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

    # Ensure matching areas
    for k in conds:
        for c in conds[k]:
            for kk in conds:
                if k != kk:
                    cond_util.create_cond_with_same_area_if_none(conds[kk], c)

    for k in conds:
        ksampler_util.pre_run_control(model, conds[k])

    if "positive" in conds:
        positive = conds["positive"]
        for k in conds:
            if k == "positive":
                continue
            ksampler_util.apply_empty_x_to_equal_area(
                [c for c in positive if c.get("control_apply_to_uncond", False)], conds[k], "control", by_key)
            ksampler_util.apply_empty_x_to_equal_area(positive, conds[k], "gligen", by_key)
    return conds