Files changed (3)
  1. README.md +11 -11
  2. modeling.py +8 -18
  3. st_quantize.py +5 -39
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- license: mit
3
  pipeline_tag: feature-extraction
4
  tags:
5
  - feature-extraction
@@ -7,8 +7,7 @@ tags:
7
  - conteb
8
  - contextual-embeddings
9
  language:
10
- - multilingual
11
- library_name: transformers
12
  ---
13
 
14
 
@@ -16,7 +15,7 @@ library_name: transformers
16
  <img src="assets/logo.svg" alt="Perplexity Logo" width="400">
17
  </p>
18
 
19
- <p align="center">pplx-embed-v1: Diffusion-Pretrained Dense and Contextual Embeddings</p>
20
 
21
  `pplx-embed-v1` and `pplx-embed-context-v1` are state-of-the-art text embedding models optimized for real-world, web-scale retrieval tasks.
22
 
@@ -52,7 +51,7 @@ curl -X POST https://api.perplexity.ai/v1/contextualizedembeddings \
52
  -H "Authorization: Bearer YOUR_API_KEY" \
53
  -H "Content-Type: application/json" \
54
  -d '{
55
- "input": [
56
  [
57
  "Curiosity begins in childhood with endless questions about the world.",
58
  "As we grow, curiosity drives us to explore new ideas and challenge assumptions.",
@@ -63,7 +62,7 @@ curl -X POST https://api.perplexity.ai/v1/contextualizedembeddings \
63
  "Each discovery on Mars sparks new questions about our place in the universe."
64
  ]
65
  ],
66
- "model": "pplx-embed-context-v1-0.6b"
67
  }'
68
  ```
69
 
@@ -254,14 +253,15 @@ batch_chunk_embeddings = [
254
 
255
  int8_embeddings = [quantize_int8_tanh(x) for x in batch_chunk_embeddings]
256
  binary_embeddings = [quantize_binary(x) for x in batch_chunk_embeddings]
257
-
258
- bits = [np.where(doc.numpy() >= 0, True, False) for doc in binary_embeddings]
259
- packed_embeddings = [np.packbits(b, axis=-1) for b in bits]
260
-
261
  ```
262
 
263
  </details>
264
 
265
  ## Technical Details
266
 
267
- For comprehensive technical details and evaluation results, see our paper on arXiv: https://arxiv.org/abs/2602.11151.
 
 
 
 
 
 
1
  ---
2
+ license: apache-2.0
3
  pipeline_tag: feature-extraction
4
  tags:
5
  - feature-extraction
 
7
  - conteb
8
  - contextual-embeddings
9
  language:
10
+ - multilingual
 
11
  ---
12
 
13
 
 
15
  <img src="assets/logo.svg" alt="Perplexity Logo" width="400">
16
  </p>
17
 
18
+ <p align="center">pplx-embed-v1: Diffusion-LM for Dense and Contextual Retrieval</p>
19
 
20
  `pplx-embed-v1` and `pplx-embed-context-v1` are state-of-the-art text embedding models optimized for real-world, web-scale retrieval tasks.
21
 
 
51
  -H "Authorization: Bearer YOUR_API_KEY" \
52
  -H "Content-Type: application/json" \
53
  -d '{
54
+ "inputs": [
55
  [
56
  "Curiosity begins in childhood with endless questions about the world.",
57
  "As we grow, curiosity drives us to explore new ideas and challenge assumptions.",
 
62
  "Each discovery on Mars sparks new questions about our place in the universe."
63
  ]
64
  ],
65
+ "model": "pplx-embed-context-v1-0.6B"
66
  }'
67
  ```
68
 
 
253
 
254
  int8_embeddings = [quantize_int8_tanh(x) for x in batch_chunk_embeddings]
255
  binary_embeddings = [quantize_binary(x) for x in batch_chunk_embeddings]
 
 
 
 
256
  ```
257
 
258
  </details>
259
 
260
  ## Technical Details
261
 
