Spaces:
Sleeping
Sleeping
File size: 2,908 Bytes
from koja_diffuser.model import DiffusionTranslate
from koja_diffuser.runtime.schedule import DiffusionSchedule
from koja_diffuser.runtime.bridge_utils import bridge_forward
import torch
from torch import Tensor
import dataclasses
from koja_diffuser.util import Emitter, noop_emitter, tensor_norm
@dataclasses.dataclass(frozen=True)
class DdimConfig:
    """Immutable settings consumed by ``ddim_sample_bridge``.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # First (largest) timestep of the sampling schedule; the schedule is
    # spaced linearly from this value down to 0.
    start_timestep: int = 999
    # Number of timesteps in the schedule, i.e. how many sampling steps run.
    num_steps: int = 6
    # Forwarded unchanged to ``bridge_forward`` on every step.
    # NOTE(review): presumably toggles activation checkpointing - confirm in
    # koja_diffuser.runtime.bridge_utils.
    use_checkpoint: bool = False
    # NOTE(review): not read by any code visible in this file - confirm
    # whether another module consumes it.
    return_trace: bool = False
def ddim_step(
    *,
    bridge: DiffusionTranslate,
    schedule: DiffusionSchedule,
    x: Tensor,
    guide: Tensor,
    t: Tensor,
    t_next: Tensor | None,
    guide_encoded: Tensor | None = None,
    use_checkpoint=False,
):
    """Advance the sample ``x`` from timestep ``t`` to ``t_next``.

    The output of ``bridge_forward`` is used as the noise prediction and is
    converted into an estimate of the clean sample through
    ``schedule.predict_x0_from_eps``. No fresh random noise is added by this
    function, so the update is deterministic given its inputs.

    Args:
        bridge: Model handed to ``bridge_forward``.
        schedule: Supplies ``predict_x0_from_eps``, ``extract`` and
            ``alpha_bars``.
        x: Current sample.
        guide: Conditioning tensor passed through to ``bridge_forward``.
        t: Current timestep(s).
        t_next: Timestep(s) to move to, or ``None`` on the final step.
        guide_encoded: Optional pre-encoded guide passed through to
            ``bridge_forward``.
        use_checkpoint: Passed through to ``bridge_forward``.

    Returns:
        The clean-sample estimate when ``t_next`` is ``None``; otherwise the
        sample re-noised to the level read from ``schedule.alpha_bars`` at
        ``t_next``.
    """
    predicted_noise = bridge_forward(
        bridge=bridge,
        x=x,
        guide=guide,
        t=t,
        guide_encoded=guide_encoded,
        use_checkpoint=use_checkpoint,
    )
    clean_estimate = schedule.predict_x0_from_eps(x, t, predicted_noise)

    # On the last step, return the x0 estimate directly.
    if t_next is None:
        return clean_estimate

    next_alpha_bar = schedule.extract(schedule.alpha_bars, t_next, x)
    signal_scale = next_alpha_bar.sqrt()
    noise_scale = (1.0 - next_alpha_bar).sqrt()
    # Recombine the clean estimate with the same predicted noise, weighted
    # for the next timestep.
    return signal_scale * clean_estimate + noise_scale * predicted_noise
async def ddim_sample_bridge(
    *,
    bridge: DiffusionTranslate,
    schedule: DiffusionSchedule,
    guide: Tensor,
    config: DdimConfig,
    generator: torch.Generator | None = None,
    emit: Emitter = noop_emitter,
):
    """Run the full sampling loop, starting from Gaussian noise.

    The initial sample is drawn with ``torch.randn`` in shape
    ``(guide.size(0), bridge.target_latent_size, guide.size(-1))`` on the
    guide's device. The timestep schedule is ``config.num_steps`` values
    spaced linearly from ``config.start_timestep`` down to 0 and truncated
    to integers. The guide is encoded once with ``bridge.encode_guide`` and
    reused for every step.

    Events passed to ``emit`` (each call is awaited):
        * ``"ddim.init_noise"`` once before the loop, with ``step`` set to -1
          and ``tensor_norm`` summaries of the encoded guide and the initial
          noise.
        * ``"ddim.step"`` after each iteration, except when the final
          iteration ran at timestep 0.

    Args:
        bridge: Model providing ``target_latent_size`` and ``encode_guide``;
            also forwarded to ``ddim_step``.
        schedule: Noise schedule forwarded to ``ddim_step``.
        guide: Conditioning tensor; its first dimension is used as the batch
            size and its last dimension as the sample's last dimension.
        config: Sampling settings (``start_timestep``, ``num_steps``,
            ``use_checkpoint``).
        generator: Optional RNG used for the initial noise.
        emit: Awaitable callback taking an event name and a payload dict.

    Returns:
        The sample after the last step. If ``config.num_steps`` is 0 the loop
        body never runs and the initial noise is returned unchanged.
    """
    num_steps = config.num_steps
    batch = guide.size(0)

    x = torch.randn(
        batch,
        bridge.target_latent_size,
        guide.size(-1),
        device=guide.device,
        generator=generator,
    )
    # linspace yields floats; .long() truncates them to integer timesteps.
    timesteps = torch.linspace(
        config.start_timestep,
        0,
        steps=num_steps,
        device=guide.device,
    ).long()
    guide_encoded = bridge.encode_guide(guide)

    init_payload = {
        "step": -1,
        "guide_encoded": tensor_norm(guide_encoded),
        "x": tensor_norm(x),
    }
    await emit("ddim.init_noise", init_payload)

    final_index = num_steps - 1
    for i in range(num_steps):
        is_final = i == final_index
        t = timesteps[i].expand(batch)
        # The final iteration has no successor timestep; ddim_step then
        # returns its x0 estimate.
        t_next = None if is_final else timesteps[i + 1].expand(batch)

        x = ddim_step(
            bridge=bridge,
            schedule=schedule,
            x=x,
            guide=guide,
            t=t,
            t_next=t_next,
            guide_encoded=guide_encoded,
            use_checkpoint=config.use_checkpoint,
        )

        # No "ddim.step" event is sent for a final iteration that ran at
        # timestep 0. The short-circuit keeps .item() off earlier iterations.
        if is_final and timesteps[i].item() == 0:
            continue

        await emit("ddim.step", {"step": i, "x": tensor_norm(x)})

    return x
|