import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch_pruning as tp
from qwen_vl_utils import process_vision_info
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPatchMerger
from torch import nn
from typing import Sequence
import os


def prune_model(model, processor, pruning_ratio: float):
    """Prune the LM (and vision) modules in sync, keeping dimensions aligned.

    Runs torch_pruning's MetaPruner over the model using a dummy tokenized
    text prompt as the tracing input, then patches the mutated layer shapes
    back into ``model.config`` and the per-layer attention/MLP attributes so
    the pruned model remains internally consistent.

    Args:
        model: A loaded ``Qwen2_5_VLForConditionalGeneration`` instance.
            Mutated in place (layers are pruned, config fields rewritten).
        processor: The matching ``AutoProcessor``; only its tokenizer is used
            to build the example input.
        pruning_ratio: Fraction of channels to remove, passed both as the
            global ``pruning_ratio`` and as ``head_pruning_ratio``.

    Returns:
        The same ``model`` object, pruned in place.
    """
    # Map each attention projection to its head count so the pruner keeps
    # head-grouped channels together (q uses full heads; k/v use KV heads).
    num_heads = {}
    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[module.q_proj] = model.config.num_attention_heads
            num_heads[module.k_proj] = model.config.num_key_value_heads
            num_heads[module.v_proj] = model.config.num_key_value_heads

    # Group L2-norm importance, averaged over each dependency group.
    importance = tp.importance.GroupNormImportance(p=2, group_reduction='mean')  # tp.importance.ActivationImportance(p=2, target_types=[torch.nn.Linear])

    # Parameters not wrapped by any traced module (none needed here).
    unwrapped_parameters = []

    # Skip the final lm_head and the LM's embedding layers.
    # NOTE(review): 151936 is presumably the Qwen vocab size — confirm it
    # matches ``model.config.vocab_size`` for the checkpoint being pruned.
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == 151936:
            ignored_layers.append(m)
        if isinstance(m, torch.nn.Embedding):
            ignored_layers.append(m)
    print("ignored_layers", ignored_layers)

    # Build the example input used for dependency tracing (text-only; no
    # image input is traced here, so vision-path coverage depends on how
    # MetaPruner handles untraced branches — verify if pruning vision too).
    # example_inputs = torch.randint(0, 100000, (3, 56, 56), dtype=torch.long, device='cuda:1')
    text = "描述这张图片。"
    example_inputs = torch.tensor(processor.tokenizer.encode(text)).unsqueeze(0).to(model.device)
    print(example_inputs.shape)

    # Create the pruner. KV cache is disabled so tracing sees plain forward.
    model.config.use_cache = False
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        global_pruning=False,          # per-layer (local) ratios
        pruning_ratio=pruning_ratio,
        ignored_layers=ignored_layers,
        num_heads=num_heads,
        prune_num_heads=False,         # keep head count; shrink head dims via channels
        prune_head_dims=False,
        head_pruning_ratio=pruning_ratio,
        round_to=4,                    # keep pruned dims multiples of 4
        unwrapped_parameters=unwrapped_parameters,
    )

    # Execute the pruning step group by group.
    for g in pruner.step(interactive=True):
        # print(g)
        g.prune()

    # Re-sync config/module attributes with the new (pruned) layer shapes.
    # hidden_size is recovered from lm_head's input features.
    model.config.hidden_size = model.lm_head.in_features
    for name, m in model.model.named_modules():
        if name.endswith("self_attn"):
            print(name)
            m.hidden_size = m.q_proj.out_features
            m.num_heads = m.hidden_size // m.head_dim
            model.config.num_attention_heads = m.num_heads
            # NOTE(review): m.num_key_value_heads is never recomputed from the
            # pruned k_proj/v_proj shapes — this assumes KV head count is
            # unchanged by pruning; confirm against the pruner settings.
            m.num_key_value_groups = m.num_heads // m.num_key_value_heads
        elif name.endswith("mlp"):
            if hasattr(m, "gate_proj"):
                print(name)
                m.hidden_size = m.gate_proj.in_features
                model.config.intermediate_size = m.gate_proj.out_features
    return model


def main():
    """Load a Qwen2.5-VL checkpoint, prune it at ratio 0.5, and report sizes.

    Saving/reloading of the pruned model is left commented out below.
    """
    model_path = "/home/rzhong/project/unsloth/model_pretrain_sft_20250303_125849"
    # model_path = "/home/rzhong/project/FSTSPrune/Qwen2.5-VL-3B-Instruct-LatexOCR"
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, device_map="cuda:1"
    )
    processor = AutoProcessor.from_pretrained(model_path)

    print("========= Before Pruning =========")
    print(model)
    ori_size = tp.utils.count_params(model)

    print("Starting pruning process...")
    pruned_model = prune_model(model, processor, pruning_ratio=0.5)

    print("========= After Pruning =========")
    print(pruned_model)
    print(" Params: %.2f M => %.2f M" % (ori_size / 1e6, tp.utils.count_params(pruned_model) / 1e6))

    # pruned_model.zero_grad()
    # save_path = "/home/rzhong/project/FSTSPrune/model_pretrain_sft_20250303_125849-Pruned"
    # os.makedirs(save_path, exist_ok=True)
    # # pruned_model.save_pretrained(save_path)
    # torch.save(pruned_model, os.path.join(save_path, "pytorch_model.bin"))
    # processor.save_pretrained(save_path)
    # pruned_model.config.save_pretrained(save_path)

    # pruned_model.zero_grad()
    # save_path = "/home/rzhong/project/FSTSPrune/model_pretrain_sft_20250303_125849-Pruned-hf"
    # os.makedirs(save_path, exist_ok=True)
    # pruned_model.save_pretrained(save_path)
    # # torch.save(pruned_model, os.path.join(save_path, "pytorch_model.bin"))
    # processor.save_pretrained(save_path)
    # # pruned_model.config.save_pretrained(save_path)

    # load_test_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    #     save_path,
    #     device_map="cpu"
    # )
    # print("load test pass!")


if __name__ == "__main__":
    main()