Spaces:
Sleeping
Sleeping
| from koja_diffuser.model import DiffusionTranslate | |
| from koja_diffuser.runtime.schedule import DiffusionSchedule | |
| from koja_diffuser.runtime.bridge_utils import bridge_forward | |
| import torch | |
| from torch import Tensor | |
| import dataclasses | |
| from koja_diffuser.util import Emitter, noop_emitter, tensor_norm | |
@dataclasses.dataclass
class DdimConfig:
    """Settings for a DDIM sampling run.

    Declared as a dataclass so that callers can override individual settings
    per instance, e.g. ``DdimConfig(num_steps=10)``; calling ``DdimConfig()``
    with no arguments still yields the defaults below.
    """

    # Timestep the sampler starts from (first entry of the timestep grid).
    start_timestep: int = 999
    # Number of grid points between ``start_timestep`` and 0, which is also
    # the number of denoising steps that are executed.
    num_steps: int = 6
    # Forwarded unchanged to the model forward call on every step.
    use_checkpoint: bool = False
    # NOTE(review): not read by the sampler code visible in this file;
    # presumably meant to request per-step intermediates -- confirm.
    return_trace: bool = False
def ddim_step(
    *,
    bridge: DiffusionTranslate,
    schedule: DiffusionSchedule,
    x: Tensor,
    guide: Tensor,
    t: Tensor,
    t_next: Tensor | None,
    guide_encoded: Tensor | None = None,
    use_checkpoint=False,
):
    """Run one DDIM update, moving ``x`` from timestep ``t`` to ``t_next``.

    Args:
        bridge: Model handed to ``bridge_forward``; its output is used as the
            predicted noise for ``x`` at ``t``.
        schedule: Supplies ``predict_x0_from_eps``, ``extract`` and the
            ``alpha_bars`` table.
        x: Current sample.
        guide: Conditioning input forwarded to the model.
        t: Timestep(s) of the current sample.
        t_next: Timestep(s) to move to, or ``None`` on the final step.
        guide_encoded: Optional pre-encoded guide, forwarded to the model.
        use_checkpoint: Forwarded unchanged to ``bridge_forward``.

    Returns:
        The clean-sample estimate when ``t_next`` is ``None``; otherwise the
        sample re-noised to ``t_next`` using the predicted noise itself (no
        fresh random noise is drawn).
    """
    noise_estimate = bridge_forward(
        bridge=bridge,
        x=x,
        guide=guide,
        t=t,
        guide_encoded=guide_encoded,
        use_checkpoint=use_checkpoint,
    )
    clean_estimate = schedule.predict_x0_from_eps(x, t, noise_estimate)

    if t_next is not None:
        # Schedule coefficient for the target timestep, extracted against x.
        ab_next = schedule.extract(schedule.alpha_bars, t_next, x)
        signal_scale = ab_next.sqrt()
        noise_scale = (1.0 - ab_next).sqrt()
        return signal_scale * clean_estimate + noise_scale * noise_estimate

    # Final step: return the x0 estimate directly.
    return clean_estimate
async def ddim_sample_bridge(
    *,
    bridge: DiffusionTranslate,
    schedule: DiffusionSchedule,
    guide: Tensor,
    config: DdimConfig,
    generator: torch.Generator | None = None,
    emit: Emitter = noop_emitter,
):
    """Sample from ``bridge`` by iterating ``ddim_step`` over a timestep grid.

    The starting sample is Gaussian noise of shape
    ``(guide.size(0), bridge.target_latent_size, guide.size(-1))`` created on
    the guide's device. The grid holds ``config.num_steps`` values spaced
    linearly from ``config.start_timestep`` down to 0 and truncated to
    integers. The guide is encoded once and reused for every step.

    Events sent through ``emit``:
        * ``"ddim.init_noise"`` once before the loop, with ``step`` -1 and
          the norms of the encoded guide and the initial noise.
        * ``"ddim.step"`` after each step with the step index and the norm of
          the updated sample -- skipped for the last step when its timestep
          is 0.

    Args:
        bridge: Model providing ``target_latent_size`` and ``encode_guide``.
        schedule: Noise schedule passed through to ``ddim_step``.
        guide: Conditioning tensor; its first dimension is used as the batch
            size and its last as the feature size of the sample.
        config: Sampling settings (``start_timestep``, ``num_steps``,
            ``use_checkpoint`` are read here).
        generator: Optional RNG used for the initial noise.
        emit: Awaitable callback taking an event name and a payload dict.

    Returns:
        The sample produced by the last step.
    """
    device = guide.device
    batch_size = guide.size(0)
    noise_shape = (batch_size, bridge.target_latent_size, guide.size(-1))
    total_steps = config.num_steps

    x = torch.randn(*noise_shape, device=device, generator=generator)
    timesteps = torch.linspace(
        config.start_timestep, 0, steps=total_steps, device=device
    ).long()
    guide_encoded = bridge.encode_guide(guide)

    init_payload = {
        "step": -1,
        "guide_encoded": tensor_norm(guide_encoded),
        "x": tensor_norm(x),
    }
    await emit("ddim.init_noise", init_payload)

    final_index = total_steps - 1
    for index in range(total_steps):
        is_final = index == final_index
        current = timesteps[index]
        # No target timestep on the final step: ddim_step then returns x0.
        t_next = None if is_final else timesteps[index + 1].expand(batch_size)

        x = ddim_step(
            bridge=bridge,
            schedule=schedule,
            x=x,
            guide=guide,
            t=current.expand(batch_size),
            t_next=t_next,
            guide_encoded=guide_encoded,
            use_checkpoint=config.use_checkpoint,
        )

        # The last step at timestep 0 is not reported as a "ddim.step" event.
        if is_final and current.item() == 0:
            continue
        await emit("ddim.step", {"step": index, "x": tensor_norm(x)})

    return x