| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .general import rebuild_tucker, FUNC_LIST |
| from .general import factorization |
|
|
|
|
| def make_kron(w1, w2, scale): |
| for _ in range(w2.dim() - w1.dim()): |
| w1 = w1.unsqueeze(-1) |
| w2 = w2.contiguous() |
| rebuild = torch.kron(w1, w2) |
|
|
| if scale != 1: |
| rebuild = rebuild * scale |
|
|
| return rebuild |
|
|
|
|
| def weight_gen( |
| org_weight, |
| rank, |
| tucker=True, |
| factor=-1, |
| decompose_both=False, |
| full_matrix=False, |
| unbalanced_factorization=False, |
| ): |
| """### weight_gen |
| |
| Args: |
| org_weight (torch.Tensor): the weight tensor |
| rank (int): low rank |
| |
| Returns: |
| torch.Tensor | None: w1, w1a, w1b, w2, w2a, w2b, t2 |
| """ |
| out_dim, in_dim, *k = org_weight.shape |
| w1 = w1a = w1b = None |
| w2 = w2a = w2b = None |
| t2 = None |
| use_w1 = use_w2 = False |
|
|
| if k: |
| k_size = k |
| shape = (out_dim, in_dim, *k_size) |
|
|
| in_m, in_n = factorization(in_dim, factor) |
| out_l, out_k = factorization(out_dim, factor) |
| if unbalanced_factorization: |
| out_l, out_k = out_k, out_l |
| shape = ((out_l, out_k), (in_m, in_n), *k_size) |
| tucker = tucker and any(i != 1 for i in k_size) |
| if ( |
| decompose_both |
| and rank < max(shape[0][0], shape[1][0]) / 2 |
| and not full_matrix |
| ): |
| w1a = torch.empty(shape[0][0], rank) |
| w1b = torch.empty(rank, shape[1][0]) |
| else: |
| use_w1 = True |
| w1 = torch.empty(shape[0][0], shape[1][0]) |
|
|
| if rank >= max(shape[0][1], shape[1][1]) / 2 or full_matrix: |
| use_w2 = True |
| w2 = torch.empty(shape[0][1], shape[1][1], *k_size) |
| elif tucker: |
| t2 = torch.empty(rank, rank, *shape[2:]) |
| w2a = torch.empty(rank, shape[0][1]) |
| w2b = torch.empty(rank, shape[1][1]) |
| else: |
| |
| w2a = torch.empty(shape[0][1], rank) |
| w2b = torch.empty(rank, shape[1][1], *shape[2:]) |
| |
| else: |
| shape = (out_dim, in_dim) |
|
|
| in_m, in_n = factorization(in_dim, factor) |
| out_l, out_k = factorization(out_dim, factor) |
| if unbalanced_factorization: |
| out_l, out_k = out_k, out_l |
| shape = ( |
| (out_l, out_k), |
| (in_m, in_n), |
| ) |
| |
| if decompose_both and rank < max(shape[0][0], shape[1][0]) / 2: |
| w1a = torch.empty(shape[0][0], rank) |
| w1b = torch.empty(rank, shape[1][0]) |
| else: |
| use_w1 = True |
| w1 = torch.empty(shape[0][0], shape[1][0]) |
| if rank < max(shape[0][1], shape[1][1]) / 2: |
| |
| w2a = torch.empty(shape[0][1], rank) |
| w2b = torch.empty(rank, shape[1][1]) |
| |
| else: |
| use_w2 = True |
| w2 = torch.empty(shape[0][1], shape[1][1]) |
|
|
| if use_w2: |
| torch.nn.init.constant_(w2, 1) |
| else: |
| if tucker: |
| torch.nn.init.kaiming_uniform_(t2, a=math.sqrt(5)) |
| torch.nn.init.kaiming_uniform_(w2a, a=math.sqrt(5)) |
| torch.nn.init.constant_(w2b, 1) |
|
|
| if use_w1: |
| torch.nn.init.kaiming_uniform_(w1, a=math.sqrt(5)) |
| else: |
| torch.nn.init.kaiming_uniform_(w1a, a=math.sqrt(5)) |
| torch.nn.init.kaiming_uniform_(w1b, a=math.sqrt(5)) |
|
|
| return w1, w1a, w1b, w2, w2a, w2b, t2 |
|
|
|
|
| def diff_weight(*weights, gamma=1.0): |
| """### diff_weight |
| |
| Args: |
| weights (tuple[torch.Tensor]): (w1, w1a, w1b, w2, w2a, w2b, t) |
| gamma (float, optional): scale factor, normally alpha/rank here |
| |
| Returns: |
| torch.Tensor: ΔW |
| """ |
| w1, w1a, w1b, w2, w2a, w2b, t = weights |
| if w1a is not None: |
| rank = w1a.shape[1] |
| elif w2a is not None: |
| rank = w2a.shape[1] |
| else: |
| rank = gamma |
| scale = gamma / rank |
| if w1 is None: |
| w1 = w1a @ w1b |
| if w2 is None: |
| if t is None: |
| r, o, *k = w2b.shape |
| w2 = w2a @ w2b.view(r, -1) |
| w2 = w2.view(-1, o, *k) |
| else: |
| w2 = rebuild_tucker(t, w2a, w2b) |
| return make_kron(w1, w2, scale) |
|
|
|
|
| def bypass_forward_diff(h, org_out, *weights, gamma=1.0, extra_args={}): |
| """### bypass_forward_diff |
| |
| Args: |
| weights (tuple[torch.Tensor]): (w1, w1a, w1b, w2, w2a, w2b, t) |
| gamma (float, optional): scale factor, normally alpha/rank here |
| extra_args (dict, optional): extra args for forward func, \ |
| e.g. padding, stride for Conv1/2/3d |
| |
| Returns: |
| torch.Tensor: output tensor |
| """ |
| w1, w1a, w1b, w2, w2a, w2b, t = weights |
| use_w1 = w1 is not None |
| use_w2 = w2 is not None |
| tucker = t is not None |
| dim = t.dim() if tucker else w2.dim() if w2 is not None else w2b.dim() |
| rank = w1b.size(0) if not use_w1 else w2b.size(0) if not use_w2 else gamma |
| scale = gamma / rank |
| is_conv = dim > 2 |
| op = FUNC_LIST[dim] |
|
|
| if is_conv: |
| kw_dict = extra_args |
| else: |
| kw_dict = {} |
|
|
| if use_w2: |
| ba = w2 |
| else: |
| a = w2b |
| b = w2a |
|
|
| if t is not None: |
| a = a.view(*a.shape, *[1] * (dim - 2)) |
| b = b.view(*b.shape, *[1] * (dim - 2)) |
| elif is_conv: |
| b = b.view(*b.shape, *[1] * (dim - 2)) |
|
|
| if use_w1: |
| c = w1 |
| else: |
| c = w1a @ w1b |
| uq = c.size(1) |
|
|
| if is_conv: |
| |
| B, _, *rest = h.shape |
| h_in_group = h.reshape(B * uq, -1, *rest) |
| else: |
| |
| h_in_group = h.reshape(*h.shape[:-1], uq, -1) |
|
|
| if use_w2: |
| hb = op(h_in_group, ba, **kw_dict) |
| else: |
| if is_conv: |
| if tucker: |
| ha = op(h_in_group, a) |
| ht = op(ha, t, **kw_dict) |
| hb = op(ht, b) |
| else: |
| ha = op(h_in_group, a, **kw_dict) |
| hb = op(ha, b) |
| else: |
| ha = op(h_in_group, a, **kw_dict) |
| hb = op(ha, b) |
|
|
| if is_conv: |
| |
| |
| |
| hb = hb.view(B, -1, *hb.shape[1:]) |
| h_cross_group = hb.transpose(1, -1) |
| else: |
| |
| |
| h_cross_group = hb.transpose(-1, -2) |
|
|
| hc = F.linear(h_cross_group, c) |
| if is_conv: |
| |
| |
| |
| hc = hc.transpose(1, -1) |
| h = hc.reshape(B, -1, *hc.shape[3:]) |
| else: |
| |
| |
| |
| hc = hc.transpose(-1, -2) |
| h = hc.reshape(*hc.shape[:-2], -1) |
|
|
| return h * scale |
|
|