Wan2GP / shared /utils /self_refiner.py
Egnalkram's picture
Upload folder using huggingface_hub
4689c2b verified
import torch
import copy
from diffusers.utils.torch_utils import randn_tensor
def is_int_string(s: str) -> bool:
    """Return True if ``int(s)`` succeeds, False otherwise.

    Follows ``int`` semantics, so surrounding whitespace, a leading sign and
    digit-group underscores are accepted. Inputs that ``int`` rejects with a
    TypeError (e.g. None) yield False instead of raising, so this stays a
    pure predicate.
    """
    try:
        int(s)
    except (TypeError, ValueError):
        return False
    return True
def _normalize_single_self_refiner_plan_from_str(plan_str):
    """Parse one plan string such as ``"1-5:3,8:2"`` into rule dicts.

    Each comma-separated entry is ``start-end:steps`` or ``index:steps`` (a
    single index means ``start == end``). Blank entries are ignored. Reversed
    ranges (``"5-1:3"``) are normalized to ``start <= end``, the same way
    ``add_refiner_rule`` orders its bounds; otherwise such a rule would expand
    to no steps and would escape overlap detection.

    Returns:
        ``(entries, error)``. On success, ``entries`` are dicts with int
        ``start``/``end``/``steps`` sorted by ``start``, and ``error`` is ``""``.
        On the first malformed entry, returns ``([], message)``.
    """
    entries = []
    if not plan_str.strip():
        return [], ""
    for chunk in plan_str.split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        if ":" not in chunk:
            return [], f"Invalid format in '{chunk}'. Entries must be in 'start-end:steps' format."
        range_part, steps_part = chunk.split(":", 1)
        range_part = range_part.strip()
        steps_part = steps_part.strip()
        if not steps_part:
            return [], f"Missing step count in '{chunk}'."
        if "-" in range_part:
            start_s, end_s = range_part.split("-", 1)
        else:
            # A bare index covers exactly one step.
            start_s = end_s = range_part
        start_s = start_s.strip()
        end_s = end_s.strip()
        if not is_int_string(start_s) or not is_int_string(end_s):
            return [], f"Range '{range_part}' must contain integers."
        if not is_int_string(steps_part):
            return [], f"Steps '{steps_part}' must be an integer."
        start_i, end_i = int(start_s), int(end_s)
        if start_i > end_i:
            start_i, end_i = end_i, start_i
        entries.append({
            "start": start_i,
            "end": end_i,
            "steps": int(steps_part),
        })
    entries.sort(key=lambda x: x["start"])
    return entries, ""
def convert_refiner_list_to_string(rules_list):
    """Serialize rule dicts back into the ``start-end:steps`` plan syntax.

    Rules whose ``start`` equals ``end`` are written as ``index:steps``.
    Items that are not dicts are skipped. Missing keys render as ``None``.
    Entries are joined with commas, with no spaces.
    """
    def _format_rule(rule):
        lo, hi, count = rule.get("start"), rule.get("end"), rule.get("steps")
        span = f"{lo}" if lo == hi else f"{lo}-{hi}"
        return f"{span}:{count}"

    return ",".join(_format_rule(rule) for rule in rules_list if isinstance(rule, dict))
def normalize_self_refiner_plan(plan_input, max_plans: int = 1):
    """Normalize a refiner plan given either as rule dicts or as a plan string.

    Args:
        plan_input: One of the following.
            - None.
            - A list or tuple of rule dicts, treated as a single plan.
            - Any other value, which is converted with ``str()`` and may hold
              one or more ``;``-separated plans in ``start-end:steps`` syntax.
        max_plans: Currently NOT enforced. Every ``;`` segment is parsed and
            returned. The parameter is kept for interface compatibility.

    Returns:
        ``(plans, error)``, where ``plans`` is a list of rule lists.
        - For list/tuple input, there is one plan, keeping only dicts that
          have both ``start`` and ``end``.
        - An empty ``;`` segment yields ``[]``.
        - On the first parse error, the result is ``([], message)``.
    """
    if plan_input is None:
        return [[]], ""
    if isinstance(plan_input, (list, tuple)):
        # Already-structured rules: keep only dicts that carry a range.
        kept = [rule for rule in plan_input if isinstance(rule, dict) and 'start' in rule and 'end' in rule]
        return [kept], ""
    plan_str = str(plan_input).strip()
    if not plan_str:
        return [[]], ""
    plans = []
    for seg in plan_str.split(";"):
        seg = seg.strip()
        if not seg:
            plans.append([])
            continue
        plan_rules, error = _normalize_single_self_refiner_plan_from_str(seg)
        if error:
            return [], error
        plans.append(plan_rules)
    return plans, ""
