| import torch |
|
|
| def _identity(x: torch.Tensor, f_x: torch.Tensor) -> torch.Tensor: |
| """ |
| linear residual connection |
| x : residual stream |
| f_x : attention/MLP/conv(if channel-wise) output |
| """ |
| return f_x |
|
|
| def _orthogonal_channel(x: torch.Tensor, f_x: torch.Tensor, dim: int, eps: torch.Tensor) -> torch.Tensor: |
| """ |
| orthogonal residual connection |
| x : residual stream |
| f_x : attention/MLP/conv(if channel-wise) output |
| """ |
| dot = (x * f_x).sum(dim, keepdim=True) |
| norm_x2 = (x * x ).sum(dim, keepdim=True) + eps |
| scale = (dot / norm_x2).to(x.dtype) |
| proj_out = scale * x |
| return f_x - proj_out |
|
|
| def _orthogonal_global(x: torch.Tensor, f_x: torch.Tensor, dim: int, eps: torch.Tensor) -> torch.Tensor: |
| """ |
| orthogonal residual connection |
| x : residual stream |
| f_x : conv output |
| """ |
| original_shape = x.shape |
| positive_dim = dim if dim >= 0 else len(original_shape) + dim |
|
|
| x_view = x.flatten(dim) |
| f_view = f_x.flatten(dim) |
| dot = (x_view * f_view).sum(dim=dim, keepdim=True) |
| norm_sq = (x_view * x_view).sum(dim=dim, keepdim=True) + eps |
|
|
| scale = (dot / norm_sq).to(x.dtype) |
| unsqueeze_times = len(original_shape) - positive_dim - 1 |
| for _ in range(unsqueeze_times): |
| scale = scale.unsqueeze(-1) |
| proj_out = scale * x |
| return f_x - proj_out |
|
|
def connect(x, f_x, *,
            method="linear", orthogonal_method="global",
            dim=-1, eps=1e-6, perturbation=None):
    """Add a branch output to the residual stream.

    Args:
        x: Residual stream.
        f_x: Output of the branch (attention/MLP/conv) to merge into ``x``.
        method: ``"linear"`` for the plain update ``x + f_x``, or
            ``"orthogonal"`` to add only the part of ``f_x`` orthogonal to
            ``x``.
        orthogonal_method: Only read when ``method == "orthogonal"``.
            ``"global"`` projects over all axes from ``dim`` onward as one
            vector; ``"channel"`` projects along ``dim`` only.
        dim: Axis argument forwarded to the orthogonal variants.
        eps: Stabilizer forwarded to the orthogonal variants.
        perturbation: Reserved; any value other than ``None`` is rejected.

    Returns:
        The updated residual stream.

    Raises:
        NotImplementedError: If ``perturbation`` is not ``None``.
        ValueError: If ``method`` or ``orthogonal_method`` is not recognized.
    """
    if perturbation is not None:
        raise NotImplementedError("perturbation is not implemented yet")

    if method == "linear":
        return x + _identity(x, f_x)

    if method != "orthogonal":
        raise ValueError(f"unknown connect method: {method}")

    if orthogonal_method == "global":
        return x + _orthogonal_global(x, f_x, dim, eps)
    if orthogonal_method == "channel":
        return x + _orthogonal_channel(x, f_x, dim, eps)

    # Report the value that was actually rejected; ``method`` is always
    # "orthogonal" at this point.
    raise ValueError(f"unknown orthogonal method: {orthogonal_method}")
|
|
|
|