"""
Storage utilities for ternary and role-aware small-model quantization.

The legacy format stores per-row alpha/mu.
The newer role-aware format stores grouped alpha/mu plus an optional sparse
FP16 residual for the most sensitive weights and an optional low-rank residual.
"""
from __future__ import annotations

import json
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch

from ternary_quant.ptq_families import TritPlaneParameter
from ternary_quant.quantizer import TernaryParameter
from ternary_quant.quantizer_small import (
    GroupwiseTernaryParameter,
    RoleAwareQuantizationPlan,
    SmallModelQuantizationConfig,
    config_to_dict,
    plan_to_dict,
)
def pack_ternary(codes: torch.Tensor) -> torch.Tensor:
    """Pack int8 ternary codes {-1, 0, +1} into uint8 with 4 values per byte.

    Codes are shifted to {0, 1, 2} and packed low-bits-first, two bits per
    value.  The tail is padded with 1 (the encoding of code 0) so that
    ``unpack_ternary`` recovers zeros for any positions past the real data.

    Args:
        codes: Tensor of ternary codes in {-1, 0, +1}; any shape.

    Returns:
        1-D uint8 tensor of ``ceil(numel / 4)`` packed bytes, on the same
        device as ``codes``.
    """
    flat = codes.flatten().to(torch.int8)
    encoded = (flat + 1).to(torch.uint8)
    pad_len = (4 - flat.shape[0] % 4) % 4
    if pad_len > 0:
        # Create the pad on the SAME device as the input: the previous
        # CPU-only torch.ones() made torch.cat fail for CUDA tensors.
        pad = torch.ones(pad_len, dtype=torch.uint8, device=encoded.device)
        encoded = torch.cat([encoded, pad])
    encoded = encoded.reshape(-1, 4)
    packed = (
        encoded[:, 0]
        | (encoded[:, 1] << 2)
        | (encoded[:, 2] << 4)
        | (encoded[:, 3] << 6)
    )
    return packed
def unpack_ternary(packed: torch.Tensor, num_elements: int) -> torch.Tensor:
    """Unpack uint8 packed ternary back to int8 {-1, 0, +1}.

    Args:
        packed: 1-D uint8 tensor produced by ``pack_ternary``.
        num_elements: Number of real codes; trailing pad values are dropped.

    Returns:
        1-D int8 tensor of length ``num_elements`` with values in {-1, 0, +1}.
    """
    # Pull the four 2-bit fields out of every byte, low bits first.
    fields = [(packed >> shift) & 0x03 for shift in (0, 2, 4, 6)]
    interleaved = torch.stack(fields, dim=1).flatten()
    # Trim the padding, then shift the {0, 1, 2} encoding back to {-1, 0, +1}.
    return interleaved[:num_elements].to(torch.int8) - 1
def save_quantized_model(
    ternary_params: dict,
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict] = None,
    summary: Optional[dict] = None,
    plan: Optional[RoleAwareQuantizationPlan] = None,
    method_name: Optional[str] = None,
    model_family: Optional[str] = None,
):
    """Save either legacy ternary params or role-aware grouped params.

    Dispatches on the type of the first value in ``ternary_params``:
    GroupwiseTernaryParameter -> role-aware format, TritPlaneParameter ->
    trit-plane format, TernaryParameter -> legacy per-row format.

    Raises:
        ValueError: if ``ternary_params`` is empty.
        TypeError: if the parameter type is not one of the three above.
    """
    if not ternary_params:
        raise ValueError("No quantized parameters were provided.")
    first_param = next(iter(ternary_params.values()))
    # Keyword arguments shared by all three format writers.
    shared = dict(
        model_name=model_name,
        model_config=model_config,
        quant_config=quant_config,
        output_dir=output_dir,
        stats=stats,
        summary=summary,
        model_family=model_family,
    )
    if isinstance(first_param, GroupwiseTernaryParameter):
        _save_groupwise_quantized_model(
            quantized_params=ternary_params,
            plan=plan,
            method_name=method_name or "RAST-small",
            **shared,
        )
    elif isinstance(first_param, TritPlaneParameter):
        _save_tritplane_quantized_model(
            quantized_params=ternary_params,
            plan=plan,
            method_name=method_name or "TritPlane",
            **shared,
        )
    elif isinstance(first_param, TernaryParameter):
        _save_legacy_quantized_model(
            ternary_params=ternary_params,
            method_name=method_name or "legacy-ternary",
            **shared,
        )
    else:
        raise TypeError(f"Unsupported quantized parameter type: {type(first_param)!r}")
