Spaces:
Running
Running
File size: 5,388 Bytes
fea62df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
from loguru import logger
from typing import Dict, List, Optional, Any
from sentence_transformers import SentenceTransformer
from sentence_transformers import SparseEncoder
class ModelConfig:
    """
    Per-model settings read from a configuration mapping.

    Every key used below is required; a missing one raises ``KeyError``.
    ``dimension`` and ``max_tokens`` are converted with ``int()``, so they may
    be supplied as numeric strings (a non-numeric value raises ``ValueError``).

    Attributes:
        id: Identifier the model is registered under (the ``model_id`` argument).
        name: Model name as given in the config.
        type: Model kind, "embedding" or "sparse".
        dimension: Embedding dimension, as an int.
        max_tokens: Maximum token count, as an int.
        description: Free-text description.
        language: Language value from the config.
        repository: Repository value from the config.
    """

    def __init__(self, model_id: str, config: Dict[str, Any]):
        self.id = model_id
        # Keys are read in a fixed sequence, so a config missing several keys
        # always reports the first missing one in this order.
        # self.type holds "embedding" or "sparse".
        self.name, self.type = config["name"], config["type"]
        self.dimension, self.max_tokens = int(config["dimension"]), int(config["max_tokens"])
        self.description, self.language, self.repository = (
            config["description"],
            config["language"],
            config["repository"],
        )
class EmbeddingModel:
    """
    Embedding model wrapper for dense embeddings.

    The underlying ``SentenceTransformer`` is created lazily: ``embed`` calls
    ``load`` on first use, so constructing this object does no model I/O.

    Attributes:
        config: ModelConfig instance
        model: SentenceTransformer instance (None until ``load`` succeeds)
        _loaded: Flag indicating if the model is loaded
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SentenceTransformer] = None
        self._loaded = False

    def load(self) -> None:
        """
        Load the embedding model; a no-op if it is already loaded.

        Raises:
            Exception: whatever ``SentenceTransformer`` raises is logged and
                re-raised; ``_loaded`` stays False in that case.
        """
        if self._loaded:
            return
        logger.info(f"Loading embedding model: {self.config.name}")
        try:
            # NOTE(review): trust_remote_code=True lets the model repository run
            # its own Python code on load; only point config.name at trusted
            # repositories. The device is pinned to CPU.
            self.model = SentenceTransformer(self.config.name, device="cpu", trust_remote_code=True)
            self._loaded = True
            logger.success(f"Loaded embedding model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load embedding model {self.config.id}: {e}")
            raise

    def embed(self, texts: List[str], prompt: Optional[str] = None) -> List[List[float]]:
        """
        Generate dense embeddings for a list of texts.

        Args:
            texts: List of input texts. An empty input returns ``[]`` without
                loading the model.
            prompt: Optional prompt for instruction-based models, forwarded to
                ``SentenceTransformer.encode``.

        Returns:
            One embedding vector (a list of floats) per input text.

        Raises:
            Exception: model-loading or encoding errors are logged and re-raised.
        """
        # len() instead of truthiness so array-like inputs (e.g. numpy arrays)
        # are still accepted. Skipping the load avoids fetching a model just to
        # encode nothing.
        if len(texts) == 0:
            return []
        if not self._loaded:
            self.load()
        try:
            embeddings = self.model.encode(texts, prompt=prompt)
            # .tolist() turns each row into plain Python floats for serialization.
            return [embedding.tolist() for embedding in embeddings]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise
class SparseEmbeddingModel:
    """
    Sparse embedding model wrapper.

    The underlying ``SparseEncoder`` is created lazily: both ``embed`` and
    ``embed_batch`` call ``load`` on first use.

    Attributes:
        config: ModelConfig instance
        model: SparseEncoder instance (None until ``load`` succeeds)
        _loaded: Flag indicating if the model is loaded
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model: Optional[SparseEncoder] = None
        self._loaded = False

    def _format_values(self, values: List[float]) -> List[float]:
        """Round each value to 7 decimal places, returned as Python floats."""
        return [round(float(v), 7) for v in values]

    def _encode(self, texts: List[str], prompt: Optional[str]):
        """Run the encoder on ``texts``, forwarding ``prompt`` only when one is given."""
        # Omitting the keyword when prompt is None keeps the exact original
        # call for the common no-prompt case.
        if prompt is None:
            return self.model.encode(texts)
        return self.model.encode(texts, prompt=prompt)

    def _to_sparse_dict(self, tensor) -> Dict[str, List]:
        """Convert one per-text sparse tensor into ``{'indices', 'values'}`` lists."""
        # Coalesce once and reuse it for both indices and values.
        # indices()[0] takes the first index row; assumes a 1-D sparse vector
        # per text (TODO confirm against the encoder's output layout).
        coalesced = tensor.coalesce()
        return {
            "indices": coalesced.indices()[0].tolist(),
            "values": self._format_values(coalesced.values().tolist()),
        }

    def load(self) -> None:
        """
        Load the sparse embedding model; a no-op if it is already loaded.

        Raises:
            Exception: whatever ``SparseEncoder`` raises is logged and
                re-raised; ``_loaded`` stays False in that case.
        """
        if self._loaded:
            return
        logger.info(f"Loading sparse model: {self.config.name}")
        try:
            self.model = SparseEncoder(self.config.name)
            self._loaded = True
            logger.success(f"Loaded sparse model: {self.config.id}")
        except Exception as e:
            logger.error(f"Failed to load sparse model {self.config.id}: {e}")
            raise

    def embed(self, text: str, prompt: Optional[str] = None) -> Dict[str, Any]:
        """
        Generate a sparse embedding for a single text.

        Args:
            text: Input text
            prompt: Optional prompt for instruction-based models; forwarded to
                the encoder when not None.

        Returns:
            Sparse embedding as a dictionary with 'indices' and 'values' keys.

        Raises:
            Exception: model-loading or encoding errors are logged and re-raised.
        """
        # Lazy-load like embed_batch; otherwise self.model is still None here
        # and encode() fails with AttributeError.
        if not self._loaded:
            self.load()
        try:
            tensors = self._encode([text], prompt)
            return self._to_sparse_dict(tensors[0])
        except Exception as e:
            logger.error(f"Embedding error: {e}")
            raise

    def embed_batch(self, texts: List[str], prompt: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Generate sparse embeddings for a batch of texts.

        Args:
            texts: List of input texts. An empty input returns ``[]`` without
                loading the model.
            prompt: Optional prompt for instruction-based models; forwarded to
                the encoder when not None.

        Returns:
            List of dictionaries with 'text' and 'sparse_embedding' keys, where
            'sparse_embedding' holds 'indices' and 'values' lists.

        Raises:
            Exception: model-loading or encoding errors are logged and re-raised.
        """
        # len() instead of truthiness so array-like inputs are still accepted.
        if len(texts) == 0:
            return []
        if not self._loaded:
            self.load()
        try:
            tensors = self._encode(texts, prompt)
            return [
                {"text": text, "sparse_embedding": self._to_sparse_dict(tensor)}
                for text, tensor in zip(texts, tensors)
            ]
        except Exception as e:
            logger.error(f"Sparse embedding generation failed: {e}")
            raise
|