| | """ |
| | coding = utf-8 |
| | Copyright 2026 Rikka Botan. All rights reserved |
| | Licensed under "MIT License" |
| | Stable Static Embedding official PyTorch implementation |
| | """ |
| |
|
| | from __future__ import annotations |
| | import os |
| | from pathlib import Path |
| | from safetensors.torch import save_file as save_safetensors_file |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from typing import Dict |
| | from dataclasses import dataclass |
| | from tokenizers import Tokenizer |
| | from transformers import PreTrainedTokenizerFast |
| | from sentence_transformers.models.InputModule import InputModule |
| | from safetensors.torch import load_file |
| |
|
| |
|
def quantize_q4_k_m(weight: torch.Tensor):
    """Quantize a 2-D weight matrix to packed 4-bit codes with one scale per row.

    Each row is divided by its maximum absolute value (plus ``1e-8`` so an
    all-zero row does not divide by zero), which puts it in ``[-1, 1]``.
    That range is mapped linearly onto the integer codes ``0..15`` and two
    neighbouring columns are packed into one byte: the even column goes in
    the high nibble, the odd column in the low nibble.

    Args:
        weight: Tensor of shape ``(vocab, dim)``. ``dim`` must be even so
            that the columns pair up. Any floating dtype that can be cast to
            float32 is accepted.

    Returns:
        A dict with two NumPy arrays:
            ``"packed"``: uint8, shape ``(vocab, dim // 2)``.
            ``"scales"``: float32, shape ``(vocab, 1)``.
        No zero-point is stored; the mapping is fixed by the code range.

    Raises:
        ValueError: If ``weight`` is not 2-D or ``dim`` is odd.
    """
    if weight.dim() != 2:
        raise ValueError(
            f"weight must be 2-D (vocab, dim), got shape {tuple(weight.shape)}"
        )
    if weight.shape[1] % 2 != 0:
        raise ValueError(
            f"dim must be even to pack two 4-bit codes per byte, "
            f"got dim={weight.shape[1]}"
        )

    # Cast inside torch first: NumPy has no bfloat16, so calling .numpy()
    # directly on such a tensor would fail.
    w = weight.detach().to(device="cpu", dtype=torch.float32).numpy()

    # Per-row scale; the epsilon keeps all-zero rows finite.
    scales = np.max(np.abs(w), axis=1, keepdims=True) + 1e-8
    w_norm = w / scales

    # [-1, 1] -> [0, 15], rounded to the nearest code.
    q = np.clip(np.round((w_norm + 1) * 7.5), 0, 15).astype(np.uint8)

    # Even columns -> high nibble, odd columns -> low nibble.
    packed = (q[:, 0::2] << 4) | q[:, 1::2]

    return {
        "packed": packed,
        "scales": scales.astype(np.float32),
    }
| |
|
| |
|
def dequantize_q4_k_m(packed: np.ndarray, scales: np.ndarray):
    """Expand packed 4-bit codes back into a float weight tensor.

    Every byte of ``packed`` holds two codes: the high nibble becomes the
    even output column and the low nibble the odd one. Codes ``0..15`` are
    mapped linearly back to ``[-1, 1]`` and multiplied by the row's scale.

    Args:
        packed: 2-D array of packed bytes, shape ``(rows, cols)``.
        scales: Per-row scales that broadcast against ``(rows, cols * 2)``.

    Returns:
        ``torch.Tensor`` of shape ``(rows, cols * 2)`` built from the
        resulting NumPy array.
    """
    high_nibbles = (packed >> 4) & 0xF
    low_nibbles = packed & 0xF

    # Stack as (rows, cols, 2) and flatten the last two axes so the result
    # alternates high, low, high, low, ... along each row.
    n_rows, n_packed_cols = packed.shape[0], packed.shape[1]
    codes = (
        np.stack((high_nibbles, low_nibbles), axis=-1)
        .astype(np.uint8)
        .reshape(n_rows, n_packed_cols * 2)
    )

    restored = (codes.astype(np.float32) / 7.5) - 1.0
    restored = restored * scales
    return torch.from_numpy(restored)
