# Atelier-AI / engine.py
import numpy as np
import open_clip
import torch
from PIL import Image
class FashionEncoder:
    """Encodes images, text, or fused (image + text modifier) queries.

    Wraps the Marqo FashionSigLIP checkpoint via open_clip. Every public
    method returns an L2-normalised numpy vector, so cosine similarity
    between any two outputs reduces to a plain dot product.
    """

    # open_clip hub identifier for the fashion-tuned SigLIP checkpoint.
    MODEL_NAME = "hf-hub:Marqo/marqo-fashionSigLIP"

    def __init__(self):
        # create_model_and_transforms returns (model, train_transform,
        # eval_transform); only the eval preprocessing pipeline is kept.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.MODEL_NAME
        )
        self.tokenizer = open_clip.get_tokenizer(self.MODEL_NAME)
        self.model.eval()  # inference only: freeze dropout / norm statistics

    @torch.no_grad()
    def encode_image(self, image):
        """Encode a single PIL image -> normalised 768-dim vector."""
        x = self.preprocess(image).unsqueeze(0)  # add batch dim: (1, C, H, W)
        v = self.model.encode_image(x).squeeze().numpy()
        return v / np.linalg.norm(v)

    @torch.no_grad()
    def encode_images(self, images):
        """Encode a batch of PIL images -> (N, 768) normalised array.

        Note: torch.stack raises if `images` is empty — callers should
        guard against an empty batch.
        """
        tensors = torch.stack([self.preprocess(img) for img in images])
        vecs = self.model.encode_image(tensors).numpy()
        # Row-wise L2 norms, kept 2-D so broadcasting divides each row.
        norms = np.linalg.norm(vecs, axis=1, keepdims=True)
        return vecs / norms

    @torch.no_grad()
    def encode_text(self, text):
        """Encode a text string -> normalised 768-dim vector."""
        tok = self.tokenizer([text])
        v = self.model.encode_text(tok).squeeze().numpy()
        return v / np.linalg.norm(v)

    def encode_multimodal(self, image=None, text=None, alpha=0.7):
        """
        Fuse image + text modifier into one query vector via weighted sum.

        alpha = image weight; (1 - alpha) = text weight. The split only
        applies when BOTH modalities are present: a single-modality query
        always gets full weight 1.0. (Previously an image-only query was
        scaled by alpha, which at alpha=0.0 produced a zero vector and NaNs
        after normalisation; for alpha > 0 the normalised result is
        unchanged, so this fix is backward compatible.)

        Parameters
        ----------
        image : PIL.Image, optional
            Query image; anything accepted by ``encode_image``.
        text : str, optional
            Text modifier; falsy values ("" or None) are ignored.
        alpha : float
            Image weight when both modalities are given; expected in [0, 1].

        Returns
        -------
        np.ndarray
            Unit-norm fused query vector.

        Raises
        ------
        ValueError
            If neither an image nor a non-empty text is provided.
        """
        parts = []
        if image is not None:
            # Full weight when the image is the only modality — mirrors the
            # text-only branch below; otherwise apply the alpha split.
            weight = alpha if text else 1.0
            parts.append((weight, self.encode_image(image)))
        if text:
            weight = (1.0 - alpha) if image is not None else 1.0
            parts.append((weight, self.encode_text(text)))
        if not parts:
            raise ValueError("At least one of image or text must be provided.")
        fused = sum(w * v for w, v in parts)
        return fused / np.linalg.norm(fused)