Spaces:
Sleeping
Sleeping
File size: 832 Bytes
e0552b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import torch
from torch import Tensor
from koja_diffuser.model import DiffusionTranslate
from torch.utils.checkpoint import checkpoint
def bridge_forward(
*,
bridge: DiffusionTranslate,
x: Tensor,
guide: Tensor,
t: Tensor,
guide_encoded: Tensor | None = None,
use_checkpoint: bool,
) -> Tensor:
t_embed = t.float().unsqueeze(-1)
if use_checkpoint and torch.is_grad_enabled():
return checkpoint(
lambda x_, guide_, t_, guide_encoded_: bridge(
x_, guide_, t_, guide_encoded_
),
x,
guide,
t_embed,
guide_encoded,
use_reentrant=False,
preserve_rng_state=False, # bridge dropout 사용시 True, 미사용시 False
)
return bridge(x, guide, t_embed, guide_encoded)
|