File size: 1,465 Bytes
aabff3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#!/usr/bin/env python
"""Cast a converted component's safetensors to fp16 or bf16.

Casts every float32, float16 and bfloat16 tensor (weights, biases, filters,
alpha/beta) to the target dtype; every other tensor (integer, bool, ...) passes
through untouched. Copies config.json across if one is present. The result is a
smaller download and possibly faster decode, at the cost of reduced precision.

Usage: cast_component.py <src_dir> <dst_dir> <fp16|bf16>
"""
import shutil
import sys
from pathlib import Path

import mlx.core as mx

# --- Command line ------------------------------------------------------------
# Exit with a usage line rather than an IndexError traceback when arguments are
# missing. Extra trailing arguments are ignored, exactly as before.
if len(sys.argv) < 4:
    sys.exit(f"usage: {Path(sys.argv[0]).name} <src_dir> <dst_dir> <fp16|bf16>")

src = Path(sys.argv[1])
dst = Path(sys.argv[2])
target = sys.argv[3].lower()

# Map the user-facing dtype name to the mlx dtype it selects.
dtypes = {"fp16": mx.float16, "bf16": mx.bfloat16}
if target not in dtypes:
    sys.exit(f"target must be fp16 or bf16, got {sys.argv[3]!r}")
out_dtype = dtypes[target]

# mlx float dtypes we are willing to cast from. Anything else (int*, bool) passes
# through unchanged so index/mask tensors keep their semantics.
float_dtypes = {mx.float32, mx.float16, mx.bfloat16}

# --- Validate inputs before touching the filesystem --------------------------
# Checking first means a bad invocation does not leave an empty dst directory.
src_model = src / "model.safetensors"
if not src_model.is_file():
    sys.exit(f"no model.safetensors found in {src}")

# Casting in place would write the output over the input model and make
# shutil.copy raise SameFileError on config.json, so refuse it up front.
if src.resolve() == dst.resolve():
    sys.exit("src_dir and dst_dir must be different directories")

# --- Load, cast, save --------------------------------------------------------
w = mx.load(str(src_model))
# A safetensors load is expected to return a name->array dict. Use an explicit
# check rather than `assert`, which is stripped when Python runs with -O.
if not isinstance(w, dict):
    sys.exit(f"expected a name->array dict from {src_model}, got {type(w).__name__}")

# Cast every tensor whose dtype is in float_dtypes; keep the rest as-is.
out = {k: (v.astype(out_dtype) if v.dtype in float_dtypes else v) for k, v in w.items()}
# Number of tensors that went through the cast branch above.
ncast = sum(1 for v in w.values() if v.dtype in float_dtypes)

dst.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(dst / "model.safetensors"), out)

# config.json is optional; carry it over verbatim when the source has one.
src_config = src / "config.json"
if src_config.exists():
    shutil.copy(src_config, dst / "config.json")
print(f"{src.name}: cast {ncast}/{len(w)} float tensors to {target} -> {dst}")