def load_quantized_params(model_dir: str) -> tuple[dict, dict]:
    """Load quantized params from either the legacy or role-aware format.

    Reads ``metadata.json`` from ``model_dir`` and picks the loader that
    matches its ``format_family`` field (missing/unknown -> legacy).

    Returns:
        ``(params, metadata)`` where ``params`` maps parameter names to the
        reconstructed quantized-parameter objects.
    """
    root = Path(model_dir)
    metadata = json.loads((root / "metadata.json").read_text())
    loaders = {
        "groupwise_small": _load_groupwise_quantized_params,
        "tritplane_small": _load_tritplane_quantized_params,
    }
    family = metadata.get("format_family", "legacy")
    loader = loaders.get(family, _load_legacy_quantized_params)
    return loader(root, metadata), metadata
def _save_legacy_quantized_model(
    ternary_params: dict[str, TernaryParameter],
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict],
    summary: Optional[dict],
    method_name: str,
    model_family: Optional[str],
) -> None:
    """Write the legacy per-row format: packed codes + per-row alpha/mu.

    Layout: ``codes/<name>.bin`` (uint8, 4 codes/byte), ``scales/<name>.bin``
    (fp16: all alphas then all mus), and ``metadata.json`` with per-layer
    bookkeeping (format_family "legacy", version "1.0").
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "codes").mkdir(exist_ok=True)
    (output_dir / "scales").mkdir(exist_ok=True)
    layer_info = {}
    total_packed_bytes = 0
    total_fp16_bytes = 0
    for name, tp in ternary_params.items():
        # Parameter names contain dots (module paths); make them file-safe.
        safe_name = name.replace(".", "_")
        packed = pack_ternary(tp.ternary_codes)
        codes_path = output_dir / "codes" / f"{safe_name}.bin"
        # .cpu() so CUDA-resident tensors serialize (consistent with the
        # trit-plane saver, which already did this).
        packed.cpu().numpy().tofile(str(codes_path))
        scales_path = output_dir / "scales" / f"{safe_name}.bin"
        scales_data = torch.cat([tp.alpha.flatten(), tp.mu.flatten()])
        # The loader reads this file as float16 unconditionally; cast
        # explicitly so fp32 alpha/mu cannot silently corrupt the scales.
        scales_data.to(torch.float16).cpu().numpy().tofile(str(scales_path))
        num_elements = tp.original_shape[0] * tp.original_shape[1]
        packed_bytes = packed.numel()
        # NOTE: as in the original format, the totals count code bytes only;
        # scale bytes are excluded from total_packed_bytes.
        total_packed_bytes += packed_bytes
        total_fp16_bytes += num_elements * 2
        layer_info[name] = {
            "scheme": "legacy_rowwise",
            "shape": list(tp.original_shape),
            "dtype": str(tp.original_dtype),
            "num_elements": num_elements,
            "packed_bytes": packed_bytes,
            "bits_per_param": tp.bits_per_param,
        }
    metadata = {
        "model_name": model_name,
        "model_config": model_config.to_dict()
        if hasattr(model_config, "to_dict")
        else {},
        "quant_config": _serialize_quant_config(quant_config),
        "layer_info": layer_info,
        "stats": stats or {},
        "summary": summary or {},
        "method_name": method_name,
        "model_family": model_family or "causal_lm",
        "format_family": "legacy",
        "format_version": "1.0",
        "total_packed_bytes": total_packed_bytes,
        "total_fp16_bytes": total_fp16_bytes,
        "compression_ratio": total_fp16_bytes / total_packed_bytes
        if total_packed_bytes > 0
        else 0.0,
    }
    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2, default=_json_default)
    print(f"Saved ternary model to {output_dir}")
    print(f" Packed size: {total_packed_bytes / 1e6:.1f} MB")
    print(f" FP16 size: {total_fp16_bytes / 1e6:.1f} MB")
    print(f" Compression: {metadata['compression_ratio']:.1f}x")
def _save_groupwise_quantized_model(
    quantized_params: dict[str, GroupwiseTernaryParameter],
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict],
    summary: Optional[dict],
    plan: Optional[RoleAwareQuantizationPlan],
    method_name: str,
    model_family: Optional[str],
) -> None:
    """Write the role-aware grouped format (format_family "groupwise_small").

    Per layer: packed ternary codes, grouped fp16 alpha/mu scales, an
    optional sparse residual (int32 flat indices + fp16 values), and an
    optional low-rank residual (fp16 U and V factors).  ``metadata.json``
    records per-layer shapes and sizes so the loader can reverse this.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for sub in (
        "codes",
        "group_alpha",
        "group_mu",
        "sparse_indices",
        "sparse_residual",
        "low_rank_u",
        "low_rank_v",
    ):
        (output_dir / sub).mkdir(exist_ok=True)
    layer_info = {}
    total_stored_bytes = 0
    total_fp16_bytes = 0
    for name, param in quantized_params.items():
        safe_name = name.replace(".", "_")
        packed = pack_ternary(param.ternary_codes)
        # .cpu() so CUDA-resident tensors serialize.  Cast scales/residuals
        # to fp16 and indices to int32: the loader reads those dtypes back
        # unconditionally, so saving other dtypes would corrupt on load.
        packed.cpu().numpy().tofile(str(output_dir / "codes" / f"{safe_name}.bin"))
        param.group_alpha.to(torch.float16).cpu().numpy().tofile(
            str(output_dir / "group_alpha" / f"{safe_name}.bin")
        )
        param.group_mu.to(torch.float16).cpu().numpy().tofile(
            str(output_dir / "group_mu" / f"{safe_name}.bin")
        )
        sparse_nnz = 0
        if param.sparse_indices is not None and param.sparse_residual is not None:
            param.sparse_indices.to(torch.int32).cpu().numpy().tofile(
                str(output_dir / "sparse_indices" / f"{safe_name}.bin")
            )
            param.sparse_residual.to(torch.float16).cpu().numpy().tofile(
                str(output_dir / "sparse_residual" / f"{safe_name}.bin")
            )
            sparse_nnz = param.sparse_indices.numel()
        lr_rank = 0
        if param.lr_U is not None and param.lr_V is not None:
            param.lr_U.to(torch.float16).cpu().numpy().tofile(
                str(output_dir / "low_rank_u" / f"{safe_name}.bin")
            )
            param.lr_V.to(torch.float16).cpu().numpy().tofile(
                str(output_dir / "low_rank_v" / f"{safe_name}.bin")
            )
            lr_rank = int(param.lr_U.shape[1])
        num_elements = param.num_params
        # On-disk accounting: 1 byte per 4 codes, 2 bytes per fp16 value,
        # 4 bytes per int32 sparse index.
        stored_bytes = (
            packed.numel()
            + param.group_alpha.numel() * 2
            + param.group_mu.numel() * 2
            + sparse_nnz * 4
            + sparse_nnz * 2
            + (0 if param.lr_U is None else param.lr_U.numel() * 2)
            + (0 if param.lr_V is None else param.lr_V.numel() * 2)
        )
        total_stored_bytes += stored_bytes
        total_fp16_bytes += num_elements * 2
        layer_info[name] = {
            "scheme": "groupwise_small_v1",
            "shape": list(param.original_shape),
            "dtype": str(param.original_dtype),
            "num_elements": num_elements,
            "group_size": param.group_size,
            "n_groups": int(param.group_alpha.shape[1]),
            "packed_bytes": int(packed.numel()),
            "stored_bytes": int(stored_bytes),
            "effective_bits": param.effective_bits,
            "sparse_nnz": int(sparse_nnz),
            "lr_rank": lr_rank,
        }
    metadata = {
        "model_name": model_name,
        "model_config": model_config.to_dict()
        if hasattr(model_config, "to_dict")
        else {},
        "quant_config": _serialize_quant_config(quant_config),
        # `plan` may arrive as a placeholder string; serialize real plans only.
        "plan": plan_to_dict(plan) if (plan is not None and not isinstance(plan, str)) else {},
        "layer_info": layer_info,
        "stats": stats or {},
        "summary": summary or {},
        "method_name": method_name,
        "model_family": model_family or "causal_lm",
        "format_family": "groupwise_small",
        "format_version": "2.0",
        "total_packed_bytes": total_stored_bytes,
        "total_fp16_bytes": total_fp16_bytes,
        "compression_ratio": total_fp16_bytes / total_stored_bytes
        if total_stored_bytes > 0
        else 0.0,
    }
    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2, default=_json_default)
    print(f"Saved role-aware ternary model to {output_dir}")
    print(f" Stored size: {total_stored_bytes / 1e6:.1f} MB")
    print(f" FP16 size: {total_fp16_bytes / 1e6:.1f} MB")
    print(f" Compression: {metadata['compression_ratio']:.1f}x")
