# Source: uploaded by acd23 as src/models/embedding.py (commit 21a70bc, verified).
"""
CLIP Embedding Model -- AI Reel Creator Platform
=================================================
Wraps openai/clip-vit-large-patch14 for generating 768-dim
normalized embeddings for both images and text.
"""
import os
import warnings
from typing import List, Union, Optional
from pathlib import Path
import torch
import numpy as np
from PIL import Image
def _get_device() -> str:
if torch.cuda.is_available():
return "cuda"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return "mps"
return "cpu"
class CLIPEmbedder:
    """Generate CLIP embeddings for images and text.

    Wraps a Hugging Face ``CLIPModel`` together with a ``CLIPProcessor``
    (used for images) and a ``CLIPTokenizer`` (used for text). Every
    ``encode_*`` method returns NumPy arrays; when ``normalize`` is true each
    vector is scaled to unit L2 norm, so dot products between embeddings are
    cosine similarities.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        device: Torch device string. When falsy, the ``EMBEDDING_DEVICE``
            environment variable is used, falling back to ``_get_device()``.
        batch_size: Items per forward pass in the ``batch_encode_*`` methods.
        normalize: Whether to L2-normalize returned vectors.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    # Token limit passed to the tokenizer; 77 is the text context length of
    # the standard CLIP text encoder configs.
    _MAX_TEXT_TOKENS = 77

    def __init__(
        self,
        model_name: str = "openai/clip-vit-large-patch14",
        device: Optional[str] = None,
        batch_size: int = 32,
        normalize: bool = True,
    ):
        if batch_size < 1:
            # A zero step makes range() raise and a negative step yields no
            # batches at all, so reject bad values before loading the model.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.model_name = model_name
        self.device = device or os.environ.get("EMBEDDING_DEVICE", _get_device())
        self.batch_size = batch_size
        self.normalize = normalize
        # Deferred import: transformers is only needed once an embedder is built.
        from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer

        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.model.eval()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _load_image(image: Union[str, Path, "Image.Image"]) -> "Image.Image":
        """Open a path as an RGB PIL image; pass anything else through unchanged."""
        if isinstance(image, (str, Path)):
            # convert() returns a new in-memory image, so the file can be
            # closed immediately instead of lingering until garbage collection.
            with Image.open(str(image)) as handle:
                return handle.convert("RGB")
        return image

    def _to_numpy(self, features: torch.Tensor) -> np.ndarray:
        """Copy model output to a NumPy array, L2-normalizing rows if enabled."""
        vecs = features.cpu().numpy()
        if self.normalize:
            # The epsilon avoids division by zero for an all-zero vector.
            vecs = vecs / (np.linalg.norm(vecs, axis=-1, keepdims=True) + 1e-12)
        return vecs

    def _empty_embeddings(self) -> np.ndarray:
        """Return a ``(0, projection_dim)`` float32 array for empty batch input."""
        return np.empty((0, self.model.config.projection_dim), dtype=np.float32)

    def _batch_starts(self, n_items: int, desc: str, show_progress: bool):
        """Return the batch start offsets, wrapped in tqdm when requested and installed."""
        starts = range(0, n_items, self.batch_size)
        if not show_progress:
            return starts
        try:
            from tqdm import tqdm
        except ImportError:
            return starts
        # len(range) is the exact batch count (ceil(n / batch_size)).
        return tqdm(starts, total=len(starts), desc=desc)

    # ----------------------------------------------------------------- images

    def encode_image(self, image: Union[str, Path, "Image.Image"]) -> np.ndarray:
        """Embed a single image.

        Args:
            image: A PIL image, or a path that is opened and converted to RGB.

        Returns:
            A 1-D float array (unit-norm when ``self.normalize``).
        """
        inputs = self.processor(images=self._load_image(image), return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return self._to_numpy(features)[0]

    def batch_encode_images(self, images: List[Union[str, Path, "Image.Image"]], show_progress: bool = False) -> np.ndarray:
        """Embed many images in batches of ``self.batch_size``.

        Paths are opened lazily, one batch at a time, so peak memory is bounded
        by the batch size rather than by the total number of images.

        Args:
            images: PIL images and/or paths, in the desired output order.
            show_progress: Show a tqdm progress bar if tqdm is installed.

        Returns:
            A ``(len(images), dim)`` array; ``(0, dim)`` for empty input.
        """
        images = list(images)
        if not images:
            return self._empty_embeddings()
        chunks = []
        for start in self._batch_starts(len(images), "Embedding images", show_progress):
            batch = [self._load_image(img) for img in images[start : start + self.batch_size]]
            # `padding` is a tokenizer option, not an image-processor one, so it
            # is not passed here (consistent with encode_image).
            inputs = self.processor(images=batch, return_tensors="pt").to(self.device)
            with torch.no_grad():
                features = self.model.get_image_features(**inputs)
            chunks.append(self._to_numpy(features))
        return np.vstack(chunks)

    # ------------------------------------------------------------------- text

    def encode_text(self, text: str) -> np.ndarray:
        """Embed a single string (truncated to ``_MAX_TEXT_TOKENS`` tokens).

        Returns:
            A 1-D float array (unit-norm when ``self.normalize``).
        """
        inputs = self.tokenizer(
            text, padding=True, truncation=True, max_length=self._MAX_TEXT_TOKENS, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return self._to_numpy(features)[0]

    def batch_encode_texts(self, texts: List[str], show_progress: bool = False) -> np.ndarray:
        """Embed many strings in batches of ``self.batch_size``.

        Args:
            texts: Strings to embed, in the desired output order.
            show_progress: Show a tqdm progress bar if tqdm is installed.

        Returns:
            A ``(len(texts), dim)`` array; ``(0, dim)`` for empty input.
        """
        if len(texts) == 0:
            return self._empty_embeddings()
        chunks = []
        for start in self._batch_starts(len(texts), "Embedding texts", show_progress):
            batch = texts[start : start + self.batch_size]
            inputs = self.tokenizer(
                batch, padding=True, truncation=True, max_length=self._MAX_TEXT_TOKENS, return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                features = self.model.get_text_features(**inputs)
            chunks.append(self._to_numpy(features))
        return np.vstack(chunks)

    def encode_brochure_nodes(self, nodes: List[dict], text_field: str = "content", show_progress: bool = True) -> np.ndarray:
        """Embed brochure nodes as their title and body separated by a blank line.

        Args:
            nodes: Dict-like nodes; ``"title"`` and ``text_field`` are read with
                ``.get``, and missing or ``None`` values become empty strings.
            text_field: Key holding the node body.
            show_progress: Forwarded to ``batch_encode_texts``.

        Returns:
            A ``(len(nodes), dim)`` array in input order.
        """
        def _field(node, key):
            # Treat an explicit None like a missing key so "None" is never embedded.
            value = node.get(key)
            return "" if value is None else value

        texts = [f"{_field(node, 'title')}\n\n{_field(node, text_field)}" for node in nodes]
        return self.batch_encode_texts(texts, show_progress=show_progress)

    def compute_similarity(self, text_vecs: np.ndarray, image_vecs: np.ndarray) -> np.ndarray:
        """Return the dot-product matrix ``text_vecs @ image_vecs.T``.

        For 2-D inputs of shape ``(n_text, d)`` and ``(n_image, d)`` the result
        is ``(n_text, n_image)``; with normalized embeddings the entries are
        cosine similarities.
        """
        return np.dot(text_vecs, image_vecs.T)
def main():
    """Command-line entry point: embed images or text lines and save them with ``np.save``.

    ``--mode image`` takes a single image path or a directory; for a directory,
    its top-level ``.jpg``/``.jpeg``/``.png``/``.webp`` files are embedded in
    sorted path order. ``--mode text`` reads a UTF-8 file and embeds each
    non-blank line after stripping surrounding whitespace.

    Exits with a usage error (via ``parser.error``) when no inputs are found.
    """
    import argparse

    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}

    parser = argparse.ArgumentParser(description="Embed images or texts with CLIP.")
    parser.add_argument("--mode", choices=["image", "text"], required=True)
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", default="embeddings.npy")
    parser.add_argument("--model", default="openai/clip-vit-large-patch14")
    parser.add_argument("--device", default=None)
    parser.add_argument("--batch-size", type=int, default=32)
    args = parser.parse_args()

    # Gather inputs before constructing CLIPEmbedder (which loads the model
    # weights) so that missing or empty input fails fast.
    if args.mode == "image":
        path = Path(args.input)
        if path.is_dir():
            items = sorted(
                str(p) for p in path.glob("*") if p.is_file() and p.suffix.lower() in image_suffixes
            )
        else:
            items = [args.input]
    else:
        with open(args.input, "r", encoding="utf-8") as f:
            items = [stripped for stripped in (line.strip() for line in f) if stripped]

    if not items:
        parser.error(f"no {args.mode} inputs found in {args.input}")

    embedder = CLIPEmbedder(model_name=args.model, device=args.device, batch_size=args.batch_size)
    if args.mode == "image":
        embs = embedder.batch_encode_images(items, show_progress=True)
    else:
        embs = embedder.batch_encode_texts(items, show_progress=True)
    np.save(args.output, embs)
    print(f"Saved {len(embs)} embeddings ({embs.shape[1]}-dim) -> {args.output}")


if __name__ == "__main__":
    main()