Spaces:
Running
Running
File size: 2,454 Bytes
fa16bad fea62df fa16bad fea62df 0231daa fea62df 0231daa fea62df 0231daa fea62df 0231daa fea62df 0231daa fea62df 0231daa fea62df fa16bad fea62df fa16bad 0231daa fea62df fa16bad fea62df 0231daa fa16bad fea62df 0231daa fea62df af36df4 fea62df 0231daa fea62df fa16bad 0231daa fea62df 0231daa fa16bad fea62df 0231daa fea62df fa16bad fea62df fa16bad fea62df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from typing import List, Optional, Union

from loguru import logger
from sentence_transformers import SentenceTransformer

from ..src.core.config import ModelConfig
class EmbeddingModel:
    """
    Embedding model wrapper for dense embeddings.

    The underlying SentenceTransformer is created lazily: either by an
    explicit call to ``load()`` or on the first embedding request.

    Attributes:
        config: ModelConfig instance
        model: SentenceTransformer instance (None until loaded)
        _loaded: Flag indicating if the model is loaded
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SentenceTransformer] = None
        self._loaded = False

    def load(self) -> None:
        """
        Load the embedding model.

        Does nothing if the model has already been loaded.

        Raises:
            Exception: whatever the SentenceTransformer constructor raised;
                it is logged and re-raised unchanged.
        """
        if self._loaded:
            return

        # NOTE(review): this message uses config.name while the success and
        # failure messages below use config.id -- confirm that is intended.
        logger.info(f"Loading embedding model: {self.config.name}")

        # Keep the try body to the one call that can fail while loading, so a
        # problem in the bookkeeping below is not reported as a load failure.
        try:
            # NOTE(review): the device is fixed to "cpu", and
            # trust_remote_code=True presumably permits running code shipped
            # with the model repository -- confirm the source is trusted.
            model = SentenceTransformer(
                self.config.name, device="cpu", trust_remote_code=True
            )
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.config.id}: {e}")
            raise

        self.model = model
        self._loaded = True
        logger.success(f"Loaded embedding model: {self.config.id}")

    def query_embed(
        self, text: Union[str, List[str]], prompt: Optional[str] = None
    ) -> Union[List[float], List[List[float]]]:
        """
        Generate query embeddings for a single text or a list of texts.

        Args:
            text: One query string, or a list of query strings
            prompt: Optional prompt for instruction-based models

        Returns:
            For a list input, one embedding vector per text (an empty list
            for an empty input). For a single string, the elements of the
            encoder output converted one by one -- assumes the encoder
            returns a 1-D array in that case, which yields a flat vector.

        Raises:
            Exception: any error raised while loading or encoding; it is
                logged and re-raised unchanged.
        """
        # An empty batch has nothing to encode; skip loading the model.
        # A string (even an empty one) is a real query and is still encoded.
        if not isinstance(text, str) and len(text) == 0:
            return []

        if not self._loaded:
            self.load()

        try:
            embeddings = self.model.encode_query(text, prompt=prompt)
            # Element-wise tolist() covers both shapes: rows of a 2-D result
            # become lists, scalars of a 1-D result become plain floats.
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise

    def embed_documents(
        self, texts: List[str], prompt: Optional[str] = None
    ) -> List[List[float]]:
        """
        Generate document embeddings for a list of texts.

        Args:
            texts: List of input texts
            prompt: Optional prompt for instruction-based models

        Returns:
            List of embedding vectors, one per input text (an empty list
            for an empty input).

        Raises:
            Exception: any error raised while loading or encoding; it is
                logged and re-raised unchanged.
        """
        # An empty batch has nothing to encode; skip loading the model.
        if len(texts) == 0:
            return []

        if not self._loaded:
            self.load()

        try:
            embeddings = self.model.encode_document(texts, prompt=prompt)
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise
|