| """ |
| coding = utf-8 |
| Copyright 2026 Rikka Botan. All rights reserved |
| Licensed under "MIT License" |
| Stable Static Embedding official PyTorch implementation |
| """ |
|
|
from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file as save_safetensors_file
from sentence_transformers.models.InputModule import InputModule
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
|
|
|
|
def quantize_q4_k_m(weight: torch.Tensor):
    """
    weight: (vocab, dim) float tensor; dim must be even for nibble packing.
    returns: dict with "packed" (uint8, two 4-bit codes per byte) and
    per-row "scales" (float32).
    """
    w = weight.detach().cpu().numpy().astype(np.float32)

    # Per-row absmax scale; the epsilon guards against all-zero rows.
    scales = np.max(np.abs(w), axis=1, keepdims=True) + 1e-8
    w_norm = w / scales

    # Map [-1, 1] onto the 16 levels {0, ..., 15}.
    q = np.clip(np.round((w_norm + 1) * 7.5), 0, 15).astype(np.uint8)

    # Pack two 4-bit codes per byte: even columns in the high nibble,
    # odd columns in the low nibble.
    packed = (q[:, 0::2] << 4) | q[:, 1::2]

    return {
        "packed": packed,
        "scales": scales.astype(np.float32),
    }
|
|
|
|
def dequantize_q4_k_m(packed: np.ndarray, scales: np.ndarray):
    # Unpack the two 4-bit codes stored in each byte.
    hi = (packed >> 4) & 0xF
    lo = packed & 0xF

    q = np.empty((packed.shape[0], packed.shape[1] * 2), dtype=np.uint8)
    q[:, 0::2] = hi
    q[:, 1::2] = lo

    # Invert the quantization map: {0, ..., 15} -> [-1, 1] -> original scale.
    w = (q.astype(np.float32) / 7.5) - 1.0
    w = w * scales
    return torch.from_numpy(w)
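

def _sse_quant_roundtrip_demo() -> None:
    """
    Hedged sanity-check sketch for quantize_q4_k_m / dequantize_q4_k_m
    (illustrative only, not part of the public API; shapes are arbitrary
    assumptions). With 16 levels spread over [-1, 1], the per-element
    reconstruction error is bounded by about scale / 15 per row.
    """
    w = torch.randn(8, 16)  # (vocab, dim); dim must be even for nibble packing
    q = quantize_q4_k_m(w)
    w_hat = dequantize_q4_k_m(q["packed"], q["scales"])
    # Half of one quantization step per row, plus a little float slack.
    tol = torch.from_numpy(q["scales"]) / 15.0 + 1e-6
    assert ((w - w_hat).abs() <= tol).all()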
|
|
|
|
class SeparableDyT(nn.Module):
    """
    Dynamic Tanh (DyT) with per-dimension ("separable") parameters:
    y = beta * tanh(alpha * x + bias), applied elementwise.
    """
    def __init__(
        self,
        hidden_dim: int,
        alpha_init: float = 0.5
    ):
        super().__init__()
        self.alpha = nn.Parameter(alpha_init * torch.ones(hidden_dim))
        self.beta = nn.Parameter(torch.ones(hidden_dim))
        self.bias = nn.Parameter(torch.zeros(hidden_dim))
| |
| def forward( |
| self, |
| x: torch.Tensor |
| ) -> torch.Tensor: |
        x = self.beta * torch.tanh(self.alpha * x + self.bias)
| return x |
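

def _separable_dyt_demo() -> None:
    """
    Illustrative sketch (not part of the public API): SeparableDyT is an
    elementwise, shape-preserving map, so its outputs are bounded by |beta|.
    """
    dyt = SeparableDyT(hidden_dim=8)
    x = torch.randn(4, 8)
    y = dyt(x)
    assert y.shape == x.shape
    assert (y.abs() <= dyt.beta.abs() + 1e-6).all()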
|
|
|
|
| class SSEQ(InputModule): |
| """ |
| Stable Static Embedding (SSE) |
| StaticEmbedding-compatible Sentence-Transformers module |
| """ |
|
|
| def __init__( |
| self, |
| tokenizer: Tokenizer | PreTrainedTokenizerFast, |
| vocab_size: int, |
| hidden_dim: int = 1024, |
| **kwargs, |
| ): |
| super().__init__() |
|
|
| if isinstance(tokenizer, PreTrainedTokenizerFast): |
| tokenizer = tokenizer._tokenizer |
| elif not isinstance(tokenizer, Tokenizer): |
| raise ValueError("Tokenizer must be a fast (Rust) tokenizer") |
|
|
        self.tokenizer: Tokenizer = tokenizer
        # Padding is unnecessary: EmbeddingBag consumes a flat id stream
        # plus per-sentence offsets (see tokenize()).
        self.tokenizer.no_padding()

        # EmbeddingBag defaults to mode="mean", i.e. mean pooling over tokens.
        self.embedding = nn.EmbeddingBag(vocab_size, hidden_dim)
        self.dyt = SeparableDyT(hidden_dim)

        self.embedding_dim = hidden_dim

        self.base_model = kwargs.get("base_model")
|
|
| |
    def tokenize(
        self,
        texts: list[str],
        **kwargs
    ) -> dict[str, torch.Tensor]:
        encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
        encodings_ids = [encoding.ids for encoding in encodings]

        # Row i of the batch starts at offsets[i] in the flattened id stream.
        # int64 is forced explicitly because np.cumsum defaults to int32 on
        # some platforms, while EmbeddingBag expects long offsets.
        offsets = torch.from_numpy(
            np.cumsum(
                [0] + [len(token_ids) for token_ids in encodings_ids[:-1]],
                dtype=np.int64,
            )
        )
        input_ids = torch.tensor(
            [token_id for token_ids in encodings_ids for token_id in token_ids],
            dtype=torch.long
        )
        return {
            "input_ids": input_ids,
            "offsets": offsets
        }
|
|
| |
| def forward( |
| self, |
| features: Dict[str, torch.Tensor], |
| **kwargs, |
| ) -> Dict[str, torch.Tensor]: |
| x = self.embedding(features["input_ids"], features["offsets"]) |
| x = self.dyt(x) |
| features["sentence_embedding"] = x |
| return features |
|
|
| |
| def get_sentence_embedding_dimension(self) -> int: |
| return self.embedding_dim |
|
|
    @property
    def max_seq_length(self) -> float:
        # Static embeddings impose no sequence-length limit.
        return torch.inf
| |
    def save(self, output_path: str):
        os.makedirs(output_path, exist_ok=True)

        state = self.state_dict()

        # The embedding table is stored 4-bit quantized; everything else
        # (the DyT parameters) goes into a small safetensors file.
        emb = state["embedding.weight"]
        q = quantize_q4_k_m(emb)

        del state["embedding.weight"]

        save_safetensors_file(
            state,
            os.path.join(output_path, "model_rest.safetensors"),
        )

        # Binary layout: all packed nibbles first, then the per-row
        # float32 scales. load() relies on this ordering.
        with open(os.path.join(output_path, "embedding.q4_k_m.bin"), "wb") as f:
            f.write(q["packed"].tobytes())
            f.write(q["scales"].tobytes())

        self.tokenizer.save(
            str(Path(output_path) / "tokenizer.json")
        )
| |
    @classmethod
    def load(cls, model_path: str):
        tokenizer = Tokenizer.from_file(
            os.path.join(model_path, "tokenizer.json")
        )

        state = load_file(
            os.path.join(model_path, "model_rest.safetensors"),
            device="cpu"
        )

        bin_path = os.path.join(model_path, "embedding.q4_k_m.bin")
        with open(bin_path, "rb") as f:
            raw = f.read()

        # The hidden size is recovered from the DyT parameters; the vocab
        # size then follows from the file size. Each row contributes
        # hidden // 2 packed bytes plus one float32 scale (4 bytes).
        hidden = state["dyt.alpha"].shape[0]
        total_uint8 = len(raw)

        bytes_per_row = hidden // 2 + 4
        vocab = total_uint8 // bytes_per_row

        # Packed nibbles come first in the file, then the scales (see save()).
        packed_size = vocab * hidden // 2

        packed = np.frombuffer(raw[:packed_size], dtype=np.uint8)
        scales = np.frombuffer(raw[packed_size:], dtype=np.float32)

        packed = packed.reshape(vocab, hidden // 2)
        scales = scales.reshape(vocab, 1)

        emb = dequantize_q4_k_m(packed, scales)

        model = cls(
            tokenizer=tokenizer,
            vocab_size=emb.shape[0],
            hidden_dim=emb.shape[1]
        )

        state["embedding.weight"] = emb
        model.load_state_dict(state)

        return model
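

def _sse_save_load_demo(output_path: str = "sse_demo") -> None:
    """
    Hedged end-to-end sketch of the quantized save/load round trip
    (illustrative only; "tokenizer.json" is a hypothetical local file,
    any fast-tokenizer JSON would work).
    """
    tokenizer = Tokenizer.from_file("tokenizer.json")
    model = SSEQ(
        tokenizer=tokenizer,
        vocab_size=tokenizer.get_vocab_size(),
        hidden_dim=1024,
    )
    model.save(output_path)            # writes Q4-packed table + DyT params
    restored = SSEQ.load(output_path)  # dequantizes the table back to float32
    features = restored.tokenize(["stable static embedding"])
    embedding = restored(features)["sentence_embedding"]  # shape (1, 1024)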
|
|
|
|
| @dataclass |
| class SSESforzandoConfig: |
| hidden_dim: int = 512 |
| vocab_size: int = 30522 |
|
|
|
|
| @dataclass |
| class SSEForzandoConfig: |
| hidden_dim: int = 384 |
| vocab_size: int = 30522 |
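
# A hedged sketch of wiring one of the configs above into SSEQ (the
# tokenizer path is an assumption):
#
#   cfg = SSESforzandoConfig()
#   tok = Tokenizer.from_file("tokenizer.json")
#   model = SSEQ(tokenizer=tok, vocab_size=cfg.vocab_size, hidden_dim=cfg.hidden_dim)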
|
|