# NOTE: the ternary_quant package is bundled directly with this repository
# because the upstream repo is private (vendored at commit 162f86a).
"""
Storage utilities for ternary and role-aware small-model quantization.
The legacy format stores per-row alpha/mu.
The newer role-aware format stores grouped alpha/mu plus an optional sparse
FP16 residual for the most sensitive weights and an optional low-rank residual.
"""
from __future__ import annotations
import json
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from ternary_quant.quantizer import TernaryParameter
from ternary_quant.ptq_families import TritPlaneParameter
from ternary_quant.quantizer_small import (
GroupwiseTernaryParameter,
RoleAwareQuantizationPlan,
SmallModelQuantizationConfig,
config_to_dict,
plan_to_dict,
)
def pack_ternary(codes: torch.Tensor) -> torch.Tensor:
    """Pack int8 ternary codes {-1, 0, +1} into uint8 with 4 values per byte.

    Each code is shifted to {0, 1, 2} and stored in a 2-bit field; the input
    is padded up to a multiple of 4 with the encoding of code 0, so the
    output holds ``ceil(numel / 4)`` bytes.  ``unpack_ternary`` drops the
    padding via its ``num_elements`` argument.

    Args:
        codes: Tensor of ternary codes in {-1, 0, +1} (any shape).

    Returns:
        1-D uint8 tensor of packed codes on the same device as ``codes``.
    """
    flat = codes.flatten().to(torch.int8)
    # Shift {-1, 0, +1} -> {0, 1, 2} so each value fits in 2 bits.
    encoded = (flat + 1).to(torch.uint8)
    pad_len = (4 - flat.shape[0] % 4) % 4
    if pad_len > 0:
        # Pad with 1 (the encoding of code 0).  Allocate the pad on the same
        # device as the input: the old code used the default CPU device,
        # which made torch.cat fail for CUDA inputs.
        pad = torch.ones(pad_len, dtype=torch.uint8, device=encoded.device)
        encoded = torch.cat([encoded, pad])
    encoded = encoded.reshape(-1, 4)
    packed = (
        encoded[:, 0]
        | (encoded[:, 1] << 2)
        | (encoded[:, 2] << 4)
        | (encoded[:, 3] << 6)
    )
    return packed
def unpack_ternary(packed: torch.Tensor, num_elements: int) -> torch.Tensor:
    """Unpack uint8 packed ternary back to int8 {-1, 0, +1}.

    Args:
        packed: 1-D uint8 tensor produced by ``pack_ternary``.
        num_elements: Number of real values; trailing pad codes are dropped.

    Returns:
        1-D int8 tensor of length ``num_elements`` with values in {-1, 0, +1}.
    """
    # Pull the four 2-bit fields out of every byte, then interleave them
    # back into their original order before trimming the padding.
    fields = [(packed >> shift) & 0x03 for shift in (0, 2, 4, 6)]
    values = torch.stack(fields, dim=1).flatten()[:num_elements]
    # Shift the stored {0, 1, 2} range back down to {-1, 0, +1}.
    return values.to(torch.int8) - 1
def save_quantized_model(
    ternary_params: dict,
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict] = None,
    summary: Optional[dict] = None,
    plan: Optional[RoleAwareQuantizationPlan] = None,
    method_name: Optional[str] = None,
    model_family: Optional[str] = None,
):
    """Save either legacy ternary params or role-aware grouped params.

    The on-disk format is chosen by inspecting the type of the first value
    in ``ternary_params``; all entries are assumed to share that type.

    Raises:
        ValueError: if ``ternary_params`` is empty.
        TypeError: if the parameter type is not one of the supported ones.
    """
    if not ternary_params:
        raise ValueError("No quantized parameters were provided.")
    probe = next(iter(ternary_params.values()))
    # Keyword arguments shared by every format-specific writer.
    common = dict(
        model_name=model_name,
        model_config=model_config,
        quant_config=quant_config,
        output_dir=output_dir,
        stats=stats,
        summary=summary,
        model_family=model_family,
    )
    if isinstance(probe, GroupwiseTernaryParameter):
        _save_groupwise_quantized_model(
            quantized_params=ternary_params,
            plan=plan,
            method_name=method_name or "RAST-small",
            **common,
        )
    elif isinstance(probe, TritPlaneParameter):
        _save_tritplane_quantized_model(
            quantized_params=ternary_params,
            plan=plan,
            method_name=method_name or "TritPlane",
            **common,
        )
    elif isinstance(probe, TernaryParameter):
        _save_legacy_quantized_model(
            ternary_params=ternary_params,
            method_name=method_name or "legacy-ternary",
            **common,
        )
    else:
        raise TypeError(f"Unsupported quantized parameter type: {type(probe)!r}")
