# Scraped from a Hugging Face Spaces page (status: "Sleeping") — kept as a comment.
import torch
import open_clip
from typing import List
import numpy as np
class CLIPEmbeddingsHandler:
    """Handles CLIP embeddings for multimodal content (text and images).

    Loads an open_clip model once at construction and exposes helpers that
    return L2-normalized numpy embedding vectors.
    """

    # Width of the zero-vector fallback used when encoding fails.
    # NOTE(review): 512 is correct for ViT-B-32; other model names may
    # produce a different embedding dim — confirm before changing models.
    _FALLBACK_DIM = 512

    def __init__(self, model_name: str = "ViT-B-32", pretrained: str = "openai"):
        """Load the CLIP model, eval preprocessing transform, and tokenizer.

        Args:
            model_name: open_clip architecture name (e.g. "ViT-B-32").
            pretrained: pretrained weight tag (e.g. "openai").

        Raises:
            Exception: re-raised after logging if model loading fails.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            # create_model_and_transforms returns (model, train_transform,
            # eval_transform); only the eval transform is kept.
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained=pretrained,
                device=self.device,
            )
            self.tokenizer = open_clip.get_tokenizer(model_name)
            self.model.eval()  # inference mode: freeze dropout/batchnorm behavior
            print(f"✅ CLIP model loaded on {self.device}")
            print(f" Model: {model_name}")
        except Exception as e:
            print(f"❌ Error loading CLIP model: {e}")
            raise

    def embed_text(self, texts: List[str]) -> np.ndarray:
        """Generate L2-normalized embeddings for a list of texts.

        Args:
            texts: strings to embed.

        Returns:
            Array of shape (len(texts), dim). A failed item becomes a zero
            row (best-effort behavior, preserved from the original).
        """
        # Guard the empty case: np.array([]).squeeze() below would otherwise
        # yield a meaningless (1, 0) result.
        if not texts:
            return np.empty((0, self._FALLBACK_DIM))
        embeddings = []
        with torch.no_grad():
            for text in texts:
                try:
                    tokens = self.tokenizer(text).to(self.device)
                    text_features = self.model.encode_text(tokens)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    embeddings.append(text_features.cpu().numpy())
                except Exception as e:
                    print(f"⚠️ Error embedding text: {e}")
                    # FIX: fallback must be shape (1, dim) to match the
                    # (1, dim) outputs of encode_text; the old bare
                    # np.zeros(512) made the list ragged, so np.array()
                    # raised ValueError whenever successes and failures
                    # were mixed in one batch.
                    embeddings.append(np.zeros((1, self._FALLBACK_DIM)))
        result = np.array(embeddings).squeeze()
        if len(result.shape) == 1:
            # Single input collapses to (dim,) after squeeze; restore 2-D.
            result = np.expand_dims(result, axis=0)
        return result

    def embed_image_base64(self, image_base64: str) -> np.ndarray:
        """Generate an L2-normalized embedding for a base64-encoded image.

        Args:
            image_base64: base64 string of the image bytes.

        Returns:
            1-D array of shape (dim,); a zero vector on failure.
        """
        import base64
        import io
        from PIL import Image
        try:
            image_data = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
            # Use the evaluation preprocessing transform from __init__.
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)
                image_features /= image_features.norm(dim=-1, keepdim=True)
            return image_features.cpu().numpy().squeeze()
        except Exception as e:
            print(f"❌ Error embedding image: {e}")
            return np.zeros(self._FALLBACK_DIM)
# LangChain wrapper
from langchain_core.embeddings import Embeddings
class CLIPLangChainEmbeddings(Embeddings):
    """LangChain `Embeddings` adapter backed by a CLIPEmbeddingsHandler."""

    def __init__(self, model_name: str = "ViT-B-32", pretrained: str = "openai"):
        # All embedding work is delegated to the shared CLIP handler.
        self.handler = CLIPEmbeddingsHandler(model_name, pretrained)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        matrix = self.handler.embed_text(texts)
        # A 1-D result means a single document: wrap it as one row.
        return [matrix.tolist()] if matrix.ndim == 1 else matrix.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        vector = self.handler.embed_text([text])
        # Return a flat list of floats either way.
        return vector.tolist() if vector.ndim == 1 else vector[0].tolist()