# embeddings-api / model_service.py
# Author: Soumik Bose
# Commit: 16530ae ("ok")
import logging
import torch
from sentence_transformers import SentenceTransformer
logger = logging.getLogger("EmbedService")
class MultiEmbeddingService:
    """Serves sentence embeddings from several locally stored models.

    Models are keyed by their output dimensionality (384 / 768 / 1024)
    and loaded from local checkpoint directories under ``./models``.
    Call :meth:`load_all_models` once at startup, then
    :meth:`generate_embedding` per request.
    """

    def __init__(self) -> None:
        # Loaded models keyed by embedding dimension; empty until
        # load_all_models() runs (loads may partially fail — see below).
        self.models = {}
        # Prefer GPU when available; SentenceTransformer accepts the
        # device string directly.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Dimension -> local checkpoint path. Paths are relative, so the
        # process working directory must be the project root.
        self.model_map = {
            384: "./models/bge-384",
            768: "./models/bge-768",
            1024: "./models/bge-1024",
        }

    def load_all_models(self) -> None:
        """Load every model in ``model_map`` into memory (best-effort).

        A failure to load one model is logged and skipped so the
        remaining dimensions can still be served; it does not raise.
        """
        logger.info("🚀 Acceleration Device: %s", self.device.upper())
        for dim, path in self.model_map.items():
            try:
                logger.info("Loading %d-dimension model...", dim)
                model = SentenceTransformer(path, device=self.device)
                model.eval()  # inference only: disables dropout, etc.
                self.models[dim] = model
            except Exception as e:
                # Best-effort startup: keep serving whatever did load.
                logger.error("❌ Failed to load %d-dim model: %s", dim, e)

    def generate_embedding(self, text, dimension: int) -> list:
        """Return a unit-normalized embedding for *text*.

        Args:
            text: A string (or, per SentenceTransformer.encode, a list of
                strings — then a list of vectors is returned).
            dimension: Which loaded model to use (its output size).

        Returns:
            A plain Python list of floats (vector) — nested lists when
            *text* is a batch.

        Raises:
            ValueError: If no model for *dimension* is loaded. The message
                distinguishes an unconfigured dimension from one whose
                model failed to load at startup.
        """
        if dimension not in self.models:
            # load_all_models() swallows load errors, so a configured
            # dimension may still be absent here — say which case it is.
            if dimension in self.model_map:
                raise ValueError(
                    f"Dimension {dimension} not supported. "
                    f"Model is configured but failed to load; "
                    f"loaded dimensions: {sorted(self.models)}"
                )
            raise ValueError(
                f"Dimension {dimension} not supported. "
                f"Available dimensions: {sorted(self.model_map)}"
            )
        # show_progress_bar=False stops the spam
        return self.models[dimension].encode(
            text,
            normalize_embeddings=True,
            convert_to_numpy=True,
            show_progress_bar=False,
            batch_size=32,
        ).tolist()