File size: 2,887 Bytes
1e103b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import os

import torch
from transformers import Qwen2_5_VLForConditionalGeneration
from dfloat11 import compress_model

def _build_parser():
    """Build the command-line parser for the DFloat11 MLLM compression script."""
    # The title must be passed as ``description``; the first positional
    # argument of ArgumentParser is ``prog`` (the program name in "usage:").
    parser = argparse.ArgumentParser(
        description="Compress OmniGen2 MLLM (Qwen2.5-VL) using DFloat11"
    )
    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help='The path to the OmniGen2 model (containing "mllm" folder) or direct path to MLLM checkpoint'
    )
    parser.add_argument(
        '--save_path',
        type=str,
        default='./OmniGen2-mllm-DF11',
        help='The path to save the compressed model'
    )
    parser.add_argument(
        '--save_single_file',
        action='store_true',
        help='Save the compressed model as a single .safetensors file'
    )
    parser.add_argument(
        '--check_correctness',
        action='store_true',
        help='Check the correctness of the compressed weights during compression'
    )
    parser.add_argument(
        '--block_range',
        type=int,
        nargs=2,
        default=(0, 100),
        help='The range of transformer blocks to compress (for parallel compression over multiple CPU cores)'
    )
    return parser


def _resolve_mllm_path(model_path):
    """Return the MLLM checkpoint location for ``model_path``.

    If ``model_path`` contains an ``mllm`` sub-directory (the OmniGen2
    layout), that sub-directory is returned; otherwise ``model_path`` is
    returned unchanged and treated as a direct MLLM checkpoint path.
    """
    candidate = os.path.join(model_path, "mllm")
    return candidate if os.path.isdir(candidate) else model_path


def _untie_lm_head(model):
    """Give ``model.lm_head`` its own weight storage if it is shared.

    safetensors refuses to save tensors that share memory, so an lm_head
    whose weight aliases the input-embedding weight must be cloned before
    saving. When the two already have independent storage, nothing is
    copied, which avoids duplicating a large vocab x hidden matrix.

    Args:
        model: The loaded model; it is modified in place when a clone is needed.
    """
    lm_head = getattr(model, "lm_head", None)
    lm_weight = getattr(lm_head, "weight", None)
    if lm_weight is None:
        return

    try:
        embed_weight = model.get_input_embeddings().weight
    except (AttributeError, NotImplementedError):
        # Cannot tell whether the weights are tied; clone to stay safe.
        embed_weight = None

    if embed_weight is not None and embed_weight.data_ptr() != lm_weight.data_ptr():
        return  # Already independent storage; safetensors can save it as-is.

    print("Untying lm_head weights to avoid safetensors shared memory error...")
    lm_head.weight = torch.nn.Parameter(lm_weight.detach().clone())
    # NOTE(review): model.config.tie_word_embeddings is left unchanged;
    # confirm the saved config reloads as intended.


def main():
    """Compress the OmniGen2 MLLM (Qwen2.5-VL) with DFloat11.

    Parses CLI arguments, loads the MLLM in bfloat16, makes sure lm_head
    does not share storage with the input embeddings, then compresses the
    attention and MLP projections of the language-model decoder layers
    and writes the result to ``--save_path``.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Reject an obviously invalid range before the expensive model load.
    start, end = args.block_range
    if start < 0 or end < start:
        parser.error(
            f"--block_range expects 0 <= START <= END, got {start} {end}"
        )

    mllm_path = _resolve_mllm_path(args.model_path)
    print(f"Loading MLLM from: {mllm_path}")

    # Load the Qwen2.5-VL model in bfloat16 precision.
    # trust_remote_code=True mirrors the original loading configuration.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        mllm_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    _untie_lm_head(model)

    # Compress the decoder-layer linear projections with DFloat11.
    # NOTE(review): the pattern targets module names of the form
    # "model.language_model.layers.<i>"; verify it matches the module
    # naming of the installed transformers version.
    compress_model(
        model=model,
        pattern_dict={
            r"model\.language_model\.layers\.\d+": (
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.o_proj",
                "mlp.gate_proj",
                "mlp.up_proj",
                "mlp.down_proj",
            ),
        },
        save_path=args.save_path,
        save_single_file=args.save_single_file,  # user-selected via --save_single_file
        check_correctness=args.check_correctness,
        block_range=args.block_range,
    )


if __name__ == "__main__":
    main()