#!/usr/bin/env python3
"""Strip the ``language_model.`` prefix from a sharded safetensors checkpoint.

The script reads the checkpoint directory given by ``--src``, keeps only the
tensors whose names start with ``language_model.``, removes that prefix from
their names, and writes the resulting shards together with a rebuilt
``model.safetensors.index.json`` into ``--dst``.  ``config.json`` is
re-serialized into ``--dst`` and every other non-weight file that sits
directly in ``--src`` is copied over unchanged.
"""

import argparse
import json
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Sequence

# Tensor-name prefix selecting the weights to keep; it is removed on output.
_PREFIX = "language_model."
_CONFIG_NAME = "config.json"
_INDEX_NAME = "model.safetensors.index.json"


def _load_json(path: Path) -> dict:
    """Return the parsed JSON content of *path*."""
    # Explicit UTF-8: the platform default encoding may not be UTF-8 and
    # would then misdecode non-ASCII characters in the file.
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(path: Path, obj: dict) -> None:
    """Write *obj* to *path* as indented, key-sorted JSON plus a newline."""
    with path.open("w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, sort_keys=True)
        f.write("\n")


def _strip_prefix(key: str, prefix: str = _PREFIX) -> Optional[str]:
    """Return *key* without *prefix*, or ``None`` if it lacks the prefix."""
    if not key.startswith(prefix):
        return None
    return key[len(prefix):]


def _copy_aux_files(src_dir: Path, dst_dir: Path) -> None:
    """Copy the non-weight files directly inside *src_dir* to *dst_dir*.

    Sub-directories, ``*.safetensors`` shards, ``config.json`` and the
    safetensors index are skipped; the caller writes those itself.
    """
    handled_elsewhere = {_CONFIG_NAME, _INDEX_NAME}
    for child in src_dir.iterdir():
        if not child.is_file():
            continue
        if child.suffix == ".safetensors" or child.name in handled_elsewhere:
            continue
        shutil.copy2(child, dst_dir / child.name)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run the conversion described in the module docstring.

    Args:
        argv: Command-line arguments without the program name.  ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--src`` and ``--dst`` resolve to the same directory.
        RuntimeError: If no tensor name starts with ``language_model.``.
    """
    parser = argparse.ArgumentParser(
        description="convert k2.5 to thinking"
    )
    parser.add_argument("--src", required=True)
    parser.add_argument("--dst", required=True)
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
    )
    args = parser.parse_args(argv)

    src_dir = Path(args.src).resolve()
    dst_dir = Path(args.dst).resolve()
    if src_dir == dst_dir:
        # Shards are written under their original file names, so converting
        # in place would overwrite files that are still being read.
        raise ValueError("--src and --dst must be different directories")

    # Imported here (not at module level) so argument parsing and --help work
    # without safetensors installed, and before anything is written so a
    # missing dependency does not leave a half-populated destination.
    from safetensors import safe_open
    from safetensors.torch import save_file

    dst_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the config is passed through unchanged.  A disabled
    # earlier variant extracted the nested "text_config" entry instead --
    # confirm that the full config is what the stripped checkpoint needs.
    _write_json(dst_dir / _CONFIG_NAME, dict(_load_json(src_dir / _CONFIG_NAME)))

    _copy_aux_files(src_dir, dst_dir)

    if args.trust_remote_code:
        # example_utils is expected one directory above this script's folder.
        sys.path.append(str(Path(__file__).resolve().parents[1]))
        from example_utils import copy_custom_model_files  # noqa: E402

        # Presumably copies the model's custom code files -- see example_utils.
        copy_custom_model_files(str(src_dir), str(dst_dir), trust_remote_code=True)

    index_in = _load_json(src_dir / _INDEX_NAME)
    shard_names = sorted(set(index_in["weight_map"].values()))

    weight_map_out: Dict[str, str] = {}
    metadata: Optional[dict] = None
    total_size = 0  # bytes occupied by all tensors that were kept

    for shard in shard_names:
        src_file = src_dir / shard
        dst_file = dst_dir / shard
        out_tensors = {}
        with safe_open(src_file, framework="pt") as f:
            # The header metadata of the first shard is reused for every
            # output shard.
            if metadata is None:
                metadata = dict(f.metadata() or {})
            for key in f.keys():
                new_key = _strip_prefix(key)
                if new_key is None:
                    continue
                tensor = f.get_tensor(key)
                out_tensors[new_key] = tensor
                weight_map_out[new_key] = shard
                total_size += tensor.numel() * tensor.element_size()

        # A shard without any matching tensor produces no output file.
        if out_tensors:
            save_file(out_tensors, str(dst_file), metadata=metadata)
            print(f"Wrote {len(out_tensors)} tensors to {dst_file}")

    if not weight_map_out:
        raise RuntimeError("No tensors matched prefix 'language_model.'")

    # Keep the shard metadata keys that were always written here and add
    # "total_size", the field sharded-checkpoint indexes conventionally carry.
    index_metadata = dict(metadata or {})
    index_metadata["total_size"] = total_size

    _write_json(
        dst_dir / _INDEX_NAME,
        {"metadata": index_metadata, "weight_map": weight_map_out},
    )

    print("Done.")


if __name__ == "__main__":
    main()