File size: 2,171 Bytes
7b3f0ec
 
 
 
 
 
754b5f9
 
7b3f0ec
 
754b5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b3f0ec
754b5f9
7b3f0ec
754b5f9
7b3f0ec
754b5f9
 
 
 
 
 
 
 
7b3f0ec
754b5f9
 
 
 
 
 
 
 
 
 
 
 
 
 
7b3f0ec
754b5f9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import json
import torch
from safetensors.torch import load_file, save_file

# -------- CONFIG --------
src_model = "model_tune.safetensors"                 # your fine-tuned safetensors
ref_index = "model_ref.safetensors.index.json" # reference index.json
# ------------------------

# Output directory = same as input file.
# os.path.abspath always returns an absolute path, so dirname() can never be
# "" here — the old `if out_dir == "": out_dir = "."` fallback was dead code.
out_dir = os.path.dirname(os.path.abspath(src_model))

print(f"Loading tuned model: {src_model}")
src_tensors = load_file(src_model)

print(f"Loading reference index: {ref_index}")
# Explicit UTF-8: index.json files are UTF-8 by convention; relying on the
# platform default encoding breaks on Windows (cp1252 etc.).
with open(ref_index, "r", encoding="utf-8") as f:
    ref_data = json.load(f)
# Maps tensor name -> shard filename, as written by HF sharded checkpoints.
ref_weight_map = ref_data["weight_map"]

# -------- Remap + split gate_up_proj --------
# The reference checkpoint stores separate gate_proj / up_proj tensors, while
# the tuned checkpoint fuses them into gate_up_proj; undo both the "model."
# prefix and the fusion so names match ref_weight_map.
remapped = {}

for k, v in src_tensors.items():
    # Drop "model." prefix
    new_k = k[len("model."):] if k.startswith("model.") else k

    if new_k.endswith(".gate_up_proj.weight"):
        # Split into gate_proj + up_proj along dim 0.
        # NOTE(review): assumes the fused layout is [gate; up] stacked, not
        # interleaved — confirm against the model implementation.
        if v.shape[0] % 2 != 0:
            # An odd dim 0 cannot be an even [gate; up] stack; splitting it
            # would silently corrupt both halves, so fail loudly instead.
            raise ValueError(
                f"gate_up_proj dim 0 is odd for {k}: {tuple(v.shape)}"
            )
        half = v.shape[0] // 2
        base = new_k.replace(".gate_up_proj.weight", "")
        remapped[base + ".gate_proj.weight"] = v[:half, :]
        remapped[base + ".up_proj.weight"] = v[half:, :]
    else:
        remapped[new_k] = v

print(f"Remapped tensors: {len(remapped)}")

# -------- Save shards --------
# Group tensors by the shard file the reference index assigns them to, so the
# output shard layout mirrors the reference checkpoint exactly.
shards = {}
for name, shard_file in ref_weight_map.items():
    if name not in remapped:
        raise KeyError(f"❌ Missing tensor: {name} (expected by reference)")
    shards.setdefault(shard_file, {})[name] = remapped[name]

# Tensors present in the tuned model but absent from the reference map are
# not written anywhere — surface them so conversion mistakes are visible.
extra = sorted(set(remapped) - set(ref_weight_map))
if extra:
    print(f"⚠️ {len(extra)} tensors not in reference index (dropped): {extra}")

for shard_file, tensors in shards.items():
    out_path = os.path.join(out_dir, shard_file)
    print(f"Saving {len(tensors)} tensors to {out_path} ...")
    # "format": "pt" metadata is what HF transformers expects when loading
    # safetensors checkpoints saved from PyTorch tensors.
    save_file(tensors, out_path, metadata={"format": "pt"})

# -------- Write new index.json --------
# NOTE(review): total_size here is the sum of on-disk shard file sizes, which
# includes each file's safetensors header; HF's own writer computes it from
# tensor storage sizes, so this value is slightly larger — confirm loaders
# tolerate that (it is informational metadata).
total_size = sum(os.path.getsize(os.path.join(out_dir, f)) for f in shards.keys())
new_index = {
    "metadata": {"total_size": total_size},
    "weight_map": ref_weight_map,
}
out_index = os.path.join(out_dir, "model.safetensors.index.json")
# Explicit UTF-8 so the index is portable regardless of platform default.
with open(out_index, "w", encoding="utf-8") as f:
    json.dump(new_index, f, indent=2)

print("\n✅ Done.")
print(f"   - {len(shards)} shard files in {out_dir}")
print(f"   - {out_index}")