""" coding = utf-8 Copyright 2026 Rikka Botan. All rights reserved Licensed under "MIT License" Stable Static Embedding official PyTorch implementation """ from __future__ import annotations import os from pathlib import Path from safetensors.torch import save_file as save_safetensors_file import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Dict from dataclasses import dataclass from tokenizers import Tokenizer from transformers import PreTrainedTokenizerFast from sentence_transformers.models.InputModule import InputModule from safetensors.torch import load_file def quantize_q4_k_m(weight: torch.Tensor): """ weight: (vocab, dim) returns: packed uint8 + scale + zero """ w = weight.detach().cpu().numpy().astype(np.float32) scales = np.max(np.abs(w), axis=1, keepdims=True) + 1e-8 w_norm = w / scales q = np.clip(np.round((w_norm + 1) * 7.5), 0, 15).astype(np.uint8) # pack 2x4bit -> uint8 packed = (q[:, 0::2] << 4) | q[:, 1::2] return { "packed": packed, "scales": scales.astype(np.float32), } def dequantize_q4_k_m(packed: np.ndarray, scales: np.ndarray): hi = (packed >> 4) & 0xF lo = packed & 0xF q = np.empty((packed.shape[0], packed.shape[1]*2), dtype=np.uint8) q[:, 0::2] = hi q[:, 1::2] = lo w = (q.astype(np.float32) / 7.5) - 1.0 w = w * scales return torch.from_numpy(w) class SeparableDyT(nn.Module): def __init__( self, hidden_dim: int, alpha_init: float = 0.5 ): super().__init__() self.alpha = nn.Parameter(alpha_init*torch.ones(hidden_dim)) self.beta = nn.Parameter(torch.ones(hidden_dim)) self.bias = nn.Parameter(torch.zeros(hidden_dim)) def forward( self, x: torch.Tensor ) -> torch.Tensor: x = self.beta * F.tanh(self.alpha * x + self.bias) return x class SSEQ(InputModule): """ Stable Static Embedding (SSE) StaticEmbedding-compatible Sentence-Transformers module """ def __init__( self, tokenizer: Tokenizer | PreTrainedTokenizerFast, vocab_size: int, hidden_dim: int = 1024, **kwargs, ): super().__init__() if isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = tokenizer._tokenizer elif not isinstance(tokenizer, Tokenizer): raise ValueError("Tokenizer must be a fast (Rust) tokenizer") self.tokenizer: Tokenizer = tokenizer self.tokenizer.no_padding() self.embedding = nn.EmbeddingBag(vocab_size, hidden_dim) self.dyt = SeparableDyT(hidden_dim) self.embedding_dim = hidden_dim # For model card compatibility self.base_model = kwargs.get("base_model", None) # Tokenization (StaticEmbedding-compatible) def tokenize( self, texts: list[str], **kwargs ) -> dict[str, torch.Tensor]: encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False) encodings_ids = [encoding.ids for encoding in encodings] offsets = torch.from_numpy( np.cumsum( [0] + [len(token_ids) for token_ids in encodings_ids[:-1]] ) ) input_ids = torch.tensor( [token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long ) return { "input_ids": input_ids, "offsets": offsets } # Forward def forward( self, features: Dict[str, torch.Tensor], **kwargs, ) -> Dict[str, torch.Tensor]: x = self.embedding(features["input_ids"], features["offsets"]) x = self.dyt(x) features["sentence_embedding"] = x return features # Required APIs def get_sentence_embedding_dimension(self) -> int: return self.embedding_dim @property def max_seq_length(self) -> int: return torch.inf def save(self, output_path: str): os.makedirs(output_path, exist_ok=True) state = self.state_dict() emb = state["embedding.weight"] q = quantize_q4_k_m(emb) del state["embedding.weight"] save_safetensors_file( state, os.path.join(output_path, "model_rest.safetensors"), ) with open(os.path.join(output_path, "embedding.q4_k_m.bin"), "wb") as f: f.write(q["packed"].tobytes()) f.write(q["scales"].tobytes()) self.tokenizer.save( str(Path(output_path) / "tokenizer.json") ) @classmethod def load(cls, model_path: str): tokenizer = Tokenizer.from_file( os.path.join(model_path, "tokenizer.json") ) state = load_file( os.path.join(model_path, "model_rest.safetensors"), device="cpu" ) # read q4 binary bin_path = os.path.join(model_path, "embedding.q4_k_m.bin") with open(bin_path, "rb") as f: raw = f.read() hidden = state["dyt.alpha"].shape[0] total_uint8 = len(raw) bytes_per_row = hidden // 2 + 4 vocab = total_uint8 // bytes_per_row packed_size = vocab * hidden // 2 packed = np.frombuffer(raw[:packed_size], dtype=np.uint8) scales = np.frombuffer(raw[packed_size:], dtype=np.float32) packed = packed.reshape(vocab, hidden // 2) scales = scales.reshape(vocab, 1) emb = dequantize_q4_k_m(packed, scales) # rebuild model model = cls( tokenizer=tokenizer, vocab_size=emb.shape[0], hidden_dim=emb.shape[1] ) state["embedding.weight"] = emb model.load_state_dict(state) return model @dataclass class SSESforzandoConfig: hidden_dim: int = 512 vocab_size: int = 30522 @dataclass class SSEForzandoConfig: hidden_dim: int = 384 vocab_size: int = 30522