mkrimmel-pplx committed on
Commit
8d01c68
·
1 Parent(s): ff0893c

feat: updated context model

Browse files
Files changed (2) hide show
  1. modeling.py +27 -12
  2. st_quantize.py +12 -9
modeling.py CHANGED
@@ -15,7 +15,7 @@
15
  # See the License for the specific language governing permissions and
16
  # limitations under the License.
17
 
18
- from typing import Optional, Tuple
19
 
20
  import numpy as np
21
  import torch
@@ -26,7 +26,7 @@ from transformers.modeling_utils import PreTrainedModel
26
  from transformers.modeling_outputs import BaseModelOutputWithPast
27
 
28
  from .configuration import PPLXQwen3Config
29
- from .st_quantize import Int8TanhQuantizer
30
 
31
 
32
  # Activation functions mapping
@@ -553,7 +553,7 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
553
  super().__init__(config)
554
  self.model = PPLXQwen3Model(config)
555
  self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
556
- self.quantizer = Int8TanhQuantizer(hard=True)
557
  self.post_init()
558
 
559
  def forward(
@@ -594,6 +594,7 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
594
  device: str | torch.device | None = None,
595
  normalize_embeddings: bool = False,
596
  convert_to_numpy: bool = True,
 
597
  ) -> list[np.ndarray] | list[torch.Tensor]:
598
  """
599
  Encode documents with late chunking (contextual embeddings).
@@ -605,8 +606,9 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
605
  1. Concatenate chunks with separator tokens
606
  2. Run forward pass to get token embeddings
607
  3. Extract and pool individual chunk embeddings (late chunking)
608
- 4. Apply quantization (Int8 tanh quantization)
609
- 5. Convert to numpy or return as tensors
 
610
 
611
  Args:
612
  documents: List of documents, where each document is a list of text chunks.
@@ -614,14 +616,19 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
614
  batch_size: Batch size for encoding
615
  show_progress_bar: Show progress bar during encoding
616
  device: Device to use for computation (defaults to model's device)
617
- normalize_embeddings: Normalize embeddings to unit length (applied before quantization)
618
  convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
 
 
 
619
 
620
  Returns:
621
  List of numpy arrays or tensors (preserves document structure).
622
  Each element has shape (n_chunks, hidden_dim).
623
  embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
624
- With quantization, embeddings are int8 values in range [-128, 127].
 
 
625
  """
626
 
627
  if not isinstance(documents, list) or not all(
@@ -631,6 +638,13 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
631
  "Input 'documents' must be a list of lists of strings for contextual encoding."
632
  )
633
 
 
 
 
 
 
 
 
634
  self.eval()
635
 
636
  if device is None:
@@ -676,10 +690,12 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
676
  for doc_chunks in batch_chunk_embeddings
677
  ]
678
 
679
- if self.quantizer is not None:
680
- batch_chunk_embeddings = [
681
- self.quantizer(emb) for emb in batch_chunk_embeddings
682
- ]
 
 
683
 
684
  if normalize_embeddings:
685
  batch_chunk_embeddings = [
@@ -691,7 +707,6 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
691
 
692
  all_embeddings.extend(batch_chunk_embeddings)
693
 
694
- # Convert to numpy if requested
695
  if convert_to_numpy:
696
  all_embeddings = [emb.numpy() for emb in all_embeddings]
697
 
 
15
  # See the License for the specific language governing permissions and
16
  # limitations under the License.
17
 
18
+ from typing import Optional, Tuple, Literal
19
 
20
  import numpy as np
21
  import torch
 
26
  from transformers.modeling_outputs import BaseModelOutputWithPast
27
 
28
  from .configuration import PPLXQwen3Config
29
+ from .st_quantize import FlexibleQuantizer
30
 
31
 
32
  # Activation functions mapping
 
553
  super().__init__(config)
554
  self.model = PPLXQwen3Model(config)
555
  self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
556
+ self._flexible_quantizer = FlexibleQuantizer()
557
  self.post_init()
558
 
559
  def forward(
 
594
  device: str | torch.device | None = None,
595
  normalize_embeddings: bool = False,
596
  convert_to_numpy: bool = True,
597
+ quantization: Literal["int8", "binary"] = "int8",
598
  ) -> list[np.ndarray] | list[torch.Tensor]:
599
  """
600
  Encode documents with late chunking (contextual embeddings).
 
606
  1. Concatenate chunks with separator tokens
607
  2. Run forward pass to get token embeddings
608
  3. Extract and pool individual chunk embeddings (late chunking)
609
+ 4. Apply quantization (Int8 or binary, always enabled)
610
+ 5. Normalize embeddings if requested (applied after quantization)
611
+ 6. Convert to numpy or return as tensors
612
 
613
  Args:
614
  documents: List of documents, where each document is a list of text chunks.
 
616
  batch_size: Batch size for encoding
617
  show_progress_bar: Show progress bar during encoding
618
  device: Device to use for computation (defaults to model's device)
619
+ normalize_embeddings: Normalize embeddings to unit length (applied after quantization)
620
  convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
621
+ quantization: Quantization type to apply. Options:
622
+ - "int8": Int8 tanh quantization (default)
623
+ - "binary": Binary tanh quantization
624
 
625
  Returns:
626
  List of numpy arrays or tensors (preserves document structure).
627
  Each element has shape (n_chunks, hidden_dim).
628
  embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
629
+ Output type depends on quantization method:
630
+ - Int8: int8 values in range [-128, 127]
631
+ - Binary: float values -1.0 or 1.0
632
  """
633
 
634
  if not isinstance(documents, list) or not all(
 
638
  "Input 'documents' must be a list of lists of strings for contextual encoding."
639
  )
640
 
641
+ if quantization not in ["int8", "binary"]:
642
+ raise ValueError(
643
+ f"Unsupported quantization type: '{quantization}'. "
644
+ f"Supported types are: 'int8', 'binary'. "
645
+ f"Got: {type(quantization).__name__} = '{quantization}'"
646
+ )
647
+
648
  self.eval()
649
 
650
  if device is None:
 
690
  for doc_chunks in batch_chunk_embeddings
691
  ]
