Spaces:
Running on Zero
Running on Zero
| """Conditioning logic for CFG guidance.""" | |
import logging
import math
import os

import torch

from src.Device import Device
from src.Utilities import util
from src.cond import cond_util
from src.sample import ksampler_util
class CONDRegular:
    """Wraps a plain conditioning tensor and defines how it batches."""

    def __init__(self, cond: torch.Tensor):
        self.cond = cond

    def _copy_with(self, cond: torch.Tensor) -> "CONDRegular":
        # Instantiate via self.__class__ so subclasses clone to their own type.
        return self.__class__(cond)

    def process_cond(self, batch_size: int, device: torch.device, **kwargs) -> "CONDRegular":
        """Return a copy repeated to *batch_size* and moved to *device*."""
        repeated = util.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(repeated.to(device))

    def can_concat(self, other: "CONDRegular") -> bool:
        """Regular conditions batch only when their shapes match exactly."""
        return self.cond.shape == other.cond.shape

    def concat(self, others: list) -> torch.Tensor:
        """Concatenate this condition with *others* along the batch axis."""
        tensors = [self.cond]
        tensors.extend(o.cond for o in others)
        return torch.cat(tensors)
class CONDCrossAttn(CONDRegular):
    """Cross-attention condition wrapper.

    Conditions whose token dimension (dim 1) differs may still be batched
    together by repeating the shorter ones up to the least common multiple
    of their token lengths.
    """

    def can_concat(self, other: "CONDRegular") -> bool:
        """Return True if *other* can be concatenated with this condition.

        Batch (dim 0) and feature (dim 2) axes must match exactly; the token
        axis (dim 1) may differ as long as padding to the LCM would not grow
        the shorter sequence by more than 4x.
        """
        s1, s2 = self.cond.shape, other.cond.shape
        if s1 != s2:
            if s1[0] != s2[0] or s1[2] != s2[2]:
                return False
            # Bug fix: shape entries are plain Python ints, so use math.lcm.
            # torch.lcm requires integer *tensors* and raises TypeError here.
            if math.lcm(s1[1], s2[1]) // min(s1[1], s2[1]) > 4:
                return False
        return True

    def concat(self, others: list) -> torch.Tensor:
        """Concatenate along the batch axis, padding token lengths to the LCM."""
        conds = [self.cond] + [x.cond for x in others]
        shapes = [c.shape[1] for c in conds]
        if all(s == shapes[0] for s in shapes):
            # Fast path: all token lengths already agree.
            return torch.cat(conds)
        max_len = math.lcm(*shapes)
        return torch.cat([c.repeat(1, max_len // c.shape[1], 1) if c.shape[1] < max_len else c for c in conds])
def convert_cond(cond: list) -> list:
    """Convert conditions to cross-attention conditions."""
    converted = []
    for entry in cond:
        # Entries may be (tensor, metadata) pairs or bare tensors.
        if isinstance(entry, (list, tuple)):
            meta = entry[1].copy() if len(entry) > 1 and isinstance(entry[1], dict) else {}
            tensor = entry[0]
        else:
            meta = {}
            tensor = entry
        model_conds = meta.get("model_conds", {})
        if tensor is not None:
            try:
                model_conds["c_crossattn"] = CONDCrossAttn(tensor)
                meta["cross_attn"] = tensor
            except Exception:
                # Best-effort: a condition that cannot be wrapped is skipped.
                pass
        # Pass pooled_output as 'y_pooled' for SDXL conditioning
        pooled = meta.get("pooled_output")
        if pooled is not None:
            model_conds["y_pooled"] = CONDRegular(pooled)
        # Pass attention_mask for Klein/Flux2 models
        attention_mask = meta.get("attention_mask")
        if attention_mask is not None:
            model_conds["attention_mask"] = CONDRegular(attention_mask)
        meta["model_conds"] = model_conds
        converted.append(meta)
    return converted
| def _build_timestep_for_chunk(timestep, batch_size, batch_indices, x_in, device): | |
| """Build timestep tensor for a single chunk.""" | |
| if isinstance(timestep, torch.Tensor): | |
| if timestep.numel() == 1: | |
| return timestep.to(device).reshape(1).repeat(batch_size) | |
| elif timestep.shape[0] == x_in.shape[0]: | |
| if batch_indices is None: | |
| return timestep.to(device) | |
| idx = torch.tensor(batch_indices, dtype=torch.long, device=device) | |
| return timestep.to(device)[idx] | |
| elif timestep.shape[0] == batch_size: | |
| return timestep.to(device) | |
| return timestep.to(device).reshape(1).repeat(batch_size) | |
| return torch.tensor([timestep], device=device).repeat(batch_size) | |
| def _run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options): | |
| """Run model on each chunk individually.""" | |
| output_parts = [] | |
| for idx in range(len(batch_sizes)): | |
| single_input = input_x_list[idx] | |
| timestep_j = _build_timestep_for_chunk(timestep, batch_sizes[idx], batch_indices_list[idx], x_in, single_input.device) | |
| c_chunk = cond_util.cond_cat([c_list[idx]]) | |
| c_chunk["transformer_options"] = {"cond_or_uncond": [cond_or_uncond[idx]], "sigmas": timestep_j} | |
| if "model_function_wrapper" in model_options: | |
| out_j = model_options["model_function_wrapper"]( | |
| model.apply_model, | |
| {"input": single_input, "timestep": timestep_j, "c": c_chunk, "cond_or_uncond": [cond_or_uncond[idx]]}) | |
| else: | |
| out_j = model.apply_model(single_input, timestep_j, **c_chunk) | |
| output_parts.append(out_j) | |
| return output_parts | |
def calc_cond_batch(model, conds, x_in, timestep, model_options) -> list:
    """Calculate the condition batch.

    Evaluates every condition in *conds* against the latent *x_in* at
    *timestep*, batching memory-compatible conditions together, and returns
    one averaged output tensor per entry of *conds* (weighted by each
    condition's mult/area mask).
    """
    logging.debug("calc_cond_batch: model type %s, memory_required attr=%s", type(model), getattr(model, "memory_required", None))
    # Handle mock objects in tests
    if not isinstance(x_in, torch.Tensor):
        x_in = torch.zeros((1, 4, 8, 8))
    # Accumulators: weighted output sums and weight totals. Counts are seeded
    # with a tiny epsilon so the final division never divides by zero.
    out_conds = [torch.zeros_like(x_in) for _ in range(len(conds))]
    out_counts = [torch.ones_like(x_in) * 1e-37 for _ in range(len(conds))]
    to_run = []
    batched_cfg = model_options.get("batched_cfg", True)
    # Flatten (condition, source-index) pairs; get_area_and_mult may drop
    # conditions that are inactive at this timestep.
    for i, cond in enumerate(conds):
        if cond is not None:
            for x in cond:
                p = ksampler_util.get_area_and_mult(x, x_in, timestep)
                if p is not None:
                    to_run.append((p, i))
    while to_run:
        first = to_run[0]
        first_shape = first[0][0].shape
        first_cond_index = first[1]
        # Find compatible conditions
        # (when batched_cfg is False, only conditions from the same source
        # list as `first` may be batched together)
        to_batch_temp = [
            x
            for x in range(len(to_run))
            if cond_util.can_concat_cond(to_run[x][0], first[0])
            and (batched_cfg or to_run[x][1] == first_cond_index)
        ]
        # Reverse so indices are in descending order: popping them later
        # keeps the remaining to_run indices valid.
        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]
        # Batch size optimization based on memory: try progressively smaller
        # prefixes (1/1, 1/2, 1/3, ...) until one fits with 1.5x headroom.
        free_memory = Device.get_free_memory(x_in.device)
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp) // i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) * 1.5 < free_memory:
                to_batch = batch_amount
                break
        # Collect batch data
        input_x_list, mult, c_list, cond_or_uncond, area = [], [], [], [], []
        batch_sizes, batch_indices_list = [], []
        control, patches = None, None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x_list.append(p.input_x)
            batch_sizes.append(p.input_x.shape[0])
            batch_indices_list.append(p.batch_indices)
            mult.append(p.mult)
            c_list.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            # NOTE(review): last chunk wins here — assumes every chunk in a
            # batch shares the same control/patches; confirm upstream.
            control, patches = p.control, p.patches
        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x_list)
        c = cond_util.cond_cat(c_list)
        device = input_x.device
        # Build timestep tensor (one segment per chunk, concatenated to match
        # the batched input_x rows)
        per_chunk_timesteps = [_build_timestep_for_chunk(timestep, s, b, x_in, device)
                               for s, b in zip(batch_sizes, batch_indices_list)]
        timestep_ = torch.cat(per_chunk_timesteps)
        if control is not None:
            c["control"] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
        # Handle transformer options and patches
        transformer_options = model_options.get("transformer_options", {}).copy()
        # Merge any per-condition transformer options (e.g. from ADetailer crop conditioning)
        for cond_item in c_list:
            if isinstance(cond_item, dict):
                per_to = cond_item.get("transformer_options")
                if isinstance(per_to, dict):
                    for k, v in per_to.items():
                        # NOTE(review): this coerces every int()-convertible
                        # value to int (floats get truncated) — confirm this
                        # is intended for all per-condition option keys.
                        try:
                            transformer_options[k] = int(v)
                        except Exception:
                            transformer_options[k] = v
        if patches is not None:
            # Per-key list concatenation: batch patches extend global ones.
            cur_patches = transformer_options.get("patches", {}).copy()
            for p in patches:
                cur_patches[p] = cur_patches.get(p, []) + patches[p]
            transformer_options["patches"] = cur_patches
        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep_
        # Validate image sizing if present and log helpful diagnostics
        try:
            if "img_h" in transformer_options and "img_w" in transformer_options:
                # presumably 16 = pixels per latent token — TODO confirm.
                token_h = transformer_options["img_h"] // 16
                token_w = transformer_options["img_w"] // 16
                if token_h != input_x.shape[2] or token_w != input_x.shape[3]:
                    logging.info("calc_cond_batch: transformer_options img_h/img_w %r -> tokens %dx%d doesn't match input_x grid %dx%d; falling back to per-chunk",
                                 (transformer_options.get("img_h"), transformer_options.get("img_w")), token_h, token_w, input_x.shape[2], input_x.shape[3])
                    # Fall back to running the model on each chunk individually to avoid RoPE/positional-embedding mismatches.
                    output_parts = _run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options)
                    # Apply outputs immediately and continue with next batch
                    for o in range(batch_chunks):
                        cond_index = cond_or_uncond[o]
                        a = area[o]
                        out_part = output_parts[o]
                        batch_inds = batch_indices_list[o]
                        if a is None:
                            _apply_output_no_area(out_conds, out_counts, cond_index, out_part, mult[o], batch_inds)
                        else:
                            _apply_output_with_area(out_conds, out_counts, cond_index, out_part, mult[o], batch_inds, a)
                    continue
        except Exception as ex:
            # NOTE(review): if the per-chunk fallback above raised mid-apply,
            # this swallows it and the full batch is still attempted below —
            # chunks applied before the failure would then be double-counted.
            logging.debug("calc_cond_batch: transformer_options validation failed: %s", ex)
        c["transformer_options"] = transformer_options
        # Run model
        expected_sum = sum(batch_sizes)
        if input_x.shape[0] != expected_sum:
            # Batched rows don't line up with the per-chunk sizes; the fast
            # path's split would be wrong, so run chunk by chunk.
            output_parts = _run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options)
        else:
            try:
                if "model_function_wrapper" in model_options:
                    full_out = model_options["model_function_wrapper"](
                        model.apply_model,
                        {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond})
                else:
                    full_out = model.apply_model(input_x, timestep_, **c)
                # Robust split: ensure sum matches full_out.shape[0]
                actual_out_batch = full_out.shape[0]
                if actual_out_batch != expected_sum:
                    # If model returned more/fewer items than expected (e.g. HiDiffusion internal batching)
                    # use actual_out_batch to prevent torch.split crash
                    split_size = actual_out_batch // len(batch_sizes)
                    if split_size > 0:
                        output_parts = list(torch.split(full_out, split_size, dim=0))
                    else:
                        # Fallback for single item output
                        output_parts = [full_out] * len(batch_sizes)
                else:
                    output_parts = list(torch.split(full_out, batch_sizes, dim=0))
            except Exception as e:
                logging.exception("Fast-path model call failed, falling back to per-chunk: %s; input_x.shape=%s; transformer_options=%s",
                                  e, input_x.shape, transformer_options)
                output_parts = _run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_sizes, batch_indices_list, cond_or_uncond, model_options)
        # Apply outputs
        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            out_part = output_parts[o]
            batch_inds = batch_indices_list[o]
            if a is None:
                _apply_output_no_area(out_conds, out_counts, cond_index, out_part, mult[o], batch_inds)
            else:
                _apply_output_with_area(out_conds, out_counts, cond_index, out_part, mult[o], batch_inds, a)
    # Final normalization: divide weighted sums by accumulated weights.
    for i in range(len(out_conds)):
        out_conds[i].div_(out_counts[i])
    return out_conds