def _save_tritplane_quantized_model(
    quantized_params: dict[str, TritPlaneParameter],
    model_name: str,
    model_config,
    quant_config,
    output_dir: str,
    stats: Optional[dict],
    summary: Optional[dict],
    plan: Optional[RoleAwareQuantizationPlan],
    method_name: str,
    model_family: Optional[str],
) -> None:
    """Serialize trit-plane quantized params as one .npz archive per layer.

    Each archive holds, per plane, the packed ternary codes plus grouped
    alpha/mu scales and the plane's group size, and optionally the indices
    and fp16 values of "rescued" rows kept at full precision.  A
    ``metadata.json`` with per-layer bookkeeping is written alongside
    (format_family "tritplane_small", version "1.0").
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "tritplane").mkdir(exist_ok=True)
    layer_info = {}
    total_stored_bytes = 0
    total_fp16_bytes = 0
    for name, param in quantized_params.items():
        # Parameter names contain dots (module paths); make them file-safe.
        safe_name = name.replace(".", "_")
        archive_path = output_dir / "tritplane" / f"{safe_name}.npz"
        arrays = {"n_planes": np.array([param.n_planes], dtype=np.int32)}
        stored_bytes = 0
        group_sizes = []
        for idx, plane in enumerate(param.planes):
            packed = pack_ternary(plane.ternary_codes)
            arrays[f"packed_{idx}"] = packed.cpu().numpy()
            arrays[f"group_alpha_{idx}"] = plane.group_alpha.cpu().numpy()
            arrays[f"group_mu_{idx}"] = plane.group_mu.cpu().numpy()
            arrays[f"group_size_{idx}"] = np.array([plane.group_size], dtype=np.int32)
            group_sizes.append(int(plane.group_size))
            # Accounting: 1 byte per packed group of 4 codes, 2 bytes per
            # scale entry (assumes group_alpha/group_mu are fp16 — TODO
            # confirm the dtype produced upstream).
            stored_bytes += int(packed.numel())
            stored_bytes += int(plane.group_alpha.numel() * 2)
            stored_bytes += int(plane.group_mu.numel() * 2)
        rescued_rows = 0
        if param.rescued_row_indices is not None and param.rescued_row_values is not None:
            arrays["rescued_row_indices"] = param.rescued_row_indices.cpu().numpy()
            arrays["rescued_row_values"] = param.rescued_row_values.cpu().numpy()
            rescued_rows = int(param.rescued_row_indices.numel())
            # 4 bytes per row index, 2 bytes per rescued fp16 value.
            stored_bytes += int(param.rescued_row_indices.numel() * 4)
            stored_bytes += int(param.rescued_row_values.numel() * 2)
        np.savez(archive_path, **arrays)
        num_elements = param.num_params
        total_stored_bytes += stored_bytes
        total_fp16_bytes += num_elements * 2
        layer_info[name] = {
            "scheme": "tritplane_small_v1",
            "shape": list(param.original_shape),
            "dtype": str(param.original_dtype),
            "num_elements": num_elements,
            "n_planes": int(param.n_planes),
            "group_sizes": group_sizes,
            "rescued_rows": rescued_rows,
            "stored_bytes": int(stored_bytes),
            "effective_bits": float(param.effective_bits),
        }
    metadata = {
        "model_name": model_name,
        "model_config": model_config.to_dict()
        if hasattr(model_config, "to_dict")
        else {},
        "quant_config": _serialize_quant_config(quant_config),
        # `plan` may arrive as a placeholder string; serialize real plans only.
        "plan": plan_to_dict(plan) if (plan is not None and not isinstance(plan, str)) else {},
        "layer_info": layer_info,
        "stats": stats or {},
        "summary": summary or {},
        "method_name": method_name,
        "model_family": model_family or "causal_lm",
        "format_family": "tritplane_small",
        "format_version": "1.0",
        # Note: for this format the "packed" total includes scales and
        # rescued rows, i.e. everything stored on disk.
        "total_packed_bytes": total_stored_bytes,
        "total_fp16_bytes": total_fp16_bytes,
        "compression_ratio": total_fp16_bytes / total_stored_bytes
        if total_stored_bytes > 0
        else 0.0,
    }
    with open(output_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2, default=_json_default)
    print(f"Saved trit-plane ternary model to {output_dir}")
    print(f" Stored size: {total_stored_bytes / 1e6:.1f} MB")
    print(f" FP16 size: {total_fp16_bytes / 1e6:.1f} MB")
    print(f" Compression: {metadata['compression_ratio']:.1f}x")
def _load_legacy_quantized_params(
    model_dir: Path,
    metadata: dict,
) -> dict[str, TernaryParameter]:
    """Reconstruct legacy per-row TernaryParameters from disk.

    Reads ``codes/<name>.bin`` (uint8 packed codes) and ``scales/<name>.bin``
    (fp16: all row alphas followed by all row mus) for every layer listed in
    ``metadata["layer_info"]``.
    """
    params: dict[str, TernaryParameter] = {}
    for name, info in metadata["layer_info"].items():
        safe_name = name.replace(".", "_")
        shape = tuple(info["shape"])
        rows = shape[0]
        raw_codes = np.fromfile(
            str(model_dir / "codes" / f"{safe_name}.bin"), dtype=np.uint8
        )
        codes = unpack_ternary(
            torch.from_numpy(raw_codes), info["num_elements"]
        ).reshape(shape)
        raw_scales = np.fromfile(
            str(model_dir / "scales" / f"{safe_name}.bin"), dtype=np.float16
        )
        scales = torch.from_numpy(raw_scales)
        # First half of the scales file is alpha (one per row), second is mu.
        params[name] = TernaryParameter(
            ternary_codes=codes,
            alpha=scales[:rows].reshape(rows, 1),
            mu=scales[rows:].reshape(rows, 1),
            original_shape=shape,
            original_dtype=_str_to_dtype(info["dtype"]),
        )
    return params
def _load_groupwise_quantized_params(
    model_dir: Path,
    metadata: dict,
) -> dict[str, GroupwiseTernaryParameter]:
    """Rebuild role-aware grouped ternary params from their on-disk layout.

    For every layer in ``metadata["layer_info"]`` this reads:
      * ``codes/``          packed ternary codes (uint8, 4 codes per byte),
      * ``group_alpha/``    fp16 group scales, reshaped to (rows, n_groups),
      * ``group_mu/``       fp16 group offsets, same shape,
      * ``sparse_indices/`` + ``sparse_residual/``  optional sparse fp16
        residual addressed by int32 flat indices (only when sparse_nnz > 0),
      * ``low_rank_u/`` + ``low_rank_v/``  optional fp16 low-rank factors
        U (rows, rank) and V (rank, cols) (only when lr_rank > 0).
    """
    quantized_params = {}
    for name, info in metadata["layer_info"].items():
        # Files were written with dots replaced by underscores.
        safe_name = name.replace(".", "_")
        shape = tuple(info["shape"])
        dtype = _str_to_dtype(info["dtype"])
        num_elements = info["num_elements"]
        group_size = int(info["group_size"])
        packed = torch.from_numpy(
            np.fromfile(str(model_dir / "codes" / f"{safe_name}.bin"), dtype=np.uint8)
        )
        codes = unpack_ternary(packed, num_elements).reshape(shape)
        # Grouped scales are stored flat and reshaped to (rows, n_groups).
        group_alpha = torch.from_numpy(
            np.fromfile(
                str(model_dir / "group_alpha" / f"{safe_name}.bin"),
                dtype=np.float16,
            )
        ).reshape(shape[0], int(info["n_groups"]))
        group_mu = torch.from_numpy(
            np.fromfile(
                str(model_dir / "group_mu" / f"{safe_name}.bin"),
                dtype=np.float16,
            )
        ).reshape(shape[0], int(info["n_groups"]))
        sparse_indices = None
        sparse_residual = None
        lr_U = None
        lr_V = None
        # Residual files exist only when metadata says they were written.
        if int(info.get("sparse_nnz", 0)) > 0:
            sparse_indices = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "sparse_indices" / f"{safe_name}.bin"),
                    dtype=np.int32,
                )
            )
            sparse_residual = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "sparse_residual" / f"{safe_name}.bin"),
                    dtype=np.float16,
                )
            )
        if int(info.get("lr_rank", 0)) > 0:
            lr_rank = int(info["lr_rank"])
            lr_U = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "low_rank_u" / f"{safe_name}.bin"),
                    dtype=np.float16,
                )
            ).reshape(shape[0], lr_rank)
            lr_V = torch.from_numpy(
                np.fromfile(
                    str(model_dir / "low_rank_v" / f"{safe_name}.bin"),
                    dtype=np.float16,
                )
            ).reshape(lr_rank, shape[1])
        quantized_params[name] = GroupwiseTernaryParameter(
            ternary_codes=codes,
            group_alpha=group_alpha,
            group_mu=group_mu,
            group_size=group_size,
            sparse_indices=sparse_indices,
            sparse_residual=sparse_residual,
            lr_U=lr_U,
            lr_V=lr_V,
            original_shape=shape,
            original_dtype=dtype,
        )
    return quantized_params
def _load_tritplane_quantized_params(
    model_dir: Path,
    metadata: dict,
) -> dict[str, TritPlaneParameter]:
    """Reconstruct TritPlaneParameters from their per-layer .npz archives.

    Each archive contains per-plane packed codes, grouped alpha/mu scales
    and group sizes, plus optional rescued-row indices/values.
    """
    quantized_params: dict[str, TritPlaneParameter] = {}
    for name, info in metadata["layer_info"].items():
        safe_name = name.replace(".", "_")
        shape = tuple(info["shape"])
        dtype = _str_to_dtype(info["dtype"])
        num_elements = int(info["num_elements"])
        archive = np.load(model_dir / "tritplane" / f"{safe_name}.npz")

        def _read_plane(idx: int) -> GroupwiseTernaryParameter:
            # One plane: codes + grouped scales; planes carry no residuals.
            packed = torch.from_numpy(archive[f"packed_{idx}"])
            return GroupwiseTernaryParameter(
                ternary_codes=unpack_ternary(packed, num_elements).reshape(shape),
                group_alpha=torch.from_numpy(
                    archive[f"group_alpha_{idx}"]
                ).to(torch.float16),
                group_mu=torch.from_numpy(
                    archive[f"group_mu_{idx}"]
                ).to(torch.float16),
                group_size=int(archive[f"group_size_{idx}"][0]),
                sparse_indices=None,
                sparse_residual=None,
                lr_U=None,
                lr_V=None,
                original_shape=shape,
                original_dtype=dtype,
            )

        planes = [_read_plane(idx) for idx in range(int(info["n_planes"]))]
        rescued_row_indices = None
        rescued_row_values = None
        if int(info.get("rescued_rows", 0)) > 0:
            rescued_row_indices = torch.from_numpy(
                archive["rescued_row_indices"]
            ).to(torch.int32)
            rescued_row_values = torch.from_numpy(
                archive["rescued_row_values"]
            ).to(torch.float16)
        quantized_params[name] = TritPlaneParameter(
            planes=planes,
            original_shape=shape,
            original_dtype=dtype,
            rescued_row_indices=rescued_row_indices,
            rescued_row_values=rescued_row_values,
        )
    return quantized_params
def _serialize_quant_config(quant_config) -> dict:
    """Best-effort conversion of an arbitrary quant config into a plain dict.

    Checks are ordered: preset-style dataclasses win over the small-model
    config path; anything unrecognized is stringified under "value".
    """
    if is_dataclass(quant_config) and hasattr(quant_config, "preset_name"):
        return asdict(quant_config)
    if isinstance(quant_config, SmallModelQuantizationConfig):
        return config_to_dict(quant_config)
    if is_dataclass(quant_config):
        return asdict(quant_config)
    if hasattr(quant_config, "to_dict"):
        return quant_config.to_dict()
    if isinstance(quant_config, dict):
        return quant_config
    # Last resort: keep something JSON-serializable.
    return {"value": str(quant_config)}
def _json_default(obj):
    """``json.dump`` fallback for numpy/torch values and dataclasses.

    Tries ``.item()`` first (scalars), falling back to ``.tolist()``:
    multi-element arrays/tensors also expose ``.item()`` but raise when it
    is called, which previously made serialization fail for them.

    Raises:
        TypeError: if the object matches none of the supported shapes.
    """
    if hasattr(obj, "item"):
        try:
            return obj.item()
        except (ValueError, RuntimeError):
            # numpy raises ValueError, torch RuntimeError, for size > 1.
            pass
    if hasattr(obj, "tolist"):
        return obj.tolist()
    if is_dataclass(obj):
        return asdict(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def _str_to_dtype(s: str) -> torch.dtype:
    """Map a serialized ``str(torch.dtype)`` back to the torch dtype.

    Known storage dtypes are resolved from a fixed table; any other
    "torch.<name>" string (e.g. "torch.float64") is resolved dynamically
    instead of being silently degraded.  Unknown strings fall back to
    float16, the storage default.
    """
    mapping = {
        "torch.float16": torch.float16,
        "torch.bfloat16": torch.bfloat16,
        "torch.float32": torch.float32,
    }
    if s in mapping:
        return mapping[s]
    if isinstance(s, str) and s.startswith("torch."):
        candidate = getattr(torch, s.split(".", 1)[1], None)
        if isinstance(candidate, torch.dtype):
            return candidate
    return torch.float16