import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .general import FUNC_LIST


class HadaWeight(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w1d, w1u, w2d, w2u, scale=torch.tensor(1)):
        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale
        temp = grad_out * (w2u @ w2d)
        grad_w1u = temp @ w1d.T
        grad_w1d = w1u.T @ temp

        temp = grad_out * (w1u @ w1d)
        grad_w2u = temp @ w2d.T
        grad_w2d = w2u.T @ temp

        del temp
        return grad_w1d, grad_w1u, grad_w2d, grad_w2u, None


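# Illustrative sketch (not part of the original module): a gradcheck-style
# sanity check for HadaWeight. The shapes, rank and scale below are arbitrary
# assumptions; double precision is required by torch.autograd.gradcheck.
def _example_hada_weight_gradcheck():
    torch.manual_seed(0)
    w1d = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    w1u = torch.randn(16, 4, dtype=torch.double, requires_grad=True)
    w2d = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    w2u = torch.randn(16, 4, dtype=torch.double, requires_grad=True)
    scale = torch.tensor(0.5, dtype=torch.double)
    # The hand-written backward above should match numerical gradients of the
    # plain expression ((w1u @ w1d) * (w2u @ w2d)) * scale.
    return torch.autograd.gradcheck(
        lambda a, b, c, d: HadaWeight.apply(a, b, c, d, scale),
        (w1d, w1u, w2d, w2u),
    )

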
class HadaWeightTucker(torch.autograd.Function):
    @staticmethod
    def forward(ctx, t1, w1d, w1u, t2, w2d, w2u, scale=torch.tensor(1)):
        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Partial rebuilds of both branches. Because the output is a Hadamard
        # product, each branch's gradients need the *other* branch's full
        # rebuild and its *own* core/down factor, so keep both temps around.
        temp1 = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
        temp2 = torch.einsum("i j ..., j r -> i r ...", t2, w2d)

        # Gradients of the first branch: dL/d(rebuild1) = grad_out * rebuild2.
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp2, w2u)
        grad_w = rebuild * grad_out
        del rebuild

        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp1, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
        del grad_w

        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
        del grad_temp

        # Gradients of the second branch: dL/d(rebuild2) = grad_out * rebuild1.
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp1, w1u)
        grad_w = rebuild * grad_out
        del rebuild, temp1

        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp2, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
        del grad_w, temp2

        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
        del grad_temp
        return grad_t1, grad_w1d, grad_w1u, grad_t2, grad_w2d, grad_w2u, None


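# Illustrative sketch (not part of the original module): the Tucker path
# rebuilds a conv-shaped delta weight. Shapes and rank here are arbitrary
# assumptions.
def _example_hada_weight_tucker_shapes():
    rank, in_dim, out_dim, kernel = 4, 8, 16, (3, 3)
    t1 = torch.randn(rank, rank, *kernel)
    w1d = torch.randn(rank, in_dim)
    w1u = torch.randn(rank, out_dim)
    t2 = torch.randn(rank, rank, *kernel)
    w2d = torch.randn(rank, in_dim)
    w2u = torch.randn(rank, out_dim)
    delta = HadaWeightTucker.apply(t1, w1d, w1u, t2, w2d, w2u, torch.tensor(1.0))
    # Each branch is rebuilt as (out_dim, in_dim, *kernel); the result is
    # their element-wise (Hadamard) product.
    assert delta.shape == (out_dim, in_dim, *kernel)
    return delta

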
def make_weight(w1d, w1u, w2d, w2u, scale):
    return HadaWeight.apply(w1d, w1u, w2d, w2u, scale)


def make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, scale):
    return HadaWeightTucker.apply(t1, w1d, w1u, t2, w2d, w2u, scale)


def weight_gen(org_weight, rank, tucker=True):
    """### weight_gen

    Args:
        org_weight (torch.Tensor): the weight tensor
        rank (int): low rank
        tucker (bool, optional): use a Tucker core for weights with kernel dims

    Returns:
        tuple[torch.Tensor]: w1d, w1u, w2d, w2u, t1, t2
            (t1 and t2 are None when no Tucker core is used)
    """
    out_dim, in_dim, *k = org_weight.shape
    if k and tucker:
        w1d = torch.empty(rank, in_dim)
        w1u = torch.empty(rank, out_dim)
        t1 = torch.empty(rank, rank, *k)
        w2d = torch.empty(rank, in_dim)
        w2u = torch.empty(rank, out_dim)
        t2 = torch.empty(rank, rank, *k)
        nn.init.normal_(t1, std=0.1)
        nn.init.normal_(t2, std=0.1)
    else:
        # Without a Tucker core the down weights keep the kernel dims so that
        # diff_weight/bypass_forward_diff below can recover them.
        w1d = torch.empty(rank, in_dim, *k)
        w1u = torch.empty(out_dim, rank)
        w2d = torch.empty(rank, in_dim, *k)
        w2u = torch.empty(out_dim, rank)
        t1 = t2 = None
    nn.init.normal_(w1d, std=1)
    nn.init.constant_(w1u, 0)
    nn.init.normal_(w2d, std=1)
    nn.init.normal_(w2u, std=0.1)
    return w1d, w1u, w2d, w2u, t1, t2


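# Illustrative sketch (not part of the original module): typical use of
# weight_gen for a Linear weight. The layer sizes and rank are arbitrary
# assumptions.
def _example_weight_gen_linear():
    org = torch.empty(64, 32)  # (out_dim, in_dim) of the frozen layer
    w1d, w1u, w2d, w2u, t1, t2 = weight_gen(org, rank=8)
    # A 2D weight gets no Tucker core, and because w1u starts at zero the
    # initial delta (w1u @ w1d) * (w2u @ w2d) is zero.
    assert t1 is None and t2 is None
    assert w1d.shape == (8, 32) and w1u.shape == (64, 8)
    return w1d, w1u, w2d, w2u

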
def diff_weight(*weights, gamma=1.0):
    """### diff_weight

    Get ΔW = (w1u @ w1d) * (w2u @ w2d) * gamma, the element-wise (Hadamard)
    product of two low-rank factors

    Args:
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u, t1, t2)
        gamma (float, optional): scale factor, normally alpha/rank here

    Returns:
        torch.Tensor: ΔW
    """
    w1d, w1u, w2d, w2u, t1, t2 = weights
    if t1 is not None and t2 is not None:
        R, I = w1d.shape
        R, O = w1u.shape
        R, R, *k = t1.shape
        result = make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, gamma)
    else:
        R, I, *k = w1d.shape
        O, R, *_ = w1u.shape
        w1d = w1d.reshape(w1d.size(0), -1)
        w1u = w1u.reshape(-1, w1u.size(1))
        w2d = w2d.reshape(w2d.size(0), -1)
        w2u = w2u.reshape(-1, w2u.size(1))
        result = make_weight(w1d, w1u, w2d, w2u, gamma)

    result = result.reshape(O, I, *k)
    return result


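# Illustrative sketch (not part of the original module): on the linear
# (non-Tucker) parametrization, diff_weight should match the naive
# Hadamard-product expression. Sizes, rank and gamma are arbitrary.
def _example_diff_weight_linear():
    w1d, w1u, w2d, w2u, t1, t2 = weight_gen(torch.empty(64, 32), rank=8)
    nn.init.normal_(w1u, std=0.1)  # make the delta non-zero for the check
    delta = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma=0.5)
    naive = ((w1u @ w1d) * (w2u @ w2d)) * 0.5
    assert delta.shape == (64, 32)
    assert torch.allclose(delta, naive)
    return delta

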
def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args={}):
    """### bypass_forward_diff

    Args:
        x (torch.Tensor): input tensor
        org_out (torch.Tensor): output of the original module (not used here)
        weights (tuple[torch.Tensor]): (w1d, w1u, w2d, w2u, t1, t2)
        gamma (float, optional): scale factor, normally alpha/rank here
        extra_args (dict, optional): extra args for forward func, \
            e.g. padding, stride for Conv1/2/3d

    Returns:
        torch.Tensor: output tensor of the ΔW branch
    """
    w1d, w1u, w2d, w2u, t1, t2 = weights
    diff_w = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma=gamma)
    return FUNC_LIST[w1d.dim() if t1 is None else t1.dim()](x, diff_w, **extra_args)
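

# Illustrative sketch (not part of the original module): bypass-mode forward
# for a Linear layer. The layer, sizes, rank and gamma are arbitrary
# assumptions; the caller combines the ΔW branch output with org_out.
def _example_bypass_linear():
    layer = nn.Linear(32, 64)
    x = torch.randn(2, 32)
    org_out = layer(x)
    w1d, w1u, w2d, w2u, t1, t2 = weight_gen(layer.weight, rank=8)
    diff_out = bypass_forward_diff(x, org_out, w1d, w1u, w2d, w2u, t1, t2, gamma=0.5)
    return org_out + diff_out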