Files changed (3) hide show
  1. README.md +2 -2
  2. modeling.py +19 -8
  3. st_quantize.py +39 -5
README.md CHANGED
@@ -34,8 +34,8 @@ language:
34
  |:-----:|:----------:|:-------:|:---:|:------------:|:-----------:|:-------:|
35
  | `pplx-embed-v1-0.6B` | 1024 | 32K | Yes | INT8/BINARY | No | Mean |
36
  | `pplx-embed-v1-4B` | 2560 | 32K | Yes | INT8/BINARY | No | Mean |
37
- | `pplx-embed-context-v1-0.6B` | 1024 | 32K | Yes | INT8/BINARY | No | Mean |
38
- | `pplx-embed-context-v1-4B` | 2560 | 32K | Yes | INT8/BINARY | No | Mean |
39
 
40
  <sub>All models are built on diffusion continued pre-trained Qwen3 at Perplexity AI.</sub>
41
 
 
34
  |:-----:|:----------:|:-------:|:---:|:------------:|:-----------:|:-------:|
35
  | `pplx-embed-v1-0.6B` | 1024 | 32K | Yes | INT8/BINARY | No | Mean |
36
  | `pplx-embed-v1-4B` | 2560 | 32K | Yes | INT8/BINARY | No | Mean |
37
+ | `pplx-embed-context-v1-0.6B` | 1024 | 32K | Yes | INT8/BINARY/UBINARY | No | Mean |
38
+ | `pplx-embed-context-v1-4B` | 2560 | 32K | Yes | INT8/BINARY/UBINARY | No | Mean |
39
 
40
  <sub>All models are built on diffusion continued pre-trained Qwen3 at Perplexity AI.</sub>
41
 
modeling.py CHANGED
@@ -12,6 +12,7 @@ from transformers import AutoTokenizer
12
  from .st_quantize import FlexibleQuantizer
13
 
14
 
 
15
  def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
16
  """
17
  This creates bidirectional attention mask.
@@ -141,7 +142,7 @@ class PPLXQwen3ContextualModel(PPLXQwen3Model):
141
  device: str | torch.device | None = None,
142
  normalize_embeddings: bool = False,
143
  convert_to_numpy: bool = True,
144
- quantization: Literal["int8", "binary"] = "int8",
145
  ) -> list[np.ndarray] | list[torch.Tensor]:
146
  """
147
  Encode documents with late chunking (contextual embeddings).
@@ -167,15 +168,17 @@ class PPLXQwen3ContextualModel(PPLXQwen3Model):
167
  convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
168
  quantization: Quantization type to apply. Options:
169
  - "int8": Int8 tanh quantization (default)
170
- - "binary": Binary tanh quantization
 
171
 
172
  Returns:
173
  List of numpy arrays or tensors (preserves document structure).
174
- Each element has shape (n_chunks, hidden_dim).
175
- embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
176
  Output type depends on quantization method:
177
- - Int8: int8 values in range [-128, 127]
178
- - Binary: float values -1.0 or 1.0
 
179
  """
180
 
181
  if not isinstance(documents, list) or not all(
@@ -185,13 +188,21 @@ class PPLXQwen3ContextualModel(PPLXQwen3Model):
185
  "Input 'documents' must be a list of lists of strings for contextual encoding."
186
  )
187
 
188
- if quantization not in ["int8", "binary"]:
189
  raise ValueError(
190
  f"Unsupported quantization type: '{quantization}'. "
191
- f"Supported types are: 'int8', 'binary'. "
192
  f"Got: {type(quantization).__name__} = '{quantization}'"
193
  )
194
 
 
 
 
 
 
 
 
 
195
  self.eval()
196
 
197
  if device is None:
 
12
  from .st_quantize import FlexibleQuantizer
13
 
14
 
15
+ # From modeling_t5gemma.py
16
  def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
17
  """
18
  This creates bidirectional attention mask.
 
142
  device: str | torch.device | None = None,
143
  normalize_embeddings: bool = False,
144
  convert_to_numpy: bool = True,
145
+ quantization: Literal["int8", "binary", "ubinary"] = "int8",
146
  ) -> list[np.ndarray] | list[torch.Tensor]:
147
  """
148
  Encode documents with late chunking (contextual embeddings).
 
168
  convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
169
  quantization: Quantization type to apply. Options:
170
  - "int8": Int8 tanh quantization (default)
171
+ - "binary": Binary tanh quantization (-1.0 or 1.0)
172
+ - "ubinary": Unsigned packed binary (uint8, 8x compression)
173
 
174
  Returns:
175
  List of numpy arrays or tensors (preserves document structure).
176
+ Each element has shape (n_chunks, hidden_dim) or (n_chunks, hidden_dim // 8) for ubinary.
177
+ Example: embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
178
  Output type depends on quantization method:
179
+ - "int8": int8 dtype, values in range [-128, 127], shape (..., hidden_dim)
180
+ - "binary": float32 dtype, values -1.0 or 1.0, shape (..., hidden_dim)
181
+ - "ubinary": uint8 dtype, packed bits (8x smaller), shape (..., hidden_dim // 8)
182
  """
