# k25-text-only / convert.py
# Author: aspctu
# Commit: "Upload folder using huggingface_hub" (68821b2, verified)
#!/usr/bin/env python3
import argparse
import json
import shutil
import sys
from pathlib import Path
def _load_json(path: Path) -> dict:
with path.open("r") as f:
return json.load(f)
def _write_json(path: Path, obj: dict) -> None:
with path.open("w") as f:
json.dump(obj, f, indent=2, sort_keys=True)
f.write("\n")
# Only tensors whose names start with this prefix are kept; it is stripped
# from the names written to the destination checkpoint.
_LM_PREFIX = "language_model."


def _parse_args() -> argparse.Namespace:
    """Parse CLI arguments; reject a destination equal to the source."""
    parser = argparse.ArgumentParser(
        description="convert k2.5 to thinking"
    )
    parser.add_argument("--src", required=True, help="Source checkpoint directory.")
    parser.add_argument(
        "--dst",
        required=True,
        help="Destination directory (created if missing); must differ from --src.",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="Also copy custom model files via example_utils.copy_custom_model_files.",
    )
    args = parser.parse_args()
    # Writing into the source directory would overwrite config.json and the
    # shards being read, and shutil.copy2 raises SameFileError on each file.
    if Path(args.src).resolve() == Path(args.dst).resolve():
        parser.error("--src and --dst must be different directories")
    return args


def _copy_aux_files(src_dir: Path, dst_dir: Path) -> None:
    """Copy top-level files other than weights and regenerated files."""
    regenerated = {"config.json", "model.safetensors.index.json"}
    for child in src_dir.iterdir():
        if not child.is_file() or child.suffix == ".safetensors":
            continue
        if child.name in regenerated:
            continue
        shutil.copy2(child, dst_dir / child.name)


def _list_shards(src_dir: Path) -> tuple:
    """Return ``(sorted shard file names, source index metadata dict)``."""
    index_path = src_dir / "model.safetensors.index.json"
    if index_path.is_file():
        index_in = _load_json(index_path)
        shard_names = sorted(set(index_in["weight_map"].values()))
        return shard_names, dict(index_in.get("metadata") or {})
    # Unsharded checkpoint: no index file, so convert every top-level
    # *.safetensors file directly.
    shard_names = sorted(p.name for p in src_dir.glob("*.safetensors") if p.is_file())
    return shard_names, {}


def _convert_shards(src_dir: Path, dst_dir: Path, shard_names: list) -> tuple:
    """Copy prefixed tensors shard-by-shard, stripping ``_LM_PREFIX``.

    Each output shard keeps its source file name. A shard with no matching
    tensors is not written.

    Returns:
        ``(weight_map, total_size)``: a mapping from each renamed tensor to
        the shard file it was written to, and the summed byte size of all
        written tensors.
    """
    # Imported lazily so that ``--help`` works without safetensors installed.
    from safetensors import safe_open
    from safetensors.torch import save_file

    weight_map: dict = {}
    total_size = 0
    for shard in shard_names:
        src_file = src_dir / shard
        dst_file = dst_dir / shard
        out_tensors = {}
        with safe_open(src_file, framework="pt") as f:
            # Keep this shard's own header metadata. Default "format" to
            # "pt", which is the framework these tensors are loaded with.
            shard_meta = dict(f.metadata() or {})
            shard_meta.setdefault("format", "pt")
            for key in f.keys():
                if not key.startswith(_LM_PREFIX):
                    continue
                new_key = key[len(_LM_PREFIX):]
                tensor = f.get_tensor(key)
                out_tensors[new_key] = tensor
                weight_map[new_key] = shard
                total_size += tensor.numel() * tensor.element_size()
        if out_tensors:
            save_file(out_tensors, str(dst_file), metadata=shard_meta)
            print(f"Wrote {len(out_tensors)} tensors to {dst_file}")
    return weight_map, total_size


def main() -> None:
    """Extract the ``language_model.*`` tensors from a checkpoint into --dst.

    Steps:
    - Re-serialize ``config.json`` (content unchanged) into --dst.
    - Copy the other non-weight top-level files.
    - Optionally copy custom model files.
    - Write stripped-name shards and a fresh ``model.safetensors.index.json``
      whose ``metadata.total_size`` reflects only the kept tensors.

    Raises:
        RuntimeError: if no tensor name starts with ``language_model.``.
    """
    args = _parse_args()
    src_dir = Path(args.src).resolve()
    dst_dir = Path(args.dst).resolve()
    dst_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): extraction of src_config["text_config"] was previously
    # disabled, so the full source config is written as-is.
    src_config = _load_json(src_dir / "config.json")
    _write_json(dst_dir / "config.json", dict(src_config))

    _copy_aux_files(src_dir, dst_dir)

    if args.trust_remote_code:
        sys.path.append(str(Path(__file__).resolve().parents[1]))
        from example_utils import copy_custom_model_files  # noqa: E402
        copy_custom_model_files(str(src_dir), str(dst_dir), trust_remote_code=True)

    shard_names, index_metadata = _list_shards(src_dir)
    weight_map_out, total_size = _convert_shards(src_dir, dst_dir, shard_names)
    if not weight_map_out:
        raise RuntimeError(f"No tensors matched prefix {_LM_PREFIX!r}")

    # Start from the source index metadata (not a shard header) and update
    # total_size to the bytes actually written.
    index_metadata["total_size"] = total_size
    _write_json(
        dst_dir / "model.safetensors.index.json",
        {"metadata": index_metadata, "weight_map": weight_map_out},
    )
    print("Done.")
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()