File size: 832 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from torch import Tensor
from koja_diffuser.model import DiffusionTranslate
from torch.utils.checkpoint import checkpoint


def bridge_forward(
    *,
    bridge: DiffusionTranslate,
    x: Tensor,
    guide: Tensor,
    t: Tensor,
    guide_encoded: Tensor | None = None,
    use_checkpoint: bool,
    preserve_rng_state: bool = False,
) -> Tensor:
    """Run ``bridge`` once, optionally under activation checkpointing.

    ``t`` is cast to float and given a trailing singleton dimension before
    being handed to the bridge. The bridge is always called positionally as
    ``bridge(x, guide, t_embed, guide_encoded)``.

    Args:
        bridge: Callable model invoked with the four positional inputs.
        x: First positional input forwarded to ``bridge`` unchanged.
        guide: Second positional input forwarded to ``bridge`` unchanged.
        t: Timestep tensor; converted via ``t.float().unsqueeze(-1)``.
        guide_encoded: Optional fourth positional input; ``None`` is
            forwarded as-is.
        use_checkpoint: Request non-reentrant ``torch.utils.checkpoint``.
            Only honoured while autograd is enabled; otherwise the bridge
            is called directly.
        preserve_rng_state: Forwarded to ``checkpoint``. Set to ``True``
            when the bridge uses dropout (or any other RNG-dependent op) so
            the recomputation in backward replays the same random numbers
            as the forward pass. Defaults to ``False``, which skips the
            RNG save/restore overhead for deterministic bridges.

    Returns:
        Whatever ``bridge`` returns for the given inputs.
    """
    t_embed = t.float().unsqueeze(-1)

    # Checkpointing only pays off when a backward pass can follow; without
    # autograd there is nothing to recompute, so call the bridge directly.
    if not (use_checkpoint and torch.is_grad_enabled()):
        return bridge(x, guide, t_embed, guide_encoded)

    # ``checkpoint`` invokes ``bridge(*args)`` itself, so no wrapper is needed.
    return checkpoint(
        bridge,
        x,
        guide,
        t_embed,
        guide_encoded,
        use_reentrant=False,
        # True when the bridge uses dropout, False when it does not.
        preserve_rng_state=preserve_rng_state,
    )