|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from safetensors.torch import load_file, save_file |
|
|
|
|
|
|
|
|
# Input paths: the tuned single-file checkpoint to re-shard, and the
# reference index whose tensor names and shard layout are reproduced below.
src_model = "model_tune.safetensors"
ref_index = "model_ref.safetensors.index.json"

# Output shards and the new index are written next to the tuned checkpoint.
# abspath() always yields a path with a directory component, so dirname() is
# never empty and no "." fallback is required.
out_dir = os.path.dirname(os.path.abspath(src_model))
|
|
|
|
|
print(f"Loading tuned model: {src_model}")
# Dict of tensor name -> tensor from the single-file tuned checkpoint.
src_tensors = load_file(src_model)

print(f"Loading reference index: {ref_index}")
# JSON text is UTF-8 by spec; don't let the platform locale pick the codec.
with open(ref_index, "r", encoding="utf-8") as f:
    ref_data = json.load(f)

# Tensor name -> shard file name, as laid out in the reference checkpoint.
ref_weight_map = ref_data["weight_map"]
|
|
|
|
|
|
|
|
# Rename tuned-checkpoint tensors to the reference naming scheme:
#   * a leading "model." prefix is stripped, and
#   * each fused "<base>.gate_up_proj.{weight,bias}" tensor is split along
#     dim 0 into "<base>.gate_proj.*" (first half) and "<base>.up_proj.*"
#     (second half).
# Every other tensor is passed through under its (prefix-stripped) name.
_FUSED_MARKER = ".gate_up_proj."
_FUSED_SUFFIXES = (_FUSED_MARKER + "weight", _FUSED_MARKER + "bias")

remapped = {}


def _put_remapped(name, tensor, source_name):
    """Store *tensor* in ``remapped`` under *name*, refusing to overwrite.

    Raises:
        KeyError: if *name* was already produced by another source tensor,
            which would otherwise silently discard one of the two weights.
    """
    if name in remapped:
        raise KeyError(
            f"❌ Duplicate tensor after remapping: {name} (from {source_name})"
        )
    remapped[name] = tensor


for k, v in src_tensors.items():
    new_k = k[len("model."):] if k.startswith("model.") else k

    if new_k.endswith(_FUSED_SUFFIXES):
        base, _, param = new_k.rpartition(_FUSED_MARKER)
        rows = v.shape[0]
        # An odd row count cannot be split into equal gate/up halves; floor
        # division would otherwise silently give the extra row to up_proj.
        if rows % 2:
            raise ValueError(f"❌ Cannot split {k}: dim 0 has odd size {rows}")
        half = rows // 2
        # Slice dim 0 only, so both 2-D weights and 1-D biases are handled.
        _put_remapped(f"{base}.gate_proj.{param}", v[:half], k)
        _put_remapped(f"{base}.up_proj.{param}", v[half:], k)
    else:
        _put_remapped(new_k, v, k)

print(f"Remapped tensors: {len(remapped)}")
|
|
|
|
|
|
|
|
# Every tensor the reference index lists must exist after remapping. Collect
# all gaps up front so one run reports the complete list, not just the first.
missing = [name for name in ref_weight_map if name not in remapped]
if missing:
    shown = ", ".join(missing[:10])
    extra = f" (+{len(missing) - 10} more)" if len(missing) > 10 else ""
    raise KeyError(
        f"❌ Missing tensor(s): {shown}{extra} (expected by reference)"
    )

# Tensors the reference does not list are never written to any shard; say so
# rather than dropping them silently.
unused = sorted(set(remapped) - set(ref_weight_map))
if unused:
    shown = ", ".join(unused[:10])
    extra = f" (+{len(unused) - 10} more)" if len(unused) > 10 else ""
    print(f"⚠️ Dropping {len(unused)} tensor(s) not in reference: {shown}{extra}")

# Group tensors by the shard file the reference assigns them to, so the
# output reproduces the reference sharding exactly.
shards = {}
for name, shard_file in ref_weight_map.items():
    shards.setdefault(shard_file, {})[name] = remapped[name]
|
|
|
|
|
# Write one safetensors file per reference shard into out_dir.
# {"format": "pt"} is the header metadata transformers' save_pretrained writes;
# some transformers versions read this field when loading a shard, and at least
# older ones fail on a file that has no metadata at all.
for shard_file, tensors in shards.items():
    out_path = os.path.join(out_dir, shard_file)
    print(f"Saving {len(tensors)} tensors to {out_path} ...")
    save_file(tensors, out_path, metadata={"format": "pt"})
|
|
|
|
|
|
|
|
# "total_size" is the sum of tensor payload bytes (numel * element_size), the
# value transformers' own sharding records. On-disk file sizes would also
# count each shard's safetensors header and overstate it.
total_size = sum(
    t.numel() * t.element_size()
    for tensors in shards.values()
    for t in tensors.values()
)

# The weight map is copied from the reference unchanged, because the shards
# above were written using exactly that name -> file assignment.
new_index = {
    "metadata": {"total_size": total_size},
    "weight_map": ref_weight_map,
}

out_index = os.path.join(out_dir, "model.safetensors.index.json")
with open(out_index, "w", encoding="utf-8") as f:
    json.dump(new_index, f, indent=2)

print("\n✅ Done.")
print(f" - {len(shards)} shard files in {out_dir}")
print(f" - {out_index}")
|
|
|