import os

# Redirect all Hugging Face / Transformers caches to a writable folder.
# These variables must be set before transformers is imported, since the
# library reads them at import time.
cache_dir = "/tmp/hf_cache"
os.environ["XDG_CACHE_HOME"] = cache_dir
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir  # legacy name, kept for older transformers versions
os.makedirs(cache_dir, exist_ok=True)


import base64
import io

import torch
from fastapi import FastAPI, Request
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

app = FastAPI()

# Load CLIP once at startup; from_pretrained returns the model in eval mode,
# so no explicit model.eval() call is needed.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

@app.get("/")
async def root():
    return {"status": "CLIP embedder is running!"}


@app.post("/embed/text")
async def embed_text(req: Request):
    data = await req.json()
    texts = data.get("texts", [])
    inputs = proc(text=texts, return_tensors="pt", padding=True, truncation=True)
    feats = model.get_text_features(**inputs)
    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
    return {"embeddings": feats.detach().cpu().tolist()}

@app.post("/embed/image")
async def embed_image(req: Request):
    data = await req.json()
    b64_list = data.get("images", [])
    out = []
    for b in b64_list:
        img = Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB")
        inputs = proc(images=img, return_tensors="pt")
        feats = model.get_image_features(**inputs)
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
        out.append(feats.detach().cpu().tolist()[0])
    return {"embeddings": out}