Spaces:
Running
Running
import logging

import torch
from sentence_transformers import SentenceTransformer

# Module-level logger; name kept as "EmbedService" (existing log consumers may filter on it).
logger = logging.getLogger("EmbedService")
class MultiEmbeddingService:
    """Manages several SentenceTransformer models keyed by output dimension.

    Models are lazily loaded via :meth:`load_all_models`; embeddings are
    produced by :meth:`generate_embedding` for any successfully loaded
    dimension.
    """

    def __init__(self):
        # dimension -> loaded SentenceTransformer instance (filled by load_all_models)
        self.models: dict = {}
        # Prefer GPU when available; SentenceTransformer accepts the device string directly.
        self.device: str = "cuda" if torch.cuda.is_available() else "cpu"
        # dimension -> local model path. Paths are relative; presumably resolved
        # against the service's working directory — confirm deployment layout.
        self.model_map: dict = {
            384: "./models/bge-384",
            768: "./models/bge-768",
            1024: "./models/bge-1024",
        }

    def load_all_models(self):
        """Loads all defined models into memory.

        Best-effort: a model that fails to load is logged and skipped so the
        remaining dimensions stay usable (deliberate — do not turn into a raise).
        """
        # Lazy %-args keep formatting off the hot path and match logging best practice.
        logger.info("🚀 Acceleration Device: %s", self.device.upper())
        for dim, path in self.model_map.items():
            try:
                logger.info("Loading %s-dimension model...", dim)
                model = SentenceTransformer(path, device=self.device)
                model.eval()  # inference mode: disables dropout etc.
                self.models[dim] = model
            except Exception as e:
                logger.error("❌ Failed to load %s-dim model: %s", dim, e)

    def generate_embedding(self, text, dimension):
        """Return a normalized embedding for *text* at the given *dimension*.

        Args:
            text: input to encode (str, or a list of str — passed straight to
                  SentenceTransformer.encode, which accepts both).
            dimension: one of the keys in ``model_map`` whose model loaded.

        Returns:
            The embedding(s) as a plain Python list (via numpy ``tolist()``).

        Raises:
            ValueError: if *dimension* is unknown, or known but its model
                failed to load / was never loaded.
        """
        if dimension not in self.models:
            # Distinguish "never a valid dimension" from "valid but not loaded"
            # so operators can tell a config error from a load failure.
            if dimension in self.model_map:
                raise ValueError(
                    f"Dimension {dimension} model is not loaded; "
                    f"call load_all_models() and check logs for load errors."
                )
            raise ValueError(f"Dimension {dimension} not supported.")
        # show_progress_bar=False stops the spam
        return self.models[dimension].encode(
            text,
            normalize_embeddings=True,
            convert_to_numpy=True,
            show_progress_bar=False,
            batch_size=32,
        ).tolist()