def load_quantized_params(model_dir: str) -> tuple[dict, dict]:
    """Load quantized params from either the legacy or role-aware format.

    Reads ``metadata.json`` from ``model_dir`` and dispatches on its
    ``format_family`` field; anything unrecognized falls back to the
    legacy loader.

    Returns:
        Tuple of (name -> quantized-parameter dict, metadata dict).
    """
    base = Path(model_dir)
    metadata = json.loads((base / "metadata.json").read_text())
    loaders = {
        "groupwise_small": _load_groupwise_quantized_params,
        "tritplane_small": _load_tritplane_quantized_params,
    }
    family = metadata.get("format_family", "legacy")
    loader = loaders.get(family, _load_legacy_quantized_params)
    return loader(base, metadata), metadata
def _save_legacy_quantized_model(
    ternary_params: dict[str, TernaryParameter],
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict],
    summary: Optional[dict],
    method_name: str,
    model_family: Optional[str],
) -> None:
    """Write the legacy per-row format to ``output_dir``.

    Layout on disk:
        codes/<name>.bin   -- uint8, 4 packed ternary codes per byte
        scales/<name>.bin  -- float16, row alphas followed by row mus
        metadata.json      -- shapes, dtypes and size accounting
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "codes").mkdir(exist_ok=True)
    (output_dir / "scales").mkdir(exist_ok=True)
    layer_info = {}
    total_packed_bytes = 0
    total_fp16_bytes = 0
    for name, tp in ternary_params.items():
        safe_name = name.replace(".", "_")
        # Move to CPU before .numpy(): .numpy() raises on CUDA tensors.
        packed = pack_ternary(tp.ternary_codes).cpu()
        codes_path = output_dir / "codes" / f"{safe_name}.bin"
        packed.numpy().tofile(str(codes_path))
        scales_path = output_dir / "scales" / f"{safe_name}.bin"
        # The loader reads this file as float16, so enforce that dtype here
        # instead of writing whatever dtype alpha/mu happen to carry (the
        # old code silently produced unreadable files for float32 scales).
        scales_data = (
            torch.cat([tp.alpha.flatten(), tp.mu.flatten()])
            .to(torch.float16)
            .cpu()
        )
        scales_data.numpy().tofile(str(scales_path))
        # Works for any rank (the previous shape[0] * shape[1] assumed 2-D).
        num_elements = int(np.prod(tp.original_shape))
        packed_bytes = packed.numel()
        total_packed_bytes += packed_bytes
        total_fp16_bytes += num_elements * 2  # original weights at fp16
        layer_info[name] = {
            "scheme": "legacy_rowwise",
            "shape": list(tp.original_shape),
            "dtype": str(tp.original_dtype),
            "num_elements": num_elements,
            "packed_bytes": packed_bytes,
            "bits_per_param": tp.bits_per_param,
        }
    metadata = {
        "model_name": model_name,
        "model_config": model_config.to_dict()
        if hasattr(model_config, "to_dict")
        else {},
        "quant_config": _serialize_quant_config(quant_config),
        "layer_info": layer_info,
        "stats": stats or {},
        "summary": summary or {},
        "method_name": method_name,
        "model_family": model_family or "causal_lm",
        "format_family": "legacy",
        "format_version": "1.0",
        "total_packed_bytes": total_packed_bytes,
        "total_fp16_bytes": total_fp16_bytes,
        "compression_ratio": total_fp16_bytes / total_packed_bytes
        if total_packed_bytes > 0
        else 0.0,
    }
    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2, default=_json_default)
    print(f"Saved ternary model to {output_dir}")
    print(f"  Packed size: {total_packed_bytes / 1e6:.1f} MB")
    print(f"  FP16 size: {total_fp16_bytes / 1e6:.1f} MB")
    print(f"  Compression: {metadata['compression_ratio']:.1f}x")
def _save_groupwise_quantized_model(
    quantized_params: dict[str, GroupwiseTernaryParameter],
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict],
    summary: Optional[dict],
    plan: Optional[RoleAwareQuantizationPlan],
    method_name: str,
    model_family: Optional[str],
) -> None:
    """Write the role-aware grouped format to ``output_dir``.

    Layout on disk (one ``<name>.bin`` per layer in each subdir):
        codes/           -- uint8, 4 packed ternary codes per byte
        group_alpha/     -- float16 per-group scales
        group_mu/        -- float16 per-group offsets
        sparse_indices/  -- int32 flat indices of rescued weights (optional)
        sparse_residual/ -- float16 residual values (optional)
        low_rank_u/, low_rank_v/ -- float16 low-rank factors (optional)

    Dtypes are enforced here because the loader reads the raw files with
    fixed dtypes (float16 / int32); the byte accounting below assumes the
    same sizes.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "codes").mkdir(exist_ok=True)
    (output_dir / "group_alpha").mkdir(exist_ok=True)
    (output_dir / "group_mu").mkdir(exist_ok=True)
    (output_dir / "sparse_indices").mkdir(exist_ok=True)
    (output_dir / "sparse_residual").mkdir(exist_ok=True)
    (output_dir / "low_rank_u").mkdir(exist_ok=True)
    (output_dir / "low_rank_v").mkdir(exist_ok=True)
    layer_info = {}
    total_stored_bytes = 0
    total_fp16_bytes = 0
    for name, param in quantized_params.items():
        safe_name = name.replace(".", "_")
        # .cpu() before .numpy(): .numpy() raises on CUDA tensors (the
        # tritplane saver already did this; this writer previously did not).
        packed = pack_ternary(param.ternary_codes).cpu()
        packed.numpy().tofile(str(output_dir / "codes" / f"{safe_name}.bin"))
        # Cast to the dtypes the loader expects before writing raw bytes.
        param.group_alpha.to(torch.float16).cpu().numpy().tofile(
            str(output_dir / "group_alpha" / f"{safe_name}.bin")
        )
        param.group_mu.to(torch.float16).cpu().numpy().tofile(
            str(output_dir / "group_mu" / f"{safe_name}.bin")
        )
        sparse_nnz = 0
        if param.sparse_indices is not None and param.sparse_residual is not None:
            param.sparse_indices.to(torch.int32).cpu().numpy().tofile(
                str(output_dir / "sparse_indices" / f"{safe_name}.bin")
            )
            param.sparse_residual.to(torch.float16).cpu().numpy().tofile(
                str(output_dir / "sparse_residual" / f"{safe_name}.bin")
            )
            sparse_nnz = param.sparse_indices.numel()
        lr_rank = 0
        if param.lr_U is not None and param.lr_V is not None:
            param.lr_U.to(torch.float16).cpu().numpy().tofile(
                str(output_dir / "low_rank_u" / f"{safe_name}.bin")
            )
            param.lr_V.to(torch.float16).cpu().numpy().tofile(
                str(output_dir / "low_rank_v" / f"{safe_name}.bin")
            )
            lr_rank = int(param.lr_U.shape[1])
        num_elements = param.num_params
        # Byte accounting: 1 byte / 4 codes, 2 bytes per fp16 value,
        # 4 bytes per int32 index -- matching the casts above.
        stored_bytes = (
            packed.numel()
            + param.group_alpha.numel() * 2
            + param.group_mu.numel() * 2
            + sparse_nnz * 4
            + sparse_nnz * 2
            + (0 if param.lr_U is None else param.lr_U.numel() * 2)
            + (0 if param.lr_V is None else param.lr_V.numel() * 2)
        )
        total_stored_bytes += stored_bytes
        total_fp16_bytes += num_elements * 2
        layer_info[name] = {
            "scheme": "groupwise_small_v1",
            "shape": list(param.original_shape),
            "dtype": str(param.original_dtype),
            "num_elements": num_elements,
            "group_size": param.group_size,
            "n_groups": int(param.group_alpha.shape[1]),
            "packed_bytes": int(packed.numel()),
            "stored_bytes": int(stored_bytes),
            "effective_bits": param.effective_bits,
            "sparse_nnz": int(sparse_nnz),
            "lr_rank": lr_rank,
        }
    metadata = {
        "model_name": model_name,
        "model_config": model_config.to_dict()
        if hasattr(model_config, "to_dict")
        else {},
        "quant_config": _serialize_quant_config(quant_config),
        "plan": plan_to_dict(plan) if (plan is not None and not isinstance(plan, str)) else {},
        "layer_info": layer_info,
        "stats": stats or {},
        "summary": summary or {},
        "method_name": method_name,
        "model_family": model_family or "causal_lm",
        "format_family": "groupwise_small",
        "format_version": "2.0",
        "total_packed_bytes": total_stored_bytes,
        "total_fp16_bytes": total_fp16_bytes,
        "compression_ratio": total_fp16_bytes / total_stored_bytes
        if total_stored_bytes > 0
        else 0.0,
    }
    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2, default=_json_default)
    print(f"Saved role-aware ternary model to {output_dir}")
    print(f"  Stored size: {total_stored_bytes / 1e6:.1f} MB")
    print(f"  FP16 size: {total_fp16_bytes / 1e6:.1f} MB")
    print(f"  Compression: {metadata['compression_ratio']:.1f}x")
