KoJaDiffuser / koja_diffuser /runtime /bridge_utils.py
github-actions[bot]
Sync from GitHub 33c12db74322f3d28409b5dc0a8c441914c9178b
e0552b0
raw
history blame contribute delete
832 Bytes
import torch
from torch import Tensor
from koja_diffuser.model import DiffusionTranslate
from torch.utils.checkpoint import checkpoint
def bridge_forward(
*,
bridge: DiffusionTranslate,
x: Tensor,
guide: Tensor,
t: Tensor,
guide_encoded: Tensor | None = None,
use_checkpoint: bool,
) -> Tensor:
t_embed = t.float().unsqueeze(-1)
if use_checkpoint and torch.is_grad_enabled():
return checkpoint(
lambda x_, guide_, t_, guide_encoded_: bridge(
x_, guide_, t_, guide_encoded_
),
x,
guide,
t_embed,
guide_encoded,
use_reentrant=False,
preserve_rng_state=False, # bridge dropout 사용시 True, 미사용시 False
)
return bridge(x, guide, t_embed, guide_encoded)