File size: 2,686 Bytes
8602273
 
 
a4d5a67
8602273
 
 
 
 
a4d5a67
8602273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4d5a67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8602273
 
 
a4d5a67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8602273
 
a4d5a67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import open_clip
import torch
from PIL import Image

# Architecture name handed to open_clip for both the model and the tokenizer.
MODEL_NAME = "ViT-B-32"
# Pretrained-weights tag handed to open_clip together with MODEL_NAME.
PRETRAINED = "laion2b_s34b_b79k"

# The model is built once at module load; encode() and encode_image() below
# reuse these module-level objects on every call.
print(f"Loading {MODEL_NAME} / {PRETRAINED}...")
# The middle element of the returned triple is discarded; this file only uses
# the model and the `preprocess` transform.
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
# Switch to evaluation mode: the model is only used for inference in this file.
model.eval()
# Tokenizer matching MODEL_NAME, used by encode() to turn text into tokens.
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
print("Model loaded.")


def encode(text: str):
    """Embed a text query with the CLIP text encoder.

    Returns a dict with the model name, the pretrained tag, the embedding
    width (``dim``) and the L2-normalized embedding as a plain list
    (512-d for the configured model). A missing or whitespace-only query
    yields ``{"error": "empty text", "embedding": None}`` instead.
    """
    # Falsy input (None, "") short-circuits before .strip() is ever called.
    has_content = bool(text) and bool(text.strip())
    if not has_content:
        return {"error": "empty text", "embedding": None}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        raw = model.encode_text(tokenizer([text]))
        # Scale each row to unit length so dot products are cosine similarities.
        unit = raw / raw.norm(dim=-1, keepdim=True)

    payload = {"model": MODEL_NAME, "pretrained": PRETRAINED}
    payload["dim"] = unit.shape[-1]
    # The batch holds exactly one query; return its row as a JSON-friendly list.
    payload["embedding"] = unit[0].tolist()
    return payload


def encode_image(image):
    """Embed an image with the CLIP image encoder.

    The result lives in the same vector space as ``encode(text)``, so
    text-image cosine similarity is meaningful. It is also the same space as
    the dataset features baked into the static site
    (ViT-B-32 / laion2b_s34b_b79k), which lets user-uploaded image cues be
    compared directly against those features for compositional retrieval.

    Returns a dict with the model name, the pretrained tag, the embedding
    width (``dim``) and the L2-normalized embedding as a plain list
    (512-d for the configured model). ``None`` input, or input that cannot be
    turned into a PIL image, yields a dict with an ``error`` message and
    ``embedding`` set to ``None``.
    """
    if image is None:
        return {"error": "no image", "embedding": None}

    if isinstance(image, Image.Image):
        picture = image
    else:
        # Anything that is not already a PIL image is given one chance to be
        # converted via Image.fromarray; failures are reported, not raised.
        try:
            picture = Image.fromarray(image)
        except Exception as e:
            return {"error": f"unsupported image input: {type(image).__name__}: {e}", "embedding": None}

    rgb = picture.convert("RGB")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # preprocess() yields a single image tensor; add the batch axis.
        batch = preprocess(rgb).unsqueeze(0)
        raw = model.encode_image(batch)
        # Scale each row to unit length so dot products are cosine similarities.
        unit = raw / raw.norm(dim=-1, keepdim=True)

    width = unit.shape[-1]
    values = unit[0].tolist()
    return {
        "model": MODEL_NAME,
        "pretrained": PRETRAINED,
        "dim": width,
        "embedding": values,
    }


# --- Gradio UI / API wiring ---------------------------------------------------
# Each encoder is exposed through its own gr.Interface. `api_name` is set
# explicitly so the endpoint names advertised in the descriptions (/encode and
# /encode_image) hold regardless of the installed Gradio version's default
# endpoint naming (some releases name Interface endpoints "predict").
text_iface = gr.Interface(
    fn=encode,
    inputs=gr.Textbox(label="Query", placeholder="a hotel bathroom with a walk-in shower"),
    outputs=gr.JSON(label="CLIP text embedding"),
    title="CLIP text encoder",
    description="Returns a 512-d L2-normalized text embedding (ViT-B-32 / LAION-2B). API endpoint: /encode",
    api_name="encode",
)

image_iface = gr.Interface(
    fn=encode_image,
    # type="pil" asks Gradio to hand encode_image a PIL image.
    inputs=gr.Image(type="pil", label="Image"),
    outputs=gr.JSON(label="CLIP image embedding"),
    title="CLIP image encoder",
    description="Returns a 512-d L2-normalized image embedding (ViT-B-32 / LAION-2B). API endpoint: /encode_image",
    api_name="encode_image",
)

# One app with a tab per encoder; tab labels are positional and match the
# order of the interfaces list.
demo = gr.TabbedInterface(
    [text_iface, image_iface],
    ["Text", "Image"],
    title="CLIP encoder (ViT-B-32 / LAION-2B)",
)

# Start the server when this file is executed.
demo.launch()