File size: 3,036 Bytes
68821b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3

import argparse
import json
import shutil
import sys
from pathlib import Path


def _load_json(path: Path) -> dict:
    """Read and decode a JSON file.

    Args:
        path: JSON file to read.

    Returns:
        The decoded JSON document (callers in this script index it as a
        mapping).
    """
    # JSON text is UTF-8 by spec (RFC 8259); do not depend on the platform's
    # locale encoding, which may not be UTF-8 (e.g. on Windows).
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(path: Path, obj: dict) -> None:
    """Write *obj* to *path* as pretty-printed JSON.

    The output uses two-space indentation and sorted keys, and it ends with a
    trailing newline. Any existing file at *path* is overwritten.
    """
    rendered = json.dumps(obj, indent=2, sort_keys=True)
    with path.open("w") as out:
        out.write(rendered + "\n")


def main() -> None:
    """Copy the ``language_model.`` weights of a checkpoint into a new directory.

    Reads the checkpoint directory given by ``--src`` and writes ``--dst``:

    * ``config.json`` is copied (re-serialized with sorted keys).
    * Every other top-level regular file is copied verbatim, except
      ``*.safetensors`` shards and ``model.safetensors.index.json``.
    * With ``--trust_remote_code``, ``example_utils.copy_custom_model_files``
      (imported from this script's parent directory) is also called.
    * For each shard named in the source index, only the tensors whose names
      start with ``language_model.`` are kept. They are saved, with that
      prefix removed, under the same shard file name. All other tensors are
      dropped, and shards with no matching tensor are not written.
    * A new ``model.safetensors.index.json`` maps every renamed tensor to
      its shard and records the total tensor byte size.

    Raises:
        SystemExit: via ``argparse`` when ``--src`` and ``--dst`` resolve to
            the same directory.
        RuntimeError: if no tensor name starts with ``language_model.``.
    """
    prefix = "language_model."

    parser = argparse.ArgumentParser(
        description="convert k2.5 to thinking"
    )
    parser.add_argument("--src", required=True, help="source checkpoint directory")
    parser.add_argument("--dst", required=True, help="output checkpoint directory")
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="also copy custom modeling files via example_utils",
    )
    args = parser.parse_args()

    src_dir = Path(args.src).resolve()
    dst_dir = Path(args.dst).resolve()
    if src_dir == dst_dir:
        # Converting in place would overwrite the source config.json and then
        # fail in shutil.copy2 (SameFileError) with the source half-modified.
        parser.error("--src and --dst must be different directories")
    dst_dir.mkdir(parents=True, exist_ok=True)

    # Imported after argument parsing, so `--help` works without safetensors.
    from safetensors import safe_open
    from safetensors.torch import save_file

    # The full source config is copied unchanged (only re-serialized).
    # NOTE(review): an earlier revision of this script extracted a nested
    # `text_config`; if the source config nests the language-model config
    # under that key, the text-only checkpoint may need it instead — confirm.
    src_config = _load_json(src_dir / "config.json")
    _write_json(dst_dir / "config.json", dict(src_config))

    # Copy every other top-level file; shards and the index are regenerated
    # below.
    regenerated = {"config.json", "model.safetensors.index.json"}
    for child in src_dir.iterdir():
        if not child.is_file():
            continue
        if child.suffix == ".safetensors" or child.name in regenerated:
            continue
        shutil.copy2(child, dst_dir / child.name)

    if args.trust_remote_code:
        sys.path.append(str(Path(__file__).resolve().parents[1]))
        from example_utils import copy_custom_model_files  # noqa: E402

        copy_custom_model_files(str(src_dir), str(dst_dir), trust_remote_code=True)

    index_in = _load_json(src_dir / "model.safetensors.index.json")
    shard_names = sorted(set(index_in["weight_map"].values()))
    weight_map_out: dict[str, str] = {}
    total_size = 0  # bytes of tensor data written across all output shards
    for shard in shard_names:
        src_file = src_dir / shard
        dst_file = dst_dir / shard
        out_tensors = {}
        with safe_open(src_file, framework="pt") as f:
            # Carry over this shard's own header metadata rather than
            # stamping one shard's header onto every output file.
            shard_metadata = dict(f.metadata() or {})
            # Some loaders (e.g. transformers) inspect the "format" entry.
            shard_metadata.setdefault("format", "pt")
            for key in f.keys():
                if not key.startswith(prefix):
                    continue
                new_key = key[len(prefix):]
                tensor = f.get_tensor(key)
                out_tensors[new_key] = tensor
                weight_map_out[new_key] = shard
                total_size += tensor.numel() * tensor.element_size()
        if out_tensors:
            save_file(out_tensors, str(dst_file), metadata=shard_metadata)
            print(f"Wrote {len(out_tensors)} tensors to {dst_file}")

    if not weight_map_out:
        raise RuntimeError(f"No tensors matched prefix '{prefix}'")

    # Index-level metadata describes the whole checkpoint ("total_size" in
    # bytes), not the header of any single shard.
    _write_json(
        dst_dir / "model.safetensors.index.json",
        {"metadata": {"total_size": total_size}, "weight_map": weight_map_out},
    )

    print("Done.")


# Run the conversion only when executed as a script; importing this module
# defines the helpers without side effects.
if __name__ == "__main__":
    main()