262
+ For comprehensive technical details and evaluation results, see our paper on arXiv.
263
+
264
+ ## Contact
265
+
266
+ - Website: https://perplexity.ai
267
+ - API Support: api-support@perplexity.ai
modeling.py CHANGED
@@ -142,7 +142,7 @@ class PPLXQwen3ContextualModel(PPLXQwen3Model):
142
  device: str | torch.device | None = None,
143
  normalize_embeddings: bool = False,
144
  convert_to_numpy: bool = True,
145
- quantization: Literal["int8", "binary", "ubinary"] = "int8",
146
  ) -> list[np.ndarray] | list[torch.Tensor]:
147
  """
148
  Encode documents with late chunking (contextual embeddings).
@@ -168,17 +168,15 @@ class PPLXQwen3ContextualModel(PPLXQwen3Model):
168
  convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
169
  quantization: Quantization type to apply. Options:
170
  - "int8": Int8 tanh quantization (default)
171
- - "binary": Binary tanh quantization (-1.0 or 1.0)
172
- - "ubinary": Unsigned packed binary (uint8, 8x compression)
173
 
174
  Returns:
175
  List of numpy arrays or tensors (preserves document structure).
176
- Each element has shape (n_chunks, hidden_dim) or (n_chunks, hidden_dim // 8) for ubinary.
177
- Example: embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
178
  Output type depends on quantization method:
179
- - "int8": int8 dtype, values in range [-128, 127], shape (..., hidden_dim)
180
- - "binary": float32 dtype, values -1.0 or 1.0, shape (..., hidden_dim)
181
- - "ubinary": uint8 dtype, packed bits (8x smaller), shape (..., hidden_dim // 8)
182
  """
183
 
184
  if not isinstance(documents, list) or not all(
@@ -188,21 +186,13 @@ class PPLXQwen3ContextualModel(PPLXQwen3Model):
188
  "Input 'documents' must be a list of lists of strings for contextual encoding."
189
  )
190
 
191
- if quantization not in ["int8", "binary", "ubinary"]:
192
  raise ValueError(
193
  f"Unsupported quantization type: '{quantization}'. "
194
- f"Supported types are: 'int8', 'binary', 'ubinary'. "
195
  f"Got: {type(quantization).__name__} = '{quantization}'"
196
  )
197
 
198
- if normalize_embeddings and quantization == "ubinary":
199
- raise ValueError(
200
- "normalize_embeddings=True is incompatible with quantization='ubinary'. "
201
- "Packed binary embeddings (uint8) cannot be normalized because each byte "
202
- "represents 8 packed bits, not a single dimension. "
203
- "Either set normalize_embeddings=False or use 'binary' quantization instead."
204
- )
205
-
206
  self.eval()
207
 
208
  if device is None:
 
142
  device: str | torch.device | None = None,
143
  normalize_embeddings: bool = False,
144
  convert_to_numpy: bool = True,
145
+ quantization: Literal["int8", "binary"] = "int8",
146
  ) -> list[np.ndarray] | list[torch.Tensor]:
147
  """
148
  Encode documents with late chunking (contextual embeddings).
 
168
  convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
169
  quantization: Quantization type to apply. Options:
170
  - "int8": Int8 tanh quantization (default)
171
+ - "binary": Binary tanh quantization
 
172
 
173
  Returns:
174
  List of numpy arrays or tensors (preserves document structure).
175
+ Each element has shape (n_chunks, hidden_dim).
176
+ embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
177
  Output type depends on quantization method:
178
+ - Int8: int8 values in range [-128, 127]
179
+ - Binary: float values -1.0 or 1.0
 
180
  """
181
 
182
  if not isinstance(documents, list) or not all(
 
186
  "Input 'documents' must be a list of lists of strings for contextual encoding."
187
  )
188
 
189
+ if quantization not in ["int8", "binary"]:
190
  raise ValueError(
191
  f"Unsupported quantization type: '{quantization}'. "
192
+ f"Supported types are: 'int8', 'binary'. "
193
  f"Got: {type(quantization).__name__} = '{quantization}'"
194
  )
195
 
 
 
 
 
 
 
 
 
196
  self.eval()
197
 
198
  if device is None:
