# clip-embedder / app.py
# (Hugging Face Space metadata — author akashD22, commit a38cf9c — preserved as a comment;
# the original scraped lines were not valid Python.)
import os
# Redirect all HF & Transformers cache to a writable folder
# (Spaces containers often mount the default cache location read-only,
# so every cache-related env var is pointed at /tmp before transformers
# is imported — the libraries read these at import time).
cache_dir = "/tmp/hf_cache"
os.environ["XDG_CACHE_HOME"] = cache_dir
os.environ["HF_HOME"] = cache_dir
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME; kept here for older versions — confirm
# against the pinned transformers version.
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.makedirs(cache_dir, exist_ok=True)
from fastapi import FastAPI, Request
from transformers import CLIPModel, CLIPProcessor
import torch, base64, io
from PIL import Image
app = FastAPI()
# Model and processor are loaded once at import time (downloads on first
# start-up, then served from cache_dir). Both endpoints share these globals.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
@app.get("/")
async def root():
    """Health-check endpoint: confirms the service is up and reachable."""
    payload = {"status": "CLIP embedder is running!"}
    return payload
@app.post("/embed/text")
async def embed_text(req: Request):
    """Embed a batch of texts with CLIP.

    Expects a JSON body {"texts": [str, ...]} and returns
    {"embeddings": [[float, ...], ...]} — one L2-normalized 512-d vector
    per input text, in the same order.
    """
    data = await req.json()
    texts = data.get("texts", [])
    if not texts:
        # The processor raises on an empty batch; short-circuit instead.
        return {"embeddings": []}
    inputs = proc(text=texts, return_tensors="pt", padding=True, truncation=True)
    # Inference only: no_grad avoids building the autograd graph
    # (lower memory, faster forward pass).
    with torch.no_grad():
        feats = model.get_text_features(**inputs)
        # L2-normalize so dot products are cosine similarities.
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
    return {"embeddings": feats.cpu().tolist()}
@app.post("/embed/image")
async def embed_image(req: Request):
    """Embed a batch of base64-encoded images with CLIP.

    Expects a JSON body {"images": [b64_str, ...]} and returns
    {"embeddings": [[float, ...], ...]} — one L2-normalized 512-d vector
    per input image, in the same order.
    """
    data = await req.json()
    b64_list = data.get("images", [])
    if not b64_list:
        # The processor raises on an empty batch; short-circuit instead.
        return {"embeddings": []}
    # Decode all images first, then run a single batched forward pass
    # instead of one model call per image (same outputs, one invocation).
    imgs = [
        Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB")
        for b in b64_list
    ]
    inputs = proc(images=imgs, return_tensors="pt")
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        feats = model.get_image_features(**inputs)
        # L2-normalize so dot products are cosine similarities.
        feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
    return {"embeddings": feats.cpu().tolist()}