#!/usr/bin/env python
"""Cast a converted component's safetensors to fp16 or bf16.

Casts every floating point tensor (weights, biases, filters, alpha/beta) to
the target dtype; leaves integer tensors untouched. Copies config.json across
if one is present. Smaller download and possibly faster decode, at the cost of
reduced precision.

Usage: cast_component.py SRC_DIR DST_DIR {fp16|bf16}

SRC_DIR must contain ``model.safetensors``. DST_DIR is created if missing and
receives the cast ``model.safetensors`` (plus ``config.json`` when SRC_DIR has
one). The process exits with an error message when arguments are missing, the
target is not fp16/bf16, or the source weights file does not exist.
"""
import shutil
import sys
from pathlib import Path

import mlx.core as mx

WEIGHTS_NAME = "model.safetensors"
CONFIG_NAME = "config.json"

# Fail with a usage line instead of an IndexError traceback when arguments are
# missing. Extra trailing arguments are still ignored, as before.
if len(sys.argv) < 4:
    sys.exit("usage: cast_component.py SRC_DIR DST_DIR {fp16|bf16}")

src = Path(sys.argv[1])
dst = Path(sys.argv[2])
target = sys.argv[3].lower()

dtypes = {"fp16": mx.float16, "bf16": mx.bfloat16}
if target not in dtypes:
    sys.exit(f"target must be fp16 or bf16, got {sys.argv[3]!r}")
out_dtype = dtypes[target]

# mlx float dtypes we are willing to cast from. Anything else (int*, bool)
# passes through unchanged so index/mask tensors keep their semantics.
float_dtypes = {mx.float32, mx.float16, mx.bfloat16}

# Check the input before touching the filesystem so that a mistyped source
# path does not leave an empty destination directory behind.
weights_path = src / WEIGHTS_NAME
if not weights_path.is_file():
    sys.exit(f"no {WEIGHTS_NAME} found in {src}")

dst.mkdir(parents=True, exist_ok=True)

w = mx.load(str(weights_path))
# The script relies on a name -> array mapping. Use an explicit check rather
# than `assert`, which is stripped when Python runs with -O.
if not isinstance(w, dict):
    sys.exit(f"expected a name->array dict from {weights_path}, got {type(w).__name__}")

out = {
    name: tensor.astype(out_dtype) if tensor.dtype in float_dtypes else tensor
    for name, tensor in w.items()
}
# Counts every tensor that went through astype(), including ones that were
# already stored in the target dtype.
ncast = sum(tensor.dtype in float_dtypes for tensor in w.values())

mx.save_safetensors(str(dst / WEIGHTS_NAME), out)

if (src / CONFIG_NAME).exists():
    shutil.copy(src / CONFIG_NAME, dst / CONFIG_NAME)

print(f"{src.name}: cast {ncast}/{len(w)} float tensors to {target} -> {dst}")