""" PackedConv1d — nibble-packed веса + батчевый параллельный unpack. Архитектура: PackedConv1d / PackedBatchNorm1d — хранят nibbles в RAM (~3 MB) BatchedUnpackEngine — один numba prange вызов распаковывает ВСЕ слои параллельно перед forward() build_packed_model() — строит модель + движок Три бэкенда (выбирается автоматически): numba-batched — ОДИН вызов numba prange, ~0.1 ms overhead numba-per-layer— 69 вызовов numba, ~16 ms overhead torch — 69 вызовов torch, ~34 ms overhead """ from __future__ import annotations import sys, time from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F sys.path.insert(0, str(Path(__file__).parent)) sys.stdout.reconfigure(encoding="utf-8") TEACHER_REPO = "yangwang825/ecapa-tdnn-vox2" TEACHER_DIR = Path("models/ecapa_vox2") PACKED_PATH = Path("models/ecapa_onnx/ecapa_qat_packed.pt") N_WARM = 10 N_BENCH = 80 # ── Numba JIT функции ───────────────────────────────────────────────────────── def _try_init_numba(): """Компилирует numba JIT функции. 
Возвращает (ok: bool, funcs: dict).""" try: import numba as nb # Батчевый unpack 4-bit: ВСЕ nibble-слои за ОДИН вызов с prange @nb.njit(cache=True, fastmath=True, parallel=True, nogil=True) def nb_unpack_4bit_batch(all_packed, centroids_mat, n_origs, pack_offsets, out_offsets, out_flat): """ all_packed : uint8 [total_packed_bytes] — все nibble данные подряд centroids_mat: float32 [n_layers, 16] — центроиды каждого слоя n_origs : int64 [n_layers] — кол-во весов на слой pack_offsets : int64 [n_layers] — смещения в all_packed out_offsets : int64 [n_layers] — смещения в out_flat out_flat : float32[total_weights] — pre-allocated выход """ for i in nb.prange(len(n_origs)): n = n_origs[i] po = pack_offsets[i] oo = out_offsets[i] c = centroids_mat[i] half = n >> 1 for j in range(half): b = all_packed[po + j] out_flat[oo + 2*j ] = c[b & np.uint8(0x0F)] out_flat[oo + 2*j+1] = c[b >> np.uint8(4)] if n & 1: out_flat[oo + n - 1] = c[all_packed[po + half] & np.uint8(0x0F)] # Батчевый unpack 8-bit @nb.njit(cache=True, fastmath=True, parallel=True, nogil=True) def nb_unpack_8bit_batch(all_packed, centroids_mat, n_origs, pack_offsets, out_offsets, out_flat): for i in nb.prange(len(n_origs)): n = n_origs[i] po = pack_offsets[i] oo = out_offsets[i] c = centroids_mat[i] for j in range(n): out_flat[oo + j] = c[all_packed[po + j]] # Per-layer unpack (для сравнения) @nb.njit(cache=True, fastmath=True, nogil=True) def nb_unpack_nibbles(packed, centroids, n_orig): out = np.empty(n_orig, dtype=np.float32) half = n_orig >> 1 for j in range(half): b = packed[j] out[2*j ] = centroids[b & np.uint8(0x0F)] out[2*j+1] = centroids[b >> np.uint8(4)] if n_orig & 1: out[n_orig-1] = centroids[packed[half] & np.uint8(0x0F)] return out @nb.njit(cache=True, fastmath=True, nogil=True) def nb_unpack_bytes(packed, centroids, n_orig): out = np.empty(n_orig, dtype=np.float32) for i in range(n_orig): out[i] = centroids[packed[i]] return out # Прогрев компилятора _dp = np.zeros(4, dtype=np.uint8) _dc16 = 
np.zeros((1, 16), dtype=np.float32) _dc256 = np.zeros((1, 256), dtype=np.float32) _dn = np.array([8], dtype=np.int64) _do = np.array([0], dtype=np.int64) _out = np.zeros(8, dtype=np.float32) nb_unpack_4bit_batch(_dp, _dc16, _dn, _do, _do, _out) nb_unpack_nibbles(_dp, _dc16[0], 8) nb_unpack_bytes(_dp[:4], _dc256[0], 4) return True, { "batch_4": nb_unpack_4bit_batch, "batch_8": nb_unpack_8bit_batch, "single_4": nb_unpack_nibbles, "single_8": nb_unpack_bytes, } except ImportError: return False, {} NUMBA_OK, _NB = _try_init_numba() # ── PackedConv1d ────────────────────────────────────────────────────────────── class PackedConv1d(nn.Module): """Conv1d с nibble-packed весами. forward() использует FP32 тензор, который заполняется BatchedUnpackEngine перед каждым encode_batch().""" def __init__(self, weight_meta: dict, bias: torch.Tensor | None, in_channels: int, out_channels: int, kernel_size, stride, padding, dilation, groups: int): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.dilation = dilation self.groups = groups self._shape = tuple(weight_meta["shape"]) self._n_orig = int(weight_meta["n_orig"]) self._n_bits = int(weight_meta["n_bits"]) self.register_buffer("_w_packed", weight_meta["packed"]) self.register_buffer("_w_centroids", weight_meta["centroids"]) if bias is not None: self.register_buffer("_bias", bias) self._has_bias = True else: self._has_bias = False # Заполняется BatchedUnpackEngine; None → inline unpack (fallback) self._w_ready: torch.Tensor | None = None def _unpack_inline(self) -> torch.Tensor: """Inline fallback: используется без BatchedUnpackEngine.""" p = self._w_packed.numpy() c = self._w_centroids.numpy() if NUMBA_OK: if self._n_bits == 4: arr = _NB["single_4"](p, c, self._n_orig) else: arr = _NB["single_8"](p, c, self._n_orig) return torch.from_numpy(arr).reshape(self._shape) # torch fallback lo = (self._w_packed & 0x0F).long() hi 
= ((self._w_packed >> 4) & 0x0F).long() idx = torch.empty(lo.numel() + hi.numel(), dtype=torch.long) idx[0::2] = lo; idx[1::2] = hi return self._w_centroids[idx[:self._n_orig]].reshape(self._shape) def forward(self, x: torch.Tensor) -> torch.Tensor: w = self._w_ready if self._w_ready is not None else self._unpack_inline() b = self._bias if self._has_bias else None return F.conv1d(x, w, b, self.stride, self.padding, self.dilation, self.groups) @property def weight_bytes(self) -> int: return self._w_packed.nbytes + self._w_centroids.nbytes class PackedBatchNorm1d(nn.Module): """BatchNorm1d с nibble-packed γ и β.""" def __init__(self, weight_meta: dict | None, bias_meta: dict | None, running_mean: torch.Tensor, running_var: torch.Tensor, eps: float, num_features: int): super().__init__() self.eps = eps; self.num_features = num_features self.register_buffer("running_mean", running_mean) self.register_buffer("running_var", running_var) def _reg(meta, prefix): if meta is not None: self.register_buffer(f"_{prefix}_packed", meta["packed"]) self.register_buffer(f"_{prefix}_centroids", meta["centroids"]) setattr(self, f"_{prefix}_shape", tuple(meta["shape"])) setattr(self, f"_{prefix}_n_orig", int(meta["n_orig"])) setattr(self, f"_{prefix}_n_bits", int(meta["n_bits"])) setattr(self, f"_has_{prefix}", True) setattr(self, f"_{prefix}_ready", None) else: setattr(self, f"_has_{prefix}", False) _reg(weight_meta, "w") _reg(bias_meta, "b") def _unpack_one(self, prefix: str) -> torch.Tensor: p = getattr(self, f"_{prefix}_packed").numpy() c = getattr(self, f"_{prefix}_centroids").numpy() n = getattr(self, f"_{prefix}_n_orig") sh = getattr(self, f"_{prefix}_shape") nb_bits = getattr(self, f"_{prefix}_n_bits") if NUMBA_OK: fn = _NB["single_4"] if nb_bits == 4 else _NB["single_8"] return torch.from_numpy(fn(p, c, n)).reshape(sh) pk = getattr(self, f"_{prefix}_packed") ct = getattr(self, f"_{prefix}_centroids") lo = (pk & 0x0F).long(); hi = ((pk >> 4) & 0x0F).long() idx = 
torch.empty(lo.numel()+hi.numel(), dtype=torch.long) idx[0::2]=lo; idx[1::2]=hi return ct[idx[:n]].reshape(sh) def forward(self, x: torch.Tensor) -> torch.Tensor: w = (self._w_ready if (self._has_w and self._w_ready is not None) else (self._unpack_one("w") if self._has_w else None)) b = (self._b_ready if (self._has_b and self._b_ready is not None) else (self._unpack_one("b") if self._has_b else None)) return F.batch_norm(x, self.running_mean, self.running_var, w, b, False, 0.0, self.eps) @property def weight_bytes(self) -> int: total = self.running_mean.nbytes + self.running_var.nbytes for prefix in ("w", "b"): if getattr(self, f"_has_{prefix}", False): total += (getattr(self, f"_{prefix}_packed").nbytes + getattr(self, f"_{prefix}_centroids").nbytes) return total # ── BatchedUnpackEngine ─────────────────────────────────────────────────────── class BatchedUnpackEngine: """ Один numba prange вызов распаковывает ВСЕ nibble-слои параллельно. Схема: init: сканирует модель, конкатенирует all_packed, строит смещения, pre-allocates out_flat, создаёт numpy views для каждого слоя. unpack_all(): ОДИН nb_unpack_4bit_batch + ОДИН nb_unpack_8bit_batch вызов. Заполняет _w_ready / _b_ready на каждом packed-модуле. 
""" def __init__(self, enc: nn.Module): self.enc = enc self._entries: list[dict] = [] # [{"module", "attr", "n_bits", "meta_key"}] # Собираем все тензоры для упаковки for name, m in enc.named_modules(): if isinstance(m, PackedConv1d): self._entries.append({ "module": m, "attr": "_w_ready", "packed": m._w_packed.numpy().copy(), "centroids": m._w_centroids.numpy().copy(), "n_orig": m._n_orig, "shape": m._shape, "n_bits": m._n_bits, }) elif isinstance(m, PackedBatchNorm1d): for prefix, attr in (("w", "_w_ready"), ("b", "_b_ready")): if getattr(m, f"_has_{prefix}", False): self._entries.append({ "module": m, "attr": attr, "packed": getattr(m, f"_{prefix}_packed").numpy().copy(), "centroids": getattr(m, f"_{prefix}_centroids").numpy().copy(), "n_orig": getattr(m, f"_{prefix}_n_orig"), "shape": getattr(m, f"_{prefix}_shape"), "n_bits": getattr(m, f"_{prefix}_n_bits"), }) # Строим батчи по n_bits self._batches: dict[int, dict] = {} for nb_bits in (4, 8): entries = [e for e in self._entries if e["n_bits"] == nb_bits] if not entries: continue all_packed = np.concatenate([e["packed"] for e in entries]) centroids_mat = np.stack([e["centroids"] for e in entries]) n_origs = np.array([e["n_orig"] for e in entries], dtype=np.int64) pack_sizes = np.array([len(e["packed"]) for e in entries], dtype=np.int64) pack_offsets = np.concatenate([[0], np.cumsum(pack_sizes[:-1])]).astype(np.int64) out_offsets = np.concatenate([[0], np.cumsum(n_origs[:-1])]).astype(np.int64) total_weights = int(n_origs.sum()) # ── Ключевая оптимизация ────────────────────────────────────────── # out_flat аллоцируется как torch-тензор (правильное выравнивание # для MKL/BLAS). Получаем numpy-view той же памяти для numba. out_flat_t = torch.empty(total_weights, dtype=torch.float32) out_flat_np = out_flat_t.numpy() # zero-copy view # _w_ready / _b_ready устанавливаются ОДИН РАЗ здесь (не в unpack_all). # Numba пишет в out_flat_np → out_flat_t обновляется автоматически. 
for i, e in enumerate(entries): o = int(out_offsets[i]) # torch-view той же памяти — persistent, без копирования t_view = out_flat_t[o : o + e["n_orig"]].reshape(e["shape"]) setattr(e["module"], e["attr"], t_view) self._batches[nb_bits] = { "entries": entries, "all_packed": all_packed, "centroids_mat": centroids_mat, "n_origs": n_origs, "pack_offsets": pack_offsets, "out_offsets": out_offsets, "out_flat_np": out_flat_np, # numba пишет сюда "out_flat_t": out_flat_t, # torch читает отсюда } def unpack_all(self): """Один (или два) numba-вызова, никаких Python-операций после.""" for nb_bits, bat in self._batches.items(): fn = _NB["batch_4"] if nb_bits == 4 else _NB["batch_8"] fn(bat["all_packed"], bat["centroids_mat"], bat["n_origs"], bat["pack_offsets"], bat["out_offsets"], bat["out_flat_np"]) # Тензоры уже указывают на out_flat_t (та же память) — ничего больше не нужно def encode_batch(self, wav: torch.Tensor) -> torch.Tensor: """Батчевый unpack + forward модели.""" self.unpack_all() return self.enc.encode_batch(wav) def ram_nibbles_mb(self) -> float: total = sum(m.weight_bytes for m in self.enc.modules() if isinstance(m, (PackedConv1d, PackedBatchNorm1d))) return total / 1024**2 def ram_total_mb(self) -> float: return (sum(p.nbytes for p in self.enc.parameters()) + sum(b.nbytes for b in self.enc.buffers())) / 1024**2 # ── Построение модели ───────────────────────────────────────────────────────── def build_packed_model(packed_path: Path) -> tuple[nn.Module, BatchedUnpackEngine]: from speechbrain.inference.classifiers import EncoderClassifier from pack_qat_weights import (QATConv1d, QATBatchNorm1d, replace_with_qat, unpack_tensor) enc = EncoderClassifier.from_hparams( source=TEACHER_REPO, savedir=str(TEACHER_DIR), run_opts={"device": "cpu"}) replace_with_qat(enc, n_bits=4) ckpt = torch.load(packed_path, map_location="cpu", weights_only=False) packed_layers = ckpt["packed_layers"] fp32_params = ckpt["fp32_params"] all_p = dict(enc.named_parameters()); all_b = 
dict(enc.named_buffers()) with torch.no_grad(): for k, v in fp32_params.items(): if k in all_p: all_p[k].data.copy_(v) elif k in all_b: all_b[k].data.copy_(v) mods = dict(enc.named_modules()) for name, layer in packed_layers.items(): m = mods.get(name) if m is None: continue parts = name.rsplit(".", 1) parent = mods[parts[0]] if len(parts) == 2 else enc attr = parts[-1] if isinstance(m, QATConv1d): bias_t = None if layer["bias"] is not None: bias_t = (unpack_tensor(layer["bias"]) if isinstance(layer["bias"], dict) else layer["bias"].clone()) new_m = PackedConv1d( weight_meta=layer["weight"], bias=bias_t, in_channels=m.in_channels, out_channels=m.out_channels, kernel_size=m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, groups=m.groups) setattr(parent, attr, new_m) elif isinstance(m, QATBatchNorm1d): bias_meta = layer["bias"] if isinstance(layer.get("bias"), dict) else None new_m = PackedBatchNorm1d( weight_meta=layer["weight"], bias_meta=bias_meta, running_mean=layer["running_mean"], running_var=layer["running_var"], eps=layer["eps"], num_features=m.num_features) setattr(parent, attr, new_m) enc.eval() engine = BatchedUnpackEngine(enc) return enc, engine # ── Бенчмарк ───────────────────────────────────────────────────────────────── def bench_model(encode_fn, n_warm=N_WARM, n_bench=N_BENCH, secs=3) -> dict: w1 = torch.randn(1, 16_000 * secs) w8 = torch.randn(8, 16_000 * secs) with torch.no_grad(): for _ in range(n_warm): encode_fn(w1) t1, t8 = [], [] with torch.no_grad(): for _ in range(n_bench): t = time.perf_counter(); encode_fn(w1) t1.append((time.perf_counter()-t)*1000) for _ in range(n_bench): t = time.perf_counter(); encode_fn(w8) t8.append((time.perf_counter()-t)*1000) return {"mean_b1": float(np.mean(t1)), "median_b1": float(np.median(t1)), "p95_b1": float(np.percentile(t1, 95)), "mean_b8": float(np.mean(t8))} def model_ram(enc) -> float: return (sum(p.nbytes for p in enc.parameters()) + sum(b.nbytes for b in enc.buffers())) / 1024**2 
def main():
    """Benchmark FP32, unpack-at-load and the packed backends; print a summary.

    Loads the teacher model and the packed checkpoint, measures latency for
    every scheme available on this machine, and prints RAM / latency tables.
    """
    # The torch-inline benchmark must flip the flag that the classes in THIS
    # module read. Importing the file again under its module name would patch
    # a different module object when the file is executed as a script.
    global NUMBA_OK

    from speechbrain.inference.classifiers import EncoderClassifier
    from pack_qat_weights import (QATConv1d, QATBatchNorm1d,
                                  replace_with_qat, unpack_tensor)

    def _drop_ready_views(model) -> None:
        """Detach the engine's FP32 views so every layer unpacks inline."""
        for m in model.modules():
            if isinstance(m, (PackedConv1d, PackedBatchNorm1d)):
                for attr in ("_w_ready", "_b_ready"):
                    if hasattr(m, attr):
                        setattr(m, attr, None)

    bar = "=" * 72
    print(bar)
    print(" PackedConv1d — сравнение бэкендов unpack")
    print(f" Numba: {'OK ' + __import__('numba').__version__ if NUMBA_OK else 'НЕТ (torch fallback)'}")
    print(bar)

    results = {}

    # [1] FP32 baseline
    print("\n[1] FP32 baseline ...")
    enc_fp32 = EncoderClassifier.from_hparams(
        source=TEACHER_REPO, savedir=str(TEACHER_DIR),
        run_opts={"device": "cpu"})
    enc_fp32.eval()
    results["fp32"] = {
        "label": "FP32 оригинал",
        "ram": model_ram(enc_fp32),
        "nibbles": None,
        "inf": bench_model(enc_fp32.encode_batch)}
    print(f" RAM: {results['fp32']['ram']:.2f} MB | "
          f"median b=1: {results['fp32']['inf']['median_b1']:.1f} ms")

    # [2] Unpack-at-load (for RAM comparison)
    print("\n[2] Unpack-at-load (FP32 в памяти после распаковки) ...")
    enc_ul = EncoderClassifier.from_hparams(
        source=TEACHER_REPO, savedir=str(TEACHER_DIR),
        run_opts={"device": "cpu"})
    replace_with_qat(enc_ul, n_bits=4)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    ckpt = torch.load(PACKED_PATH, map_location="cpu", weights_only=False)
    all_p = dict(enc_ul.named_parameters())
    all_b_d = dict(enc_ul.named_buffers())
    for k, v in ckpt["fp32_params"].items():
        if k in all_p:
            all_p[k].data.copy_(v)
        elif k in all_b_d:
            all_b_d[k].data.copy_(v)
    ms = dict(enc_ul.named_modules())
    with torch.no_grad():
        for nm, lyr in ckpt["packed_layers"].items():
            m = ms.get(nm)
            if isinstance(m, QATConv1d):
                m.weight.data = unpack_tensor(lyr["weight"])
                if lyr["bias"] is not None:
                    m.bias.data = (unpack_tensor(lyr["bias"])
                                   if isinstance(lyr["bias"], dict)
                                   else lyr["bias"].clone())
            elif isinstance(m, QATBatchNorm1d):
                if lyr["weight"] is not None:
                    m.weight.data = unpack_tensor(lyr["weight"])
                if lyr["bias"] is not None:
                    m.bias.data = unpack_tensor(lyr["bias"])
                m.running_mean.data = lyr["running_mean"].clone()
                m.running_var.data = lyr["running_var"].clone()
    enc_ul.eval()
    results["unload"] = {
        "label": "Unpack-at-load",
        "ram": model_ram(enc_ul),
        "nibbles": None,
        "inf": bench_model(enc_ul.encode_batch)}
    print(f" RAM: {results['unload']['ram']:.2f} MB | "
          f"median b=1: {results['unload']['inf']['median_b1']:.1f} ms")

    # [3] PackedConv1d per-layer torch
    print("\n[3] PackedConv1d per-layer [torch] ...")
    enc_pk, engine_pk = build_packed_model(PACKED_PATH)
    # Disable the batched engine: ready views = None -> inline unpack.
    _drop_ready_views(enc_pk)
    # Force the torch fallback for this measurement only; always restore the
    # flag, even if the benchmark raises.
    saved_numba = NUMBA_OK
    NUMBA_OK = False
    try:
        results["torch_inline"] = {
            "label": "PackedConv1d [torch inline]",
            "ram": model_ram(enc_pk),
            "nibbles": engine_pk.ram_nibbles_mb(),
            "inf": bench_model(enc_pk.encode_batch)}
    finally:
        NUMBA_OK = saved_numba
    print(f" RAM: {results['torch_inline']['ram']:.2f} MB "
          f"(nibbles: {results['torch_inline']['nibbles']:.2f} MB) | "
          f"median b=1: {results['torch_inline']['inf']['median_b1']:.1f} ms")

    if NUMBA_OK:
        # [4] PackedConv1d per-layer numba
        print("\n[4] PackedConv1d per-layer [numba] ...")
        enc_nb, engine_nb = build_packed_model(PACKED_PATH)
        _drop_ready_views(enc_nb)
        # JIT warm-up
        with torch.no_grad():
            enc_nb.encode_batch(torch.randn(1, 48_000))
        results["numba_inline"] = {
            "label": "PackedConv1d [numba inline]",
            "ram": model_ram(enc_nb),
            "nibbles": engine_nb.ram_nibbles_mb(),
            "inf": bench_model(enc_nb.encode_batch)}
        print(f" RAM: {results['numba_inline']['ram']:.2f} MB "
              f"(nibbles: {results['numba_inline']['nibbles']:.2f} MB) | "
              f"median b=1: {results['numba_inline']['inf']['median_b1']:.1f} ms")

        # [5] BatchedUnpackEngine + numba prange (MAIN RESULT)
        print("\n[5] BatchedUnpackEngine [numba prange] ...")
        enc_bt, engine_bt = build_packed_model(PACKED_PATH)
        # prange JIT warm-up
        print(" prange JIT прогрев...", end=" ", flush=True)
        with torch.no_grad():
            engine_bt.encode_batch(torch.randn(1, 48_000))
        print("готово")
        results["batched"] = {
            "label": "Batched [numba prange]",
            "ram": engine_bt.ram_total_mb(),
            "nibbles": engine_bt.ram_nibbles_mb(),
            "inf": bench_model(engine_bt.encode_batch)}
        print(f" RAM: {results['batched']['ram']:.2f} MB "
              f"(nibbles: {results['batched']['nibbles']:.2f} MB) | "
              f"median b=1: {results['batched']['inf']['median_b1']:.1f} ms")

    # ── Summary ──────────────────────────────────────────────────────────────
    ref_b1 = results["fp32"]["inf"]["median_b1"]
    print(f"\n{bar}")
    print(f" ИТОГОВАЯ СВОДКА (CPU, 3s audio, n={N_BENCH}, median latency)")
    print(bar)
    print(f" {'Схема':<32} {'RAM':>8} {'Nibbles':>9} {'b=1':>9} {'b=8':>9} {'vs FP32':>9}")
    print(f" {'-'*72}")
    for r in results.values():
        sp = ref_b1 / r["inf"]["median_b1"]
        pstr = f"{r['nibbles']:.2f} MB" if r["nibbles"] else " —"
        mark = " ◀" if r.get("label","").startswith("Batched") else ""
        print(f" {r['label']:<32} {r['ram']:>6.2f} MB {pstr:>9}"
              f" {r['inf']['median_b1']:>8.1f}ms {r['inf']['mean_b8']:>8.1f}ms"
              f" {sp:>7.2f}x{mark}")

    print(f"\n {'-'*72}")
    fp32_ram = results["fp32"]["ram"]
    if "batched" in results:
        bt = results["batched"]
        oh = bt["inf"]["median_b1"] - ref_b1
        print(f" RAM (nibbles): {bt['nibbles']:.2f} MB vs FP32: {fp32_ram:.2f} MB "
              f"({fp32_ram/bt['nibbles']:.1f}x меньше весов)")
        print(f" Overhead batched: {oh:+.1f} ms vs FP32")
        # Compare against numba-inline only when it was actually measured.
        if "numba_inline" in results:
            nb_oh = results["numba_inline"]["inf"]["median_b1"] - ref_b1
            if nb_oh:
                print(f" Overhead numba inline: {nb_oh:+.1f} ms "
                      f"Ускорение батча: {(ref_b1+nb_oh)/(ref_b1+oh):.1f}x над numba inline")
    print(bar)


if __name__ == "__main__":
    main()