| def _apply_output_no_area(out_conds, out_counts, cond_index, out_part, mult, batch_inds): | |
| """Apply output without area specification.""" | |
| if batch_inds is None: | |
| # Ensure out_part matches batch size of target | |
| if out_part.shape[0] != out_conds[cond_index].shape[0]: | |
| out_part = out_part[:out_conds[cond_index].shape[0]] | |
| mult = mult[:out_counts[cond_index].shape[0]] | |
| out_conds[cond_index] += out_part * mult | |
| out_counts[cond_index] += mult | |
| else: | |
| dev = out_conds[cond_index].device | |
| max_batch = out_conds[cond_index].shape[0] | |
| valid = [int(b) for b in batch_inds if -max_batch <= int(b) < max_batch] | |
| if not valid: | |
| return | |
| idx = torch.tensor(valid, dtype=torch.long, device=dev) | |
| # Slice out_part to match the number of valid indices | |
| out_part_final = out_part[:idx.shape[0]] | |
| mult_final = mult[:idx.shape[0]] | |
| out_conds[cond_index][idx] += out_part_final * mult_final | |
| out_counts[cond_index][idx] += mult_final | |
| def _apply_output_with_area(out_conds, out_counts, cond_index, out_part, mult, batch_inds, a): | |
| """Apply output with area specification.""" | |
| dims = len(a) // 2 | |
| starts, sizes = a[dims:], a[:dims] | |
| if dims == 2: | |
| H, W = out_conds[cond_index].shape[2], out_conds[cond_index].shape[3] | |
| y0, x0 = max(0, int(starts[0])), max(0, int(starts[1])) | |
| y1, x1 = min(H, y0 + max(0, int(sizes[0]))), min(W, x0 + max(0, int(sizes[1]))) | |
| if y1 <= y0 or x1 <= x0: | |
| return | |
| region_h, region_w = y1 - y0, x1 - x0 | |
| out_part_crop = out_part[..., :region_h, :region_w] | |
| mult_crop = mult[..., :region_h, :region_w] | |
| if batch_inds is None: | |
| # Ensure out_part matches batch size of target if not using indices | |
| if out_part_crop.shape[0] != out_conds[cond_index].shape[0]: | |
| out_part_crop = out_part_crop[:out_conds[cond_index].shape[0]] | |
| mult_crop = mult_crop[:out_counts[cond_index].shape[0]] | |
| out_conds[cond_index][:, :, y0:y1, x0:x1] += out_part_crop * mult_crop | |
| out_counts[cond_index][:, :, y0:y1, x0:x1] += mult_crop | |
| else: | |
| dev = out_conds[cond_index].device | |
| max_batch = out_conds[cond_index].shape[0] | |
| valid = [int(b) for b in batch_inds if -max_batch <= int(b) < max_batch] | |
| if not valid: | |
| return | |
| idx = torch.tensor(valid, dtype=torch.long, device=dev) | |
| # Slice out_part to match the number of valid indices | |
| out_part_final = out_part_crop[:idx.shape[0]] | |
| mult_final = mult_crop[:idx.shape[0]] | |
| out_conds[cond_index][idx, :, y0:y1, x0:x1] += out_part_final * mult_final | |
| out_counts[cond_index][idx, :, y0:y1, x0:x1] += mult_final | |
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs) -> list:
    """Encode each condition through *model_function* and merge the results.

    Mutates and returns *conds*; each entry's "model_conds" dict is replaced
    by a copy updated with the function's output.
    """
    for idx, entry in enumerate(conds):
        params = entry.copy()
        params["device"] = device
        params["noise"] = noise
        # Derive pixel dimensions from the latent grid; the model's latent
        # format overrides the default x8 factor when available.
        downscale_factor = 8
        if hasattr(model_function, "__self__"):
            owner = model_function.__self__
            if hasattr(owner, "latent_format") and hasattr(owner.latent_format, "downscale_factor"):
                downscale_factor = owner.latent_format.downscale_factor
        if len(noise.shape) >= 4:
            params.setdefault("width", noise.shape[3] * downscale_factor)
            params.setdefault("height", noise.shape[2] * downscale_factor)
        else:
            params.setdefault("height", noise.shape[2] * downscale_factor)
        params.setdefault("prompt_type", prompt_type)
        # Extra kwargs fill in only the keys not already present.
        params.update({k: v for k, v in kwargs.items() if k not in params})
        result = model_function(**params)
        updated = entry.copy()
        model_conds = updated["model_conds"].copy()
        model_conds.update(result)
        updated["model_conds"] = model_conds
        conds[idx] = updated
    return conds
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
    """Resolve areas and masks for conditions in place.

    Percentage-based areas become absolute integer ``(sizes..., offsets...)``
    tuples scaled by *dims*; masks are moved to *device*, given a batch axis
    if missing, and resized to *dims* when their spatial shape differs.

    Args:
        conditions: list of condition dicts, entries replaced in place.
        dims: spatial dimensions of the latent (e.g. ``noise.shape[2:]``).
        device: target device for masks.
    """
    for i, c in enumerate(conditions):
        updated = c
        if "area" in updated:
            area = updated["area"]
            if area[0] == "percentage":
                a = area[1:]
                a_len = len(a) // 2
                # Sizes (first half) are clamped to at least one latent cell;
                # axes beyond len(dims) fall back to the last dim.
                first = [max(1, int(round(a[j] * (dims[j] if j < len(dims) else dims[-1])))) for j in range(a_len)]
                second = [int(round(a[j] * (dims[j - a_len] if j - a_len < len(dims) else dims[-1]))) for j in range(a_len, 2 * a_len)]
                updated = {**updated, "area": tuple(first) + tuple(second)}
        if "mask" in updated:
            mask = updated["mask"].to(device=device)
            if len(mask.shape) == len(dims):
                mask = mask.unsqueeze(0)  # add a batch axis
            if mask.shape[1:] != dims:
                # NOTE(review): bilinear interpolation assumes 2D spatial dims.
                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode="bilinear", align_corners=False).squeeze(1)
            # Bug fix: build on *updated* rather than the original dict, so a
            # just-resolved percentage area is not discarded when the same
            # condition also carries a mask.
            updated = {**updated, "mask": mask}
        if updated is not c:
            conditions[i] = updated
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None) -> dict:
    """Process all conditions."""
    # Work on shallow copies so callers' lists are not mutated, then resolve
    # percentage areas and masks against the latent dimensions.
    for key in conds:
        conds[key] = conds[key][:]
        resolve_areas_and_cond_masks_multidim(conds[key], noise.shape[2:], device)
    for key in conds:
        ksampler_util.calculate_start_end_timesteps(model, conds[key])
    if hasattr(model, "extra_conds"):
        for key in conds:
            conds[key] = encode_model_conds(
                model.extra_conds, conds[key], noise, device, key,
                latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
    # Ensure matching areas: every other condition list gets an entry
    # covering each area found here, if it lacks one.
    for key in conds:
        for cond in conds[key]:
            for other_key in conds:
                if other_key != key:
                    cond_util.create_cond_with_same_area_if_none(conds[other_key], cond)
    for key in conds:
        ksampler_util.pre_run_control(model, conds[key])
    if "positive" in conds:
        positive = conds["positive"]
        for key in conds:
            if key == "positive":
                continue
            ksampler_util.apply_empty_x_to_equal_area(
                [c for c in positive if c.get("control_apply_to_uncond", False)],
                conds[key], "control", lambda cond_cnets, x: cond_cnets[x])
            ksampler_util.apply_empty_x_to_equal_area(positive, conds[key], "gligen", lambda cond_cnets, x: cond_cnets[x])
    return conds