def _save_tritplane_quantized_model(
    quantized_params: dict[str, TritPlaneParameter],
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict],
    summary: Optional[dict],
    plan: Optional[RoleAwareQuantizationPlan],
    method_name: str,
    model_family: Optional[str],
) -> None:
    """Write the trit-plane format: one ``tritplane/<name>.npz`` per layer.

    Each archive stores, per plane ``idx``: ``packed_{idx}`` (uint8 packed
    codes), ``group_alpha_{idx}``, ``group_mu_{idx}`` and the scalar
    ``group_size_{idx}``, plus optional ``rescued_row_indices`` /
    ``rescued_row_values`` for rows kept at higher precision.  npz preserves
    array dtypes, so the loader casts on read rather than relying on fixed
    raw-file dtypes.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "tritplane").mkdir(exist_ok=True)
    layer_info = {}
    total_stored_bytes = 0
    total_fp16_bytes = 0
    for name, param in quantized_params.items():
        # Dots are not filesystem-friendly; mirror the naming used elsewhere.
        safe_name = name.replace(".", "_")
        archive_path = output_dir / "tritplane" / f"{safe_name}.npz"
        arrays = {"n_planes": np.array([param.n_planes], dtype=np.int32)}
        stored_bytes = 0
        group_sizes = []
        for idx, plane in enumerate(param.planes):
            packed = pack_ternary(plane.ternary_codes)
            arrays[f"packed_{idx}"] = packed.cpu().numpy()
            arrays[f"group_alpha_{idx}"] = plane.group_alpha.cpu().numpy()
            arrays[f"group_mu_{idx}"] = plane.group_mu.cpu().numpy()
            arrays[f"group_size_{idx}"] = np.array([plane.group_size], dtype=np.int32)
            group_sizes.append(int(plane.group_size))
            # NOTE(review): the "* 2" terms assume fp16 scales; if the plane
            # tensors are a wider dtype these totals undercount -- confirm.
            stored_bytes += int(packed.numel())
            stored_bytes += int(plane.group_alpha.numel() * 2)
            stored_bytes += int(plane.group_mu.numel() * 2)
        rescued_rows = 0
        if param.rescued_row_indices is not None and param.rescued_row_values is not None:
            arrays["rescued_row_indices"] = param.rescued_row_indices.cpu().numpy()
            arrays["rescued_row_values"] = param.rescued_row_values.cpu().numpy()
            rescued_rows = int(param.rescued_row_indices.numel())
            # Accounting assumes int32 indices (4 B) and fp16 values (2 B).
            stored_bytes += int(param.rescued_row_indices.numel() * 4)
            stored_bytes += int(param.rescued_row_values.numel() * 2)
        np.savez(archive_path, **arrays)
        num_elements = param.num_params
        total_stored_bytes += stored_bytes
        # Baseline for the compression ratio: original weights at 2 B each.
        total_fp16_bytes += num_elements * 2
        layer_info[name] = {
            "scheme": "tritplane_small_v1",
            "shape": list(param.original_shape),
            "dtype": str(param.original_dtype),
            "num_elements": num_elements,
            "n_planes": int(param.n_planes),
            "group_sizes": group_sizes,
            "rescued_rows": rescued_rows,
            "stored_bytes": int(stored_bytes),
            "effective_bits": float(param.effective_bits),
        }
    metadata = {
        "model_name": model_name,
        "model_config": model_config.to_dict()
        if hasattr(model_config, "to_dict")
        else {},
        "quant_config": _serialize_quant_config(quant_config),
        # A plan may arrive as a plain string in some call paths; only real
        # plan objects are serialized.
        "plan": plan_to_dict(plan) if (plan is not None and not isinstance(plan, str)) else {},
        "layer_info": layer_info,
        "stats": stats or {},
        "summary": summary or {},
        "method_name": method_name,
        "model_family": model_family or "causal_lm",
        "format_family": "tritplane_small",
        "format_version": "1.0",
        "total_packed_bytes": total_stored_bytes,
        "total_fp16_bytes": total_fp16_bytes,
        "compression_ratio": total_fp16_bytes / total_stored_bytes
        if total_stored_bytes > 0
        else 0.0,
    }
    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2, default=_json_default)
    print(f"Saved trit-plane ternary model to {output_dir}")
    print(f"  Stored size: {total_stored_bytes / 1e6:.1f} MB")
    print(f"  FP16 size: {total_fp16_bytes / 1e6:.1f} MB")
    print(f"  Compression: {metadata['compression_ratio']:.1f}x")
def _load_legacy_quantized_params(
    model_dir: Path,
    metadata: dict,
) -> dict[str, TernaryParameter]:
    """Reconstruct legacy row-wise parameters written by the legacy saver.

    ``scales/<name>.bin`` holds the row alphas followed by the row mus,
    both float16; ``codes/<name>.bin`` holds the packed ternary codes.
    """
    params = {}
    for name, info in metadata["layer_info"].items():
        safe_name = name.replace(".", "_")
        shape = tuple(info["shape"])
        rows = shape[0]
        raw_codes = np.fromfile(
            str(model_dir / "codes" / f"{safe_name}.bin"), dtype=np.uint8
        )
        codes = unpack_ternary(
            torch.from_numpy(raw_codes), info["num_elements"]
        ).reshape(shape)
        raw_scales = np.fromfile(
            str(model_dir / "scales" / f"{safe_name}.bin"), dtype=np.float16
        )
        scales = torch.from_numpy(raw_scales)
        # First `rows` entries are alpha, the rest are mu (one per row).
        params[name] = TernaryParameter(
            ternary_codes=codes,
            alpha=scales[:rows].reshape(rows, 1),
            mu=scales[rows:].reshape(rows, 1),
            original_shape=shape,
            original_dtype=_str_to_dtype(info["dtype"]),
        )
    return params
def _load_groupwise_quantized_params(
    model_dir: Path,
    metadata: dict,
) -> dict[str, GroupwiseTernaryParameter]:
    """Reconstruct role-aware grouped parameters from the on-disk layout.

    Mirrors ``_save_groupwise_quantized_model``: raw files are read with
    fixed dtypes (float16 scales/residuals/factors, int32 indices) and the
    shapes are recovered from ``metadata["layer_info"]``.
    """
    quantized_params = {}
    for name, info in metadata["layer_info"].items():
        safe_name = name.replace(".", "_")
        shape = tuple(info["shape"])
        dtype = _str_to_dtype(info["dtype"])
        num_elements = info["num_elements"]
        group_size = int(info["group_size"])
        packed = torch.from_numpy(
            np.fromfile(str(model_dir / "codes" / f"{safe_name}.bin"), dtype=np.uint8)
        )
        # Trailing pad codes are dropped by unpack_ternary.
        codes = unpack_ternary(packed, num_elements).reshape(shape)
        # Group scales/offsets are stored flat as (rows, n_groups).
        group_alpha = torch.from_numpy(
            np.fromfile(
                str(model_dir / "group_alpha" / f"{safe_name}.bin"),
                dtype=np.float16,
            )
        ).reshape(shape[0], int(info["n_groups"]))
        group_mu = torch.from_numpy(
            np.fromfile(
                str(model_dir / "group_mu" / f"{safe_name}.bin"),
                dtype=np.float16,
            )
        ).reshape(shape[0], int(info["n_groups"]))
        sparse_indices = None
        sparse_residual = None
        lr_U = None
        lr_V = None
        # Optional sparse residual: only present when the saver recorded
        # a non-zero nnz for this layer.
        if int(info.get("sparse_nnz", 0)) > 0:
            sparse_indices = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "sparse_indices" / f"{safe_name}.bin"),
                    dtype=np.int32,
                )
            )
            sparse_residual = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "sparse_residual" / f"{safe_name}.bin"),
                    dtype=np.float16,
                )
            )
        # Optional low-rank residual: U is (rows, rank), V is (rank, cols).
        if int(info.get("lr_rank", 0)) > 0:
            lr_rank = int(info["lr_rank"])
            lr_U = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "low_rank_u" / f"{safe_name}.bin"),
                    dtype=np.float16,
                )
            ).reshape(shape[0], lr_rank)
            lr_V = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "low_rank_v" / f"{safe_name}.bin"),
                    dtype=np.float16,
                )
            ).reshape(lr_rank, shape[1])
        quantized_params[name] = GroupwiseTernaryParameter(
            ternary_codes=codes,
            group_alpha=group_alpha,
            group_mu=group_mu,
            group_size=group_size,
            sparse_indices=sparse_indices,
            sparse_residual=sparse_residual,
            lr_U=lr_U,
            lr_V=lr_V,
            original_shape=shape,
            original_dtype=dtype,
        )
    return quantized_params
def _load_tritplane_quantized_params(
    model_dir: Path,
    metadata: dict,
) -> dict[str, TritPlaneParameter]:
    """Reconstruct trit-plane parameters from per-layer ``.npz`` archives.

    Mirrors ``_save_tritplane_quantized_model``: each plane is rebuilt as a
    GroupwiseTernaryParameter (without sparse/low-rank extras), and optional
    rescued rows are restored as int32 indices + float16 values.
    """
    quantized_params = {}
    for name, info in metadata["layer_info"].items():
        safe_name = name.replace(".", "_")
        shape = tuple(info["shape"])
        dtype = _str_to_dtype(info["dtype"])
        num_elements = int(info["num_elements"])
        archive = np.load(model_dir / "tritplane" / f"{safe_name}.npz")
        # Plane count comes from metadata; the archive also stores it.
        n_planes = int(info["n_planes"])
        planes = []
        for idx in range(n_planes):
            packed = torch.from_numpy(archive[f"packed_{idx}"])
            # Every plane covers the full weight matrix.
            codes = unpack_ternary(packed, num_elements).reshape(shape)
            # Cast explicitly: npz preserves whatever dtype was saved.
            group_alpha = torch.from_numpy(archive[f"group_alpha_{idx}"]).to(torch.float16)
            group_mu = torch.from_numpy(archive[f"group_mu_{idx}"]).to(torch.float16)
            group_size = int(archive[f"group_size_{idx}"][0])
            planes.append(
                GroupwiseTernaryParameter(
                    ternary_codes=codes,
                    group_alpha=group_alpha,
                    group_mu=group_mu,
                    group_size=group_size,
                    sparse_indices=None,
                    sparse_residual=None,
                    lr_U=None,
                    lr_V=None,
                    original_shape=shape,
                    original_dtype=dtype,
                )
            )
        rescued_row_indices = None
        rescued_row_values = None
        # Rescued rows are optional; presence is signalled via metadata.
        if int(info.get("rescued_rows", 0)) > 0:
            rescued_row_indices = torch.from_numpy(
                archive["rescued_row_indices"]
            ).to(torch.int32)
            rescued_row_values = torch.from_numpy(
                archive["rescued_row_values"]
            ).to(torch.float16)
        quantized_params[name] = TritPlaneParameter(
            planes=planes,
            original_shape=shape,
            original_dtype=dtype,
            rescued_row_indices=rescued_row_indices,
            rescued_row_values=rescued_row_values,
        )
    return quantized_params
def _serialize_quant_config(quant_config) -> dict:
if hasattr(quant_config, "preset_name") and is_dataclass(quant_config):
return asdict(quant_config)
if isinstance(quant_config, SmallModelQuantizationConfig):
return config_to_dict(quant_config)
if is_dataclass(quant_config):
return asdict(quant_config)
if hasattr(quant_config, "to_dict"):
return quant_config.to_dict()
if isinstance(quant_config, dict):
return quant_config
return {"value": str(quant_config)}
def _json_default(obj):
if hasattr(obj, "item"):
return obj.item()
if hasattr(obj, "tolist"):
return obj.tolist()
if is_dataclass(obj):
return asdict(obj)
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def _str_to_dtype(s: str) -> torch.dtype:
mapping = {
"torch.float16": torch.float16,
"torch.bfloat16": torch.bfloat16,
"torch.float32": torch.float32,
}
return mapping.get(s, torch.float16)