import json import math import site import struct import sys from dataclasses import dataclass from pathlib import Path from typing import Any _VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor" for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"): if _vendor_path.exists(): vendor_text = str(_vendor_path) if vendor_text not in sys.path: sys.path.insert(0, vendor_text) try: import numpy as np except ModuleNotFoundError: user_site = site.getusersitepackages() if user_site and user_site not in sys.path: sys.path.append(user_site) try: import numpy as np except ModuleNotFoundError: np = None if np is not None and not hasattr(np, "asarray"): np = None DTYPE_CODES = { "F32": ("f", 4), "F64": ("d", 8), "I32": ("i", 4), } @dataclass(slots=True) class SafeTensorFile: tensors: dict[str, Any] metadata: dict[str, str] def _read_safetensor_header(path: str | Path) -> dict[str, Any]: with Path(path).open("rb") as handle: length_bytes = handle.read(8) if len(length_bytes) < 8: raise ValueError("Invalid safetensors file: missing header length.") header_length = struct.unpack(" list[int]: if np is not None and hasattr(value, "shape"): return [int(axis) for axis in value.shape] if not isinstance(value, list): return [] if not value: return [0] first_shape = _shape_of(value[0]) for item in value[1:]: if _shape_of(item) != first_shape: raise ValueError("Safetensor writer does not support ragged tensors.") return [len(value)] + first_shape def _flatten(value: Any) -> list[Any]: if np is not None and hasattr(value, "reshape"): return value.reshape(-1).tolist() if isinstance(value, list): flattened: list[Any] = [] for item in value: flattened.extend(_flatten(item)) return flattened return [value] def _dtype_of(flat_values: list[Any]) -> str: if all(isinstance(value, int) and not isinstance(value, bool) for value in flat_values): return "I32" return "F64" def _pack_tensor(dtype: str, values: list[Any]) -> bytes: if not values: return b"" code, _ = DTYPE_CODES[dtype] cast_values = [int(value) for value in values] if dtype == "I32" else [float(value) for value in values] return struct.pack(f"<{len(cast_values)}{code}", *cast_values) def _array_payload(value: Any) -> tuple[str, list[int], Any] | None: if np is None: return None try: array = np.asarray(value) except (TypeError, ValueError): return None if array.dtype == object: return None shape = [int(axis) for axis in array.shape] if np.issubdtype(array.dtype, np.integer) and not np.issubdtype(array.dtype, np.bool_): return "I32", shape, np.ascontiguousarray(array.astype(" Any: if not shape: return values[0] if len(shape) == 1: return values[: shape[0]] chunk = math.prod(shape[1:]) return [ _reshape(values[index * chunk : (index + 1) * chunk], shape[1:]) for index in range(shape[0]) ] def write_safetensor_file( path: str | Path, tensors: dict[str, Any], *, metadata: dict[str, str] | None = None, ) -> None: tensor_header: dict[str, Any] = {} payloads: list[Any] = [] offset = 0 for name, value in tensors.items(): array_payload = _array_payload(value) if array_payload is None: flat_values = _flatten(value) dtype = _dtype_of(flat_values) shape = _shape_of(value) payload = _pack_tensor(dtype, flat_values) else: dtype, shape, payload = array_payload payload_size = int(payload.nbytes) if hasattr(payload, "nbytes") else len(payload) tensor_header[name] = { "dtype": dtype, "shape": shape, "data_offsets": [offset, offset + payload_size], } payloads.append(payload) offset += payload_size if metadata: tensor_header["__metadata__"] = metadata header_bytes = json.dumps(tensor_header, separators=(",", ":")).encode("utf-8") output_path = Path(path) output_path.parent.mkdir(parents=True, exist_ok=True) temporary_path = output_path.with_name(f"{output_path.name}.tmp") with temporary_path.open("wb") as handle: handle.write(struct.pack(" SafeTensorFile: tensor_path = Path(path) if arrays and np is not None: with tensor_path.open("rb") as handle: length_bytes = handle.read(8) if len(length_bytes) < 8: raise ValueError("Invalid safetensors file: missing header length.") header_length = struct.unpack(" dict[str, Any]: header = _read_safetensor_header(path) metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()} tensor_names = sorted(name for name in header if name != "__metadata__") config = json.loads(metadata["config"]) if "config" in metadata else {} effective_parameter_target = int(config.get("effective_parameter_target", 0)) if config else 0 return { "format": "safetensors", "path": str(Path(path).resolve()), "checkpoint_kind": metadata.get("checkpoint_kind", "unknown"), "schema_version": metadata.get("schema_version", "0"), "tokenizer_name": metadata.get("tokenizer_name", ""), "default_reasoning_profile": str(config.get("default_reasoning_profile", "none")) if config else "none", "lowercase": bool(config.get("lowercase", False)) if config else False, "tensor_count": len(tensor_names), "tensor_names": tensor_names, "tensor_dtypes": { name: str(header[name]["dtype"]) for name in tensor_names }, "tensor_shapes": { name: [int(axis) for axis in header[name]["shape"]] for name in tensor_names }, "tokenizer_vocab_size": int(metadata.get("tokenizer_vocab_size", "0")), "embedding_dim": int(config.get("embedding_dim", 0)) if config else 0, "state_dim": int(config.get("state_dim", 0)) if config else 0, "layout_profile": str(config.get("layout_profile", "rfm-base")) if config else "rfm-base", "effective_parameter_target": effective_parameter_target, "model_size": _format_model_size(effective_parameter_target), "model_size_kind": "structured_effective" if effective_parameter_target > 0 else "stored_tensor", "answer_fingerprint_count": ( int(header["answer_fingerprint_hashes"]["shape"][0]) if "answer_fingerprint_hashes" in header and header["answer_fingerprint_hashes"].get("shape") else 0 ), } def _format_model_size(parameter_count: int) -> str: if parameter_count <= 0: return "unknown" if parameter_count % 1_000_000_000 == 0: return f"{parameter_count // 1_000_000_000}B" if parameter_count >= 1_000_000_000: return f"{parameter_count / 1_000_000_000:.1f}B" if parameter_count % 1_000_000 == 0: return f"{parameter_count // 1_000_000}M" if parameter_count >= 1_000_000: return f"{parameter_count / 1_000_000:.1f}M" return str(parameter_count)