| import torch | |
| import copy | |
| from diffusers.utils.torch_utils import randn_tensor | |
def is_int_string(s: str) -> bool:
    """Report whether ``int(s)`` accepts the given text.

    Args:
        s: Candidate text.

    Returns:
        True when ``int(s)`` succeeds, False when it raises ``ValueError``.
        Any other exception raised by ``int`` is not caught here.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
def _normalize_single_self_refiner_plan_from_str(plan_str):
    """Parse one comma-separated refiner plan string into a list of rules.

    Each non-empty entry is ``start-end:steps`` or ``start:steps``; the
    latter is stored with ``start == end``. Blank entries (for example a
    trailing comma) are ignored.

    A reversed range such as ``5-1:3`` is normalized to ``1-5`` so it is not
    stored as a rule whose inclusive range is empty; this matches how
    ``add_refiner_rule`` treats reversed bounds.

    Args:
        plan_str: Plan text such as ``"1-5:3, 6:1"``.

    Returns:
        A ``(rules, error)`` tuple. On success ``rules`` is a list of
        ``{"start", "end", "steps"}`` dicts of ints sorted by ``start`` and
        ``error`` is ``""``. On the first malformed entry ``rules`` is ``[]``
        and ``error`` describes the problem.
    """
    rules = []
    for raw_entry in plan_str.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue

        range_part, colon, steps_part = entry.partition(":")
        if not colon:
            return [], f"Invalid format in '{entry}'. Entries must be in 'start-end:steps' format."
        range_part = range_part.strip()
        steps_part = steps_part.strip()
        if not steps_part:
            return [], f"Missing step count in '{entry}'."

        # Split on the first dash only; a single index means start == end.
        start_text, dash, end_text = range_part.partition("-")
        if not dash:
            end_text = start_text

        try:
            start, end = int(start_text), int(end_text)
        except ValueError:
            return [], f"Range '{range_part}' must contain integers."
        try:
            steps = int(steps_part)
        except ValueError:
            return [], f"Steps '{steps_part}' must be an integer."

        if start > end:
            start, end = end, start
        rules.append({"start": start, "end": end, "steps": steps})

    rules.sort(key=lambda rule: rule["start"])
    return rules, ""
def convert_refiner_list_to_string(rules_list):
    """Serialize refiner rules to the ``start-end:steps`` text form.

    Args:
        rules_list: Iterable of rule dicts; items that are not dicts are
            skipped.

    Returns:
        Comma-joined entries. A rule whose ``start`` equals its ``end`` is
        written as ``start:steps``, any other rule as ``start-end:steps``.
    """

    def _render(rule):
        first = rule.get("start")
        last = rule.get("end")
        count = rule.get("steps")
        span = f"{first}" if first == last else f"{first}-{last}"
        return f"{span}:{count}"

    return ",".join(_render(rule) for rule in rules_list if isinstance(rule, dict))
def normalize_self_refiner_plan(plan_input, max_plans: int = 1):
    """Normalize a refiner plan given as ``None``, a rule list or plan text.

    Args:
        plan_input: ``None``, a list of rule dicts, or any other value, which
            is converted with ``str()`` and parsed as ``;``-separated plans.
        max_plans: Accepted for interface compatibility. The number of
            ``;``-separated plans is not limited by this function.

    Returns:
        A ``(plans, error)`` tuple where ``plans`` is a list of rule lists.
        ``None`` and blank text give ``[[]]``. A list input gives a single
        plan holding only the dict items that have both ``start`` and ``end``
        keys. If any text segment fails to parse, the result is ``[]``
        together with that segment's error message.
    """
    if plan_input is None:
        return [[]], ""

    if isinstance(plan_input, list):
        kept_rules = [
            rule
            for rule in plan_input
            if isinstance(rule, dict) and 'start' in rule and 'end' in rule
        ]
        return [kept_rules], ""

    plan_text = str(plan_input).strip()
    if not plan_text:
        return [[]], ""

    parsed_plans = []
    for raw_segment in plan_text.split(";"):
        segment = raw_segment.strip()
        if not segment:
            # An empty segment still occupies a slot so plan positions stay aligned.
            parsed_plans.append([])
            continue
        segment_rules, error = _normalize_single_self_refiner_plan_from_str(segment)
        if error:
            return [], error
        parsed_plans.append(segment_rules)
    return parsed_plans, ""
def ensure_refiner_list(plan_data):
    """Coerce plan data to a list of rules.

    Args:
        plan_data: A list (returned unchanged, same object), plan text, or
            anything else.

    Returns:
        The input list itself; for text, the first plan produced by
        ``normalize_self_refiner_plan`` (``[]`` when parsing yields no
        plans); ``[]`` for every other input type.
    """
    if isinstance(plan_data, list):
        return plan_data
    if not isinstance(plan_data, str):
        return []
    plans, _error = normalize_self_refiner_plan(plan_data)
    return plans[0] if plans else []
def add_refiner_rule(current_rules, range_val, steps_val):
    """Return the rule list with a new rule added, unless it overlaps.

    Args:
        current_rules: Existing rules as a list or plan text; normalized with
            ``ensure_refiner_list``.
        range_val: Either text such as ``"3-7"`` (``,`` and ``:`` are also
            accepted as the separator; a single number means one step) or an
            indexable pair ``(start, end)``.
        steps_val: Step count, converted with ``int``.

    Returns:
        A new list sorted by ``start`` that includes the added rule. When the
        new inclusive range intersects an existing rule, a gradio ``Info``
        notice is raised and the existing list is returned unchanged.
    """
    existing_rules = ensure_refiner_list(current_rules)

    if isinstance(range_val, str):
        normalized = range_val.strip().replace(",", "-").replace(":", "-")
        low_text, dash, high_text = normalized.partition("-")
        if not dash:
            high_text = low_text
        new_start = int(low_text.strip())
        new_end = int(high_text.strip())
    else:
        new_start, new_end = int(range_val[0]), int(range_val[1])

    # Accept reversed bounds by ordering them.
    new_start, new_end = min(new_start, new_end), max(new_start, new_end)

    for rule in existing_rules:
        disjoint = new_start > rule['end'] or new_end < rule['start']
        if disjoint:
            continue
        # Imported lazily, only on the path that needs to notify the UI.
        from gradio import Info
        Info(f"Overlap detected! Steps {new_start}-{new_end} conflict with existing rule {rule['start']}-{rule['end']}.")
        return existing_rules

    merged_rules = existing_rules + [
        {
            "start": new_start,
            "end": new_end,
            "steps": int(steps_val),
        }
    ]
    merged_rules.sort(key=lambda rule: rule['start'])
    return merged_rules
def remove_refiner_rule(current_rules, index):
    """Delete the rule at ``index`` and return the rule list.

    Args:
        current_rules: Existing rules as a list or plan text; normalized with
            ``ensure_refiner_list``.
        index: Zero-based position to delete. Values outside
            ``0 <= index < len(rules)`` leave the list untouched.

    Returns:
        The normalized list. When ``current_rules`` is already a list it is
        modified in place and that same object is returned.
    """
    rules = ensure_refiner_list(current_rules)
    if 0 <= index < len(rules):
        del rules[index]
    return rules
class PnPHandler:
    """Run extra predict/perturb refinement passes on selected sampler steps.

    A plan maps step indices to a number of passes. On a step with more than
    one pass, the handler repeatedly re-noises the latest clean-sample
    estimate, asks the caller's ``denoise_func`` for a new prediction, and
    blends results: positions whose estimate changed by less than
    ``ths_uncertainty`` between passes are frozen to the previous pass.

    Attributes are plain instance fields:
        stochastic_step_map: ``{step_index: passes}`` built from the plan.
        buffer: History for the current step; starts as ``[None]`` and then
            holds ``[certain_mask, pred_original_sample, latents_next]``.
        certain_flag: True once the frozen fraction exceeded
            ``certain_percentage`` during the current step.
    """

    def __init__(self, stochastic_plan, ths_uncertainty=0.0, p_norm=1, certain_percentage=0.999, channel_dim: int = 1):
        """Store thresholds and expand ``stochastic_plan`` into a step map.

        Args:
            stochastic_plan: Iterable of rule dicts or ``(start, end, steps)``
                sequences; may be empty or ``None``.
            ths_uncertainty: Positions with uncertainty strictly below this
                value are treated as certain.
            p_norm: ``p`` passed to ``torch.norm`` for the uncertainty.
            certain_percentage: Fraction of certain positions that must be
                exceeded to set ``certain_flag``.
            channel_dim: Dimension reduced by the norm; negative values count
                from the end.
        """
        self.stochastic_step_map = self._build_stochastic_step_map(stochastic_plan)
        self.ths_uncertainty = ths_uncertainty
        self.p_norm = p_norm
        self.certain_percentage = certain_percentage
        self.channel_dim = channel_dim
        self.buffer = [None]
        self.certain_flag = False

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a Python number, accepting tensors or plain numbers."""
        return value.item() if torch.is_tensor(value) else float(value)

    @staticmethod
    def _extract_prev_sample(step_out):
        """Pull the next latents out of a scheduler ``step`` result.

        Supports an object with ``prev_sample``, a tuple/list whose first
        item is the sample, or the sample itself.
        """
        if hasattr(step_out, "prev_sample"):
            return step_out.prev_sample
        if isinstance(step_out, (tuple, list)):
            return step_out[0]
        return step_out

    def _build_stochastic_step_map(self, plan):
        """Expand plan entries into ``{step_index: passes}``.

        Dict entries read ``start``/``begin``, ``end``/``stop`` and
        ``steps``/``anneal``/``num_anneal_steps`` (default 1). Sequence
        entries are ``(start, end, steps)``. Entries of other types, entries
        missing a bound or step count, and sequences shorter than three items
        are skipped. Ranges are inclusive; entries with ``steps <= 0`` add
        nothing; later entries override earlier ones on shared indices.
        """
        step_map = {}
        if not plan:
            return step_map
        for entry in plan:
            if isinstance(entry, dict):
                start = entry.get("start", entry.get("begin"))
                end = entry.get("end", entry.get("stop"))
                steps = entry.get("steps", entry.get("anneal", entry.get("num_anneal_steps", 1)))
            elif isinstance(entry, (list, tuple)) and len(entry) >= 3:
                start, end, steps = entry[0], entry[1], entry[2]
            else:
                continue
            # Incomplete entries are skipped like unsupported ones instead of
            # failing inside int(None).
            if start is None or end is None or steps is None:
                continue
            first, last, passes = int(start), int(end), int(steps)
            if passes > 0:
                for idx in range(first, last + 1):
                    step_map[idx] = passes
        return step_map

    def get_anneal_steps(self, step_index):
        """Return the planned number of passes for ``step_index`` (0 if unplanned)."""
        return self.stochastic_step_map.get(step_index, 0)

    def reset_buffer(self):
        """Clear the per-step history and the ``certain_flag``."""
        self.buffer = [None]
        self.certain_flag = False

    def process_step(self, latents, noise_pred, sigma, sigma_next, generator=None, device=None, latents_next=None, pred_original_sample=None):
        """Record one pass and blend it with the previous pass where certain.

        Args:
            latents: Latents the prediction was made on.
            noise_pred: Model prediction; only used to derive the two values
                below when they are not supplied.
            sigma: Current sigma; used for the fallback
                ``latents - sigma * noise_pred``.
            sigma_next: Next sigma; used for the fallback
                ``latents + (sigma_next - sigma) * noise_pred``.
            generator: Unused here.
            device: Unused here.
            latents_next: Next latents from the scheduler, if already known.
            pred_original_sample: Clean-sample estimate, if already known.

        Returns:
            The next latents. On the first pass of a step they are returned
            as given; on later passes certain positions take the previous
            pass's value.
        """
        if pred_original_sample is None:
            pred_original_sample = latents - sigma * noise_pred
        if latents_next is None:
            latents_next = latents + (sigma_next - sigma) * noise_pred

        previous = self.buffer[-1]
        if previous is None:
            certain_mask_stored = None
        else:
            prev_mask, prev_pred_original, prev_latents_next = previous
            diff = pred_original_sample - prev_pred_original
            channel_dim = self.channel_dim
            if channel_dim < 0:
                channel_dim += latents.ndim
            # Norm of the change over the channel dimension, scaled by its size.
            uncertainty = torch.norm(diff, p=self.p_norm, dim=channel_dim) / latents.shape[channel_dim]
            certain_mask = uncertainty < self.ths_uncertainty
            if prev_mask is not None:
                # Once certain, a position stays certain for the rest of the step.
                certain_mask = certain_mask | prev_mask
            if certain_mask.sum() / certain_mask.numel() > self.certain_percentage:
                self.certain_flag = True
            certain_mask_float = certain_mask.to(latents.dtype).unsqueeze(channel_dim)
            latents_next = certain_mask_float * prev_latents_next + (1.0 - certain_mask_float) * latents_next
            pred_original_sample = certain_mask_float * prev_pred_original + (1.0 - certain_mask_float) * pred_original_sample
            certain_mask_stored = certain_mask

        self.buffer.append([certain_mask_stored, pred_original_sample, latents_next])
        return latents_next

    def perturb_latents(self, latents, buffer_latent, sigma, generator=None, device=None, noise_mask=None):
        """Mix ``buffer_latent`` with fresh noise at level ``sigma``.

        Returns ``(1 - s) * buffer_latent + s * noise`` where ``s`` is
        ``sigma``, or ``noise_mask * sigma`` when a mask is given. The noise
        has the shape and dtype of ``latents``.
        """
        noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
        if noise_mask is None:
            return (1.0 - sigma) * buffer_latent + sigma * noise
        sigma_t = noise_mask.to(latents.dtype) * sigma
        return (1.0 - sigma_t) * buffer_latent + sigma_t * noise

    def run_refinement_loop(self, latents, noise_pred, current_sigma, next_sigma, m_steps, denoise_func, step_func, clone_func=None, restore_func=None, generator=None, device=None, noise_mask=None):
        """Run up to ``m_steps`` passes for one sampler step.

        Args:
            latents: Latents at the start of the step.
            noise_pred: Prediction for ``latents``; ``None`` aborts.
            current_sigma: Sigma of this step.
            next_sigma: Sigma of the following step.
            m_steps: Total passes, counting the initial one.
            denoise_func: ``f(latents) -> prediction`` or ``None`` to abort.
            step_func: ``f(prediction, latents) -> (latents_next,
                pred_original_sample)``; a ``None`` in either slot aborts.
            clone_func: Optional callable returning a scheduler snapshot.
            restore_func: Optional callable that reinstates a snapshot; it is
                called before every extra pass when a snapshot exists.
            generator: Forwarded to ``perturb_latents``.
            device: Forwarded to ``perturb_latents``.
            noise_mask: Forwarded to ``perturb_latents``.

        Returns:
            The refined next latents, or ``None`` when any callback aborted.
        """
        if noise_pred is None:
            return None

        scheduler_state = clone_func() if clone_func else None

        latents_next_0, pred_original_sample_0 = step_func(noise_pred, latents)
        if latents_next_0 is None or pred_original_sample_0 is None:
            return None
        latents_next = self.process_step(
            latents, noise_pred, current_sigma, next_sigma,
            latents_next=latents_next_0, pred_original_sample=pred_original_sample_0,
        )
        if self.certain_flag:
            return latents_next

        for _ in range(1, m_steps):
            if restore_func and scheduler_state is not None:
                restore_func(scheduler_state)
            # Re-noise the most recent clean-sample estimate to the current level.
            latents_perturbed = self.perturb_latents(
                latents,
                self.buffer[-1][1],
                current_sigma,
                generator=generator,
                device=device,
                noise_mask=noise_mask,
            )
            n_pred = denoise_func(latents_perturbed)
            if n_pred is None:
                return None
            latents_next_loop, pred_original_sample_loop = step_func(n_pred, latents_perturbed)
            if latents_next_loop is None or pred_original_sample_loop is None:
                return None
            latents_next = self.process_step(
                latents_perturbed, n_pred, current_sigma, next_sigma,
                latents_next=latents_next_loop, pred_original_sample=pred_original_sample_loop,
            )
            if self.certain_flag:
                break
        return latents_next

    def step(self, step_index, latents, noise_pred, t, timesteps, target_shape, seed_g, sample_scheduler, scheduler_kwargs, denoise_func):
        """Advance one sampler step, refining it when the plan asks for it.

        Args:
            step_index: Index of the current step within ``timesteps``.
            latents: Current latents.
            noise_pred: Model prediction for ``latents``; ``None`` aborts.
            t: Current timestep as a one-element tensor or a plain number;
                sigma is taken as ``t / 1000``.
            timesteps: Indexable sequence of timesteps (tensor or numbers).
            target_shape: Shape whose second entry bounds dim 2 of the
                prediction slice.
            seed_g: Generator forwarded to the noise sampling.
            sample_scheduler: Scheduler exposing
                ``step(pred, t, latents, **kwargs)``.
            scheduler_kwargs: Extra keyword arguments for ``scheduler.step``.
            denoise_func: ``f(latents) -> prediction`` used for extra passes.

        Returns:
            ``(latents, scheduler)``. ``latents`` is ``None`` when the
            prediction was missing or a callback aborted. The scheduler may
            be a deep copy of the one passed in when it was restored between
            passes, so callers should keep using the returned object.
        """
        if noise_pred is None:
            return None, sample_scheduler

        self.reset_buffer()
        current_sigma = self._as_float(t) / 1000.0
        is_last_step = step_index == len(timesteps) - 1
        next_sigma = (0.0 if is_last_step else self._as_float(timesteps[step_index + 1])) / 1000.0
        m_steps = self.get_anneal_steps(step_index)

        if m_steps <= 1 or self.certain_flag:
            # Plain scheduler step: no refinement planned for this index.
            n_pred_sliced = noise_pred[:, :latents.shape[1], :target_shape[1]]
            step_out = sample_scheduler.step(n_pred_sliced, t, latents, **scheduler_kwargs)
            return self._extract_prev_sample(step_out), sample_scheduler

        def step_func(n_pred_in, latents_in):
            n_pred_sliced = n_pred_in[:, :latents_in.shape[1], :target_shape[1]]
            step_out = sample_scheduler.step(n_pred_sliced, t, latents_in, **scheduler_kwargs)
            if hasattr(step_out, "pred_original_sample"):
                pred_original = step_out.pred_original_sample
            else:
                # Fallback estimate when the scheduler does not report one.
                pred_original = latents_in - current_sigma * n_pred_sliced
            return self._extract_prev_sample(step_out), pred_original

        def clone_func():
            if sample_scheduler is None:
                return None
            # Schedulers are snapshotted unless they declare is_stateful falsy.
            if getattr(sample_scheduler, "is_stateful", True):
                return copy.deepcopy(sample_scheduler)
            return None

        def restore_func(saved_state):
            nonlocal sample_scheduler
            # Identity check: a scheduler object may define __len__/__bool__.
            if saved_state is not None:
                sample_scheduler = copy.deepcopy(saved_state)

        refined = self.run_refinement_loop(
            latents=latents,
            noise_pred=noise_pred,
            current_sigma=current_sigma,
            next_sigma=next_sigma,
            m_steps=m_steps,
            denoise_func=denoise_func,
            step_func=step_func,
            clone_func=clone_func,
            restore_func=restore_func,
            generator=seed_g,
            device=latents.device,
        )
        return refined, sample_scheduler
