File size: 2,160 Bytes
a090db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from transformers import AutoModelForSequenceClassification
from loraLinear import LoRALinear

# Hugging Face checkpoint name; both the LoRA-wrapped and the plain model
# in this script are instantiated from it.
MODEL_CKPT = "distilbert-base-uncased"

# Passed as the second and third arguments to every LoRALinear wrapper.
# NOTE(review): presumably these must equal the rank/alpha used when the
# adapter checkpoint was trained — confirm against the training script.
RANK = 4
ALPHA = 4

# map_location for torch.load; CPU is fine for Spaces because the merge is fast.
DEVICE = "cpu"

# Re-create the LoRA-injected architecture: the query and value projections
# of every transformer block are wrapped in LoRALinear *before* the
# fine-tuned state dict is loaded, so the module tree matches the layout the
# checkpoint is loaded into (load_state_dict is strict by default).
lora_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)
for blk in lora_model.distilbert.transformer.layer:
    for name in ("q_lin", "v_lin"):
        base_linear = getattr(blk.attention, name)
        setattr(blk.attention, name, LoRALinear(base_linear, RANK, ALPHA))

# torch.load is pickle-based. weights_only=True limits unpickling to tensors
# and plain containers, which is all a state dict needs, so a tampered
# checkpoint file cannot run arbitrary code (requires PyTorch >= 1.13).
state_dict = torch.load(
    "DISTILBERT_WITH_LORA.pth",
    map_location=DEVICE,
    weights_only=True,
)
lora_model.load_state_dict(state_dict)
lora_model.eval()

# Fold each adapter into the layer it wraps:  W ← W + (B @ A)·scale
# The whole merge is plain tensor arithmetic, so autograd is switched off once
# around both loops.
with torch.no_grad():
    for layer in lora_model.distilbert.transformer.layer:
        attention = layer.attention
        for proj_name in ("q_lin", "v_lin"):
            wrapped = getattr(attention, proj_name)
            target_weight = wrapped.original_layer.weight  # expected (out, in)
            lora_down = wrapped.lora.loraA.weight           # expected (rank, in)
            lora_up = wrapped.lora.loraB.weight             # expected (out, rank)
            delta = (lora_up @ lora_down) * wrapped.lora.scaling
            # In-place add: the wrapped layer's own weight tensor is updated,
            # so later reads of original_layer.weight see the merged values.
            target_weight.add_(delta)

# Copy the merged weights into a *plain* DistilBERT (no wrappers).
# Only the q_lin/v_lin projections and the classification head are copied;
# every other parameter keeps the value loaded from MODEL_CKPT.
# NOTE(review): this assumes fine-tuning left all other parameters untouched —
# confirm against the training script.
plain_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)
with torch.no_grad():
    # Walk both layer stacks in lockstep rather than assuming exactly six
    # blocks, so the script still works if MODEL_CKPT has a different depth.
    layer_pairs = zip(
        plain_model.distilbert.transformer.layer,
        lora_model.distilbert.transformer.layer,
    )
    for plain_blk, lora_blk in layer_pairs:
        for lin in ("q_lin", "v_lin"):
            pl = getattr(plain_blk.attention, lin)
            # The merged weights live on the layer inside the LoRA wrapper.
            lr = getattr(lora_blk.attention, lin).original_layer
            pl.weight.copy_(lr.weight)
            pl.bias.copy_(lr.bias)

    # classification head
    for head in ("pre_classifier", "classifier"):
        dst = getattr(plain_model, head)
        src = getattr(lora_model, head)
        dst.weight.copy_(src.weight)
        dst.bias.copy_(src.bias)

# Persist the wrapper-free model's parameters as a state dict and report
# where they were written.
merged_state_dict = plain_model.state_dict()
torch.save(merged_state_dict, "DISTILBERT_MERGED.pth")
print("✅  Merged weights saved to  DISTILBERT_MERGED.pth")