692
 
693
+ batch_chunk_embeddings = [
694
+ self._flexible_quantizer(
695
+ {"sentence_embedding": emb}, quantization=quantization
696
+ )["sentence_embedding"]
697
+ for emb in batch_chunk_embeddings
698
+ ]
699
 
700
  if normalize_embeddings:
701
  batch_chunk_embeddings = [
 
707
 
708
  all_embeddings.extend(batch_chunk_embeddings)
709
 
 
710
  if convert_to_numpy:
711
  all_embeddings = [emb.numpy() for emb in all_embeddings]
712
 
st_quantize.py CHANGED
@@ -24,9 +24,7 @@ class Quantizer(torch.nn.Module):
24
  result = soft
25
  else:
26
  result = (
27
- self._hard_quantize(x, *args, **kwargs).detach()
28
- + soft
29
- - soft.detach()
30
  )
31
 
32
  return result
@@ -53,13 +51,13 @@ class Int8TanhQuantizer(Quantizer):
53
 
54
  class BinaryTanhQuantizer(Quantizer):
55
  def __init__(
56
- self,
57
  hard: bool = True,
58
  scale: float = 1.0,
59
  ):
60
  super().__init__(hard)
61
  self._scale = scale
62
-
63
  def _soft_quantize(self, x, *args, **kwargs):
64
  return torch.tanh(self._scale * x)
65
 
@@ -73,7 +71,11 @@ class FlexibleQuantizer(torch.nn.Module):
73
  self._int8_quantizer = Int8TanhQuantizer()
74
  self._binary_quantizer = BinaryTanhQuantizer()
75
 
76
- def forward(self, features: dict[str, torch.Tensor], quantization: Literal["binary", "int8"] = "int8") -> dict[str, torch.Tensor]:
 
 
 
 
77
  if quantization == "int8":
78
  features["sentence_embedding"] = self._int8_quantizer(
79
  features["sentence_embedding"]
@@ -83,10 +85,11 @@ class FlexibleQuantizer(torch.nn.Module):
83
  features["sentence_embedding"]
84
  )
85
  else:
86
- raise ValueError(f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'.")
 
 
87
  return features
88
-
89
  @classmethod
90
  def load(cls, input_path: str):
91
  return cls()
92
-
 
24
  result = soft
25
  else:
26
  result = (
27
+ self._hard_quantize(x, *args, **kwargs).detach() + soft - soft.detach()
 
 
28
  )
29
 
30
  return result
 
51
 
52
  class BinaryTanhQuantizer(Quantizer):
53
  def __init__(
54
+ self,
55
  hard: bool = True,
56
  scale: float = 1.0,
57
  ):
58
  super().__init__(hard)
59
  self._scale = scale
60
+
61
  def _soft_quantize(self, x, *args, **kwargs):
62
  return torch.tanh(self._scale * x)
63
 
 
71
  self._int8_quantizer = Int8TanhQuantizer()
72
  self._binary_quantizer = BinaryTanhQuantizer()
73
 
74
+ def forward(
75
+ self,
76
+ features: dict[str, torch.Tensor],
77
+ quantization: Literal["binary", "int8"] = "int8",
78
+ ) -> dict[str, torch.Tensor]:
79
  if quantization == "int8":
80
  features["sentence_embedding"] = self._int8_quantizer(
81
  features["sentence_embedding"]
 
85
  features["sentence_embedding"]
86
  )
87
  else:
88
+ raise ValueError(
89
+ f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
90
+ )
91
  return features
92
+
93
  @classmethod
94
  def load(cls, input_path: str):
95
  return cls()