# openajaj / embedder.py
# Jindrich3's picture
# Super-squash branch 'main' using huggingface_hub
# 5eb8692
"""
embedder.py — Shared embedding model. Uses fastembed (ONNX, ~80MB RAM)
with fallback to sentence-transformers (PyTorch, ~500MB RAM).
"""
import numpy as np
# Model identifier passed to the constructor of whichever backend gets loaded.
MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Lazily created model instance; stays None until _init() succeeds.
_model = None
# Name of the backend that supplied _model: "fastembed" or "sentence-transformers".
_backend = None
def _init():
    """Load the shared embedding model once and record which backend supplied it.

    Does nothing when a model is already loaded. Otherwise tries fastembed
    first and falls back to sentence-transformers. On success the module
    globals ``_model`` and ``_backend`` are set.

    Raises:
        RuntimeError: if neither backend could be loaded. The message lists
            the error raised by each backend, and the last error is chained
            as ``__cause__`` so the real reason (missing package, failed
            download, ...) is not lost.
    """
    global _model, _backend
    if _model is not None:
        return

    # (backend name, exception) for every backend that failed to load.
    failures = []

    # Try fastembed first (lightweight ONNX)
    try:
        from fastembed import TextEmbedding
        _model = TextEmbedding(MODEL_NAME)
        _backend = "fastembed"
        return
    except Exception as exc:
        # Best-effort: remember why it failed and move on to the next backend.
        failures.append(("fastembed", exc))

    # Fall back to sentence-transformers (heavy PyTorch)
    try:
        from sentence_transformers import SentenceTransformer
        _model = SentenceTransformer(MODEL_NAME)
        _backend = "sentence-transformers"
        return
    except Exception as exc:
        failures.append(("sentence-transformers", exc))

    details = "; ".join(f"{name}: {error!r}" for name, error in failures)
    raise RuntimeError(
        "Nelze načíst embedding model. Nainstaluj fastembed nebo sentence-transformers."
        f" ({details})"
    ) from failures[-1][1]
def get_backend():
    """Return the name of the active embedding backend.

    Triggers model loading via ``_init()`` first, so the returned value is
    whatever backend name that call left in ``_backend``.
    """
    _init()
    active_backend = _backend
    return active_backend
def encode(text):
    """Embed one text string with the shared model.

    Loads the model on first use. With the fastembed backend the text is
    embedded as a one-element batch and the first result is wrapped in a
    numpy array; with any other backend the result of ``_model.encode(text)``
    is returned as-is (expected to be a numpy array).
    """
    _init()
    if _backend != "fastembed":
        return _model.encode(text)
    # fastembed's embed() takes a batch, so wrap the text and unwrap the result.
    vectors = list(_model.embed([text]))
    return np.array(vectors[0])
def encode_batch(texts, show_progress=True):
    """Encode a list of texts. Returns numpy array of shape (n, dim).

    For large batches, prefers sentence-transformers (faster on GPU/CPU).
    Falls back to fastembed if sentence-transformers is unavailable.

    Args:
        texts: the texts to embed.
        show_progress: forwarded as ``show_progress_bar`` to
            sentence-transformers; not used on the fastembed path.

    Returns:
        The embeddings produced by the chosen backend; the fastembed results
        are collected into a single numpy array.

    Raises:
        RuntimeError: propagated from ``_init()`` when sentence-transformers
            cannot be imported and no backend can be loaded.
    """
    # If the shared model already is a SentenceTransformer, reuse it instead
    # of loading a second copy of the same model on every call.
    if _backend == "sentence-transformers" and _model is not None:
        return _model.encode(texts, show_progress_bar=show_progress, batch_size=256)

    # For batch encoding, prefer sentence-transformers (much faster)
    try:
        from sentence_transformers import SentenceTransformer
        st_model = SentenceTransformer(MODEL_NAME)
        return st_model.encode(texts, show_progress_bar=show_progress, batch_size=256)
    except ImportError:
        pass

    # Fallback to fastembed (slower for batch but works without PyTorch)
    _init()
    if _backend == "fastembed":
        # embed() is consumed through list() before building the array.
        embeddings = list(_model.embed(texts, batch_size=256))
        return np.array(embeddings)
    return _model.encode(texts, show_progress_bar=show_progress, batch_size=256)
class Embedder:
    """Object-style adapter over the module-level ``encode`` function.

    Provides an ``.encode`` method for callers that expect an object with
    that interface (retrieve.py, per the original note).
    """

    def encode(self, text):
        """Embed a single text by delegating to the module-level ``encode``."""
        vector = encode(text)
        return vector