import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """
    LoRA layer: Low-Rank Adaptation.

    Approximates a weight update with the low-rank product B(A(x)), scaled by
    alpha / r. Depending on kwargs it can instead delegate to a Synaptic
    Intelligence variant (`use_si`) or keep one adapter per domain
    (`multi_domain`).
    """
    def __init__(self, in_features, out_features, r=8, alpha=1.0, dropout=0.1, **kwargs):
        super(LoRALinear, self).__init__()

        self.use_si = False
        self.multi_domain = 0
        if 'use_si' in kwargs:
            # LoRALinearSI does not accept these keys, so drop them first.
            kwargs.pop('use_si')
            kwargs.pop('num_task', None)
            self.model = LoRALinearSI(
                in_features, out_features, r, alpha, dropout=dropout, **kwargs
            )
            self.use_si = True
        elif 'multi_domain' in kwargs:
            self.r = r
            self.alpha = alpha
            self.multi_domain = kwargs['multi_domain']
            a_list, b_list, drop_list = [], [], []
            for _ in range(self.multi_domain):
                a_list.append(nn.Linear(in_features, r, bias=False))
                b_list.append(nn.Linear(r, out_features, bias=False))
                drop_list.append(nn.Dropout(dropout))
            self.A = nn.ModuleList(a_list)
            self.B = nn.ModuleList(b_list)
            self.drop = nn.ModuleList(drop_list)
            self.scaling = alpha / r
            self._init_weights()
        else:
            self.r = r
            self.alpha = alpha

            self.A = nn.Linear(in_features, r, bias=False)
            self.drop = nn.Dropout(dropout)
            self.B = nn.Linear(r, out_features, bias=False)

            # B starts at zero, so the adapter is a no-op at initialization.
            nn.init.zeros_(self.B.weight)
            nn.init.normal_(self.A.weight, std=1 / r)
            self.lora_name = "lora_layer"

            self.scaling = alpha / r

    def _init_weights(self):
        for layer in self.A:
            nn.init.normal_(layer.weight, std=1 / self.r)
        for layer in self.B:
            nn.init.zeros_(layer.weight)

    def forward(self, x, task_mask=None, i=None, task_idx=None):
        if self.use_si:
            return self.model(x)
        if self.multi_domain:
            # Route to the adapter of the requested domain; `i` takes
            # precedence over `task_idx`.
            idx = i if i is not None else task_idx
            return self.scaling * self.B[idx](self.drop[idx](self.A[idx](x)))
        return self.scaling * self.B(self.drop(self.A(x)))

    def update_si_information(self):
        if self.use_si:
            self.model.update_si_information()

    def finalize_si_importance(self):
        if self.use_si:
            self.model.finalize_si_importance()
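

# Minimal usage sketch (added for illustration; shapes and hyperparameters are
# assumptions): exercises the plain and the multi-domain branches of LoRALinear.
def test_lora_linear():
    x = torch.randn(2, 32)

    plain = LoRALinear(32, 32, r=4)
    assert plain(x).shape == (2, 32)

    multi = LoRALinear(32, 32, r=4, multi_domain=3)
    assert multi(x, i=1).shape == (2, 32)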


class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, r=8, prior_std=0.1, dropout=0.1, **kwargs):
        """
        Bayesian LoRA layer: instead of deterministic weights, it learns a
        distribution over the LoRA parameters using variational inference.

        Args:
            in_features (int): Input dimension.
            out_features (int): Output dimension.
            r (int): LoRA rank.
            prior_std (float): Standard deviation of the Gaussian prior.
        """
        super(BayesianLinear, self).__init__()

        self.scaling = 1 / r

        # Mean and log-variance of the factorized Gaussian posterior.
        self.A_mu = nn.Parameter(torch.randn(in_features, r) * (1 / r))
        self.A_logvar = nn.Parameter(torch.randn(in_features, r) * (1 / r))

        self.B_mu = nn.Parameter(torch.randn(r, out_features) * (1 / r))
        self.B_logvar = nn.Parameter(torch.randn(r, out_features) * (1 / r))

        self.drop = nn.Dropout(dropout)

        self.prior_std = prior_std

    def sample_weights(self):
        """
        Reparameterization trick: sample weights from the Gaussian posterior.
        """
        A_std = torch.exp(0.5 * self.A_logvar)
        B_std = torch.exp(0.5 * self.B_logvar)

        A_sample = self.A_mu + A_std * torch.randn_like(A_std)
        B_sample = self.B_mu + B_std * torch.randn_like(B_std)

        return A_sample, B_sample

    def forward(self, x):
        """
        Forward pass with Bayesian weight sampling: sampled weights during
        training, posterior means at evaluation time.
        """
        if self.training:
            A, B = self.sample_weights()
        else:
            A, B = self.A_mu, self.B_mu

        out = self.drop(x @ A)
        # Apply the 1 / r scaling defined in __init__ (it was unused before).
        return self.scaling * (out @ B)
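

# Hedged sketch (an addition, not part of the original API): closed-form KL
# between the learned posterior N(mu, sigma^2) and the zero-mean Gaussian
# prior N(0, prior_std^2) stored on BayesianLinear; note that test_kl_lora
# below hard-codes a unit prior instead.
def bayesian_kl_to_prior(layer):
    import math
    p_var = layer.prior_std ** 2
    kl = 0.0
    for mu, logvar in ((layer.A_mu, layer.A_logvar), (layer.B_mu, layer.B_logvar)):
        var = torch.exp(logvar)
        # KL(N(mu, var) || N(0, p_var)), summed over all entries.
        kl = kl + 0.5 * torch.sum(var / p_var + mu ** 2 / p_var - 1.0 - logvar + math.log(p_var))
    return kl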


class LoRALinearSI(nn.Module):
    def __init__(self, in_features, out_features, r=8,
                 alpha=1.0, lambda_si=0.1, si_decay=0.99, dropout=0.1,
                 plasticity_base=0.5, sparsity_threshold=1e-3):
        super().__init__()
        self.r = r
        self.alpha = alpha
        self.lambda_si = lambda_si
        self.si_decay = si_decay
        self.plasticity_base = plasticity_base
        self.sparsity_threshold = sparsity_threshold

        self.lora_A = nn.Parameter(torch.randn(in_features, r))
        self.lora_B = nn.Parameter(torch.randn(r, out_features))
        self.drop = nn.Dropout(dropout)

        nn.init.zeros_(self.lora_B)
        nn.init.normal_(self.lora_A, std=1 / r)

        # Synaptic Intelligence bookkeeping: per-parameter importance (omega),
        # the parameter values at the last SI update, and the running
        # grad * delta trajectory accumulated within the current task.
        self.register_buffer("omega_A", torch.zeros_like(self.lora_A))
        self.register_buffer("omega_B", torch.zeros_like(self.lora_B))
        self.register_buffer("prev_params_A", self.lora_A.clone().detach())
        self.register_buffer("prev_params_B", self.lora_B.clone().detach())
        self.register_buffer("trajectory_A", torch.zeros_like(self.lora_A))
        self.register_buffer("trajectory_B", torch.zeros_like(self.lora_B))
        # Global plasticity scalar, adjustable via set_plasticity().
        self.register_buffer("plasticity", torch.tensor(plasticity_base))

    def forward(self, x):
        adaptive_alpha = self.alpha
        lora_update = torch.matmul(x, self.lora_A)
        lora_update = self.drop(lora_update)
        lora_update = torch.matmul(lora_update, self.lora_B)
        return adaptive_alpha * lora_update

    @torch.no_grad()  # bookkeeping only; keep it out of the autograd graph
    def update_si_information(self):
        """Accumulate the Synaptic Intelligence trajectory online (call after backward())."""
        if self.lora_A.grad is not None:
            delta_theta_A = self.lora_A - self.prev_params_A
            self.trajectory_A += delta_theta_A * self.lora_A.grad
            self.prev_params_A = self.lora_A.detach().clone()

        if self.lora_B.grad is not None:
            delta_theta_B = self.lora_B - self.prev_params_B
            self.trajectory_B += delta_theta_B * self.lora_B.grad
            self.prev_params_B = self.lora_B.detach().clone()

    def compute_sparsity(self, param):
        """Compute the sparsity score: fraction of near-zero values."""
        return torch.mean((torch.abs(param) < self.sparsity_threshold).float())

    def finalize_si_importance(self):
        """Fold the accumulated trajectory into the importance weights after a task, then reset it."""
        self.omega_A = self.si_decay * self.omega_A + (1 - self.si_decay) * (self.trajectory_A / (self.lora_A ** 2 + 1e-6)).detach()
        self.omega_B = self.si_decay * self.omega_B + (1 - self.si_decay) * (self.trajectory_B / (self.lora_B ** 2 + 1e-6)).detach()
        self.trajectory_A.zero_()
        self.trajectory_B.zero_()

    def si_loss(self):
        """Compute the SI regularization term for both LoRA parameters."""
        loss_A = torch.sum(self.omega_A * (self.lora_A - self.prev_params_A) ** 2)
        loss_B = torch.sum(self.omega_B * (self.lora_B - self.prev_params_B) ** 2)
        return self.lambda_si * (loss_A + loss_B)

    def set_plasticity(self, value: float):
        """Manually set the global plasticity value if needed."""
        self.plasticity.fill_(value)
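

# Hedged training-loop sketch (data, loss, and optimizer are placeholders):
# how the SI hooks are meant to be sequenced around an optimizer step.
def test_si_training_loop():
    layer = LoRALinearSI(16, 16, r=4)
    opt = torch.optim.SGD(layer.parameters(), lr=1e-2)
    target = torch.randn(8, 16)
    for _ in range(5):  # one "task"
        x = torch.randn(8, 16)
        # Task loss plus the SI penalty that anchors important parameters.
        loss = (layer(x) - target).pow(2).mean() + layer.si_loss()
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Pair the step's gradient with the parameter change it produced.
        layer.update_si_information()
    layer.finalize_si_importance()  # fold the trajectory into omega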


class MOELoRALinear(nn.Module):
    """
    Mixture-of-experts LoRA layer: one LoRALinear expert per task, selected by
    an integer index, a per-sample index tensor, or blended by a float mask.
    """
    def __init__(self, in_features, out_features, r=8, alpha=1.0, dropout=0.1, num_task=3, **kwargs):
        super(MOELoRALinear, self).__init__()

        self.loras = nn.ModuleList([
            LoRALinear(
                in_features,
                out_features,
                r, alpha, dropout, **kwargs)
            for _ in range(num_task)
        ])
        self.num_task = num_task

    def forward(self, x, i):
        if isinstance(i, int):
            # A single expert serves the whole batch.
            return self.loras[i](x)
        elif i.dtype == torch.float:
            # Float mask: view the batch as (b // num_task, num_task, ...) so
            # each group of num_task consecutive samples is run through every
            # expert and blended with the mask.
            orig_shape = x.shape
            b = orig_shape[0]
            new_shape = (b // self.num_task, self.num_task) + orig_shape[1:]
            x = x.reshape(new_shape)
            mask_shape = i.shape + (1,) * len(orig_shape[1:])
            i = i.reshape(mask_shape)
            res_list = torch.stack([
                self.loras[t](x[:, t]) for t in range(self.num_task)
            ], dim=1)
            res_list = res_list * i
            res_list = res_list.reshape(orig_shape[:-1] + (-1,))
            return res_list

        # Integer index tensor: run every expert, then gather one per sample.
        res_list = torch.stack([
            self.loras[t](x) for t in range(self.num_task)
        ], dim=1)

        b = res_list.shape[0]
        res = res_list[torch.arange(b), i]

        return res
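

# Hedged usage sketch (shapes are assumptions): the three routing modes of
# MOELoRALinear: a single int expert index, a per-sample index tensor, and a
# float blending mask over groups of num_task consecutive samples.
def test_moe_lora():
    moe = MOELoRALinear(16, 16, r=4, num_task=3)
    x = torch.randn(6, 16)

    y_int = moe(x, 0)                                # one expert for the whole batch
    idx = torch.randint(0, 3, (6,))
    y_idx = moe(x, idx)                              # one expert per sample
    mask = torch.softmax(torch.randn(2, 3), dim=-1)
    y_soft = moe(x, mask)                            # batch viewed as 2 groups of 3
    assert y_int.shape == y_idx.shape == y_soft.shape == (6, 16)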


class ZeroAdapter(nn.Module):
    """
    Zero-initialized bottleneck adapter.

    Both linear layers start at zero, so the adapter initially contributes
    nothing to the wrapped layer's output.
    """
    def __init__(self, in_features, out_feature, dropout=0.1, **kwargs):
        super(ZeroAdapter, self).__init__()
        mid_feature = in_features // 2
        self.down_linear = nn.Linear(in_features, mid_feature)
        self.up_linear = nn.Linear(mid_feature, out_feature)

        nn.init.zeros_(self.down_linear.weight)
        nn.init.zeros_(self.down_linear.bias)

        nn.init.zeros_(self.up_linear.weight)
        nn.init.zeros_(self.up_linear.bias)

        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.lora_name = "lora_layer"

    def forward(self, x):
        x = self.down_linear(x)
        x = self.drop(self.act(x))
        x = self.up_linear(x)
        return x
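

# Hedged sketch: at initialization ZeroAdapter is the zero function, so adding
# it to a frozen backbone leaves the pretrained behavior unchanged.
def test_zero_adapter():
    adapter = ZeroAdapter(8, 8)
    x = torch.randn(2, 8)
    assert torch.allclose(adapter(x), torch.zeros(2, 8))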


class LoRAMoECLAdapter(nn.Module):
    """
    Soft mixture-of-experts adapter: a two-layer LoRA bottleneck per task,
    blended by a learned router over the input.
    """
    def __init__(self, in_features, mid_feature, out_feature,
                 num_task=6, r=8, alpha=1.0, dropout=0.1, **kwargs):
        super(LoRAMoECLAdapter, self).__init__()
        self.r = r
        self.alpha = alpha
        self.num_task = num_task

        self.adapters = nn.ModuleList([
            nn.Sequential(
                LoRALinear(in_features, mid_feature, r, alpha, dropout),
                nn.Dropout(dropout),
                nn.ReLU(),
                LoRALinear(mid_feature, out_feature, r, alpha, dropout),
            )
            for _ in range(num_task)
        ])

        self.router = nn.Linear(in_features, num_task)
        self.out_drop = nn.Dropout(dropout)

        self.lora_name = "lora_layer"

    def forward(self, x, i=None):
        # `i` is accepted for interface compatibility; routing is learned.
        outputs = []
        logits = self.router(x)
        route_prob = logits.softmax(-1)

        for t in range(self.num_task):
            outputs.append(self.adapters[t](x))
        outputs = torch.stack(outputs, dim=-2)
        outputs = torch.sum(outputs * route_prob[..., None], dim=-2)
        outputs = self.out_drop(outputs)

        return outputs
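

# Hedged usage sketch (shapes are assumptions): the router softly blends all
# task adapters, so no task index is required at call time.
def test_lora_moe_cl_adapter():
    adapter = LoRAMoECLAdapter(16, 32, 16, num_task=3, r=4)
    x = torch.randn(2, 16)
    assert adapter(x).shape == (2, 16)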


class LoRACLAdapter(nn.Module):
    """
    Continual-learning LoRA adapter: one LoRALinear per task plus learned
    attention weights over tasks; a task mask stops gradients through the
    experts of inactive tasks to mitigate catastrophic forgetting.
    """
    def __init__(self, in_features, out_feature,
                 num_task=6, r=8, alpha=1.0, dropout=0.1, **kwargs):
        super(LoRACLAdapter, self).__init__()
        self.r = r
        self.alpha = alpha

        self.loras = nn.ModuleList([
            LoRALinear(in_features, out_feature, r, alpha, dropout) for _ in range(num_task)
        ])

        self.attn_weights = nn.ModuleList([nn.Linear(out_feature, 1) for _ in range(num_task)])
        self.attn_drop = nn.Dropout(dropout)

        self.num_task = num_task

        self.scaling = alpha / r
        self.lora_name = "lora_layer"

    def forward(self, x, task_mask=None):
        assert task_mask is not None

        outputs = []
        output_weights = []

        for t in range(self.num_task):
            out = self.loras[t](x)
            weight_out = self.attn_weights[t](out)
            outputs.append(out)
            output_weights.append(weight_out)

        outputs = torch.cat(outputs, dim=1)
        output_weights = torch.cat(output_weights, dim=1)
        output_weights = output_weights.softmax(1)
        outputs = outputs * self.attn_drop(output_weights)

        # Detach the experts of masked-out tasks so only active tasks train.
        task_mask = task_mask[0]
        task_mask = task_mask.unsqueeze(-1).expand(outputs.shape[0], -1, outputs.shape[2])

        outputs[task_mask == 0] = outputs[task_mask == 0].detach()
        outputs = outputs.sum(1)
        return outputs[:, None]
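

# Hedged usage sketch (an assumption about the expected input layout): x is
# (batch, 1, features) and task_mask flags which task columns stay trainable;
# masked tasks still contribute to the output but receive no gradient.
def test_lora_cl_adapter():
    adapter = LoRACLAdapter(16, 16, num_task=4, r=4)
    x = torch.randn(2, 1, 16)
    task_mask = torch.tensor([[1, 0, 0, 0]]).expand(2, -1)  # only task 0 trains
    assert adapter(x, task_mask=task_mask).shape == (2, 1, 16)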


valid_lora_list = (LoRALinear, LoRACLAdapter, ZeroAdapter, LoRAMoECLAdapter, MOELoRALinear)


def lora_wrapper(
        module,
        LoraLayer=LoRALinear,
        rank=8, alpha=1.0, dropout=0.1,
        num_task=6,
        **kwargs):
    """
    Creates a separate LoRA module that mirrors the Linear layers in the original model.
    """
    if isinstance(module, nn.ModuleList):
        lora_module = nn.ModuleList()
        for m in module:
            lora_module.append(lora_wrapper(
                m, LoraLayer,
                rank=rank, alpha=alpha, dropout=dropout, num_task=num_task, **kwargs
            ))
        return lora_module

    if isinstance(module, nn.ModuleDict):
        lora_module = nn.ModuleDict()
        for k, v in module.items():
            lora_module[f'lora_{k}'] = lora_wrapper(
                v, LoraLayer,
                rank=rank, alpha=alpha, dropout=dropout, num_task=num_task, **kwargs
            )
        return lora_module

    if len(list(module.named_modules())) == 1:
        # Leaf module: only nn.Linear gets a LoRA counterpart.
        if not isinstance(module, nn.Linear):
            print(f'Wrapping non-nn.Linear unit {type(module)}, skipping with Identity')
            return nn.Identity()
        lora_module = LoraLayer(module.in_features, module.out_features,
                                r=rank, alpha=alpha, dropout=dropout, num_task=num_task, **kwargs)
        return lora_module

    lora_module = nn.Sequential()

    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            lora_layer = LoraLayer(child.in_features, child.out_features,
                                   r=rank, alpha=alpha, dropout=dropout, num_task=num_task, **kwargs)
            lora_module.add_module(f'lora_{name}', lora_layer)
        elif isinstance(child, nn.Sequential):
            lora_module.add_module(f'lora_{name}',
                                   lora_wrapper(child,
                                                LoraLayer,
                                                rank=rank, alpha=alpha, dropout=dropout, num_task=num_task,
                                                **kwargs)
                                   )
        else:
            lora_module.add_module(f'lora_{name}', nn.Identity())

    return lora_module


def single_peft_forward(x, model, lora_model, lora_only=False, idx=None):
    # Only pass the task index along when one is given, since some adapters
    # (e.g. ZeroAdapter) do not accept an `i` keyword.
    lora_out = lora_model(x) if idx is None else lora_model(x, i=idx)
    if lora_only:
        return lora_out
    return model(x) + lora_out


def peft_wrapper_forward(x, model, lora_model, use_lora=True,
                         layer_idx=-1, layer_name="", lora_only=False, task_idx=None):
    """
    Custom forward function combining the original model output with the LoRA output.

    layer_idx: can be specified for an nn.ModuleList model; by default the whole ModuleList runs sequentially.
    layer_name: can be specified for an nn.ModuleDict model; by default the whole ModuleDict runs sequentially.
    lora_only: if True, the forward pass goes through only the LoRA layer whenever a matching Linear is met.
    """
    if isinstance(model, nn.ModuleList):
        if layer_idx > -1:
            return single_peft_forward(x, model[layer_idx], lora_model[layer_idx], lora_only, task_idx)

    if isinstance(model, nn.ModuleDict):
        if layer_name != "":
            # lora_wrapper prefixes ModuleDict keys with "lora_".
            return single_peft_forward(x, model[layer_name], lora_model[f'lora_{layer_name}'], lora_only, task_idx)

    if len(list(model.named_modules())) == 1:
        return single_peft_forward(x, model, lora_model, lora_only, task_idx)

    def process_layer(orig_layer, lora_layer, x):
        """Recursively process nested nn.Sequential layers."""
        if isinstance(orig_layer, nn.Sequential) and isinstance(lora_layer, nn.Sequential):
            for o_layer, l_layer in zip(orig_layer.children(), lora_layer.children()):
                x = process_layer(o_layer, l_layer, x)
            return x
        else:
            if use_lora and not isinstance(lora_layer, nn.Identity):
                return single_peft_forward(x, orig_layer, lora_layer, lora_only, task_idx)
            else:
                return orig_layer(x)

    for orig_layer, lora_layer in zip(model.children(), lora_model.children()):
        x = process_layer(orig_layer, lora_layer, x)

    return x
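

# Hedged sketch (illustrative model and shapes): wrap a small MLP and run the
# combined backbone + LoRA forward pass.
def test_lora_wrapper():
    mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    lora = lora_wrapper(mlp, LoRALinear, rank=4)
    x = torch.randn(2, 8)
    y = peft_wrapper_forward(x, mlp, lora)
    assert y.shape == (2, 8)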


def finetuning_detach(model):
    '''
    Freeze everything except the LoRA layers before finetuning.

    If a custom sublayer contains such a LoRA layer or an adapter with the
    "lora_name" attribute, make sure it also exposes this finetuning hook and
    the lora_name attribute so it is handled here.
    '''
    for name, module in model.named_modules():
        if 'lora' in name:
            for param in module.parameters():
                param.requires_grad = True
        else:
            for param in module.parameters():
                param.requires_grad = False
            # Keep dropout inactive in the frozen (non-LoRA) parts.
            if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
                module.eval()


def frozen_grad(model):
    for param in model.parameters():
        param.requires_grad = False
    return model


class TestModule(nn.Module):
    def __init__(self):
        super(TestModule, self).__init__()

        self.model = nn.ModuleDict()
        for i in range(3):
            self.model[f'{i}'] = nn.Linear(10, 10)
        self.lora_layer = lora_wrapper(
            self.model,
            ZeroAdapter,
            rank=4, alpha=1.0)

    def forward(self, x):
        x = peft_wrapper_forward(x, self.model, self.lora_layer)
        return x
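

# Hedged sketch (an addition): after finetuning_detach, only parameters under
# a module whose qualified name contains "lora" should keep requires_grad.
def test_finetuning_detach():
    model = TestModule()
    finetuning_detach(model)
    for name, p in model.named_parameters():
        assert p.requires_grad == ('lora' in name)
    x = torch.randn(2, 10)
    assert model(x).shape == (2, 10)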


def retrieve_bayesian_lora_param(module):
    '''
    Input: any nn.Module.

    Searches for all Bayesian LoRA parameters.

    Returns:
        lora_dict: Dict[sub_name, Dict['A_mu', 'B_mu', 'A_logvar', 'B_logvar']]
    '''
    lora_dict = {}
    lora_list = set(['A_mu', 'B_mu', 'A_logvar', 'B_logvar'])
    if isinstance(module, BayesianLinear):
        lora_dict['.'] = dict()
        for name, m in module.named_parameters():
            lora_dict['.'][name] = m
        return lora_dict

    for name, m in module.named_parameters():
        name_list = name.split('.')
        # The parameter itself (the last path component) carries the Bayesian
        # LoRA name; group it under its owning submodule.
        if name_list[-1] in lora_list:
            m_prefix = ".".join(name_list[:-1])
            if m_prefix not in lora_dict:
                lora_dict[m_prefix] = dict()
            lora_dict[m_prefix][name_list[-1]] = m
    return lora_dict


def test_lora_si():
    from time import time
    import numpy as np

    lora_model = LoRALinearSI(
        256, 256, 16
    )
    t = []
    for _ in range(10):
        s = time()
        x = torch.randn(2, 256)
        y = lora_model(x)
        loss = lora_model.si_loss()
        t.append(time() - s)
    print(loss, np.mean(t))


def test_kl_lora():
    lora_layer = BayesianLinear(
        32, 32, r=8
    )
    inputs = torch.randn(4, 10, 32)
    out = lora_layer(inputs)

    bayesian_params = retrieve_bayesian_lora_param(lora_layer)
    loss = 0.
    for v_dict in bayesian_params.values():
        print(v_dict.keys())
        B_std = torch.exp(0.5 * v_dict['B_logvar'])
        A_std = torch.exp(0.5 * v_dict['A_logvar'])

        # KL divergence of each posterior against a unit Gaussian prior.
        kl_B = (v_dict['B_mu'] ** 2 + B_std ** 2 - 2 * torch.log(B_std) - 1).sum()
        kl_A = (v_dict['A_mu'] ** 2 + A_std ** 2 - 2 * torch.log(A_std) - 1).sum()

        module_loss = 0.5 * (kl_B + kl_A)
        loss += module_loss

    print(out.shape, loss)


if __name__ == "__main__":
    test_kl_lora()