File size: 2,714 Bytes
f9e119d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch, torch.nn as nn, torch.nn.functional as F
def _pack_int4(q):
    q_u4=(q+8).to(torch.uint8); even=q_u4[::2]; odd=q_u4[1::2]; p=(even<<4)|odd
    if q_u4.numel()%2==1: p=torch.cat([p,(q_u4[-1]<<4)])
    return p.contiguous()
def _unpack_int4(packed,total_elems,device=None):
    hi=(packed>>4)&0x0F; lo=packed&0x0F; u4=torch.stack([hi,lo],-1).flatten()[:total_elems]
    q=(u4.to(torch.int16)-8).to(torch.int8); return q.to(device) if device is not None else q
def quantize_per_outchannel_int4(weight):
    """Symmetrically quantize a 2-D weight to packed int4, one scale per row.

    Each row ``i`` gets ``scale[i] = max(max|weight[i]|, 1e-8) / 7``.  Codes
    are ``round(weight / scale)`` clamped to ``[-8, 7]``.  Because the row
    peak maps to exactly 7, only ``[-7, 7]`` is reached in practice.  Codes
    are packed with ``_pack_int4`` in row-major order.

    Args:
        weight: 2-D tensor, presumably ``(out_features, in_features)`` as in
            ``nn.Linear``.  It is detached and processed in float32.

    Returns:
        ``(packed, scales, shape)``:
        - ``packed``: the ``uint8`` bytes from ``_pack_int4``.
        - ``scales``: a float32 tensor with one scale per row.
        - ``shape``: the weight's ``torch.Size``.
    """
    assert weight.dim() == 2
    w = weight.detach().to(torch.float32)
    # The floor keeps all-zero rows from dividing by zero.
    row_peak = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    row_scale = row_peak / 7.0
    codes = (w / row_scale).round().clamp(-8, 7).to(torch.int8)
    packed = _pack_int4(codes.flatten())
    return packed, row_scale.squeeze(1).to(torch.float32), w.shape
class Int4Linear(nn.Module):
    """Linear layer whose weight is stored as packed signed 4-bit codes.

    The weight is stored as ``packed_weight``, with two 4-bit codes per
    ``uint8`` byte, high nibble first.  Alongside it is one float32
    ``scales`` entry per output row.  The weight is dequantized on every
    forward call.  ``bias``, if present, remains a regular parameter.

    Build one directly (e.g. before ``load_state_dict``), or via
    ``from_linear`` to quantize an existing ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = (
            nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype))
            if bias
            else None
        )
        # Pre-size the packed buffer so a state dict saved from a quantized
        # module can be loaded into a freshly constructed one.
        # load_state_dict rejects buffers whose shape differs.  Each 0x88 byte
        # holds two codes of 8, which decode to 0, so an unloaded module
        # behaves as an all-zero weight instead of reading uninitialized memory.
        n_bytes = (in_features * out_features + 1) // 2
        self.register_buffer(
            "packed_weight",
            torch.full((n_bytes,), 0x88, dtype=torch.uint8, device=device),
            persistent=True,
        )
        # Per-output-row dequantization scales.
        self.register_buffer(
            "scales",
            torch.ones(out_features, dtype=torch.float32, device=device),
            persistent=True,
        )
        # Kept as a persistent buffer so existing state dicts keep their key set.
        self.register_buffer(
            "orig_in_features",
            torch.tensor(in_features, dtype=torch.int32, device=device),
            persistent=True,
        )

    @classmethod
    def from_linear(cls, m: nn.Linear):
        """Quantize ``m``'s weight into a new module on the same device.

        The bias (if any) is copied rather than aliased, so the new module
        does not share storage with ``m``.
        """
        q = cls(
            m.in_features,
            m.out_features,
            bias=(m.bias is not None),
            device=m.weight.device,
            dtype=m.weight.dtype,
        )
        packed, scales, shape = quantize_per_outchannel_int4(m.weight)
        q.packed_weight = packed
        q.scales = scales
        q.orig_in_features = torch.tensor(shape[1], dtype=torch.int32, device=m.weight.device)
        if m.bias is not None:
            q.bias = nn.Parameter(m.bias.detach().to(dtype=m.weight.dtype, copy=True))
        return q

    def forward(self, x):
        """Dequantize the weight to ``x``'s device/dtype and apply ``F.linear``."""
        # Use the Python-int in_features rather than orig_in_features.item():
        # .item() on a device buffer forces a host/device sync on every call.
        total = self.out_features * self.in_features
        q = _unpack_int4(self.packed_weight, total, device=x.device)
        w = q.to(torch.float32).view(self.out_features, self.in_features)
        scales = self.scales.to(device=x.device, dtype=torch.float32).unsqueeze(1)
        w = (w * scales).to(x.dtype)
        # Match the bias to the input dtype too, as is already done for the weight.
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, w, bias)

    def extra_repr(self):
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )
def quantize_model_to_int4(model, name_exclude_patterns=()):
    """Replace ``nn.Linear`` submodules of ``model`` with ``Int4Linear``, in place.

    Args:
        model: module whose descendants are scanned.  The root module itself
            is never replaced, even if it is an ``nn.Linear``.
        name_exclude_patterns: substrings to exclude.  A Linear is skipped if
            its dotted qualified name (e.g. ``"encoder.layers.0.fc1"``)
            contains any of them.  A single string is treated as one pattern.

    Linears owned directly by ``nn.MultiheadAttention`` (its ``out_proj``) are
    skipped.  That module reads ``out_proj.weight`` itself, and
    ``Int4Linear`` does not provide it.  A Linear shared by several parents
    is quantized once, and the same replacement is installed at every site,
    so the sharing is preserved.

    Returns:
        ``model`` (mutated in place).
    """
    if isinstance(name_exclude_patterns, str):
        # Iterating a bare string would make every character its own pattern.
        name_exclude_patterns = (name_exclude_patterns,)
    patterns = tuple(name_exclude_patterns)

    def excluded(qualified_name):
        return any(p in qualified_name for p in patterns)

    replacements = {}  # id(original Linear) -> its Int4Linear replacement
    # Snapshot the module list: setattr below mutates the tree while we walk it.
    for name, mod in list(model.named_modules()):
        if isinstance(mod, nn.MultiheadAttention):
            continue
        for cn, ch in list(mod.named_children()):
            full = f"{name}.{cn}" if name else cn
            if not isinstance(ch, nn.Linear) or excluded(full):
                continue
            if id(ch) not in replacements:
                replacements[id(ch)] = Int4Linear.from_linear(ch)
            setattr(mod, cn, replacements[id(ch)])
    print(f"[INT4] Replaced {len(replacements)} Linear layers with Int4Linear.")
    return model