st_quantize.py CHANGED
@@ -1,5 +1,4 @@
1
  import torch
2
- import numpy as np
3
  from typing import Literal
4
  from sentence_transformers.models import Module
5
 
@@ -67,46 +66,17 @@ class BinaryTanhQuantizer(Quantizer):
67
  return torch.where(x >= 0, 1.0, -1.0)
68
 
69
 
70
- class PackedBinaryQuantizer:
71
- """
72
- Packs binary embeddings into uint8 format for efficient storage.
73
-
74
- This quantizer applies a binary threshold (x >= 0) and packs 8 consecutive
75
- bits into a single uint8 byte using numpy.packbits. This reduces memory
76
- usage by 8x compared to float32 and by 4x compared to int8.
77
-
78
- IMPORTANT: This is an inference-only quantizer - it is not differentiable
79
- and should only be used for encoding/inference, not during training.
80
-
81
- Args:
82
- x: Input tensor of any float dtype, shape (..., embedding_dim)
83
-
84
- Returns:
85
- Packed binary tensor of dtype uint8, shape (..., embedding_dim // 8)
86
-
87
- Example:
88
- >>> quantizer = PackedBinaryQuantizer()
89
- >>> embeddings = torch.randn(2, 1024) # float32
90
- >>> packed = quantizer(embeddings) # uint8, shape (2, 128)
91
- """
92
- def __call__(self, x: torch.Tensor) -> torch.Tensor:
93
- bits = np.where(x.cpu().numpy() >= 0, True, False)
94
- packed = np.packbits(bits, axis=-1)
95
- return torch.from_numpy(packed).to(x.device)
96
-
97
-
98
  class FlexibleQuantizer(Module):
99
  def __init__(self):
100
  super().__init__()
101
  self._int8_quantizer = Int8TanhQuantizer()
102
  self._binary_quantizer = BinaryTanhQuantizer()
103
- self._packed_binary_quantizer = PackedBinaryQuantizer()
104
 
105
  def forward(
106
  self,
107
  features: dict[str, torch.Tensor],
108
- quantization: Literal["int8", "binary", "ubinary"] = "int8",
109
- **kwargs,
110
  ) -> dict[str, torch.Tensor]:
111
  if quantization == "int8":
112
  features["sentence_embedding"] = self._int8_quantizer(
@@ -116,13 +86,9 @@ class FlexibleQuantizer(Module):
116
  features["sentence_embedding"] = self._binary_quantizer(
117
  features["sentence_embedding"]
118
  )
119
- elif quantization == "ubinary":
120
- features["sentence_embedding"] = self._packed_binary_quantizer(
121
- features["sentence_embedding"]
122
- )
123
  else:
124
  raise ValueError(
125
- f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
126
  )
127
  return features
128
 
@@ -138,6 +104,6 @@ class FlexibleQuantizer(Module):
138
  **kwargs,
139
  ):
140
  return cls()
141
-
142
- def save(self, output_path: str, *args, **kwargs) -> None:
143
  return
 
1
  import torch
 
2
  from typing import Literal
3
  from sentence_transformers.models import Module
4
 
 
66
  return torch.where(x >= 0, 1.0, -1.0)
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class FlexibleQuantizer(Module):
70
  def __init__(self):
71
  super().__init__()
72
  self._int8_quantizer = Int8TanhQuantizer()
73
  self._binary_quantizer = BinaryTanhQuantizer()
 
74
 
75
  def forward(
76
  self,
77
  features: dict[str, torch.Tensor],
78
+ quantization: Literal["binary", "int8"] = "int8",
79
+ **kwargs
80
  ) -> dict[str, torch.Tensor]:
81
  if quantization == "int8":
82
  features["sentence_embedding"] = self._int8_quantizer(
 
86
  features["sentence_embedding"] = self._binary_quantizer(
87
  features["sentence_embedding"]
88
  )
 
 
 
 
89
  else:
90
  raise ValueError(
91
+ f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
92
  )
93
  return features
94
 
 
104
  **kwargs,
105
  ):
106
  return cls()
107
+
108
+ def save(self, output_path: str, *args, **kwargs) -> None:
109
  return