| | """Sentence embedding generation utilities.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Iterable, List |
| |
|
| | import numpy as np |
| | from sentence_transformers import SentenceTransformer |
| | from tqdm.auto import tqdm |
| |
|
| |
|
@dataclass
class EmbeddingGenerator:
    """Turn texts into sentence embeddings with a ``SentenceTransformer`` model.

    Attributes:
        model_name: Identifier passed to ``SentenceTransformer(...)`` in
            ``__post_init__``.
        batch_size: Maximum number of texts handed to the model per call.
            Must be >= 1.
        normalize: Forwarded to ``SentenceTransformer.encode`` as
            ``normalize_embeddings``.
    """

    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    batch_size: int = 64
    normalize: bool = True

    def __post_init__(self) -> None:
        """Validate the configuration and load the model.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Validate before loading so a bad config fails fast. A batch size
        # < 1 would otherwise silently disable batching in ``encode``.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        self.model = SentenceTransformer(self.model_name)

    def encode(self, texts: Iterable[str]) -> np.ndarray:
        """Embed ``texts`` batch by batch and stack the results row-wise.

        ``texts`` may be any iterable, including a one-shot generator. It is
        consumed lazily, so at most ``batch_size`` texts are buffered at once.

        Args:
            texts: The strings to embed.

        Returns:
            An array with one row per input text, in input order. For an
            empty input, an array of shape ``(0, dim)``, where ``dim`` comes
            from the model (0 if the model does not report it).
        """
        embeddings: List[np.ndarray] = []
        batch: List[str] = []
        for text in texts:
            batch.append(text)
            if len(batch) == self.batch_size:
                embeddings.append(self._encode_batch(batch))
                batch = []
        if batch:
            # Flush the final, partially filled batch.
            embeddings.append(self._encode_batch(batch))
        if not embeddings:
            # np.vstack([]) raises ValueError; return a well-shaped empty
            # array instead. float32 is assumed to match the model output.
            dim = self.model.get_sentence_embedding_dimension() or 0
            return np.empty((0, dim), dtype=np.float32)
        return np.vstack(embeddings)

    def _encode_batch(self, batch: List[str]) -> np.ndarray:
        """Encode one batch of texts with the configured options."""
        # Forward batch_size so the model does not re-split our batch using
        # its own default batch size.
        return self.model.encode(
            batch,
            batch_size=self.batch_size,
            normalize_embeddings=self.normalize,
            convert_to_numpy=True,
        )

    def save(self, embeddings: np.ndarray, path: Path) -> None:
        """Write ``embeddings`` to ``path`` with ``np.save``.

        Missing parent directories are created first. ``np.save`` appends a
        ``.npy`` suffix when ``path`` does not already have one.

        Args:
            embeddings: The array to store.
            path: Destination file, as a ``Path`` or a ``str``.
        """
        path = Path(path)  # accept plain strings as well as Path objects
        path.parent.mkdir(parents=True, exist_ok=True)
        np.save(path, embeddings)
| |
|
| |
|
| | __all__ = ["EmbeddingGenerator"] |
| |
|