Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import Tensor | |
| from koja_diffuser.model import DiffusionTranslate | |
| from torch.utils.checkpoint import checkpoint | |
| def bridge_forward( | |
| *, | |
| bridge: DiffusionTranslate, | |
| x: Tensor, | |
| guide: Tensor, | |
| t: Tensor, | |
| guide_encoded: Tensor | None = None, | |
| use_checkpoint: bool, | |
| ) -> Tensor: | |
| t_embed = t.float().unsqueeze(-1) | |
| if use_checkpoint and torch.is_grad_enabled(): | |
| return checkpoint( | |
| lambda x_, guide_, t_, guide_encoded_: bridge( | |
| x_, guide_, t_, guide_encoded_ | |
| ), | |
| x, | |
| guide, | |
| t_embed, | |
| guide_encoded, | |
| use_reentrant=False, | |
| preserve_rng_state=False, # bridge dropout 사용시 True, 미사용시 False | |
| ) | |
| return bridge(x, guide, t_embed, guide_encoded) | |