def create_self_refiner_handler(pnp_plan, pnp_f_uncertainty, pnp_p_norm, pnp_certain_percentage, channel_dim: int = 1):
    """Build a ``PnPHandler`` from user settings.

    Args:
        pnp_plan: Plan input accepted by ``normalize_self_refiner_plan``.
        pnp_f_uncertainty: Passed as the handler's ``ths_uncertainty``.
        pnp_p_norm: Passed as the handler's ``p_norm``.
        pnp_certain_percentage: Passed as ``certain_percentage``.
        channel_dim: Passed as ``channel_dim``.

    Returns:
        A handler using the first normalized plan. When that plan is empty
        or normalization produced no plans (its error is ignored here), a
        built-in default plan is used: 3 passes on steps 1-5 and 1 pass on
        steps 6-13.
    """
    plans, _error = normalize_self_refiner_plan(pnp_plan, max_plans=2)
    stochastic_plan = plans[0] if plans else None
    if not stochastic_plan:
        stochastic_plan = [
            {"start": 1, "end": 5, "steps": 3},
            {"start": 6, "end": 13, "steps": 1},
        ]
    handler_options = {
        "ths_uncertainty": pnp_f_uncertainty,
        "p_norm": pnp_p_norm,
        "certain_percentage": pnp_certain_percentage,
        "channel_dim": channel_dim,
    }
    return PnPHandler(stochastic_plan, **handler_options)
