final_project / src /embeddings_handler.py
dnj0's picture
Upload 7 files
835ecb4 verified
import torch
import open_clip
from typing import List
import numpy as np
class CLIPEmbeddingsHandler:
    """Produce L2-normalised CLIP embeddings for text and base64 images.

    The model, image preprocessing transform and tokenizer are loaded once in
    ``__init__`` via ``open_clip``; the embedding methods then run the model
    under ``torch.no_grad()`` and return NumPy arrays.
    """

    # Width used for zero-vector fallbacks before any embedding has succeeded.
    # 512 is the value this class previously hard-coded; it is overwritten
    # with the real model output width after every successful encode, so
    # fallbacks stay consistent for models with a different width.
    embedding_dim: int = 512

    def __init__(self, model_name: str = "ViT-B-32", pretrained: str = "openai"):
        """Load the CLIP model, preprocessing transform and tokenizer.

        Args:
            model_name: open_clip architecture name.
            pretrained: open_clip pretrained-weights tag.

        Raises:
            Exception: whatever ``open_clip`` raised while loading; the error
                is printed and then re-raised unchanged.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            # create_model_and_transforms returns three values; only the
            # model and the last transform (used for inference) are kept.
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained=pretrained,
                device=self.device,
            )
            self.tokenizer = open_clip.get_tokenizer(model_name)
            self.model.eval()  # Set to evaluation mode
            print(f"✅ CLIP model loaded on {self.device}")
            print(f" Model: {model_name}")
        except Exception as e:
            print(f"❌ Error loading CLIP model: {e}")
            raise

    def embed_text(self, texts: List[str]) -> np.ndarray:
        """Embed each text and return one row per input.

        Args:
            texts: Strings to embed. May be empty.

        Returns:
            Array of shape ``(len(texts), embedding_dim)``. Rows for texts
            that failed to embed are all zeros (the error is printed), so the
            output always stays aligned with the input order.
        """
        rows = []  # one 1-D vector per text, or None where embedding failed
        with torch.no_grad():
            for text in texts:
                try:
                    tokens = self.tokenizer(text).to(self.device)
                    features = self.model.encode_text(tokens)
                    features = features / features.norm(dim=-1, keepdim=True)
                    # A single string yields a batch of one; keep just that row
                    # so every entry in `rows` has the same 1-D shape.
                    vector = features.cpu().numpy()[0]
                    self.embedding_dim = int(vector.shape[-1])
                    rows.append(vector)
                except Exception as e:
                    print(f"⚠️ Error embedding text: {e}")
                    rows.append(None)

        dim = self.embedding_dim
        if not rows:
            return np.zeros((0, dim))

        # Fill failures only after the loop: by then the real width is known
        # even when the first text failed and a later one succeeded.
        dtype = next((row.dtype for row in rows if row is not None), np.float64)
        return np.stack(
            [np.zeros(dim, dtype=dtype) if row is None else row for row in rows]
        )

    def embed_image_base64(self, image_base64: str) -> np.ndarray:
        """Embed a single base64-encoded image.

        Args:
            image_base64: Base64 text of an image file readable by PIL.

        Returns:
            1-D array of length ``embedding_dim``; all zeros if decoding or
            encoding failed (the error is printed).
        """
        import base64
        import io
        from PIL import Image

        try:
            image_data = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
            # Use the evaluation preprocessing
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)
                image_features = image_features / image_features.norm(
                    dim=-1, keepdim=True
                )
            vector = image_features.cpu().numpy()[0]
            self.embedding_dim = int(vector.shape[-1])
            return vector
        except Exception as e:
            print(f"❌ Error embedding image: {e}")
            return np.zeros(self.embedding_dim)
# LangChain wrapper
from langchain_core.embeddings import Embeddings
class CLIPLangChainEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by ``CLIPEmbeddingsHandler``."""

    def __init__(self, model_name: str = "ViT-B-32", pretrained: str = "openai"):
        """Create the underlying handler, which loads the CLIP model.

        Args:
            model_name: Passed through to ``CLIPEmbeddingsHandler``.
            pretrained: Passed through to ``CLIPEmbeddingsHandler``.
        """
        self.handler = CLIPEmbeddingsHandler(model_name, pretrained)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: Documents to embed.

        Returns:
            One list of floats per input text, in input order; an empty list
            when ``texts`` is empty.
        """
        # Zero documents must yield zero vectors, not one empty vector.
        if len(texts) == 0:
            return []
        # atleast_2d keeps the one-row-per-text contract even if the handler
        # hands back a single flat vector.
        matrix = np.atleast_2d(self.handler.embed_text(texts))
        return matrix.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: The query string.

        Returns:
            The embedding of ``text`` as a flat list of floats.
        """
        matrix = np.atleast_2d(self.handler.embed_text([text]))
        return matrix[0].tolist()