Spaces:
Sleeping
Sleeping
File size: 3,611 Bytes
835ecb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch
import open_clip
from typing import List
import numpy as np
class CLIPEmbeddingsHandler:
    """Handles CLIP embeddings for multimodal content.

    Wraps an ``open_clip`` model together with its evaluation image transform
    and its tokenizer, and exposes helpers that return L2-normalised
    embeddings as NumPy arrays.

    Attributes:
        device: ``"cuda"`` when available, otherwise ``"cpu"``.
        model: The open_clip model, switched to evaluation mode.
        preprocess: The evaluation image transform from
            ``open_clip.create_model_and_transforms``.
        tokenizer: The tokenizer matching ``model_name``.
        embedding_dim: Length of one embedding vector. Measured from the
            loaded model in ``__init__``; the class-level value is only a
            fallback for instances created without running ``__init__``.
    """

    # Fallback embedding width; overwritten per instance in __init__ with the
    # width actually produced by the loaded model.
    embedding_dim: int = 512

    def __init__(self, model_name: str = "ViT-B-32", pretrained: str = "openai"):
        """Load the CLIP model, its evaluation transform and its tokenizer.

        Args:
            model_name: open_clip architecture name.
            pretrained: Pretrained-weights tag passed to open_clip.

        Raises:
            Exception: Whatever open_clip raises while loading (or while
                probing the embedding width) is re-raised after an error
                message is printed.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            # create_model_and_transforms returns three values
            # (model, train transform, eval transform); only the evaluation
            # transform is kept for inference.
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained=pretrained,
                device=self.device,
            )
            self.tokenizer = open_clip.get_tokenizer(model_name)
            self.model.eval()  # Set to evaluation mode
            # The embedding width varies by architecture, so measure it once
            # instead of hard-coding 512; it sizes the zero vectors returned
            # when an individual input fails to embed.
            with torch.no_grad():
                probe = self.model.encode_text(self.tokenizer([""]).to(self.device))
            self.embedding_dim = int(probe.shape[-1])
            print(f"✅ CLIP model loaded on {self.device}")
            print(f" Model: {model_name}")
        except Exception as e:
            print(f"❌ Error loading CLIP model: {e}")
            raise

    def embed_text(self, texts: List[str]) -> np.ndarray:
        """Generate L2-normalised embeddings for a list of texts.

        Texts are encoded one at a time. If encoding a text raises, an error
        is printed and a zero row is used for that text, so the output always
        has one row per input.

        Args:
            texts: The strings to embed.

        Returns:
            Array of shape ``(len(texts), embedding_dim)``; shape
            ``(0, embedding_dim)`` when ``texts`` is empty.
        """
        rows = []
        with torch.no_grad():
            for text in texts:
                try:
                    tokens = self.tokenizer(text).to(self.device)
                    text_features = self.model.encode_text(tokens)
                    text_features /= text_features.norm(dim=-1, keepdim=True)
                    # Flatten the (1, D) batch of one into a (D,) row so that
                    # successful rows and zero fallback rows share one shape.
                    rows.append(text_features.cpu().numpy().reshape(-1))
                except Exception as e:
                    print(f"⚠️ Error embedding text: {e}")
                    rows.append(np.zeros(self.embedding_dim, dtype=np.float32))
        if not rows:
            return np.empty((0, self.embedding_dim), dtype=np.float32)
        return np.stack(rows)

    def embed_image_base64(self, image_base64: str) -> np.ndarray:
        """Generate an L2-normalised embedding for a base64-encoded image.

        Args:
            image_base64: Base64 text of an image file readable by PIL.

        Returns:
            A 1-D embedding of length ``embedding_dim``, or a zero vector of
            that length (after printing an error) if decoding, loading or
            encoding fails.
        """
        import base64
        import io
        from PIL import Image
        try:
            image_data = base64.b64decode(image_base64)
            # convert() returns a new image, so the source can be closed here.
            with Image.open(io.BytesIO(image_data)) as raw_image:
                image = raw_image.convert("RGB")
            # Use the evaluation preprocessing
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)
                image_features /= image_features.norm(dim=-1, keepdim=True)
            return image_features.cpu().numpy().reshape(-1)
        except Exception as e:
            print(f"❌ Error embedding image: {e}")
            return np.zeros(self.embedding_dim, dtype=np.float32)
# LangChain wrapper
from langchain_core.embeddings import Embeddings
class CLIPLangChainEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by ``CLIPEmbeddingsHandler``.

    Both documents and queries are embedded with the handler's text encoder
    (``embed_text``), and results are converted to plain Python lists.
    """

    def __init__(self, model_name: str = "ViT-B-32", pretrained: str = "openai"):
        """Create the underlying CLIP handler.

        Args:
            model_name: open_clip architecture name.
            pretrained: Pretrained-weights tag passed to open_clip.
        """
        self.handler = CLIPEmbeddingsHandler(model_name, pretrained)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search documents.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding (a list of floats) per input text, in input order;
            an empty list when ``texts`` is empty.
        """
        # Never send an empty batch to the handler: the result must contain
        # exactly one row per text, regardless of how the handler shapes an
        # empty result.
        if not texts:
            return []
        embeddings = self.handler.embed_text(texts)
        # Promote a lone 1-D vector to a single row so the result is always
        # a list of rows.
        return np.atleast_2d(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The query string.

        Returns:
            The query embedding as a flat list of floats.
        """
        embedding = self.handler.embed_text([text])
        # Accept either a (D,) or a (1, D) result and take the single row.
        return np.atleast_2d(embedding)[0].tolist()
|