from typing import Callable, List, Optional
import torch
from einops import rearrange
from torch import nn

from common.cache import Cache
from common.distributed.ops import slice_inputs


ada_layer_type = Callable[[int, int], nn.Module]


def get_ada_layer(ada_layer: str) -> ada_layer_type:
    """Look up an adaptive modulation layer class by name."""
    if ada_layer == "single":
        return AdaSingle
    raise NotImplementedError(f"{ada_layer} is not supported")
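
# Usage sketch of the factory (values are illustrative; the real dims and layer
# names come from the surrounding model config, not from this module):
#
#   ada_cls = get_ada_layer("single")                                # -> AdaSingle
#   ada = ada_cls(dim=1024, emb_dim=6 * 1024, layers=["attn", "mlp"])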

def expand_dims(x: torch.Tensor, dim: int, ndim: int) -> torch.Tensor:
    """
    Expand tensor "x" to "ndim" dimensions by inserting singleton dims at "dim".
    Example: x is (b d), target ndim is 5, insert at dim 1, return (b 1 1 1 d).
    """
    shape = x.shape
    shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
    return x.reshape(shape)
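
# Shape-only sketch of expand_dims (illustrative tensors):
#
#   x = torch.randn(2, 64)               # (b d)
#   expand_dims(x, dim=1, ndim=5).shape  # -> torch.Size([2, 1, 1, 1, 64])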


class AdaSingle(nn.Module):
    """
    Adaptive modulation from a single shared conditioning embedding: the
    embedding supplies per-layer (shift, scale, gate) vectors, which are added
    to learnable per-layer parameters registered on this module.
    """

    def __init__(
        self,
        dim: int,
        emb_dim: int,
        layers: List[str],
    ):
        assert emb_dim == 3 * len(layers) * dim, (
            "AdaSingle requires emb_dim == 3 * len(layers) * dim "
            "(one shift/scale/gate triple per modulated layer)."
        )
        super().__init__()
        self.dim = dim
        self.emb_dim = emb_dim
        self.layers = layers
        for l in layers:
            self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
            self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1))
            self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))

    def forward(
        self,
        hid: torch.FloatTensor,
        emb: torch.FloatTensor,
        layer: str,
        mode: str,
        cache: Cache = Cache(disable=True),
        branch_tag: str = "",
        hid_len: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        idx = self.layers.index(layer)
        # Split the shared embedding into per-layer (shift, scale, gate) triples
        # and keep only the triple belonging to this layer: (b, dim, 3).
        emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
        # Insert singleton dims so the triple broadcasts over hid's sequence dims.
        emb = expand_dims(emb, 1, hid.ndim + 1)

        if hid_len is not None:
            # Packed-sequence path: repeat each sample's modulation triple along
            # its own token count, then slice_inputs keeps the locally owned
            # portion. The result is cached so repeated calls for this layer reuse it.
            emb = cache(
                f"emb_repeat_{idx}_{branch_tag}",
                lambda: slice_inputs(
                    torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
                    dim=0,
                ),
            )

        # Conditioning-derived components (A) plus learnable per-layer ones (B).
        shiftA, scaleA, gateA = emb.unbind(-1)
        shiftB, scaleB, gateB = (
            getattr(self, f"{layer}_shift"),
            getattr(self, f"{layer}_scale"),
            getattr(self, f"{layer}_gate"),
        )

        if mode == "in":
            # Pre-layer modulation: scale and shift the hidden states in place.
            return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
        if mode == "out":
            # Post-layer modulation: gate the branch output in place.
            return hid.mul_(gateA + gateB)
        raise NotImplementedError(f"{mode} is not supported")

    def extra_repr(self) -> str:
        return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}"