| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .general import factorization, FUNC_LIST |
|
|
|
|
def get_r(oft_blocks, I=None, constraint=0):
    """Build the per-block orthogonal matrices R with the Cayley transform.

    ``Q = B - B^T`` is the skew-symmetric part of ``oft_blocks``.
    The result is ``R = (I + Q) (I - Q)^{-1}``.

    Args:
        oft_blocks (torch.Tensor): tensor whose last two dims are square,
            i.e. ``(..., b, b)``.
        I (torch.Tensor, optional): identity matrix to use. When omitted, a
            ``b x b`` identity is created on ``oft_blocks.device``. Leading
            singleton dims are added so it has the same rank as
            ``oft_blocks``.
        constraint (float, optional): when positive, ``Q`` is rescaled so its
            Frobenius norm (taken over all blocks together) does not exceed
            this value. ``None`` or ``<= 0`` disables the constraint.

    Returns:
        torch.Tensor: ``R`` with the broadcast shape of ``I`` and
        ``oft_blocks``. It is computed in float32 or a wider dtype: half and
        bfloat16 inputs yield float32, and float64 inputs stay float64.
    """
    if I is None:
        I = torch.eye(oft_blocks.shape[-1], device=oft_blocks.device)
    if I.ndim < oft_blocks.ndim:
        # Prepend singleton dims so I broadcasts against every block.
        I = I.reshape((1,) * (oft_blocks.ndim - I.ndim) + tuple(I.shape))

    # Q = B - B^T satisfies Q^T = -Q, which makes the Cayley transform orthogonal.
    q = oft_blocks - oft_blocks.transpose(-1, -2)
    normed_q = q
    if constraint is not None and constraint > 0:
        q_norm = torch.norm(q) + 1e-8
        if q_norm > constraint:
            normed_q = q * constraint / q_norm

    # `inverse` is not available for half/bfloat16 on every backend, so work in
    # at least float32. Both matmul operands must share one dtype: the old
    # `.float()`-only cast broke for float64 blocks or a non-float32 `I`.
    dtype = torch.promote_types(
        torch.promote_types(I.dtype, normed_q.dtype), torch.float32
    )
    I = I.to(dtype)
    normed_q = normed_q.to(dtype)
    return (I + normed_q) @ (I - normed_q).inverse()
|
|
|
|
def weight_gen(org_weight, max_block_size=-1, rescale=False):
    """### weight_gen

    Create zero-initialised OFT block parameters sized for ``org_weight``.
    Optionally also create one-initialised rescale factors.

    Args:
        org_weight (torch.Tensor): the weight tensor. Only its shape is read.
        max_block_size (int): max block size, forwarded to ``factorization``
            together with the leading dim of ``org_weight``.
        rescale (bool, optional): whether to also build a rescale weight.
            Defaults to False.

    Returns:
        tuple: ``(oft_blocks, rescale_weight)``.
            ``oft_blocks`` is a zero tensor of shape
            ``(block_num, block_size, block_size)``, using the two values
            returned by ``factorization``. ``rescale_weight`` is a tensor of
            ones with shape ``(out_dim, 1, ..., 1)``, with one singleton per
            trailing dim of ``org_weight``; it is ``None`` when ``rescale``
            is falsy. Both tensors use torch's default device and dtype.
    """
    out_dim, *trailing_dims = org_weight.shape
    # NOTE(review): presumably block_size * block_num == out_dim; diff_weight's
    # reshape relies on it. Confirm against `factorization`.
    block_size, block_num = factorization(out_dim, max_block_size)
    blocks = torch.zeros(block_num, block_size, block_size)
    singleton_dims = (1,) * len(trailing_dims)
    scale = torch.ones(out_dim, *singleton_dims) if rescale else None
    return blocks, scale
