# File size: 5,465 Bytes
# 0d00bbe
from typing import List, Union

import torch
import torch.nn as nn

from ..base.QType import QType
from ..layers.QConv import QConv2d
from ..layers.QLinear import QLinear
from ..layers.QSLinear import QSLinear
from ..layers.SLinear import SLinear


# layer conversion functions 
def replace_linear(module: nn.Module, w_Q: Union[QType, str], in_Q: Union[QType, str, None]=None, quant_grad: bool=True, exclude_layers: Union[List[str], None]=None):
    """Swap every ``nn.Linear`` submodule of ``module`` for a ``QLinear``, in place.

    Each replacement is constructed with the ``in_features``, ``out_features``
    and bias presence of the layer it replaces. ``transfer`` is then called
    with the original layer (presumably to carry its parameters over --
    confirm in ``QLinear``), followed by ``assign_qparams(w_Q)``,
    ``set_quant_grad(quant_grad)`` and, when ``in_Q`` is given,
    ``assign_input_qparams(in_Q)``.

    Args:
        module: Model whose submodules are rewritten. Modified in place.
        w_Q: Weight quantization spec forwarded to ``assign_qparams``.
        in_Q: Input quantization spec; input qparams are left untouched
            when this is ``None``.
        quant_grad: Flag forwarded to ``set_quant_grad``.
        exclude_layers: Fully qualified module names (as produced by
            ``named_modules``) to leave untouched. ``None`` excludes nothing.

    Note:
        A module can only be replaced by rebinding it on its parent. When
        ``module`` itself is an ``nn.Linear`` there is no parent, so the
        layer is reported and left unchanged.
    """
    if exclude_layers is None:
        exclude_layers = []
    assert isinstance(exclude_layers, list), 'Exclude layers must be list of string'

    # Snapshot name -> module once; children are rebound while we iterate,
    # so walking a frozen mapping keeps the traversal independent of edits.
    mod_dict = dict(module.named_modules())

    for n, m in mod_dict.items():
        if n in exclude_layers:
            print('(Replace Qlinear) Excluding layer:', n)
            continue
        if not isinstance(m, nn.Linear):
            continue
        if not n:
            # The root is named '' and mod_dict[''] is the root itself, so
            # setattr(root, '', new) would register a child called '' on the
            # old layer instead of replacing it.
            print('(Replace Qlinear) Root module is nn.Linear and cannot be replaced in place; skipping')
            continue

        new_mod = QLinear(m.in_features, m.out_features, m.bias is not None)
        new_mod.transfer(m)
        new_mod.assign_qparams(w_Q)
        new_mod.set_quant_grad(quant_grad)

        if in_Q is not None:
            new_mod.assign_input_qparams(in_Q)

        # 'a.b.c' -> parent 'a.b', attribute 'c'; a top-level child 'c'
        # yields parent '' which is the root module.
        parent_name, _, child_name = n.rpartition('.')
        setattr(mod_dict[parent_name], child_name, new_mod)
            
# layer conversion functions 
def replace_sparse_quant_linear(module: nn.Module, w_Q: Union[QType, str], in_Q: Union[QType, str, None]=None, quant_grad=True, calibration_dict=None):
    """Swap every ``nn.Linear`` submodule of ``module`` for a ``QSLinear``, in place.

    Each replacement is constructed with the ``in_features``, ``out_features``
    and bias presence of the layer it replaces, plus a per-layer
    ``sparse_ratio``. ``transfer`` is then called with the original layer,
    followed by ``assign_qparams(w_Q)``, ``set_quant_grad(quant_grad)`` and,
    when ``in_Q`` is given, ``assign_input_qparams(in_Q)``.

    Args:
        module: Model whose submodules are rewritten. Modified in place.
        w_Q: Weight quantization spec forwarded to ``assign_qparams``.
        in_Q: Input quantization spec; input qparams are left untouched
            when this is ``None``.
        quant_grad: Flag forwarded to ``set_quant_grad``.
        calibration_dict: Mapping from fully qualified layer name to the
            ``sparse_ratio`` for that layer. When ``None`` every layer gets
            ``0.0``. When given, it must contain every replaced layer's name.

    Raises:
        KeyError: If ``calibration_dict`` is given but lacks a replaced
            layer's name.

    Note:
        A module can only be replaced by rebinding it on its parent. When
        ``module`` itself is an ``nn.Linear`` there is no parent, so the
        layer is reported and left unchanged.
    """
    # Snapshot name -> module once; children are rebound while we iterate.
    mod_dict = dict(module.named_modules())

    for n, m in mod_dict.items():
        if not isinstance(m, nn.Linear):
            continue
        if not n:
            # mod_dict[''] is the root itself; setattr(root, '', new) would
            # attach the new layer as a child named '' rather than replace it.
            print('(Replace QSLinear) Root module is nn.Linear and cannot be replaced in place; skipping')
            continue

        sparse_ratio_n = calibration_dict[n] if calibration_dict is not None else 0.0
        new_mod = QSLinear(m.in_features, m.out_features, m.bias is not None, sparse_ratio=sparse_ratio_n)
        new_mod.transfer(m)
        new_mod.assign_qparams(w_Q)
        new_mod.set_quant_grad(quant_grad)

        if in_Q is not None:
            new_mod.assign_input_qparams(in_Q)

        # Split 'a.b.c' into parent path 'a.b' and attribute name 'c'.
        parent_name, _, child_name = n.rpartition('.')
        setattr(mod_dict[parent_name], child_name, new_mod)