def ensure_refiner_list(plan_data):
    """Coerce ``plan_data`` into a single list of refiner rules.

    - A list is returned as-is (the same object, not a copy).
    - A string is parsed with ``normalize_self_refiner_plan`` and its first
      plan is returned; a parse error gives ``[]``.
    - Any other value gives ``[]``.
    """
    if isinstance(plan_data, list):
        return plan_data
    if not isinstance(plan_data, str):
        return []
    parsed_plans, _ = normalize_self_refiner_plan(plan_data)
    return parsed_plans[0] if parsed_plans else []
def add_refiner_rule(current_rules, range_val, steps_val):
    """Add a ``{"start", "end", "steps"}`` rule unless it overlaps an existing one.

    ``current_rules`` may be a rule list or a plan string; it is coerced with
    ``ensure_refiner_list``.

    ``range_val`` may take these forms:
    - A string: ``"3-7"``, ``"3,7"``, ``"3:7"``, or a single ``"5"``.
    - A sequence whose first two items are the bounds.

    The bounds are ordered so that ``start <= end``.

    If the new inclusive range intersects an existing rule, a Gradio ``Info``
    message is shown and the coerced existing rules are returned unchanged.
    Otherwise a new list sorted by ``start`` is returned.
    """
    rules = ensure_refiner_list(current_rules)
    if isinstance(range_val, str):
        # Accept ',' and ':' as alternative range separators.
        normalized = range_val.strip().replace(",", "-").replace(":", "-")
        lo_text, separator, hi_text = normalized.partition("-")
        if not separator:
            hi_text = lo_text
        bounds = (int(lo_text.strip()), int(hi_text.strip()))
    else:
        bounds = (int(range_val[0]), int(range_val[1]))
    new_start, new_end = min(bounds), max(bounds)

    clash = next(
        (rule for rule in rules if new_start <= rule['end'] and new_end >= rule['start']),
        None,
    )
    if clash is not None:
        # Imported lazily so the module does not require gradio unless a conflict is reported.
        from gradio import Info
        Info(f"Overlap detected! Steps {new_start}-{new_end} conflict with existing rule {clash['start']}-{clash['end']}.")
        return rules

    appended = rules + [{"start": new_start, "end": new_end, "steps": int(steps_val)}]
    return sorted(appended, key=lambda rule: rule['start'])
def remove_refiner_rule(current_rules, index):
    """Return the rules with the entry at ``index`` removed.

    ``current_rules`` may be a rule list or a plan string; it is coerced with
    ``ensure_refiner_list``. The caller's list is never modified: a new list is
    always returned, matching ``add_refiner_rule``. ``index`` is coerced with
    ``int()`` so integral floats such as ``2.0`` work. Out-of-range indices,
    including negative ones, leave the rules unchanged.
    """
    # Copy so a list passed in by the caller is not mutated in place.
    rules = list(ensure_refiner_list(current_rules))
    index = int(index)
    if 0 <= index < len(rules):
        del rules[index]
    return rules