|
|
|
|
def diff_weight(org_weight, *weights, constraint=None):
    """### diff_weight

    Compute the weight delta ΔW so that ``org_weight + ΔW`` is the
    OFT-transformed weight.

    Each block of ``block_size`` rows ``W_k`` of the original weight becomes
    ``R_k^T W_k``. With a rescale weight ``s`` the result becomes
    ``s * R_k^T W_k``. The returned ΔW is that result minus ``W``.

    Args:
        org_weight (torch.Tensor): the weight tensor of the original model.
            Its first dim must equal ``block_num * block_size``.
        weights (tuple[torch.Tensor]): ``(oft_blocks, rescale_weight)``.
            ``rescale_weight`` may be ``None``. When present it must
            broadcast against ``org_weight``, e.g. ``(out_dim, 1, ..., 1)``.
        constraint (float, optional): constraint for oft, forwarded to
            ``get_r``.

    Returns:
        torch.Tensor: ΔW, with the same shape as ``org_weight``.
    """
    oft_blocks, rescale = weights
    block_num, block_size, _ = oft_blocks.shape
    I = torch.eye(block_size, device=oft_blocks.device)
    r = get_r(oft_blocks, I, constraint)

    _, *shape = org_weight.shape
    org_weight = org_weight.to(dtype=r.dtype)
    # Use reshape rather than view so non-contiguous weights also work.
    # Keep the flat `org_weight`: the rescale term below needs it.
    blocked_weight = org_weight.reshape(block_num, block_size, *shape)

    # Apply (R - I) instead of R so only the change is produced. With zero
    # blocks, Q = 0 and R = I, so the delta is 0.
    weight = torch.einsum(
        "k n m, k n ... -> k m ...",
        r - I,
        blocked_weight,
    ).reshape(-1, *shape)

    if rescale is not None:
        # s * R^T W - W = s * (R - I)^T W + (s - 1) * W.
        # W must be the flat (out_dim, ...) weight here. The blocked view
        # does not broadcast against `rescale`.
        weight = rescale * weight
        weight = weight + (rescale - 1) * org_weight
    return weight
|
|
|
|
def bypass_forward_diff(x, org_out, *weights, constraint=None, need_transpose=False):
    """### bypass_forward_diff

    Compute the OFT output delta directly from the original layer's output,
    without materialising a new weight.

    The channel axis of ``org_out`` is split into ``block_num`` chunks of
    ``block_size``. Each chunk ``y_k`` is right-multiplied by ``(R_k - I)``.
    With a rescale weight ``s`` the result is ``s * y (R - I) + (s - 1) * y``.
    Adding the returned tensor to ``org_out`` therefore gives ``s * y R``
    (or ``y R`` without rescale).

    Args:
        x (torch.Tensor): the input tensor for the original model. It is not
            read by this implementation; presumably it is kept so all bypass
            functions share one signature.
        org_out (torch.Tensor): the output tensor from the original model.
        weights (tuple[torch.Tensor]): ``(oft_blocks, rescale_weight)``.
            ``rescale_weight`` may be ``None``.
        constraint (float, optional): constraint for oft, forwarded to
            ``get_r``.
        need_transpose (bool, optional): set to ``True`` when the channel
            axis of ``org_out`` is dim 1 rather than the last dim (for
            example, convolution layers). Dims 1 and -1 are swapped before
            the computation and swapped back afterwards.

    Returns:
        torch.Tensor: the output delta, laid out like ``org_out`` and in
        ``get_r``'s dtype.
    """
    oft_blocks, rescale = weights
    n_blocks, b_size, _ = oft_blocks.shape
    eye = torch.eye(b_size, device=oft_blocks.device)
    r = get_r(oft_blocks, eye, constraint)

    # Move the channel axis last when it is not already there.
    y = org_out.transpose(1, -1) if need_transpose else org_out
    y = y.to(dtype=r.dtype)
    *lead, _ = y.shape

    grouped = y.view(*lead, n_blocks, b_size)
    delta = torch.einsum("k n m, ... k n -> ... k m", r - eye, grouped)
    delta = delta.view(*lead, -1)

    if rescale is not None:
        # Swap rescale's leading channel dim to the end so it lines up with
        # the channel-last layout of `delta` / `y`.
        scale = rescale.transpose(-1, 0)
        shift = (rescale - 1).transpose(-1, 0)
        delta = scale * delta + shift * y

    return delta.transpose(1, -1) if need_transpose else delta
|
|