# E-Guide / 3.prune.py — uploaded by zrrraa (commit a1445bd)
import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch_pruning as tp
from qwen_vl_utils import process_vision_info
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPatchMerger
from torch import nn
from typing import Sequence
import os
def prune_model(model, processor, pruning_ratio):
    """Structurally prune the language-model side of a Qwen2.5-VL model in place.

    Uses torch_pruning's MetaPruner with group-norm importance, then re-syncs
    the HF config and per-module metadata with the pruned layer shapes so the
    model remains self-consistent.

    Args:
        model: A loaded ``Qwen2_5_VLForConditionalGeneration`` (modified in place).
        processor: Matching ``AutoProcessor``; its tokenizer builds the example input.
        pruning_ratio: Fraction of channels (and head dims) to remove, e.g. 0.5.

    Returns:
        The same ``model`` object, pruned.
    """
    # Map each attention projection to its head count so torch_pruning can
    # group channels by whole attention heads when building dependency groups.
    num_heads = {}
    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[module.q_proj] = model.config.num_attention_heads
            num_heads[module.k_proj] = model.config.num_key_value_heads
            num_heads[module.v_proj] = model.config.num_key_value_heads

    importance = tp.importance.GroupNormImportance(p=2, group_reduction='mean')

    # Parameters not wrapped in supported layer types (none needed here).
    unwrapped_parameters = []

    # Ignore the final lm_head (any Linear projecting to the vocabulary) and
    # every embedding table — pruning those would change the token space.
    # Generalized from the hard-coded Qwen vocab size 151936 so resized or
    # different-vocab checkpoints are handled too.
    vocab_size = getattr(model.config, "vocab_size", 151936)
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == vocab_size:
            ignored_layers.append(m)
        if isinstance(m, torch.nn.Embedding):
            ignored_layers.append(m)
    print("ignored_layers", ignored_layers)

    # Build an example input: token ids of a short prompt, batch dimension added,
    # placed on the model's device. Used by the pruner to trace dependencies.
    text = "描述这张图片。"
    example_inputs = torch.tensor(processor.tokenizer.encode(text)).unsqueeze(0).to(model.device)
    print(example_inputs.shape)

    # KV caching interferes with the dependency-tracing forward pass.
    model.config.use_cache = False
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        global_pruning=False,
        pruning_ratio=pruning_ratio,
        ignored_layers=ignored_layers,
        num_heads=num_heads,
        prune_num_heads=False,
        prune_head_dims=False,
        head_pruning_ratio=pruning_ratio,
        round_to=4,          # keep pruned dims a multiple of 4 for hardware friendliness
        unwrapped_parameters=unwrapped_parameters,
    )

    # Execute pruning group by group.
    for g in pruner.step(interactive=True):
        g.prune()

    # Re-sync config and module bookkeeping with the new (pruned) layer shapes.
    model.config.hidden_size = model.lm_head.in_features
    for name, m in model.model.named_modules():
        if name.endswith("self_attn"):
            print(name)
            m.hidden_size = m.q_proj.out_features
            m.num_heads = m.hidden_size // m.head_dim
            model.config.num_attention_heads = m.num_heads
            m.num_key_value_groups = m.num_heads // m.num_key_value_heads
        elif name.endswith("mlp"):
            if hasattr(m, "gate_proj"):
                print(name)
                m.hidden_size = m.gate_proj.in_features
                model.config.intermediate_size = m.gate_proj.out_features
    return model
def main():
    """Load a Qwen2.5-VL checkpoint, prune it at a 50% ratio, and report parameter counts."""
    model_path = "/home/rzhong/project/unsloth/model_pretrain_sft_20250303_125849"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda:1",
    )
    processor = AutoProcessor.from_pretrained(model_path)

    print("========= Before Pruning =========")
    print(model)
    ori_size = tp.utils.count_params(model)

    print("Starting pruning process...")
    pruned_model = prune_model(model, processor, pruning_ratio=0.5)

    print("========= After Pruning =========")
    print(pruned_model)
    new_size = tp.utils.count_params(pruned_model)
    print(f" Params: {ori_size / 1e6:.2f} M => {new_size / 1e6:.2f} M")

    # NOTE(review): saving/reload verification (torch.save / save_pretrained /
    # from_pretrained round-trip) was previously drafted here but disabled;
    # re-enable once the pruned config serializes and loads correctly.


if __name__ == "__main__":
    main()