Zero-Shot Image Classification
Transformers
Safetensors
siglip
vision
MOCI2001's picture
Create app.py
ed64727 verified
Raw
History Blame
1.64 kB
from io import BytesIO

import requests
import torch
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel
from transformers import AutoProcessor, SiglipModel
# 1. Initialize the FastAPI application.
app = FastAPI(title="SigLIP 2 Embedding API")
# 2. Load the SigLIP 2 processor and model. These are module-level statements,
#    so loading happens once when the module is imported, not on every request.
model_id = "google/siglip2-base-patch16-224"
print("正在載入 SigLIP 2 模型...")
processor = AutoProcessor.from_pretrained(model_id)
model = SiglipModel.from_pretrained(model_id)
print("模型載入完成!")
# Request schema: the API accepts a JSON body that carries an image URL.
class ImageInput(BaseModel):
    # Address of the image whose embedding is requested.
    url: str
# Upper bound, in seconds, on connecting to / waiting on the image host, so an
# unresponsive server cannot block a request worker indefinitely.
_DOWNLOAD_TIMEOUT_SECONDS = 10


def _download_image(url: str) -> "Image.Image":
    """Fetch *url* and return its contents as an RGB PIL image.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-success HTTP status.
        OSError: if the downloaded body cannot be decoded as an image
            (PIL's ``UnidentifiedImageError`` is a subclass).
    """
    # NOTE(security): the URL comes straight from the request body, so this is
    # a server-side fetch of a caller-chosen address (SSRF risk). Restrict the
    # reachable hosts at deployment level if this API is publicly exposed.
    with requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS) as response:
        # Fail early on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # ``response.content`` is transfer-decoded (gzip/deflate), unlike
        # ``response.raw``, and lets the connection be released right away.
        payload = response.content
    # Force three channels so RGBA / palette / grayscale inputs reach the
    # processor in the same layout as ordinary RGB images.
    return Image.open(BytesIO(payload)).convert("RGB")


# 3. HTTP endpoint: POST /embed
@app.post("/embed")
def get_embedding(data: ImageInput):
    """Return the L2-normalized SigLIP image embedding of the image at ``data.url``.

    Args:
        data: Request body holding the image URL.

    Returns:
        On success, a dict with ``"status": "success"``, ``"dimension"`` (the
        length of the vector) and ``"embedding"`` (a list of floats).
        On any failure, ``{"status": "error", "message": <exception text>}``;
        errors are reported in the response body rather than raised, which is
        the contract existing callers rely on.
    """
    try:
        # Download the image referenced by the caller.
        image = _download_image(data.url)
        # Preprocess into model-ready tensors.
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            # Extract the image feature vector (no gradients needed).
            image_features = model.get_image_features(**inputs)
            # L2-normalize so the vectors are directly usable for
            # cosine-similarity / vector search.
            image_features = image_features / image_features.norm(
                p=2, dim=-1, keepdim=True
            )
        # Exactly one image was submitted, so take the single batch row and
        # convert it to a plain Python list for JSON serialization.
        embedding_list = image_features[0].tolist()
        return {
            "status": "success",
            "dimension": len(embedding_list),
            "embedding": embedding_list,
        }
    except Exception as e:
        # Top-level API boundary: report the failure to the client in-band.
        return {"status": "error", "message": str(e)}