import numpy as np import torch def _timestep_transform(t, shift=5.0, num_timesteps=1000): t = t / num_timesteps new_t = shift * t / (1 + (shift - 1) * t) return new_t * num_timesteps class EulerSchedulerOutput: def __init__(self, prev_sample, pred_original_sample=None): self.prev_sample = prev_sample if pred_original_sample is not None: self.pred_original_sample = pred_original_sample def __getitem__(self, index): if index == 0: return self.prev_sample raise IndexError("EulerSchedulerOutput only supports index 0.") def __iter__(self): yield self.prev_sample class EulerScheduler: is_stateful = False def __init__(self, num_train_timesteps=1000, use_timestep_transform=True): self.num_train_timesteps = num_train_timesteps self.use_timestep_transform = use_timestep_transform self.timesteps = None self.num_inference_steps = None def set_timesteps(self, num_inference_steps, device=None, shift=5.0): self.num_inference_steps = num_inference_steps timesteps = list( np.linspace(self.num_train_timesteps, 1, num_inference_steps, dtype=np.float32) ) timesteps.append(0.0) if device is None: timesteps = [torch.tensor([t]) for t in timesteps] else: timesteps = [torch.tensor([t], device=device) for t in timesteps] if self.use_timestep_transform: timesteps = [ _timestep_transform(t, shift=shift, num_timesteps=self.num_train_timesteps) for t in timesteps ][:-1] self.timesteps = torch.tensor(timesteps) return self.timesteps def _timestep_to_index(self, timestep): if self.timesteps is None: raise ValueError("Timesteps are not set. Call set_timesteps first.") if torch.is_tensor(timestep): if timestep.numel() != 1: t_val = timestep.flatten()[0].item() else: t_val = timestep.item() else: t_val = float(timestep) diff = (self.timesteps - t_val).abs() idx = int(torch.argmin(diff).item()) return idx, t_val def step(self, model_output, timestep, sample, return_dict=True, **kwargs): if self.timesteps is None: raise ValueError("Timesteps are not set. Call set_timesteps first.") idx, t_val = self._timestep_to_index(timestep) if idx + 1 < len(self.timesteps): dt_raw = self.timesteps[idx] - self.timesteps[idx + 1] else: dt_raw = self.timesteps[idx] dt = dt_raw.item() / self.num_train_timesteps prev_sample = sample - model_output * dt pred_original_sample = sample - (t_val / self.num_train_timesteps) * model_output if not return_dict: return (prev_sample,) return EulerSchedulerOutput( prev_sample=prev_sample, pred_original_sample=pred_original_sample, ) def scale_model_input(self, sample, *args, **kwargs): return sample