183
 
184
  if not isinstance(documents, list) or not all(
 
188
  "Input 'documents' must be a list of lists of strings for contextual encoding."
189
  )
190
 
191
+ if quantization not in ["int8", "binary", "ubinary"]:
192
  raise ValueError(
193
  f"Unsupported quantization type: '{quantization}'. "
194
+ f"Supported types are: 'int8', 'binary', 'ubinary'. "
195
  f"Got: {type(quantization).__name__} = '{quantization}'"
196
  )
197
 
198
+ if normalize_embeddings and quantization == "ubinary":
199
+ raise ValueError(
200
+ "normalize_embeddings=True is incompatible with quantization='ubinary'. "
201
+ "Packed binary embeddings (uint8) cannot be normalized because each byte "
202
+ "represents 8 packed bits, not a single dimension. "
203
+ "Either set normalize_embeddings=False or use 'binary' quantization instead."
204
+ )
205
+
206
  self.eval()
207
 
208
  if device is None:
st_quantize.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from typing import Literal
3
  from sentence_transformers.models import Module
4
 
@@ -66,17 +67,46 @@ class BinaryTanhQuantizer(Quantizer):
66
  return torch.where(x >= 0, 1.0, -1.0)
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class FlexibleQuantizer(Module):
70
  def __init__(self):
71
  super().__init__()
72
  self._int8_quantizer = Int8TanhQuantizer()
73
  self._binary_quantizer = BinaryTanhQuantizer()
 
74
 
75
  def forward(
76
  self,
77
  features: dict[str, torch.Tensor],
78
- quantization: Literal["binary", "int8"] = "int8",
79
- **kwargs
80
  ) -> dict[str, torch.Tensor]:
81
  if quantization == "int8":
82
  features["sentence_embedding"] = self._int8_quantizer(
@@ -86,9 +116,13 @@ class FlexibleQuantizer(Module):
86
  features["sentence_embedding"] = self._binary_quantizer(
87
  features["sentence_embedding"]
88
  )
 
 
 
 
89
  else:
90
  raise ValueError(
91
- f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
92
  )
93
  return features
94
 
@@ -104,6 +138,6 @@ class FlexibleQuantizer(Module):
104
  **kwargs,
105
  ):
106
  return cls()
107
-
108
- def save(self, output_path: str, *args, **kwargs) -> None:
109
  return
 
1
  import torch
2
+ import numpy as np
3
  from typing import Literal
4
  from sentence_transformers.models import Module
5
 
 
67
  return torch.where(x >= 0, 1.0, -1.0)
68
 
69
 
70
+ class PackedBinaryQuantizer:
71
+ """
72
+ Packs binary embeddings into uint8 format for efficient storage.
73
+
74
+ This quantizer applies a binary threshold (x >= 0) and packs 8 consecutive
75
+ bits into a single uint8 byte using numpy.packbits. This reduces memory
76
+ usage by 32x compared to float32 and by 8x compared to int8.
77
+
78
+ IMPORTANT: This is an inference-only quantizer - it is not differentiable
79
+ and should only be used for encoding/inference, not during training.
80
+
81
+ Args:
82
+ x: Input tensor of any float dtype, shape (..., embedding_dim)
83
+
84
+ Returns:
85
+ Packed binary tensor of dtype uint8, shape (..., embedding_dim // 8)
86
+
87
+ Example:
88
+ >>> quantizer = PackedBinaryQuantizer()
89
+ >>> embeddings = torch.randn(2, 1024) # float32
90
+ >>> packed = quantizer(embeddings) # uint8, shape (2, 128)
91
+ """
92
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
93
+ bits = np.where(x.cpu().numpy() >= 0, True, False)
94
+ packed = np.packbits(bits, axis=-1)
95
+ return torch.from_numpy(packed).to(x.device)
96
+
97
+
98
  class FlexibleQuantizer(Module):
99
  def __init__(self):
100
  super().__init__()
101
  self._int8_quantizer = Int8TanhQuantizer()
102
  self._binary_quantizer = BinaryTanhQuantizer()
103
+ self._packed_binary_quantizer = PackedBinaryQuantizer()
104
 
105
  def forward(
106
  self,
107
  features: dict[str, torch.Tensor],
108
+ quantization: Literal["int8", "binary", "ubinary"] = "int8",
109
+ **kwargs,
110
  ) -> dict[str, torch.Tensor]:
111
  if quantization == "int8":
112
  features["sentence_embedding"] = self._int8_quantizer(
 
116
  features["sentence_embedding"] = self._binary_quantizer(
117
  features["sentence_embedding"]
118
  )
119
+ elif quantization == "ubinary":
120
+ features["sentence_embedding"] = self._packed_binary_quantizer(
121
+ features["sentence_embedding"]
122
+ )
123
  else:
124
  raise ValueError(
125
+ f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
126
  )
127
  return features
128
 
 
138
  **kwargs,
139
  ):
140
  return cls()
141
+
142
+ def save(self, output_path: str, *args, **kwargs) -> None:
143
  return