"""Embedding service for semantic search. |
|
|
|
|
|
IMPORTANT: All public methods are async to avoid blocking the event loop. |
|
|
The sentence-transformers model is CPU-bound, so we use run_in_executor(). |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import uuid |
|
|
from typing import Any |
|
|
|
|
|
import chromadb |
|
|
import structlog |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
from src.utils.config import settings |
|
|
from src.utils.models import Evidence |
|
|
|
|
|
_shared_models: dict[str, SentenceTransformer] = {}


def _get_shared_model(model_name: str) -> SentenceTransformer:
    """Get or create the shared SentenceTransformer instance for a model name.

    Cached per model name so that services requesting different models
    do not silently receive an instance of the wrong model.
    """
    if model_name not in _shared_models:
        _shared_models[model_name] = SentenceTransformer(model_name)
    return _shared_models[model_name]


class EmbeddingService:
    """Handles text embedding and vector storage using local sentence-transformers.

    All embedding operations run in a thread pool to avoid blocking
    the async event loop.

    Note:
        Uses local sentence-transformers models (no API key required).
        The model is configured via settings.local_embedding_model.
    """

    def __init__(self, model_name: str | None = None):
        self._model_name = model_name or settings.local_embedding_model
        self._model = _get_shared_model(self._model_name)
        # Each instance gets its own in-memory collection; cosine distance
        # means query results range from 0 (identical) to 2 (opposite).
        self._client = chromadb.Client()
        self._collection = self._client.create_collection(
            name=f"evidence_{uuid.uuid4().hex}", metadata={"hnsw:space": "cosine"}
        )

    def _sync_embed(self, text: str) -> list[float]:
        """Synchronous embedding - DO NOT call directly from async code."""
        result: list[float] = self._model.encode(text).tolist()
        return result

    def _sync_batch_embed(self, texts: list[str]) -> list[list[float]]:
        """Batch embedding for efficiency - DO NOT call directly from async code."""
        embeddings = self._model.encode(texts)
        return [e.tolist() for e in embeddings]

    async def embed(self, text: str) -> list[float]:
        """Embed a single text (async-safe).

        Uses run_in_executor to avoid blocking the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_embed, text)

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Batch embed multiple texts (async-safe, more efficient)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_batch_embed, texts)

    async def add_evidence(self, evidence_id: str, content: str, metadata: dict[str, Any]) -> None:
        """Add evidence to vector store (async-safe)."""
        embedding = await self.embed(content)

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None,
            lambda: self._collection.add(
                ids=[evidence_id],
                embeddings=[embedding],
                metadatas=[metadata],
                documents=[content],
            ),
        )

    async def search_similar(self, query: str, n_results: int = 5) -> list[dict[str, Any]]:
        """Find semantically similar evidence (async-safe).

        Returns a list of dicts with "id", "content", "metadata", and
        "distance" keys, ordered from most to least similar.
        """
        query_embedding = await self.embed(query)

        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            None,
            lambda: self._collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
            ),
        )

        ids = results.get("ids")
        docs = results.get("documents")
        metas = results.get("metadatas")
        dists = results.get("distances")

        if not ids or not ids[0] or not docs or not metas or not dists:
            return []

        return [
            {"id": eid, "content": doc, "metadata": meta, "distance": dist}
            for eid, doc, meta, dist in zip(ids[0], docs[0], metas[0], dists[0], strict=False)
        ]

    async def deduplicate(
        self, new_evidence: list[Evidence], threshold: float = 0.9
    ) -> list[Evidence]:
        """Remove semantically duplicate evidence (async-safe).

        Args:
            new_evidence: List of evidence items to deduplicate.
            threshold: Similarity threshold (0.9 = 90% similar is duplicate).
                ChromaDB cosine distance: 0=identical, 2=opposite. We consider
                an item a duplicate if its nearest-neighbor distance is below
                (1 - threshold); e.g. threshold=0.9 flags distances below 0.1.

        Returns:
            List of unique evidence items (not already in vector store).
        """
        unique = []
        for evidence in new_evidence:
            try:
                similar = await self.search_similar(evidence.content, n_results=1)
                is_duplicate = similar and similar[0]["distance"] < (1 - threshold)

                if not is_duplicate:
                    unique.append(evidence)
                    # Index the new item so later items in this batch
                    # deduplicate against it as well.
                    await self.add_evidence(
                        evidence_id=evidence.citation.url,
                        content=evidence.content,
                        metadata={
                            "source": evidence.citation.source,
                            "title": evidence.citation.title,
                            "date": evidence.citation.date,
                            "authors": ",".join(evidence.citation.authors or []),
                        },
                    )
            except Exception as e:
                structlog.get_logger().warning(
                    "Failed to process evidence in deduplicate",
                    url=evidence.citation.url,
                    error=str(e),
                )
                # Fail open: keep the evidence rather than risk dropping it.
                unique.append(evidence)

        return unique


def get_embedding_service() -> EmbeddingService:
    """Get a new instance of EmbeddingService."""
    return EmbeddingService()
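

# The block below is a minimal usage sketch, not part of the service API.
# It assumes settings.local_embedding_model names a valid sentence-transformers
# model (e.g. "all-MiniLM-L6-v2"); the URL, title, and query text are
# made-up illustration data.
if __name__ == "__main__":

    async def _demo() -> None:
        service = get_embedding_service()
        await service.add_evidence(
            evidence_id="https://example.com/aspirin",
            content="Aspirin reduces the risk of cardiovascular events.",
            metadata={"source": "example", "title": "Aspirin study"},
        )
        for hit in await service.search_similar(
            "Does aspirin prevent heart attacks?", n_results=1
        ):
            print(hit["id"], round(hit["distance"], 3))

    asyncio.run(_demo())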