catplusplus's picture
Upload folder using huggingface_hub
1e103b7 verified
import argparse
import torch
from transformers import Qwen2_5_VLForConditionalGeneration
from dfloat11 import compress_model
def main(argv=None):
    """Compress the OmniGen2 MLLM (Qwen2.5-VL) weights with DFloat11.

    Parses the command line, loads the checkpoint in bfloat16, replaces
    ``lm_head.weight`` with an independent copy, and hands the model to
    ``dfloat11.compress_model`` together with the output options.

    Args:
        argv: Optional list of command-line arguments, without the program
            name. When ``None`` (the default) argparse reads the process
            arguments, which is what the original zero-argument call did.

    Raises:
        SystemExit: Raised by argparse for ``--help``, for malformed
            arguments, or for a rejected ``--block_range``.
    """
    args = _parse_args(argv)

    mllm_path = _resolve_mllm_path(args.model_path)
    print(f"Loading MLLM from: {mllm_path}")

    # Load the Qwen2.5-VL model in bfloat16 precision.
    # trust_remote_code=True mirrors inference.py.
    # NOTE(security): this flag lets transformers run Python code shipped
    # with the checkpoint, so only point --model_path at trusted sources.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        mllm_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # Untie weights to avoid the safetensors error about shared memory:
    # safetensors.torch.save_file fails if tensors share memory, so
    # lm_head gets its own copy of the weight.
    if hasattr(model, 'lm_head') and hasattr(model.lm_head, 'weight'):
        print("Untying lm_head weights to avoid safetensors shared memory error...")
        model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())

    # Compress the model using DFloat11 compression.
    # The pattern matches the Qwen2.5-VL internal structure
    # (model.language_model.layers.<N>); the tuple lists the linear
    # sub-modules selected inside every matching block.
    compress_model(
        model=model,
        pattern_dict={
            r"model\.language_model\.layers\.\d+": (
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.o_proj",
                "mlp.gate_proj",
                "mlp.up_proj",
                "mlp.down_proj",
            ),
        },
        save_path=args.save_path,
        # Passed through from --save_single_file; it is NOT forced on.
        save_single_file=args.save_single_file,
        check_correctness=args.check_correctness,
        block_range=args.block_range,
    )


def _parse_args(argv=None):
    """Build the CLI parser, parse ``argv`` and validate ``--block_range``."""
    # The text must be passed as ``description=``: the first positional
    # parameter of ArgumentParser is ``prog``, which would replace the
    # program name in the usage line instead of describing the tool.
    parser = argparse.ArgumentParser(
        description="Compress OmniGen2 MLLM (Qwen2.5-VL) using DFloat11",
    )
    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help='The path to the OmniGen2 model (containing "mllm" folder) or direct path to MLLM checkpoint',
    )
    parser.add_argument(
        '--save_path',
        type=str,
        default='./OmniGen2-mllm-DF11',
        help='The path to save the compressed model',
    )
    parser.add_argument(
        '--save_single_file',
        action='store_true',
        help='Save the compressed model as a single .safetensors file',
    )
    parser.add_argument(
        '--check_correctness',
        action='store_true',
        help='Check the correctness of the compressed weights during compression',
    )
    parser.add_argument(
        '--block_range',
        type=int,
        nargs=2,
        metavar=('START', 'END'),
        default=(0, 100),
        help='The range of transformer blocks to compress (for parallel compression over multiple CPU cores)',
    )
    args = parser.parse_args(argv)

    # Reject ranges that cannot describe any set of blocks before the
    # (slow, memory-hungry) model load starts.
    start, end = args.block_range
    if start < 0 or end < start:
        parser.error(
            f"--block_range must satisfy 0 <= START <= END, got {start} {end}"
        )
    return args


def _resolve_mllm_path(model_path):
    """Return ``<model_path>/mllm`` if that directory exists, else ``model_path``."""
    # Function-scope import: ``os`` is not imported at the top of the file.
    import os

    nested = os.path.join(model_path, "mllm")
    return nested if os.path.isdir(nested) else model_path
# Script entry point: run the compression CLI only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()