# multimodal_rag/embedding/embedder.py
"""BGE text embeddings and CLIP image embeddings."""
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor
from config import settings
class TextEmbedder:
    """BGE-small text embeddings for chunks and queries.

    Wraps a SentenceTransformer model (name taken from
    ``settings.text_embedding_model``) and always produces
    L2-normalized vectors, so dot product == cosine similarity.
    """

    def __init__(self):
        self.model = SentenceTransformer(settings.text_embedding_model)
        # Inference only: disable dropout / training-mode layers.
        self.model.eval()

    def embed_texts(
        self,
        texts: list[str],
        *,
        batch_size: int = 64,
        show_progress_bar: bool = True,
    ) -> list[list[float]]:
        """Embed a batch of texts.

        Args:
            texts: Texts to embed; may be empty.
            batch_size: Encoder batch size (previously hard-coded to 64).
            show_progress_bar: Whether to display encoding progress
                (previously always on; callers in quiet contexts can
                now disable it).

        Returns:
            One normalized float vector per input text.
        """
        if not texts:
            # Short-circuit: avoids spinning up a progress bar for nothing.
            return []
        with torch.no_grad():
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=show_progress_bar,
                normalize_embeddings=True,
            )
        return embeddings.tolist()

    def embed_query(self, query: str) -> list[float]:
        """Embed a single query string into one normalized float vector."""
        with torch.no_grad():
            embedding = self.model.encode(
                query,
                normalize_embeddings=True,
            )
        return embedding.tolist()
class ImageEmbedder:
    """CLIP image embeddings for visual content.

    Uses the CLIP checkpoint named by ``settings.clip_model`` and emits
    L2-normalized vectors from both the image and text towers, so the
    two spaces are directly comparable via dot product.
    """

    def __init__(self):
        self.model = CLIPModel.from_pretrained(settings.clip_model)
        self.processor = CLIPProcessor.from_pretrained(settings.clip_model)
        # Inference only — freeze training-mode behavior.
        self.model.eval()

    def embed_images(self, images: list[Image.Image]) -> list[list[float]]:
        """Embed a batch of PIL images. Returns list of float vectors."""
        with torch.no_grad():
            batch = self.processor(images=images, return_tensors="pt", padding=True)
            features = self.model.get_image_features(**batch)
            # Scale each row to unit L2 norm so cosine similarity
            # reduces to a plain dot product.
            norms = features.norm(p=2, dim=-1, keepdim=True)
            features = features / norms
        return features.tolist()

    def embed_image(self, image: Image.Image) -> list[float]:
        """Embed a single image."""
        vectors = self.embed_images([image])
        return vectors[0]

    def embed_text_for_image_search(self, text: str) -> list[float]:
        """Embed text using CLIP text encoder (for text-to-image search)."""
        with torch.no_grad():
            batch = self.processor(text=[text], return_tensors="pt", padding=True)
            features = self.model.get_text_features(**batch)
            norms = features.norm(p=2, dim=-1, keepdim=True)
            features = features / norms
        return features[0].tolist()
# Module-level singleton cache — populated lazily on first call to the
# corresponding get_*_embedder() accessor, so importing this module does
# not trigger a model download/load.
_text_embedder: TextEmbedder | None = None
_image_embedder: ImageEmbedder | None = None
def get_text_embedder() -> TextEmbedder:
    """Return the shared TextEmbedder, constructing it on first use."""
    global _text_embedder
    if _text_embedder is not None:
        return _text_embedder
    _text_embedder = TextEmbedder()
    return _text_embedder
def get_image_embedder() -> ImageEmbedder:
    """Return the shared ImageEmbedder, constructing it on first use."""
    global _image_embedder
    if _image_embedder is not None:
        return _image_embedder
    _image_embedder = ImageEmbedder()
    return _image_embedder