Spaces:
Running
Running
File size: 2,487 Bytes
fa16bad fea62df fa16bad fea62df fa16bad fea62df fa16bad fea62df fa16bad fea62df fa16bad fea62df fa16bad fea62df af36df4 fea62df fa16bad fea62df fa16bad fea62df fa16bad fea62df fa16bad fea62df fa16bad fea62df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from loguru import logger
from .config import ModelConfig
class EmbeddingModel:
    """Wrapper around a SentenceTransformer providing dense embeddings.

    The model is loaded lazily: ``query_embed`` and ``embed_documents`` call
    ``load()`` automatically on first use.

    Attributes:
        config: ModelConfig for the model (this class reads ``name`` and ``id``).
        model: Underlying SentenceTransformer instance, or None until loaded.
        _loaded: True once the model has been successfully loaded.
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SentenceTransformer] = None
        self._loaded = False

    def load(self) -> None:
        """Load the embedding model onto CPU. Idempotent.

        Raises:
            Exception: re-raises whatever SentenceTransformer raises on failure,
                after logging it.
        """
        if self._loaded:
            return
        logger.info(f"Loading embedding model: {self.config.name}")
        try:
            # NOTE(review): trust_remote_code=True executes code shipped with the
            # model repository — confirm config.name only points at trusted models.
            self.model = SentenceTransformer(self.config.name, device="cpu", trust_remote_code=True)
            self._loaded = True
            logger.success(f"Loaded embedding model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.config.id}: {e}")
            raise

    def query_embed(self, text: List[str], prompt: Optional[str] = None) -> List[List[float]]:
        """Generate query embeddings for a list of texts.

        Args:
            text: List of query texts to embed.
            prompt: Optional prompt for instruction-based models.

        Returns:
            One embedding vector (list of floats) per input text.

        Raises:
            Exception: re-raises any error from the underlying encoder,
                after logging it.
        """
        if not self._loaded:
            self.load()
        try:
            embeddings = self.model.encode_query(text, prompt=prompt)
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise

    def embed_documents(self, texts: List[str], prompt: Optional[str] = None) -> List[List[float]]:
        """Generate document embeddings for a list of texts.

        Args:
            texts: List of input texts to embed.
            prompt: Optional prompt for instruction-based models.

        Returns:
            One embedding vector (list of floats) per input text.

        Raises:
            Exception: re-raises any error from the underlying encoder,
                after logging it.
        """
        if not self._loaded:
            self.load()
        try:
            embeddings = self.model.encode_document(texts, prompt=prompt)
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise
|