# Author: Babu Pallam
# Commit 4c612c8: Add local embedding model wrapper for RAG retrieval
# ============================================================
# FILE: src/embeddings.py
# ============================================================
# PURPOSE:
# Convert text into embedding vectors.
#
# WHAT IS AN EMBEDDING?
# An embedding is a list of numbers that represents the meaning of text.
#
# Example:
# "refund policy" -> [0.12, -0.04, 0.88, ...]
#
# Similar text should have similar vectors.
#
# The same embedding model must be used for:
# - indexing documents
# - querying documents
#
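# If you change the embedding model, rebuild the vector database
# (re-embed every document), because vectors produced by different
# models are not comparable with each other.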
# ============================================================
from typing import List
from sentence_transformers import SentenceTransformer
class EmbeddingModel:
    """
    Wrapper around SentenceTransformer.

    This keeps embedding logic separate from the vector database logic.
    Use the same model_name for indexing and for querying so that
    document vectors and query vectors are comparable.
    """

    def __init__(self, model_name: str, device: str = "cpu") -> None:
        """
        Load the embedding model.

        Args:
            model_name: Model name or path, passed straight to
                SentenceTransformer.
            device: Device passed to SentenceTransformer. The default
                "cpu" works without a GPU.
        """
        self.model_name = model_name
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    def embed_texts(
        self,
        texts: List[str],
        *,
        batch_size: int = 32,
        show_progress_bar: bool = True,
    ) -> List[List[float]]:
        """
        Convert a list of texts into embedding vectors (one per text).

        Args:
            texts: The texts to embed. Any finite iterable of strings is
                accepted; it is materialised into a list first.
            batch_size: Number of texts encoded per batch. 32 matches
                SentenceTransformer.encode's own default.
            show_progress_bar: Whether encode() displays a progress bar.
                Helpful for bulk indexing, noisy for single queries.

        Returns:
            One embedding vector (a list of Python floats) per input text,
            in input order. Empty input returns [] without calling the model.

        Raises:
            TypeError: If texts is a single str. Materialising a str would
                embed it character by character, so it is rejected; use
                embed_query for a single string.

        normalize_embeddings=True scales every vector to unit length, so a
        plain dot product between two vectors equals their cosine similarity.
        """
        if isinstance(texts, str):
            raise TypeError(
                "embed_texts expects a list of strings; "
                "use embed_query for a single string"
            )
        # SentenceTransformer.encode treats inputs that have no __len__
        # (e.g. generators) as one single input, so materialise first.
        texts = list(texts)
        if not texts:
            return []
        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            normalize_embeddings=True,
            show_progress_bar=show_progress_bar,
        )
        return embeddings.tolist()

    def embed_query(self, query: str) -> List[float]:
        """
        Convert one user question into one embedding vector.

        The progress bar is disabled: a single query is one tiny batch, and
        printing a bar for every query only clutters the output.
        """
        return self.embed_texts([query], show_progress_bar=False)[0]