| |
|
| |
|
class SeparableDyT(nn.Module):
    """Element-wise tanh layer with learnable per-feature parameters.

    Computes ``beta * tanh(alpha * x + bias)``, where ``alpha``, ``beta`` and
    ``bias`` are vectors of length ``hidden_dim`` broadcast over the input.
    (The name presumably stands for "Dynamic Tanh" -- confirm with the
    author.)

    Args:
        hidden_dim: Length of each parameter vector.
        alpha_init: Initial value for every entry of ``alpha``.
    """

    def __init__(self, hidden_dim: int, alpha_init: float = 0.5):
        super().__init__()
        # Registration order (alpha, beta, bias) fixes the state_dict order.
        initial_alpha = alpha_init * torch.ones(hidden_dim)
        self.alpha = nn.Parameter(initial_alpha)
        self.beta = nn.Parameter(torch.ones(hidden_dim))
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the scaled, shifted tanh to ``x`` and return the result."""
        pre_activation = self.alpha * x + self.bias
        return self.beta * F.tanh(pre_activation)
| |
|
| |
|
class SSEQ(InputModule):
    """Stable Static Embedding (SSE) with a 4-bit quantized on-disk format.

    StaticEmbedding-compatible Sentence-Transformers module. Texts are
    tokenized with a fast tokenizer, token embeddings are pooled per text by
    an ``nn.EmbeddingBag`` (default reduction mode), and the pooled vector is
    passed through ``SeparableDyT``.

    On disk the embedding matrix is stored as packed 4-bit codes plus one
    float32 scale per row (``embedding.q4_k_m.bin``). All other parameters
    go to ``model_rest.safetensors`` and the tokenizer to ``tokenizer.json``.
    Quantization is lossy, so a save/load round trip does not reproduce the
    embedding weights exactly.
    """

    def __init__(
        self,
        tokenizer: Tokenizer | PreTrainedTokenizerFast,
        vocab_size: int,
        hidden_dim: int = 1024,
        **kwargs,
    ):
        """Build the module.

        Args:
            tokenizer: A ``tokenizers.Tokenizer`` or a
                ``PreTrainedTokenizerFast`` (its underlying tokenizer is
                unwrapped). Padding is disabled on it in place.
            vocab_size: Number of rows of the embedding table.
            hidden_dim: Embedding width.
            **kwargs: Only ``base_model`` is read (stored as an attribute).

        Raises:
            ValueError: If ``tokenizer`` is neither of the accepted types.
        """
        super().__init__()

        if isinstance(tokenizer, PreTrainedTokenizerFast):
            tokenizer = tokenizer._tokenizer
        elif not isinstance(tokenizer, Tokenizer):
            raise ValueError("Tokenizer must be a fast (Rust) tokenizer")

        self.tokenizer: Tokenizer = tokenizer
        # Texts are concatenated and addressed by offsets, so no padding.
        self.tokenizer.no_padding()

        self.embedding = nn.EmbeddingBag(vocab_size, hidden_dim)
        self.dyt = SeparableDyT(hidden_dim)

        self.embedding_dim = hidden_dim

        self.base_model = kwargs.get("base_model", None)

    def tokenize(
        self,
        texts: list[str],
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """Tokenize ``texts`` into the flat layout ``nn.EmbeddingBag`` expects.

        Returns:
            ``"input_ids"``: int64 tensor with the token ids of all texts
            concatenated in order.
            ``"offsets"``: int64 tensor with one entry per text, giving the
            index in ``input_ids`` where that text starts. It is empty when
            ``texts`` is empty.
        """
        encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
        encodings_ids = [encoding.ids for encoding in encodings]

        # Running total of lengths with the final total dropped: exactly one
        # start offset per text, and none at all for an empty batch. The
        # dtype is pinned to int64 so it matches ``input_ids`` regardless of
        # NumPy's platform default integer.
        lengths = [len(token_ids) for token_ids in encodings_ids]
        offsets = torch.from_numpy(
            np.cumsum([0] + lengths, dtype=np.int64)[:-1]
        )
        input_ids = torch.tensor(
            [token_id for token_ids in encodings_ids for token_id in token_ids],
            dtype=torch.long
        )
        return {
            "input_ids": input_ids,
            "offsets": offsets
        }

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Compute sentence embeddings.

        Reads ``features["input_ids"]`` and ``features["offsets"]``, stores
        the result under ``features["sentence_embedding"]`` and returns the
        same dict.
        """
        x = self.embedding(features["input_ids"], features["offsets"])
        x = self.dyt(x)
        features["sentence_embedding"] = x
        return features

    def get_sentence_embedding_dimension(self) -> int:
        """Return the width of the produced sentence embeddings."""
        return self.embedding_dim

    @property
    def max_seq_length(self) -> float:
        """Return positive infinity: no length limit is applied here."""
        return torch.inf

    def save(self, output_path: str, *args, **kwargs) -> None:
        """Write the model to ``output_path`` in the quantized layout.

        Files written:
            ``model_rest.safetensors``: every state-dict entry except
            ``embedding.weight``.
            ``embedding.q4_k_m.bin``: all packed bytes, followed by all
            float32 row scales.
            ``tokenizer.json``: the tokenizer.

        Args:
            output_path: Target directory, created if missing.
            *args, **kwargs: Accepted and ignored, so callers that pass
                extra framework options (e.g. a ``safe_serialization``
                flag -- assumption, confirm against the installed
                sentence-transformers version) do not fail.

        Raises:
            ValueError: If the embedding width is odd, since two 4-bit codes
                are packed per byte.
        """
        state = self.state_dict()

        # state_dict() returns a fresh mapping, so popping does not touch
        # the live module.
        emb = state.pop("embedding.weight")
        if emb.shape[1] % 2 != 0:
            # Fail before touching the filesystem instead of part-way through.
            raise ValueError(
                f"embedding width must be even for 4-bit packing, "
                f"got {emb.shape[1]}"
            )

        os.makedirs(output_path, exist_ok=True)
        q = quantize_q4_k_m(emb)

        save_safetensors_file(
            state,
            os.path.join(output_path, "model_rest.safetensors"),
        )

        with open(os.path.join(output_path, "embedding.q4_k_m.bin"), "wb") as f:
            f.write(q["packed"].tobytes())
            f.write(q["scales"].tobytes())

        self.tokenizer.save(
            str(Path(output_path) / "tokenizer.json")
        )

    @classmethod
    def load(cls, model_path: str) -> "SSEQ":
        """Rebuild a model from a directory written by ``save``.

        The embedding width is taken from ``dyt.alpha`` in the safetensors
        file. The vocabulary size is derived from the size of the ``.bin``
        file: each row occupies ``hidden // 2`` packed bytes plus one 4-byte
        float32 scale.

        Args:
            model_path: Directory containing ``tokenizer.json``,
                ``model_rest.safetensors`` and ``embedding.q4_k_m.bin``.

        Returns:
            A new instance with the dequantized embedding loaded (CPU).

        Raises:
            ValueError: If the stored width is odd or the ``.bin`` size is
                not a whole number of rows.
        """
        tokenizer = Tokenizer.from_file(
            os.path.join(model_path, "tokenizer.json")
        )

        state = load_file(
            os.path.join(model_path, "model_rest.safetensors"),
            device="cpu"
        )

        bin_path = os.path.join(model_path, "embedding.q4_k_m.bin")
        with open(bin_path, "rb") as f:
            raw = f.read()

        hidden = state["dyt.alpha"].shape[0]
        if hidden % 2 != 0:
            raise ValueError(
                f"stored embedding width must be even, got {hidden}"
            )

        bytes_per_row = hidden // 2 + 4
        vocab, leftover = divmod(len(raw), bytes_per_row)
        if leftover:
            # A truncated or mismatched file would otherwise only surface
            # as an opaque reshape error further down.
            raise ValueError(
                f"{bin_path} has {len(raw)} bytes, which is not a multiple "
                f"of the {bytes_per_row}-byte row size for width {hidden}"
            )

        packed_size = vocab * (hidden // 2)

        # View the two regions of the buffer directly rather than slicing
        # the bytes object, which would copy each region.
        packed = np.frombuffer(raw, dtype=np.uint8, count=packed_size)
        scales = np.frombuffer(
            raw, dtype=np.float32, count=vocab, offset=packed_size
        )

        packed = packed.reshape(vocab, hidden // 2)
        scales = scales.reshape(vocab, 1)

        emb = dequantize_q4_k_m(packed, scales)

        model = cls(
            tokenizer=tokenizer,
            vocab_size=emb.shape[0],
            hidden_dim=emb.shape[1]
        )

        state["embedding.weight"] = emb
        model.load_state_dict(state)

        return model
| |
|
| |
|
@dataclass
class SSESforzandoConfig:
    """Default sizes for one SSE model preset.

    The preset name is taken from the class name; which released model it
    corresponds to is not visible here.
    """

    # Embedding width.
    hidden_dim: int = 512
    # Number of rows in the embedding table.
    vocab_size: int = 30522
| |
|
| |
|
@dataclass
class SSEForzandoConfig:
    """Default sizes for one SSE model preset.

    The preset name is taken from the class name; which released model it
    corresponds to is not visible here.
    """

    # Embedding width.
    hidden_dim: int = 384
    # Number of rows in the embedding table.
    vocab_size: int = 30522
| |
|