File size: 698 Bytes
ff53311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from typing import Dict, List


def merge_lora_weight(tensors_A, tensors_B):
    """Stack the A and B factors of several LoRA adapters into one pair.

    The A tensors are concatenated along dim 0 and the B tensors along dim 1.
    NOTE(review): this assumes the PEFT-style layout A: (r, in_features) and
    B: (out_features, r). In that layout both concatenations run along the rank
    axis, so ``B_merged @ A_merged == sum_i B_i @ A_i``. Confirm against callers.

    Args:
        tensors_A: Sequence of A tensors, one per adapter.
        tensors_B: Sequence of B tensors, one per adapter, in the same order.

    Returns:
        Tuple ``(merged_A, merged_B)``.
    """
    return torch.cat(tensors_A, dim=0), torch.cat(tensors_B, dim=1)


def merge_lora(loras: List[Dict[str, torch.Tensor]], alpha=1):
    """Merge several LoRA state dicts into a single state dict.

    For every key of ``loras[0]`` that contains ``".lora_A."``, the matching
    A tensors of all adapters and their ``".lora_B."`` counterparts are
    combined with ``merge_lora_weight``. Keys that are not ``.lora_A.`` keys
    (and their ``.lora_B.`` partners) are not copied to the result.

    NOTE(review): under the PEFT layout (A: (r, in), B: (out, r)) the merged
    product ``B @ A`` equals ``sum_i alpha_i * B_i @ A_i``, and the merged rank
    is the sum of the input ranks. Any downstream ``lora_alpha / r`` scaling
    therefore changes and must be accounted for by the caller. Confirm.

    Args:
        loras: Non-empty list of LoRA state dicts. Every dict must contain the
            ``.lora_A.`` / ``.lora_B.`` keys present in ``loras[0]``.
        alpha: Either a single scale factor applied to every merged A tensor
            (original behavior), or a list/tuple with one weight per adapter
            for a weighted merge.

    Returns:
        A new dict holding only the merged ``.lora_A.`` and ``.lora_B.`` tensors.

    Raises:
        ValueError: if ``loras`` is empty, or if a per-adapter ``alpha`` does
            not have exactly ``len(loras)`` entries.
    """
    if not loras:
        raise ValueError("merge_lora requires at least one LoRA state dict")

    per_lora_alpha = isinstance(alpha, (list, tuple))
    if per_lora_alpha and len(alpha) != len(loras):
        raise ValueError(
            f"alpha has {len(alpha)} weights but {len(loras)} LoRAs were given"
        )

    lora_merged = {}
    keys = [k for k in loras[0].keys() if ".lora_A." in k]
    for key_A in keys:
        key_B = key_A.replace(".lora_A.", ".lora_B.")
        tensors_A = [lora[key_A] for lora in loras]
        if per_lora_alpha:
            # Scale each adapter's A factor before stacking, so the merged
            # product becomes a weighted sum of the individual B_i @ A_i.
            tensors_A = [t * w for t, w in zip(tensors_A, alpha)]
        tensors_B = [lora[key_B] for lora in loras]
        lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
        # A scalar alpha keeps the original behavior: scale the whole merged A.
        lora_merged[key_A] = lora_A if per_lora_alpha else lora_A * alpha
        lora_merged[key_B] = lora_B
    return lora_merged