import torch
from transformers import AutoModelForCausalLM
from tqdm import tqdm
def copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path):
"""
Copy the language model weights from Qwen2.5-Coder-3B-Instruct into
Qwen2.5-VL-3B-Instruct, preserving its vision-language components.
"""
print(f"Loading Qwen2.5-Coder-3B-Instruct model from {coder_model_id}...")
coder_model = AutoModelForCausalLM.from_pretrained(
coder_model_id,
torch_dtype=torch.bfloat16,
device_map="cpu"
)
print(f"Loading Qwen2.5-VL-3B-Instruct model from {vl_model_id}...")
vl_model = AutoModelForCausalLM.from_pretrained(
vl_model_id,
torch_dtype=torch.bfloat16,
device_map="cpu"
)
coder_state = coder_model.state_dict()
vl_state = vl_model.state_dict()
print("Copying language weights from Coder model to VL model...")
updated_keys = 0
skipped_keys = []
for key in coder_state.keys():
# Focus on the shared transformer block
if key.startswith("transformer."):
if key in vl_state and coder_state[key].shape == vl_state[key].shape:
vl_state[key] = coder_state[key].clone()
updated_keys += 1
else:
skipped_keys.append(key)
print(f"✅ Updated {updated_keys} keys from Coder to VL.")
if skipped_keys:
print(f"⚠️ Skipped {len(skipped_keys)} keys due to shape mismatch or missing keys.")
for key in skipped_keys[:5]:
print(f" - Skipped: {key} (showing up to 5...)")
print("Saving updated Qwen2.5-VL-3B-Instruct model...")
vl_model.load_state_dict(vl_state)
vl_model.save_pretrained(output_path, safe_serialization=True)
print(f"✅ Model saved to: {output_path}")
if __name__ == "__main__":
coder_model_id = "Qwen/Qwen2.5-Coder-3B-Instruct"
vl_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
output_path = "./Qwen2.5-VL-3B-Instruct-CoderMerged"
copy_qwen2_5_coder_weights_to_vl(coder_model_id, vl_model_id, output_path)