Commit · 8d01c68
1 Parent(s): ff0893c
feat: updated context model
Browse files
- modeling.py +27 -12
- st_quantize.py +12 -9
modeling.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
| 15 |
# See the License for the specific language governing permissions and
|
| 16 |
# limitations under the License.
|
| 17 |
|
| 18 |
-
from typing import Optional, Tuple
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
import torch
|
|
@@ -26,7 +26,7 @@ from transformers.modeling_utils import PreTrainedModel
|
|
| 26 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 27 |
|
| 28 |
from .configuration import PPLXQwen3Config
|
| 29 |
-
from .st_quantize import
|
| 30 |
|
| 31 |
|
| 32 |
# Activation functions mapping
|
|
@@ -553,7 +553,7 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
|
|
| 553 |
super().__init__(config)
|
| 554 |
self.model = PPLXQwen3Model(config)
|
| 555 |
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
| 556 |
-
self.
|
| 557 |
self.post_init()
|
| 558 |
|
| 559 |
def forward(
|
|
@@ -594,6 +594,7 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
|
|
| 594 |
device: str | torch.device | None = None,
|
| 595 |
normalize_embeddings: bool = False,
|
| 596 |
convert_to_numpy: bool = True,
|
|
|
|
| 597 |
) -> list[np.ndarray] | list[torch.Tensor]:
|
| 598 |
"""
|
| 599 |
Encode documents with late chunking (contextual embeddings).
|
|
@@ -605,8 +606,9 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
|
|
| 605 |
1. Concatenate chunks with separator tokens
|
| 606 |
2. Run forward pass to get token embeddings
|
| 607 |
3. Extract and pool individual chunk embeddings (late chunking)
|
| 608 |
-
4. Apply quantization (Int8
|
| 609 |
-
5.
|
|
|
|
| 610 |
|
| 611 |
Args:
|
| 612 |
documents: List of documents, where each document is a list of text chunks.
|
|
@@ -614,14 +616,19 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
|
|
| 614 |
batch_size: Batch size for encoding
|
| 615 |
show_progress_bar: Show progress bar during encoding
|
| 616 |
device: Device to use for computation (defaults to model's device)
|
| 617 |
-
normalize_embeddings: Normalize embeddings to unit length (applied
|
| 618 |
convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
|
|
|
|
|
|
|
|
|
|
| 619 |
|
| 620 |
Returns:
|
| 621 |
List of numpy arrays or tensors (preserves document structure).
|
| 622 |
Each element has shape (n_chunks, hidden_dim).
|
| 623 |
embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
|
| 624 |
-
|
|
|
|
|
|
|
| 625 |
"""
|
| 626 |
|
| 627 |
if not isinstance(documents, list) or not all(
|
|
@@ -631,6 +638,13 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
|
|
| 631 |
"Input 'documents' must be a list of lists of strings for contextual encoding."
|
| 632 |
)
|
| 633 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
self.eval()
|
| 635 |
|
| 636 |
if device is None:
|
|
@@ -676,10 +690,12 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
|
|
| 676 |
for doc_chunks in batch_chunk_embeddings
|
| 677 |
]
|
| 678 |
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
]
|
|
|
|
|
|
|
| 683 |
|
| 684 |
if normalize_embeddings:
|
| 685 |
batch_chunk_embeddings = [
|
|
@@ -691,7 +707,6 @@ class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
|
|
| 691 |
|
| 692 |
all_embeddings.extend(batch_chunk_embeddings)
|
| 693 |
|
| 694 |
-
# Convert to numpy if requested
|
| 695 |
if convert_to_numpy:
|
| 696 |
all_embeddings = [emb.numpy() for emb in all_embeddings]
|
| 697 |
|
|
|
|
| 15 |
# See the License for the specific language governing permissions and
|
| 16 |
# limitations under the License.
|
| 17 |
|
| 18 |
+
from typing import Optional, Tuple, Literal
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
import torch
|
|
|
|
| 26 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 27 |
|
| 28 |
from .configuration import PPLXQwen3Config
|
| 29 |
+
from .st_quantize import FlexibleQuantizer
|
| 30 |
|
| 31 |
|
| 32 |
# Activation functions mapping
|
|
|
|
| 553 |
super().__init__(config)
|
| 554 |
self.model = PPLXQwen3Model(config)
|
| 555 |
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
| 556 |
+
self._flexible_quantizer = FlexibleQuantizer()
|
| 557 |
self.post_init()
|
| 558 |
|
| 559 |
def forward(
|
|
|
|
| 594 |
device: str | torch.device | None = None,
|
| 595 |
normalize_embeddings: bool = False,
|
| 596 |
convert_to_numpy: bool = True,
|
| 597 |
+
quantization: Literal["int8", "binary"] = "int8",
|
| 598 |
) -> list[np.ndarray] | list[torch.Tensor]:
|
| 599 |
"""
|
| 600 |
Encode documents with late chunking (contextual embeddings).
|
|
|
|
| 606 |
1. Concatenate chunks with separator tokens
|
| 607 |
2. Run forward pass to get token embeddings
|
| 608 |
3. Extract and pool individual chunk embeddings (late chunking)
|
| 609 |
+
4. Apply quantization (Int8 or binary, always enabled)
|
| 610 |
+
5. Normalize embeddings if requested (applied after quantization)
|
| 611 |
+
6. Convert to numpy or return as tensors
|
| 612 |
|
| 613 |
Args:
|
| 614 |
documents: List of documents, where each document is a list of text chunks.
|
|
|
|
| 616 |
batch_size: Batch size for encoding
|
| 617 |
show_progress_bar: Show progress bar during encoding
|
| 618 |
device: Device to use for computation (defaults to model's device)
|
| 619 |
+
normalize_embeddings: Normalize embeddings to unit length (applied after quantization)
|
| 620 |
convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
|
| 621 |
+
quantization: Quantization type to apply. Options:
|
| 622 |
+
- "int8": Int8 tanh quantization (default)
|
| 623 |
+
- "binary": Binary tanh quantization
|
| 624 |
|
| 625 |
Returns:
|
| 626 |
List of numpy arrays or tensors (preserves document structure).
|
| 627 |
Each element has shape (n_chunks, hidden_dim).
|
| 628 |
embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
|
| 629 |
+
Output type depends on quantization method:
|
| 630 |
+
- Int8: int8 values in range [-128, 127]
|
| 631 |
+
- Binary: float values -1.0 or 1.0
|
| 632 |
"""
|
| 633 |
|
| 634 |
if not isinstance(documents, list) or not all(
|
|
|
|
| 638 |
"Input 'documents' must be a list of lists of strings for contextual encoding."
|
| 639 |
)
|
| 640 |
|
| 641 |
+
if quantization not in ["int8", "binary"]:
|
| 642 |
+
raise ValueError(
|
| 643 |
+
f"Unsupported quantization type: '{quantization}'. "
|
| 644 |
+
f"Supported types are: 'int8', 'binary'. "
|
| 645 |
+
f"Got: {type(quantization).__name__} = '{quantization}'"
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
self.eval()
|
| 649 |
|
| 650 |
if device is None:
|
|
|
|
| 690 |
for doc_chunks in batch_chunk_embeddings
|
| 691 |
]
|
| 692 |
|
| 693 |
+
batch_chunk_embeddings = [
|
| 694 |
+
self._flexible_quantizer(
|
| 695 |
+
{"sentence_embedding": emb}, quantization=quantization
|
| 696 |
+
)["sentence_embedding"]
|
| 697 |
+
for emb in batch_chunk_embeddings
|
| 698 |
+
]
|
| 699 |
|
| 700 |
if normalize_embeddings:
|
| 701 |
batch_chunk_embeddings = [
|
|
|
|
| 707 |
|
| 708 |
all_embeddings.extend(batch_chunk_embeddings)
|
| 709 |
|
|
|
|
| 710 |
if convert_to_numpy:
|
| 711 |
all_embeddings = [emb.numpy() for emb in all_embeddings]
|
| 712 |
|
st_quantize.py
CHANGED
|
@@ -24,9 +24,7 @@ class Quantizer(torch.nn.Module):
|
|
| 24 |
result = soft
|
| 25 |
else:
|
| 26 |
result = (
|
| 27 |
-
self._hard_quantize(x, *args, **kwargs).detach()
|
| 28 |
-
+ soft
|
| 29 |
-
- soft.detach()
|
| 30 |
)
|
| 31 |
|
| 32 |
return result
|
|
@@ -53,13 +51,13 @@ class Int8TanhQuantizer(Quantizer):
|
|
| 53 |
|
| 54 |
class BinaryTanhQuantizer(Quantizer):
|
| 55 |
def __init__(
|
| 56 |
-
self,
|
| 57 |
hard: bool = True,
|
| 58 |
scale: float = 1.0,
|
| 59 |
):
|
| 60 |
super().__init__(hard)
|
| 61 |
self._scale = scale
|
| 62 |
-
|
| 63 |
def _soft_quantize(self, x, *args, **kwargs):
|
| 64 |
return torch.tanh(self._scale * x)
|
| 65 |
|
|
@@ -73,7 +71,11 @@ class FlexibleQuantizer(torch.nn.Module):
|
|
| 73 |
self._int8_quantizer = Int8TanhQuantizer()
|
| 74 |
self._binary_quantizer = BinaryTanhQuantizer()
|
| 75 |
|
| 76 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
if quantization == "int8":
|
| 78 |
features["sentence_embedding"] = self._int8_quantizer(
|
| 79 |
features["sentence_embedding"]
|
|
@@ -83,10 +85,11 @@ class FlexibleQuantizer(torch.nn.Module):
|
|
| 83 |
features["sentence_embedding"]
|
| 84 |
)
|
| 85 |
else:
|
| 86 |
-
raise ValueError(
|
|
|
|
|
|
|
| 87 |
return features
|
| 88 |
-
|
| 89 |
@classmethod
|
| 90 |
def load(cls, input_path: str):
|
| 91 |
return cls()
|
| 92 |
-
|
|
|
|
| 24 |
result = soft
|
| 25 |
else:
|
| 26 |
result = (
|
| 27 |
+
self._hard_quantize(x, *args, **kwargs).detach() + soft - soft.detach()
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
|
| 30 |
return result
|
|
|
|
| 51 |
|
| 52 |
class BinaryTanhQuantizer(Quantizer):
|
| 53 |
def __init__(
|
| 54 |
+
self,
|
| 55 |
hard: bool = True,
|
| 56 |
scale: float = 1.0,
|
| 57 |
):
|
| 58 |
super().__init__(hard)
|
| 59 |
self._scale = scale
|
| 60 |
+
|
| 61 |
def _soft_quantize(self, x, *args, **kwargs):
|
| 62 |
return torch.tanh(self._scale * x)
|
| 63 |
|
|
|
|
| 71 |
self._int8_quantizer = Int8TanhQuantizer()
|
| 72 |
self._binary_quantizer = BinaryTanhQuantizer()
|
| 73 |
|
| 74 |
+
def forward(
|
| 75 |
+
self,
|
| 76 |
+
features: dict[str, torch.Tensor],
|
| 77 |
+
quantization: Literal["binary", "int8"] = "int8",
|
| 78 |
+
) -> dict[str, torch.Tensor]:
|
| 79 |
if quantization == "int8":
|
| 80 |
features["sentence_embedding"] = self._int8_quantizer(
|
| 81 |
features["sentence_embedding"]
|
|
|
|
| 85 |
features["sentence_embedding"]
|
| 86 |
)
|
| 87 |
else:
|
| 88 |
+
raise ValueError(
|
| 89 |
+
f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
|
| 90 |
+
)
|
| 91 |
return features
|
| 92 |
+
|
| 93 |
@classmethod
|
| 94 |
def load(cls, input_path: str):
|
| 95 |
return cls()
|
|
|