# reframr/checkpoint.py (Reframr-RFM-v2-Base release files)

import json
import math
import site
import struct
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any

_VENDOR_ROOT = Path(__file__).resolve().parent.parent / ".vendor"
# Prefer vendored copies of third-party packages when they are present.
for _vendor_path in (_VENDOR_ROOT / "python", _VENDOR_ROOT / "sitepkgs"):
    if _vendor_path.exists():
        vendor_text = str(_vendor_path)
        if vendor_text not in sys.path:
            sys.path.insert(0, vendor_text)

# numpy is optional: try the vendored/normal import first, then the user site,
# and fall back to pure-Python struct packing if it is missing or unusable.
try:
    import numpy as np
except ModuleNotFoundError:
    user_site = site.getusersitepackages()
    if user_site and user_site not in sys.path:
        sys.path.append(user_site)
    try:
        import numpy as np
    except ModuleNotFoundError:
        np = None

if np is not None and not hasattr(np, "asarray"):
    np = None

DTYPE_CODES = {
    "F32": ("f", 4),
    "F64": ("d", 8),
    "I32": ("i", 4),
}

@dataclass(slots=True)
class SafeTensorFile:
    tensors: dict[str, Any]
    metadata: dict[str, str]

def _read_safetensor_header(path: str | Path) -> dict[str, Any]:
    """Read and decode only the JSON header of a safetensors file."""
    with Path(path).open("rb") as handle:
        length_bytes = handle.read(8)
        if len(length_bytes) < 8:
            raise ValueError("Invalid safetensors file: missing header length.")
        header_length = struct.unpack("<Q", length_bytes)[0]
        header_bytes = handle.read(header_length)
        if len(header_bytes) != header_length:
            raise ValueError("Invalid safetensors file: truncated header.")
        return json.loads(header_bytes.decode("utf-8"))
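
# For reference, the on-disk layout handled throughout this module is the
# standard safetensors framing:
#
#     bytes 0..8      little-endian uint64 header length N
#     bytes 8..8+N    UTF-8 JSON header mapping tensor names to
#                     {"dtype", "shape", "data_offsets"}, plus optional "__metadata__"
#     bytes 8+N..     raw tensor payloads; each "data_offsets" pair is
#                     relative to the start of this data section
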
def _shape_of(value: Any) -> list[int]:
    if np is not None and hasattr(value, "shape"):
        return [int(axis) for axis in value.shape]
    if not isinstance(value, list):
        return []
    if not value:
        return [0]
    first_shape = _shape_of(value[0])
    for item in value[1:]:
        if _shape_of(item) != first_shape:
            raise ValueError("Safetensor writer does not support ragged tensors.")
    return [len(value)] + first_shape

def _flatten(value: Any) -> list[Any]:
    if np is not None and hasattr(value, "reshape"):
        return value.reshape(-1).tolist()
    if isinstance(value, list):
        flattened: list[Any] = []
        for item in value:
            flattened.extend(_flatten(item))
        return flattened
    return [value]

def _dtype_of(flat_values: list[Any]) -> str:
    """Choose I32 when every value is a non-bool int, otherwise default to F64."""
    if all(isinstance(value, int) and not isinstance(value, bool) for value in flat_values):
        return "I32"
    return "F64"

def _pack_tensor(dtype: str, values: list[Any]) -> bytes:
    if not values:
        return b""
    code, _ = DTYPE_CODES[dtype]
    cast_values = [int(value) for value in values] if dtype == "I32" else [float(value) for value in values]
    return struct.pack(f"<{len(cast_values)}{code}", *cast_values)

def _array_payload(value: Any) -> tuple[str, list[int], Any] | None:
    """Try to express a value as (dtype, shape, contiguous little-endian array) via numpy."""
    if np is None:
        return None
    try:
        array = np.asarray(value)
    except (TypeError, ValueError):
        return None
    if array.dtype == object:
        return None
    shape = [int(axis) for axis in array.shape]
    if np.issubdtype(array.dtype, np.integer) and not np.issubdtype(array.dtype, np.bool_):
        return "I32", shape, np.ascontiguousarray(array.astype("<i4", copy=False))
    if np.issubdtype(array.dtype, np.floating):
        if array.dtype == np.float32:
            return "F32", shape, np.ascontiguousarray(array.astype("<f4", copy=False))
        return "F64", shape, np.ascontiguousarray(array.astype("<f8", copy=False))
    return "F64", shape, np.ascontiguousarray(array.astype("<f8", copy=False))
def _reshape(values: list[Any], shape: list[int]) -> Any:
    """Rebuild a nested list (or a scalar, for an empty shape) from flat values."""
    if not shape:
        return values[0]
    if len(shape) == 1:
        return values[: shape[0]]
    chunk = math.prod(shape[1:])
    return [
        _reshape(values[index * chunk : (index + 1) * chunk], shape[1:])
        for index in range(shape[0])
    ]
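
# For example:
#
#     _reshape([1, 2, 3, 4, 5, 6], [2, 3])  # -> [[1, 2, 3], [4, 5, 6]]
#     _reshape([7], [])                      # -> 7 (scalar tensor)
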
def write_safetensor_file(
    path: str | Path,
    tensors: dict[str, Any],
    *,
    metadata: dict[str, str] | None = None,
) -> None:
    """Write tensors plus optional string metadata to a safetensors file via an atomic .tmp rename."""
    tensor_header: dict[str, Any] = {}
    payloads: list[Any] = []
    offset = 0
    for name, value in tensors.items():
        array_payload = _array_payload(value)
        if array_payload is None:
            flat_values = _flatten(value)
            dtype = _dtype_of(flat_values)
            shape = _shape_of(value)
            payload = _pack_tensor(dtype, flat_values)
        else:
            dtype, shape, payload = array_payload
        payload_size = int(payload.nbytes) if hasattr(payload, "nbytes") else len(payload)
        tensor_header[name] = {
            "dtype": dtype,
            "shape": shape,
            "data_offsets": [offset, offset + payload_size],
        }
        payloads.append(payload)
        offset += payload_size
    if metadata:
        tensor_header["__metadata__"] = metadata
    header_bytes = json.dumps(tensor_header, separators=(",", ":")).encode("utf-8")
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    temporary_path = output_path.with_name(f"{output_path.name}.tmp")
    with temporary_path.open("wb") as handle:
        handle.write(struct.pack("<Q", len(header_bytes)))
        handle.write(header_bytes)
        for payload in payloads:
            if hasattr(payload, "nbytes"):
                if payload.nbytes:
                    handle.write(memoryview(payload).cast("B"))
            else:
                handle.write(payload)
        handle.flush()
    temporary_path.replace(output_path)
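
# A minimal writing sketch; the tensor names, values, and output path below are
# illustrative rather than taken from the release:
#
#     write_safetensor_file(
#         "checkpoint.safetensors",
#         {"embedding": [[0.1, 0.2], [0.3, 0.4]], "steps": [1, 2, 3]},
#         metadata={"schema_version": "1"},
#     )
#
# Nested lists are packed with struct (I32 or F64); numpy float32 arrays keep
# their precision as F32.
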
def read_safetensor_file(path: str | Path, *, arrays: bool = False) -> SafeTensorFile:
    """Load a safetensors file as numpy arrays when arrays=True and numpy is available, otherwise as nested Python lists."""
    tensor_path = Path(path)
    if arrays and np is not None:
        with tensor_path.open("rb") as handle:
            length_bytes = handle.read(8)
            if len(length_bytes) < 8:
                raise ValueError("Invalid safetensors file: missing header length.")
            header_length = struct.unpack("<Q", length_bytes)[0]
            header_bytes = handle.read(header_length)
            if len(header_bytes) != header_length:
                raise ValueError("Invalid safetensors file: truncated header.")
            header = json.loads(header_bytes.decode("utf-8"))
        data_start = 8 + header_length
        metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
        tensors: dict[str, Any] = {}
        for name, spec in header.items():
            if name == "__metadata__":
                continue
            start, end = spec["data_offsets"]
            dtype = str(spec["dtype"])
            shape = [int(value) for value in spec["shape"]]
            _, width = DTYPE_CODES[dtype]
            payload_width = end - start
            element_count = payload_width // width if width else 0
            if payload_width <= 0:
                tensors[name] = np.asarray([], dtype={"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype])
                continue
            array_dtype = {"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype]
            mapped_shape = tuple(shape) if shape else (element_count,)
            try:
                mapped = np.memmap(
                    tensor_path,
                    dtype=array_dtype,
                    mode="r",
                    offset=data_start + start,
                    shape=mapped_shape,
                    order="C",
                )
                tensors[name] = mapped if shape else mapped[0]
            except OSError:
                with tensor_path.open("rb") as handle:
                    handle.seek(data_start + start)
                    values = np.fromfile(handle, dtype=array_dtype, count=element_count)
                if values.size != element_count:
                    raise ValueError(
                        f"Invalid safetensors file: tensor {name!r} payload is truncated."
                    )
                copied = values.reshape(shape).copy() if shape else values.copy()
                tensors[name] = copied if shape else copied[0]
        return SafeTensorFile(tensors=tensors, metadata=metadata)
    raw = tensor_path.read_bytes()
    if len(raw) < 8:
        raise ValueError("Invalid safetensors file: missing header length.")
    header_length = struct.unpack("<Q", raw[:8])[0]
    header = json.loads(raw[8 : 8 + header_length].decode("utf-8"))
    data_buffer = raw[8 + header_length :]
    metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
    tensors: dict[str, Any] = {}
    for name, spec in header.items():
        if name == "__metadata__":
            continue
        start, end = spec["data_offsets"]
        dtype = str(spec["dtype"])
        shape = [int(value) for value in spec["shape"]]
        code, width = DTYPE_CODES[dtype]
        payload = data_buffer[start:end]
        element_count = len(payload) // width if width else 0
        if np is not None and payload:
            array_dtype = {"I32": "<i4", "F32": "<f4", "F64": "<f8"}[dtype]
            values = np.frombuffer(payload, dtype=array_dtype, count=element_count)
            reshaped = values.reshape(shape) if shape else values
            if arrays:
                tensors[name] = reshaped.copy() if shape else values.copy()[0]
            else:
                tensors[name] = reshaped.tolist() if shape else values.tolist()[0]
        else:
            values = list(struct.unpack(f"<{element_count}{code}", payload)) if payload else []
            tensors[name] = _reshape(values, shape)
    return SafeTensorFile(tensors=tensors, metadata=metadata)
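
# A minimal reading sketch (path and tensor name are illustrative):
#
#     loaded = read_safetensor_file("checkpoint.safetensors", arrays=True)
#     loaded.metadata               # {"schema_version": "1", ...}
#     loaded.tensors["embedding"]   # read-only numpy memmap
#
# With arrays=True and numpy available, tensors are memory-mapped rather than
# copied; with the default arrays=False they come back as plain Python lists.
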
def inspect_checkpoint(path: str | Path) -> dict[str, Any]:
    """Summarize a checkpoint from its safetensors header without loading tensor data."""
    header = _read_safetensor_header(path)
    metadata = {str(key): str(value) for key, value in header.get("__metadata__", {}).items()}
    tensor_names = sorted(name for name in header if name != "__metadata__")
    config = json.loads(metadata["config"]) if "config" in metadata else {}
    effective_parameter_target = int(config.get("effective_parameter_target", 0)) if config else 0
    return {
        "format": "safetensors",
        "path": str(Path(path).resolve()),
        "checkpoint_kind": metadata.get("checkpoint_kind", "unknown"),
        "schema_version": metadata.get("schema_version", "0"),
        "tokenizer_name": metadata.get("tokenizer_name", ""),
        "default_reasoning_profile": str(config.get("default_reasoning_profile", "none")) if config else "none",
        "lowercase": bool(config.get("lowercase", False)) if config else False,
        "tensor_count": len(tensor_names),
        "tensor_names": tensor_names,
        "tensor_dtypes": {
            name: str(header[name]["dtype"])
            for name in tensor_names
        },
        "tensor_shapes": {
            name: [int(axis) for axis in header[name]["shape"]]
            for name in tensor_names
        },
        "tokenizer_vocab_size": int(metadata.get("tokenizer_vocab_size", "0")),
        "embedding_dim": int(config.get("embedding_dim", 0)) if config else 0,
        "state_dim": int(config.get("state_dim", 0)) if config else 0,
        "layout_profile": str(config.get("layout_profile", "rfm-base")) if config else "rfm-base",
        "effective_parameter_target": effective_parameter_target,
        "model_size": _format_model_size(effective_parameter_target),
        "model_size_kind": "structured_effective" if effective_parameter_target > 0 else "stored_tensor",
        "answer_fingerprint_count": (
            int(header["answer_fingerprint_hashes"]["shape"][0])
            if "answer_fingerprint_hashes" in header
            and header["answer_fingerprint_hashes"].get("shape")
            else 0
        ),
    }
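
# Because inspect_checkpoint parses only the header, it stays cheap even for
# large checkpoints. An illustrative call (the filename is hypothetical):
#
#     report = inspect_checkpoint("Reframr-RFM-v2-Base.safetensors")
#     report["tensor_count"], report["model_size"], report["tensor_shapes"]
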
def _format_model_size(parameter_count: int) -> str:
    """Render a parameter count as a compact size string such as 350M or 1.5B."""
    if parameter_count <= 0:
        return "unknown"
    if parameter_count % 1_000_000_000 == 0:
        return f"{parameter_count // 1_000_000_000}B"
    if parameter_count >= 1_000_000_000:
        return f"{parameter_count / 1_000_000_000:.1f}B"
    if parameter_count % 1_000_000 == 0:
        return f"{parameter_count // 1_000_000}M"
    if parameter_count >= 1_000_000:
        return f"{parameter_count / 1_000_000:.1f}M"
    return str(parameter_count)
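

if __name__ == "__main__":
    # Illustrative round-trip smoke test; the tensor names and metadata below
    # are made up for the demo and are not part of the released checkpoint.
    import tempfile

    with tempfile.TemporaryDirectory() as scratch:
        demo_path = Path(scratch) / "demo.safetensors"
        write_safetensor_file(
            demo_path,
            {"weights": [[1.0, 2.0], [3.0, 4.0]], "ids": [1, 2, 3]},
            metadata={"checkpoint_kind": "demo"},
        )
        restored = read_safetensor_file(demo_path)
        print(inspect_checkpoint(demo_path)["tensor_shapes"])
        print(restored.metadata, restored.tensors["ids"])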