def run_refinement_loop_multi(
    handlers,
    latents_list,
    noise_pred_list,
    current_sigma,
    next_sigma,
    m_steps,
    denoise_func,
    step_func,
    generators=None,
    devices=None,
    noise_masks=None,
    stop_when: str = "all",
):
    """Run the refinement passes for several handlers in lockstep.

    Every pass applies the same validation: a missing prediction, a ``None``
    item inside a prediction list, a ``None`` result from ``step_func``, or a
    result whose length differs from ``len(handlers)`` aborts with ``None``
    instead of being silently truncated by ``zip``.

    Args:
        handlers: Objects exposing ``process_step``, ``perturb_latents``,
            ``buffer`` and ``certain_flag`` (one per stream).
        latents_list: Starting latents, one per handler.
        noise_pred_list: List/tuple of predictions for ``latents_list``.
        current_sigma: Sigma of this step.
        next_sigma: Sigma of the following step.
        m_steps: Total passes, counting the initial one. Values ``<= 1``
            return ``latents_list`` untouched.
        denoise_func: ``f(perturbed_list) -> predictions`` or ``None``.
        step_func: ``f(predictions, latents_list) -> (latents_next_list,
            pred_original_list)``.
        generators: Optional per-handler generators.
        devices: Optional per-handler devices; defaults to each latent's
            ``.device``.
        noise_masks: Optional per-handler noise masks.
        stop_when: ``"any"`` stops once any handler is certain; every other
            value waits for all handlers.

    Returns:
        The list of refined next latents, ``latents_list`` when no
        refinement was requested, or ``None`` when a pass aborted.
    """
    if m_steps <= 1:
        return latents_list
    if not isinstance(noise_pred_list, (list, tuple)) or any(pred is None for pred in noise_pred_list):
        return None

    handler_count = len(handlers)

    def _should_stop():
        flags = (handler.certain_flag for handler in handlers)
        return any(flags) if stop_when == "any" else all(flags)

    def _step_and_process(predictions, inputs):
        """Step all streams and record the pass; ``None`` on invalid output."""
        latents_next_list, pred_original_list = step_func(predictions, inputs)
        if latents_next_list is None or pred_original_list is None:
            return None
        if len(latents_next_list) != handler_count or len(pred_original_list) != handler_count:
            return None
        return [
            handler.process_step(
                stream_latents,
                None,
                current_sigma,
                next_sigma,
                latents_next=latents_next,
                pred_original_sample=pred_original,
            )
            for handler, stream_latents, latents_next, pred_original in zip(
                handlers, inputs, latents_next_list, pred_original_list
            )
        ]

    refined_latents_list = _step_and_process(noise_pred_list, latents_list)
    if refined_latents_list is None:
        return None
    if _should_stop():
        return refined_latents_list

    for _ in range(1, m_steps):
        perturbed_list = []
        for idx, (handler, stream_latents) in enumerate(zip(handlers, latents_list)):
            perturbed_list.append(
                handler.perturb_latents(
                    stream_latents,
                    handler.buffer[-1][1],
                    current_sigma,
                    generator=generators[idx] if generators is not None else None,
                    device=devices[idx] if devices is not None else stream_latents.device,
                    noise_mask=noise_masks[idx] if noise_masks is not None else None,
                )
            )

        noise_pred_list = denoise_func(perturbed_list)
        if noise_pred_list is None:
            return None
        # Same per-item check as the first pass, applied to list-like results.
        if isinstance(noise_pred_list, (list, tuple)) and any(pred is None for pred in noise_pred_list):
            return None

        refined_latents_list = _step_and_process(noise_pred_list, perturbed_list)
        if refined_latents_list is None:
            return None
        if _should_stop():
            break
    return refined_latents_list