# layer conversion functions 
def replace_sparse_linear(module: nn.Module, calibration_dict=None):
    """Swap every ``nn.Linear`` submodule of ``module`` for an ``SLinear``, in place.

    Each replacement is constructed with the ``in_features``, ``out_features``
    and bias presence of the layer it replaces, plus a per-layer
    ``sparse_ratio``; ``transfer`` is then called with the original layer.

    Args:
        module: Model whose submodules are rewritten. Modified in place.
        calibration_dict: Mapping from fully qualified layer name to the
            ``sparse_ratio`` for that layer. When ``None`` every layer gets
            ``0.0``. When given, it must contain every replaced layer's name.

    Raises:
        KeyError: If ``calibration_dict`` is given but lacks a replaced
            layer's name.

    Note:
        A module can only be replaced by rebinding it on its parent. When
        ``module`` itself is an ``nn.Linear`` there is no parent, so the
        layer is reported and left unchanged.
    """
    # Snapshot name -> module once; children are rebound while we iterate.
    mod_dict = dict(module.named_modules())

    for n, m in mod_dict.items():
        if not isinstance(m, nn.Linear):
            continue
        if not n:
            # mod_dict[''] is the root itself; setattr(root, '', new) would
            # attach the new layer as a child named '' rather than replace it.
            print('(Replace SLinear) Root module is nn.Linear and cannot be replaced in place; skipping')
            continue

        sparse_ratio_n = calibration_dict[n] if calibration_dict is not None else 0.0
        new_mod = SLinear(m.in_features, m.out_features, m.bias is not None, sparse_ratio=sparse_ratio_n)
        new_mod.transfer(m)

        # Split 'a.b.c' into parent path 'a.b' and attribute name 'c'.
        parent_name, _, child_name = n.rpartition('.')
        setattr(mod_dict[parent_name], child_name, new_mod)



def replace_linear_mixfp(module: nn.Module, w_Q: Union[QType, str], high_Q: Union[QType, str], ratio: float=0.0, quant_grad=True):
    """Swap ``nn.Linear`` submodules for ``QLinear`` with mixed weight precision.

    Every ``nn.Linear`` in ``module`` is replaced in place by a ``QLinear``.
    A fraction ``ratio`` of the layers listed in the error file is assigned
    ``high_Q``; all other layers are assigned ``w_Q``.

    When ``ratio > 0`` the list is loaded with ``torch.load`` from
    ``mix_fp/mixfp_err_<desc>.pt`` (a path relative to the current working
    directory), where ``<desc>`` is ``w_Q`` itself if it is a string and
    ``w_Q.desc`` otherwise. The LAST ``int(ratio * len(list))`` entries are
    promoted, and the first element of each entry is used as the layer name.
    NOTE(review): presumably the list is sorted by ascending quantization
    error -- confirm against the script that writes the file.

    Args:
        module: Model whose submodules are rewritten. Modified in place.
        w_Q: Default weight quantization spec.
        high_Q: Weight quantization spec for the promoted layers.
        ratio: Fraction of listed layers to promote. ``0`` (or less) skips
            loading the error file entirely.
        quant_grad: Flag forwarded to ``set_quant_grad``.

    Note:
        A module can only be replaced by rebinding it on its parent. When
        ``module`` itself is an ``nn.Linear`` there is no parent, so the
        layer is reported and left unchanged.
    """
    high_prec_layer_names = set()
    if ratio > 0:
        w_quant_desc = w_Q if isinstance(w_Q, str) else w_Q.desc
        quant_err_list = torch.load(f'mix_fp/mixfp_err_{w_quant_desc}.pt')
        n_layers = int(ratio * len(quant_err_list))
        # Guard the zero case explicitly: ``lst[-0:]`` is ``lst[0:]``, i.e.
        # the WHOLE list, which would promote every layer instead of none
        # when a small ratio rounds down to zero layers.
        if n_layers > 0:
            high_prec_layer_names = {entry[0] for entry in quant_err_list[-n_layers:]}
        print(f'{n_layers} layers will be assigned to high precision bit: {high_Q}')

    # Snapshot name -> module once; children are rebound while we iterate.
    mod_dict = dict(module.named_modules())

    for n, m in mod_dict.items():
        if not isinstance(m, nn.Linear):
            continue
        if not n:
            # mod_dict[''] is the root itself; setattr(root, '', new) would
            # attach the new layer as a child named '' rather than replace it.
            print('(Replace Qlinear mixfp) Root module is nn.Linear and cannot be replaced in place; skipping')
            continue

        new_mod = QLinear(m.in_features, m.out_features, m.bias is not None)
        new_mod.transfer(m)
        new_mod.assign_qparams(high_Q if n in high_prec_layer_names else w_Q)
        new_mod.set_quant_grad(quant_grad)

        # Split 'a.b.c' into parent path 'a.b' and attribute name 'c'.
        parent_name, _, child_name = n.rpartition('.')
        setattr(mod_dict[parent_name], child_name, new_mod)


