| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from einops import rearrange |
| |
|
| | from .general import power2factorization, FUNC_LIST |
| | from .diag_oft import get_r |
| |
|
| |
|
def weight_gen(org_weight, max_block_size, boft_m=-1, rescale=False):
    """### boft_weight_gen

    Build zero-initialized BOFT butterfly-factor blocks (and optionally a
    per-output-channel rescale weight) sized for ``org_weight``.

    Args:
        org_weight (torch.Tensor): the weight tensor
        max_block_size (int): max block size
        boft_m (int, optional): number of butterfly factors to use;
            -1 means the maximum needed for full connectivity. Defaults to -1.
        rescale (bool, optional): whether to rescale the weight. Defaults to False.

    Returns:
        torch.Tensor: oft_blocks[, rescale_weight]
    """
    out_dim, *rest = org_weight.shape
    block_size, block_num = power2factorization(out_dim, max_block_size)
    # Minimum butterfly depth connecting all blocks: popcount(block_num - 1) + 1
    max_boft_m = sum(int(i) for i in f"{block_num-1:b}") + 1
    if boft_m == -1:
        boft_m = max_boft_m
    boft_m = min(boft_m, max_boft_m)
    oft_blocks = torch.zeros(boft_m, block_num, block_size, block_size)
    # BUG FIX: `rescale` is a bool flag (default False); the original test
    # `if rescale is not None` was always true, so a rescale weight was
    # returned even when rescale=False and the `None` branch was dead code,
    # contradicting the documented "oft_blocks[, rescale_weight]" contract.
    if rescale:
        # One scale per output channel, broadcastable over the remaining dims.
        return oft_blocks, torch.ones(out_dim, *[1] * len(rest))
    else:
        return oft_blocks, None
| |
|
| |
|
def diff_weight(org_weight, *weights, constraint=None):
    """### boft_diff_weight

    Apply the stack of butterfly orthogonal factors (and optional rescale)
    to the original weight and return the resulting weight delta.

    Args:
        org_weight (torch.Tensor): the weight tensor of original model
        weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
        constraint (float, optional): constraint for oft

    Returns:
        torch.Tensor: ΔW
    """
    blocks, rescale = weights
    depth, _block_num, block_size, _ = blocks.shape
    half = block_size // 2
    eye = torch.eye(block_size, device=blocks.device)
    rot = get_r(blocks, eye, constraint)
    base = org_weight.to(dtype=rot.dtype)
    out = base

    group = 2
    for level in range(depth):
        stride = (2**level) * half
        # Butterfly permutation: regroup the last axis so that elements
        # `stride` apart land in the same block before rotation.
        out = out.unflatten(-1, (-1, group, stride))
        out = out.transpose(-2, -1).flatten(-3)
        out = out.unflatten(-1, (-1, block_size))
        # Apply this level's block-diagonal orthogonal rotation.
        out = torch.einsum("b i j, b j ... -> b i ...", rot[level], out)
        # Undo the permutation to restore the original element order.
        out = out.flatten(-2).unflatten(-1, (-1, stride, group))
        out = out.transpose(-2, -1).flatten(-3)

    if rescale is not None:
        out = out * rescale

    return out - base
| |
|
| |
|
def bypass_forward_diff(org_out, *weights, constraint=None, need_transpose=False):
    """### boft_bypass_forward_diff

    Apply the butterfly orthogonal factors to the original module's output
    and return the output delta (for bypass/adapter-style forward).

    Args:
        org_out (torch.Tensor): the output tensor from original model
        weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight])
        constraint (float, optional): constraint for oft
        need_transpose (bool, optional):
            whether to transpose the input and output,
            set to `True` if the original model have "dim" not in the last axis.
            For example: Convolution layers

    Returns:
        torch.Tensor: output tensor
    """
    blocks, rescale = weights
    depth, _block_num, block_size, _ = blocks.shape
    half = block_size // 2
    eye = torch.eye(block_size, device=blocks.device)
    rot = get_r(blocks, eye, constraint)
    out = base = org_out.to(dtype=rot.dtype)
    if need_transpose:
        # Bring the channel dim to the last axis so the butterfly ops apply to it.
        out = base = out.transpose(1, -1)

    group = 2
    for level in range(depth):
        stride = (2**level) * half
        # Butterfly permutation: pair elements `stride` apart into one block.
        out = out.unflatten(-1, (-1, group, stride))
        out = out.transpose(-2, -1).flatten(-3)
        out = out.unflatten(-1, (-1, block_size))
        # Rotate each block with this level's orthogonal factor
        # (batch/spatial dims lead, hence "... b j" operand order).
        out = torch.einsum("b i j, ... b j -> ... b i", rot[level], out)
        # Invert the permutation.
        out = out.flatten(-2).unflatten(-1, (-1, stride, group))
        out = out.transpose(-2, -1).flatten(-3)

    if rescale is not None:
        out = out * rescale.transpose(0, -1)

    delta = out - base
    if need_transpose:
        delta = delta.transpose(1, -1)
    return delta
| |
|