File size: 4,677 Bytes
a1445bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch_pruning as tp
from qwen_vl_utils import process_vision_info
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPatchMerger
from torch import nn
from typing import Sequence
import os

def prune_model(model, processor, pruning_ratio):
    """Prune the LM and vision modules in sync, keeping dimensions aligned.

    Args:
        model: A loaded Qwen2.5-VL model (transformers); pruned in place.
        processor: Matching AutoProcessor; its tokenizer builds the trace input.
        pruning_ratio: Fraction of channels/heads to remove (0.0-1.0).

    Returns:
        The pruned model (the same object, modified in place).
    """
    # Tell torch_pruning how many heads each attention projection serves so
    # head-grouped channels are pruned consistently across q/k/v.
    num_heads = {}
    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[module.q_proj] = model.config.num_attention_heads
            num_heads[module.k_proj] = model.config.num_key_value_heads
            num_heads[module.v_proj] = model.config.num_key_value_heads

    importance = tp.importance.GroupNormImportance(p=2, group_reduction='mean') #tp.importance.ActivationImportance(p=2, target_types=[torch.nn.Linear])

    # Parameters not wrapped in prunable modules (none needed here).
    unwrapped_parameters = []

    # Ignore the final lm_head and all embeddings: pruning either would change
    # the vocabulary dimension. Use config.vocab_size instead of the previous
    # hard-coded 151936 so this also works for checkpoints with other vocabs.
    vocab_size = model.config.vocab_size
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == vocab_size:
            ignored_layers.append(m)
        if isinstance(m, torch.nn.Embedding):
            ignored_layers.append(m)
    print("ignored_layers", ignored_layers)

    # Build a minimal text-only example input to trace the forward graph.
    # example_inputs = torch.randint(0, 100000, (3, 56, 56), dtype=torch.long, device='cuda:1')

    text = "描述这张图片。"
    example_inputs = torch.tensor(processor.tokenizer.encode(text)).unsqueeze(0).to(model.device)
    print(example_inputs.shape)

    # Disable the KV cache so tracing sees the plain forward pass.
    model.config.use_cache = False
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        global_pruning=False,
        pruning_ratio=pruning_ratio,
        ignored_layers=ignored_layers,
        num_heads=num_heads,
        prune_num_heads=False,
        prune_head_dims=False,
        head_pruning_ratio=pruning_ratio,
        round_to=4,  # keep pruned dims divisible by 4
        unwrapped_parameters=unwrapped_parameters,
    )

    # Execute the pruning, group by group.
    for g in pruner.step(interactive=True):
        # print(g)
        g.prune()

    # Re-sync config/module attributes with the new (smaller) layer shapes so
    # the pruned model can still run a forward pass.
    model.config.hidden_size = model.lm_head.in_features
    for name, m in model.model.named_modules():
        if name.endswith("self_attn"):
            print(name)
            m.hidden_size = m.q_proj.out_features
            m.num_heads = m.hidden_size // m.head_dim
            model.config.num_attention_heads = m.num_heads
            # NOTE(review): assumes m.num_key_value_heads still matches the
            # pruned k/v projections — verify after pruning.
            m.num_key_value_groups = m.num_heads // m.num_key_value_heads
        elif name.endswith("mlp"):
            if hasattr(m, "gate_proj"):
                print(name)
                m.hidden_size = m.gate_proj.in_features
                model.config.intermediate_size = m.gate_proj.out_features

    return model
            
def main(
    model_path="/home/rzhong/project/unsloth/model_pretrain_sft_20250303_125849",
    device="cuda:1",
    pruning_ratio=0.5,
):
    """Load a Qwen2.5-VL checkpoint, prune it, and report parameter savings.

    Args:
        model_path: Directory of the pretrained checkpoint to load.
            Defaults to the previously hard-coded path, so ``main()`` with no
            arguments behaves exactly as before.
        device: Device map passed to ``from_pretrained`` (e.g. "cuda:1").
        pruning_ratio: Fraction of channels to remove; forwarded to
            ``prune_model``.
    """
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device
    )
    processor = AutoProcessor.from_pretrained(model_path)

    print("========= Before Pruning =========")
    print(model)
    ori_size = tp.utils.count_params(model)

    print("Starting pruning process...")
    pruned_model = prune_model(model, processor, pruning_ratio=pruning_ratio)
    print("========= After Pruning =========")
    print(pruned_model)

    print("  Params: %.2f M => %.2f M" %
          (ori_size / 1e6, tp.utils.count_params(pruned_model) / 1e6))

    # NOTE(review): saving/reloading of the pruned model is intentionally
    # disabled for now. When needed, call pruned_model.zero_grad() and then
    # either pruned_model.save_pretrained(save_path) (HF format) or
    # torch.save(pruned_model, ...) (whole-object pickle), plus
    # processor.save_pretrained(save_path), and verify with a from_pretrained
    # round trip on CPU.

# Script entry point: run the full load -> prune -> report pipeline.
if __name__ == "__main__":
    main()