class PnPHandler:
    """Per-step self-refinement helper for a diffusion sampling loop.

    For denoising step indices listed in a stochastic plan, a step is repeated
    several times. Each repeat re-noises the latest clean-sample estimate,
    denoises it again and steps the scheduler again. Positions whose
    clean-sample estimate stops changing are frozen to their previous values.
    Positions count as stable when their per-position norm over
    ``channel_dim`` falls below ``ths_uncertainty``.
    """

    def __init__(self, stochastic_plan, ths_uncertainty=0.0, p_norm=1, certain_percentage=0.999, channel_dim: int = 1):
        # step index -> number of passes for that step (only entries with steps > 0 are stored).
        self.stochastic_step_map = self._build_stochastic_step_map(stochastic_plan)
        # Per-position uncertainty strictly below this marks the position as "certain".
        self.ths_uncertainty = ths_uncertainty
        # Norm order passed to torch.norm when measuring uncertainty.
        self.p_norm = p_norm
        # Fraction of certain positions above which refinement stops early.
        self.certain_percentage = certain_percentage
        # Axis reduced when computing per-position uncertainty (negative values allowed).
        self.channel_dim = channel_dim
        # History of [certain_mask, pred_original_sample, latents_next]; starts with a None sentinel.
        self.buffer = [None]
        self.certain_flag = False

    def _build_stochastic_step_map(self, plan):
        """Expand plan rules into a ``{step_index: steps}`` lookup.

        Each rule is either a dict or a ``(start, end, steps)`` sequence. A dict
        uses ``start``/``begin``, ``end``/``stop`` and
        ``steps``/``anneal``/``num_anneal_steps``, with steps defaulting to 1.

        - Ranges are inclusive, and reversed ranges are normalized.
        - Entries with a missing bound or an unrecognized type are skipped.
        - Rules with ``steps <= 0`` add nothing.
        - Later rules overwrite earlier ones on overlapping indices.
        """
        step_map = {}
        if not plan:
            return step_map
        for entry in plan:
            if isinstance(entry, dict):
                start = entry.get("start", entry.get("begin"))
                end = entry.get("end", entry.get("stop"))
                steps = entry.get("steps", entry.get("anneal", entry.get("num_anneal_steps", 1)))
            elif isinstance(entry, (list, tuple)):
                start, end, steps = entry[0], entry[1], entry[2]
            else:
                continue
            if start is None or end is None:
                # A rule without both bounds cannot be expanded; skip it like other malformed entries.
                continue
            start_i = int(start)
            end_i = int(end)
            steps_i = int(steps)
            if start_i > end_i:
                # Accept "high-low" ranges instead of silently expanding them to nothing.
                start_i, end_i = end_i, start_i
            if steps_i > 0:
                for idx in range(start_i, end_i + 1):
                    step_map[idx] = steps_i
        return step_map

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a Python float, unwrapping single-element tensors via ``.item()``."""
        return value.item() if torch.is_tensor(value) else float(value)

    @staticmethod
    def _prev_sample(step_out):
        """Extract the next latents from a scheduler ``step`` result.

        Accepts an output object with ``prev_sample``, a tuple/list whose first
        item is used, or the latents themselves.
        """
        if hasattr(step_out, "prev_sample"):
            return step_out.prev_sample
        if isinstance(step_out, (tuple, list)):
            return step_out[0]
        return step_out

    def get_anneal_steps(self, step_index):
        """Return the planned pass count for ``step_index`` (0 when the step is not in the plan)."""
        return self.stochastic_step_map.get(step_index, 0)

    def reset_buffer(self):
        """Clear the refinement history and the early-stop flag."""
        self.buffer = [None]
        self.certain_flag = False

    def process_step(self, latents, noise_pred, sigma, sigma_next, generator=None, device=None, latents_next=None, pred_original_sample=None):
        """Record one pass and freeze positions whose estimate has converged.

        If they are not supplied, the values are derived from ``noise_pred``:
        - ``pred_original_sample`` as ``latents - sigma * noise_pred``
        - ``latents_next`` as ``latents + (sigma_next - sigma) * noise_pred``

        When a previous pass is buffered, the uncertainty is the ``p_norm``
        norm over ``channel_dim`` of the change in ``pred_original_sample``,
        divided by the size of that axis. Positions below ``ths_uncertainty``,
        or certain in the previous pass, take the previous pass's values.
        ``certain_flag`` is set once the certain fraction exceeds
        ``certain_percentage``.

        ``generator`` and ``device`` are accepted but unused.

        Returns:
            The (possibly partially frozen) next latents.
        """
        if pred_original_sample is None:
            pred_original_sample = latents - sigma * noise_pred
        if latents_next is None:
            latents_next = latents + (sigma_next - sigma) * noise_pred
        if self.buffer[-1] is not None:
            prev_mask, prev_pred_original, prev_latents_next = self.buffer[-1]
            diff = pred_original_sample - prev_pred_original
            channel_dim = self.channel_dim
            if channel_dim < 0:
                channel_dim += latents.ndim
            uncertainty = torch.norm(diff, p=self.p_norm, dim=channel_dim) / latents.shape[channel_dim]
            certain_mask = uncertainty < self.ths_uncertainty
            if prev_mask is not None:
                # Once a position is certain it stays certain.
                certain_mask = certain_mask | prev_mask
            if certain_mask.sum() / certain_mask.numel() > self.certain_percentage:
                self.certain_flag = True
            # Re-insert the reduced axis so the mask broadcasts over channels.
            certain_mask_float = certain_mask.to(latents.dtype).unsqueeze(channel_dim)
            latents_next = certain_mask_float * prev_latents_next + (1.0 - certain_mask_float) * latents_next
            pred_original_sample = certain_mask_float * prev_pred_original + (1.0 - certain_mask_float) * pred_original_sample
            certain_mask_stored = certain_mask
        else:
            certain_mask_stored = None
        self.buffer.append([certain_mask_stored, pred_original_sample, latents_next])
        return latents_next

    def perturb_latents(self, latents, buffer_latent, sigma, generator=None, device=None, noise_mask=None):
        """Re-noise ``buffer_latent`` to level ``sigma``.

        Computes ``(1 - s) * buffer_latent + s * noise``, where ``noise`` is
        fresh Gaussian noise shaped and typed like ``latents``. ``s`` is
        ``sigma``, multiplied elementwise by ``noise_mask`` when one is given.
        """
        noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
        if noise_mask is None:
            return (1.0 - sigma) * buffer_latent + sigma * noise
        sigma_t = (noise_mask.to(latents.dtype) * sigma)
        return (1.0 - sigma_t) * buffer_latent + sigma_t * noise

    def run_refinement_loop(self, latents, noise_pred, current_sigma, next_sigma, m_steps, denoise_func, step_func, clone_func=None, restore_func=None, generator=None, device=None, noise_mask=None):
        """Run up to ``m_steps`` passes of one denoising step.

        Pass 0 steps ``noise_pred`` on ``latents``. Each further pass does the
        following:
        - Restores the scheduler snapshot taken before pass 0, when
          ``clone_func``/``restore_func`` are given.
        - Re-noises the latest clean-sample estimate at ``current_sigma``.
        - Denoises it with ``denoise_func`` and steps again.

        The loop stops early once ``certain_flag`` is set.

        Args:
            step_func: ``(noise_pred, latents) -> (latents_next, pred_original_sample)``.
            denoise_func: ``latents -> noise_pred``.

        Returns:
            The refined next latents, or None if any prediction or step returned None.
        """
        if noise_pred is None:
            return None
        scheduler_state = None
        if clone_func:
            scheduler_state = clone_func()
        latents_next_0, pred_original_sample_0 = step_func(noise_pred, latents)
        if latents_next_0 is None or pred_original_sample_0 is None:
            return None
        latents_next = self.process_step(
            latents, noise_pred, current_sigma, next_sigma,
            latents_next=latents_next_0, pred_original_sample=pred_original_sample_0
        )
        if self.certain_flag:
            return latents_next
        for _ in range(1, m_steps):
            if restore_func and scheduler_state is not None:
                # Every pass must start from the same scheduler state as pass 0.
                restore_func(scheduler_state)
            latents_perturbed = self.perturb_latents(
                latents,
                self.buffer[-1][1],
                current_sigma,
                generator=generator,
                device=device,
                noise_mask=noise_mask,
            )
            n_pred = denoise_func(latents_perturbed)
            if n_pred is None:
                return None
            latents_next_loop, pred_original_sample_loop = step_func(n_pred, latents_perturbed)
            if latents_next_loop is None or pred_original_sample_loop is None:
                return None
            latents_next = self.process_step(
                latents_perturbed, n_pred, current_sigma, next_sigma,
                latents_next=latents_next_loop, pred_original_sample=pred_original_sample_loop
            )
            if self.certain_flag:
                break
        return latents_next

    def step(self, step_index, latents, noise_pred, t, timesteps, target_shape, seed_g, sample_scheduler, scheduler_kwargs, denoise_func):
        """Advance one sampler step, refining it when the plan asks for it.

        Resets the refinement buffer. If ``get_anneal_steps(step_index) > 1``,
        it runs ``run_refinement_loop`` with scheduler snapshot/restore;
        otherwise it performs a single ``sample_scheduler.step``.

        Sigmas are ``timestep / 1000``; the next sigma is 0 on the last step.
        ``t`` and ``timesteps`` items may be tensors or plain numbers.

        Returns:
            ``(latents, sample_scheduler)``.
            - ``latents`` is None if ``noise_pred`` is None or a refinement
              pass returned None.
            - The returned scheduler may be a restored deep copy rather than
              the object passed in.
        """
        if noise_pred is None:
            return None, sample_scheduler
        self.reset_buffer()
        current_sigma = self._as_float(t) / 1000.0
        if step_index == len(timesteps) - 1:
            next_sigma = 0.0
        else:
            next_sigma = self._as_float(timesteps[step_index + 1]) / 1000.0
        m_steps = self.get_anneal_steps(step_index)
        if m_steps > 1 and not self.certain_flag:
            def _get_pred_original_sample(step_out, latents_in, n_pred_sliced):
                pred = getattr(step_out, "pred_original_sample", None)
                if pred is not None:
                    return pred
                # Scheduler gave no x0 estimate (or None): use the same form as process_step's default.
                return latents_in - current_sigma * n_pred_sliced

            def step_func(n_pred_in, latents_in):
                # Trim the prediction along dims 1 and 2 (presumably to drop padding; confirm with callers).
                n_pred_sliced = n_pred_in[:, :latents_in.shape[1], :target_shape[1]]
                # Reads the closure cell, so it sees schedulers rebound by restore_func.
                step_out = sample_scheduler.step(n_pred_sliced, t, latents_in, **scheduler_kwargs)
                return self._prev_sample(step_out), _get_pred_original_sample(step_out, latents_in, n_pred_sliced)

            def clone_func():
                if sample_scheduler is None:
                    return None
                if getattr(sample_scheduler, "is_stateful", True):
                    return copy.deepcopy(sample_scheduler)
                return None

            def restore_func(saved_state):
                nonlocal sample_scheduler
                # Identity check: a scheduler object may define __len__/__bool__.
                if saved_state is not None:
                    # Copy again so the snapshot stays pristine for later passes.
                    sample_scheduler = copy.deepcopy(saved_state)

            latents = self.run_refinement_loop(
                latents=latents,
                noise_pred=noise_pred,
                current_sigma=current_sigma,
                next_sigma=next_sigma,
                m_steps=m_steps,
                denoise_func=denoise_func,
                step_func=step_func,
                clone_func=clone_func,
                restore_func=restore_func,
                generator=seed_g,
                device=latents.device
            )
            if latents is None:
                return None, sample_scheduler
        else:
            n_pred_sliced = noise_pred[:, :latents.shape[1], :target_shape[1]]
            step_out = sample_scheduler.step(n_pred_sliced, t, latents, **scheduler_kwargs)
            latents = self._prev_sample(step_out)
        return latents, sample_scheduler
def create_self_refiner_handler(pnp_plan, pnp_f_uncertainty, pnp_p_norm, pnp_certain_percentage, channel_dim: int = 1):
    """Build a ``PnPHandler`` from a user-supplied refinement plan.

    Only the first plan parsed from ``pnp_plan`` is used; parse errors are
    ignored. If that plan is empty or missing, a built-in default is used:
    - steps 1-5 get 3 passes;
    - steps 6-13 get 1 pass, so ``PnPHandler.step`` does not refine them.
    The remaining arguments are forwarded to the handler unchanged.
    """
    parsed_plans, _ignored_error = normalize_self_refiner_plan(pnp_plan, max_plans=2)
    first_plan = parsed_plans[0] if parsed_plans else None
    if not first_plan:
        first_plan = [
            {"start": 1, "end": 5, "steps": 3},
            {"start": 6, "end": 13, "steps": 1},
        ]
    return PnPHandler(
        first_plan,
        ths_uncertainty=pnp_f_uncertainty,
        p_norm=pnp_p_norm,
        certain_percentage=pnp_certain_percentage,
        channel_dim=channel_dim,
    )
def _multi_preds_valid(noise_pred_list):
    """Return True if ``noise_pred_list`` is a list/tuple with no None entries."""
    return isinstance(noise_pred_list, (list, tuple)) and all(pred is not None for pred in noise_pred_list)


def _multi_step_and_process(handlers, latents_list, noise_pred_list, current_sigma, next_sigma, step_func):
    """Step every stream and feed each result to its handler.

    Returns None if ``step_func`` output is missing or the wrong length.
    """
    latents_next_list, pred_original_list = step_func(noise_pred_list, latents_list)
    if latents_next_list is None or pred_original_list is None:
        return None
    if len(latents_next_list) != len(handlers) or len(pred_original_list) != len(handlers):
        return None
    return [
        handler.process_step(
            latents,
            None,
            current_sigma,
            next_sigma,
            latents_next=latents_next,
            pred_original_sample=pred_original,
        )
        for handler, latents, latents_next, pred_original in zip(
            handlers, latents_list, latents_next_list, pred_original_list
        )
    ]


def run_refinement_loop_multi(
    handlers,
    latents_list,
    noise_pred_list,
    current_sigma,
    next_sigma,
    m_steps,
    denoise_func,
    step_func,
    generators=None,
    devices=None,
    noise_masks=None,
    stop_when: str = "all",
):
    """Refine one denoising step for several latent streams in lock-step.

    Stream ``i`` is tracked by ``handlers[i]``; handler buffers are not reset here.
    - Pass 0 steps ``noise_pred_list``.
    - Each further pass re-noises every handler's latest clean-sample
      estimate, calls ``denoise_func`` on the whole batch and steps again.

    Unlike ``PnPHandler.run_refinement_loop``, no scheduler snapshot or
    restore happens here; ``step_func`` must handle that if needed.

    Args:
        handlers: One handler per stream.
        latents_list: Current latents per stream; must match ``len(handlers)``.
        noise_pred_list: List/tuple of per-stream predictions for pass 0.
        denoise_func: ``list_of_latents -> list_of_noise_preds``.
        step_func: ``(noise_preds, latents) -> (latents_next_list, pred_original_list)``.
        generators, devices, noise_masks: Optional per-stream lists for
            ``perturb_latents``. The device defaults to each latent's device.
        stop_when: ``"any"`` stops once one handler is certain; any other
            value waits until all are.

    Returns:
        - The refined next latents per stream.
        - ``latents_list`` itself when ``m_steps <= 1``.
        - None on any invalid prediction, step output or length mismatch.
    """
    if m_steps <= 1:
        # NOTE(review): no scheduler step is taken here; the inputs come back
        # unchanged, so callers must step themselves in this case — confirm.
        return latents_list
    if not _multi_preds_valid(noise_pred_list):
        return None
    if len(latents_list) != len(handlers):
        # zip() would otherwise silently drop streams.
        return None

    def _should_stop():
        if stop_when == "any":
            return any(handler.certain_flag for handler in handlers)
        return all(handler.certain_flag for handler in handlers)

    refined_latents_list = _multi_step_and_process(
        handlers, latents_list, noise_pred_list, current_sigma, next_sigma, step_func
    )
    if refined_latents_list is None or _should_stop():
        return refined_latents_list
    for _ in range(1, m_steps):
        perturbed_list = []
        for idx, (handler, latents) in enumerate(zip(handlers, latents_list)):
            generator = generators[idx] if generators is not None else None
            device = devices[idx] if devices is not None else latents.device
            noise_mask = noise_masks[idx] if noise_masks is not None else None
            perturbed_list.append(
                handler.perturb_latents(
                    latents,
                    handler.buffer[-1][1],
                    current_sigma,
                    generator=generator,
                    device=device,
                    noise_mask=noise_mask,
                )
            )
        noise_pred_list = denoise_func(perturbed_list)
        # Same validation as pass 0, so later passes cannot feed None entries to step_func.
        if not _multi_preds_valid(noise_pred_list):
            return None
        refined_latents_list = _multi_step_and_process(
            handlers, perturbed_list, noise_pred_list, current_sigma, next_sigma, step_func
        )
        if refined_latents_list is None:
            return None
        if _should_stop():
            break
    return refined_latents_list