| |
|
|
| import argparse |
| import json |
| import shutil |
| import sys |
| from pathlib import Path |
|
|
|
|
| def _load_json(path: Path) -> dict: |
| with path.open("r") as f: |
| return json.load(f) |
|
|
|
|
| def _write_json(path: Path, obj: dict) -> None: |
| with path.open("w") as f: |
| json.dump(obj, f, indent=2, sort_keys=True) |
| f.write("\n") |
|
|
|
|
def main() -> None:
    """Copy a checkpoint while stripping the ``language_model.`` key prefix.

    Command-line arguments:
        --src: Directory holding ``config.json``,
            ``model.safetensors.index.json`` and the safetensors shards.
        --dst: Output directory (created if missing). Must not be the same
            directory as ``--src``.
        --trust_remote_code: Additionally call
            ``example_utils.copy_custom_model_files`` on both directories.

    Written to ``--dst``:
        * ``config.json`` re-serialised from the source config,
        * a copy of every other regular, non-``.safetensors`` file,
        * one shard per source shard that contains ``language_model.*``
          tensors, holding only those tensors with the prefix removed,
        * a new ``model.safetensors.index.json`` for the renamed tensors.

    Raises:
        SystemExit: From ``argparse`` on invalid arguments, including a
            ``--src`` that is not a directory or that is the same directory
            as ``--dst``.
        RuntimeError: If no tensor key starts with ``language_model.``.
    """
    parser = argparse.ArgumentParser(
        description="convert k2.5 to thinking"
    )
    parser.add_argument("--src", required=True)
    parser.add_argument("--dst", required=True)
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
    )
    args = parser.parse_args()

    src_dir = Path(args.src).resolve()
    dst_dir = Path(args.dst).resolve()

    # Validate before touching the filesystem so a bad invocation leaves no
    # half-created output directory behind.
    if not src_dir.is_dir():
        parser.error(f"--src is not a directory: {src_dir}")
    # Writing into the source directory would overwrite each source shard
    # with only its prefixed tensors, and rewrite config.json in place.
    if dst_dir.exists() and dst_dir.samefile(src_dir):
        parser.error("--src and --dst must be different directories")

    # Imported lazily so that ``--help`` works without safetensors installed;
    # done before creating the output directory so a missing dependency
    # fails without side effects.
    from safetensors import safe_open
    from safetensors.torch import save_file

    dst_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the entire source config is written back (only the key
    # order changes); confirm that no nested text-only section is meant to
    # be extracted here.
    src_config = _load_json(src_dir / "config.json")
    _write_json(dst_dir / "config.json", dict(src_config))

    # Copy every remaining regular file. Weight shards, the config and the
    # shard index are regenerated by this script, so they are skipped.
    for child in src_dir.iterdir():
        if not child.is_file():
            continue
        if child.suffix == ".safetensors":
            continue
        if child.name in {"config.json", "model.safetensors.index.json"}:
            continue
        shutil.copy2(child, dst_dir / child.name)

    if args.trust_remote_code:
        # ``example_utils`` is looked up one directory above this script's
        # own directory.
        sys.path.append(str(Path(__file__).resolve().parents[1]))
        from example_utils import copy_custom_model_files

        copy_custom_model_files(str(src_dir), str(dst_dir), trust_remote_code=True)

    prefix = "language_model."
    index_in = _load_json(src_dir / "model.safetensors.index.json")
    shard_names = sorted(set(index_in["weight_map"].values()))
    weight_map_out: dict[str, str] = {}
    metadata: dict | None = None
    for shard in shard_names:
        src_file = src_dir / shard
        dst_file = dst_dir / shard
        out_tensors = {}
        with safe_open(src_file, framework="pt") as f:
            # The header metadata of the first shard is reused for every
            # output shard and for the output index.
            if metadata is None:
                metadata = dict(f.metadata() or {})
            for key in f.keys():
                if not key.startswith(prefix):
                    continue
                new_key = key[len(prefix):]
                out_tensors[new_key] = f.get_tensor(key)
                weight_map_out[new_key] = shard
        # Shards without any matching tensor produce no output file.
        if out_tensors:
            save_file(out_tensors, str(dst_file), metadata=metadata)
            print(f"Wrote {len(out_tensors)} tensors to {dst_file}")

    if not weight_map_out:
        raise RuntimeError("No tensors matched prefix 'language_model.'")

    _write_json(
        dst_dir / "model.safetensors.index.json",
        {"metadata": metadata or {}, "weight_map": weight_map_out},
    )

    print("Done.")
|
|
|
|
# Run the conversion only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|
|
|
|