"""Remap a fine-tuned single-file safetensors checkpoint onto a reference shard layout.

Steps:
  1. Load the tuned model (one ``.safetensors`` file).
  2. Strip the ``model.`` key prefix and split each fused ``gate_up_proj``
     weight into separate ``gate_proj`` / ``up_proj`` tensors.
  3. Re-shard the tensors according to a reference ``index.json`` and write
     the shards plus a matching index next to the input model.
"""
import os
import json
import torch
from safetensors.torch import load_file, save_file

# -------- CONFIG --------
src_model = "model_tune.safetensors"            # your fine-tuned safetensors
ref_index = "model_ref.safetensors.index.json"  # reference index.json
# ------------------------

# Output directory = same as input file
out_dir = os.path.dirname(os.path.abspath(src_model)) or "."

print(f"Loading tuned model: {src_model}")
src_tensors = load_file(src_model)

print(f"Loading reference index: {ref_index}")
with open(ref_index, "r", encoding="utf-8") as f:
    ref_data = json.load(f)
ref_weight_map = ref_data["weight_map"]

# -------- Remap + split gate_up_proj --------
remapped = {}
for k, v in src_tensors.items():
    # Drop "model." prefix
    new_k = k[len("model."):] if k.startswith("model.") else k

    if new_k.endswith(".gate_up_proj.weight"):
        # Split the fused projection into gate_proj + up_proj along dim 0.
        # NOTE(review): assumes [gate; up] block concatenation, not row
        # interleaving — confirm against the reference model's fusion scheme.
        if v.shape[0] % 2 != 0:
            raise ValueError(
                f"Cannot split {k}: dim 0 ({v.shape[0]}) is not even"
            )
        half = v.shape[0] // 2
        # removesuffix (not replace): only strip the trailing suffix, never
        # an accidental mid-string occurrence.
        base = new_k.removesuffix(".gate_up_proj.weight")
        remapped[base + ".gate_proj.weight"] = v[:half]
        remapped[base + ".up_proj.weight"] = v[half:]
    else:
        remapped[new_k] = v

print(f"Remapped tensors: {len(remapped)}")

# -------- Save shards --------
shards = {}
for name, shard_file in ref_weight_map.items():
    if name not in remapped:
        raise KeyError(f"āŒ Missing tensor: {name} (expected by reference)")
    shards.setdefault(shard_file, {})[name] = remapped[name]

# Silently dropping tuned tensors the reference layout does not use would
# hide remapping mistakes — surface them.
unused = sorted(set(remapped) - set(ref_weight_map))
if unused:
    preview = ", ".join(unused[:5]) + (" ..." if len(unused) > 5 else "")
    print(f"āš ļø  {len(unused)} tensor(s) not in reference weight_map (dropped): {preview}")

for shard_file, tensors in shards.items():
    out_path = os.path.join(out_dir, shard_file)
    print(f"Saving {len(tensors)} tensors to {out_path} ...")
    # {"format": "pt"} matches what transformers writes; some loaders
    # reject shards without it.
    save_file(tensors, out_path, metadata={"format": "pt"})

# -------- Write new index.json --------
# HF index convention: total_size is the sum of tensor byte sizes, not
# on-disk file sizes (file sizes would also count the safetensors headers).
total_size = sum(
    remapped[name].numel() * remapped[name].element_size()
    for name in ref_weight_map
)
new_index = {
    "metadata": {"total_size": total_size},
    "weight_map": ref_weight_map,
}
out_index = os.path.join(out_dir, "model.safetensors.index.json")
with open(out_index, "w", encoding="utf-8") as f:
    json.dump(new_index, f, indent=2)

print("\nāœ… Done.")
print(f"   - {len(shards)} shard files in {out_dir}")
print(f"   - {out_index}")