Spaces:
Sleeping
Sleeping
File size: 2,686 Bytes
8602273 a4d5a67 8602273 a4d5a67 8602273 a4d5a67 8602273 a4d5a67 8602273 a4d5a67 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 | import gradio as gr
import open_clip
import torch
from PIL import Image
# Checkpoint served by this app: open_clip architecture name and the tag of
# the pretrained weights. Both values are echoed back in every successful
# response so clients can tell which embedding space a vector belongs to.
MODEL_NAME = "ViT-B-32"
PRETRAINED = "laion2b_s34b_b79k"

print(f"Loading {MODEL_NAME} / {PRETRAINED}...")

# create_model_and_transforms returns three values; the middle one is not
# used by this app, only the model and the last transform are kept.
model, _, preprocess = open_clip.create_model_and_transforms(
    MODEL_NAME,
    pretrained=PRETRAINED,
)

# The model is only used for inference, so switch it to eval mode once here.
model.eval()

tokenizer = open_clip.get_tokenizer(MODEL_NAME)

print("Model loaded.")
def encode(text: str):
    """Encode a text query into an L2-normalized CLIP text embedding.

    Args:
        text: The query string. ``None``, empty and whitespace-only values
            are rejected, as is any value that is not a ``str``.

    Returns:
        On success, a dict with the keys ``model`` and ``pretrained``
        (the checkpoint identifiers), ``dim`` (size of the embedding's last
        axis; 512 according to this app's UI description) and ``embedding``
        (the vector as a plain list, divided by its L2 norm).
        For rejected input, ``{"error": <message>, "embedding": None}`` is
        returned instead of raising, mirroring ``encode_image``.
    """
    # Falsy values (None, "", ...) keep the original "empty text" response.
    if not text:
        return {"error": "empty text", "embedding": None}

    # A truthy non-string (e.g. a number sent through the API) used to crash
    # on ``.strip()``; report it as a structured error instead.
    if not isinstance(text, str):
        return {
            "error": f"unsupported text input: {type(text).__name__}",
            "embedding": None,
        }

    if not text.strip():
        return {"error": "empty text", "embedding": None}

    with torch.no_grad():
        # The tokenizer is called with a batch of one string.
        tokens = tokenizer([text])
        feat = model.encode_text(tokens)
        # Scale to unit length so dot products equal cosine similarity.
        feat = feat / feat.norm(dim=-1, keepdim=True)

    return {
        "model": MODEL_NAME,
        "pretrained": PRETRAINED,
        "dim": feat.shape[-1],
        "embedding": feat[0].tolist(),
    }
def encode_image(image):
    """Encode an image into an L2-normalized CLIP image embedding.

    The vector lives in the same space as the output of ``encode(text)``,
    so text-image cosine similarity works. It is also the same space as the
    dataset features baked into the static site
    (ViT-B-32 / laion2b_s34b_b79k), so user-uploaded image cues can be
    compared directly to those features for compositional retrieval.

    Args:
        image: A ``PIL.Image.Image``, or any object ``Image.fromarray``
            accepts. ``None`` is rejected.

    Returns:
        On success, a dict with the keys ``model``, ``pretrained``, ``dim``
        (size of the embedding's last axis; 512 according to this app's UI
        description) and ``embedding`` (the vector as a plain list).
        Otherwise ``{"error": <message>, "embedding": None}``.
    """
    if image is None:
        return {"error": "no image", "embedding": None}

    pil_image = image
    if not isinstance(pil_image, Image.Image):
        # Best-effort conversion of array-like input; any failure is
        # reported to the caller as an error payload rather than raised.
        try:
            pil_image = Image.fromarray(pil_image)
        except Exception as exc:
            input_kind = type(image).__name__
            return {"error": f"unsupported image input: {input_kind}: {exc}", "embedding": None}

    rgb_image = pil_image.convert("RGB")

    with torch.no_grad():
        # preprocess yields a single image tensor; add the batch axis.
        batch = preprocess(rgb_image).unsqueeze(0)
        raw_features = model.encode_image(batch)
        unit_features = raw_features / raw_features.norm(dim=-1, keepdim=True)

    embedding = unit_features[0].tolist()
    return {
        "model": MODEL_NAME,
        "pretrained": PRETRAINED,
        "dim": unit_features.shape[-1],
        "embedding": embedding,
    }
# Text tab. api_name is set explicitly so the HTTP endpoint really is the
# "/encode" advertised in the description, instead of depending on whatever
# default name the installed Gradio version assigns to a gr.Interface.
text_iface = gr.Interface(
    fn=encode,
    inputs=gr.Textbox(label="Query", placeholder="a hotel bathroom with a walk-in shower"),
    outputs=gr.JSON(label="CLIP text embedding"),
    title="CLIP text encoder",
    description="Returns a 512-d L2-normalized text embedding (ViT-B-32 / LAION-2B). API endpoint: /encode",
    api_name="encode",
)

# Image tab. Same reasoning: pin the endpoint to the advertised
# "/encode_image".
image_iface = gr.Interface(
    fn=encode_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=gr.JSON(label="CLIP image embedding"),
    title="CLIP image encoder",
    description="Returns a 512-d L2-normalized image embedding (ViT-B-32 / LAION-2B). API endpoint: /encode_image",
    api_name="encode_image",
)

# Combine both encoders into one app, one tab per modality.
demo = gr.TabbedInterface(
    [text_iface, image_iface],
    ["Text", "Image"],
    title="CLIP encoder (ViT-B-32 / LAION-2B)",
)

demo.launch()
|