def replace_conv2d(module: nn.Module, w_Q: QType, in_Q: Union[QType, None]=None, quant_grad=True):
    """Swap every ``nn.Conv2d`` submodule of ``module`` for a ``QConv2d``, in place.

    Each replacement is constructed from the original layer's
    ``in_channels``, ``out_channels``, ``kernel_size``, ``stride``,
    ``padding``, ``dilation``, ``groups`` and bias presence. ``transfer`` is
    then called with the original layer, followed by ``assign_qparams(w_Q)``,
    ``assign_input_qparams(in_Q)`` when ``in_Q`` is given, and
    ``set_quant_grad(quant_grad)``.

    NOTE(review): ``padding_mode`` is not forwarded to ``QConv2d`` -- confirm
    whether ``transfer`` carries it over for convs with a non-default mode.

    Args:
        module: Model whose submodules are rewritten. Modified in place.
        w_Q: Weight quantization spec forwarded to ``assign_qparams``.
        in_Q: Input quantization spec; input qparams are left untouched
            when this is ``None``.
        quant_grad: Flag forwarded to ``set_quant_grad``.

    Note:
        A module can only be replaced by rebinding it on its parent. When
        ``module`` itself is an ``nn.Conv2d`` there is no parent, so the
        layer is reported and left unchanged.
    """
    # Snapshot name -> module once; children are rebound while we iterate.
    mod_dict = dict(module.named_modules())

    for n, m in mod_dict.items():
        if not isinstance(m, nn.Conv2d):
            continue
        if not n:
            # mod_dict[''] is the root itself; setattr(root, '', new) would
            # attach the new layer as a child named '' rather than replace it.
            print('(Replace QConv2d) Root module is nn.Conv2d and cannot be replaced in place; skipping')
            continue

        new_mod = QConv2d(m.in_channels, m.out_channels, m.kernel_size, m.stride, m.padding, m.dilation, m.groups, m.bias is not None)  # type: ignore
        new_mod.transfer(m)
        new_mod.assign_qparams(w_Q)
        if in_Q is not None:
            new_mod.assign_input_qparams(in_Q)
        new_mod.set_quant_grad(quant_grad)

        # Split 'a.b.c' into parent path 'a.b' and attribute name 'c'.
        parent_name, _, child_name = n.rpartition('.')
        setattr(mod_dict[parent_name], child_name, new_mod)
        
    
def assign_qparams(module: nn.Module, w_Q: Union[QType, str], in_Q: Union[QType, str, None]=None):
    """Re-assign quantization parameters on every ``QConv2d`` / ``QLinear`` in ``module``.

    Args:
        module: Model to walk; ``module`` itself and all of its descendants
            are inspected.
        w_Q: Weight quantization spec passed to each layer's
            ``assign_qparams``.
        in_Q: Input quantization spec passed to each layer's
            ``assign_input_qparams``; that call is skipped when ``None``.
    """
    update_input = in_Q is not None
    quant_layers = (
        layer for layer in module.modules()
        if isinstance(layer, (QConv2d, QLinear))
    )
    for layer in quant_layers:
        layer.assign_qparams(w_Q)
        if update_input:
            layer.assign_input_qparams(in_Q)


def set_fastforward(module: nn.Module, value: bool=True):
    """Set the ``_fast_forward`` attribute on every ``QLinear`` in ``module``.

    Prints a one-line notice with the requested value before walking the
    model.

    Args:
        module: Model to walk; ``module`` itself and all of its descendants
            are inspected.
        value: Value stored in each ``QLinear``'s ``_fast_forward``.
    """
    print('Switch QLinear layers to fast_forward mode:', value)
    for layer in module.modules():
        if not isinstance(layer, QLinear):
            continue
        layer._fast_forward = value