Zero-Shot Image Classification
Transformers
Safetensors
siglip
vision
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
+ from PIL import Image
5
+ import requests
6
+ from transformers import AutoProcessor, SiglipModel
7
+
8
+ # 1. 初始化 FastAPI
9
+ app = FastAPI(title="SigLIP 2 Embedding API")
10
+
11
+ # 2. 自動載入您複製的 SigLIP 2 模型 (只會在啟動時載入一次)
12
+ model_id = "google/siglip2-base-patch16-224"
13
+ print("正在載入 SigLIP 2 模型...")
14
+ processor = AutoProcessor.from_pretrained(model_id)
15
+ model = SiglipModel.from_pretrained(model_id)
16
+ print("模型載入完成!")
17
+
18
+ # 定義資料格式:API 接收一個包含圖片網址的 JSON
19
+ class ImageInput(BaseModel):
20
+ url: str
21
+
22
+ # 3. 建立網頁 API 接口 /embed
23
+ @app.post("/embed")
24
+ def get_embedding(data: ImageInput):
25
+ try:
26
+ # 下載 n8n 傳過來的圖片網址
27
+ image = Image.open(requests.get(data.url, stream=True).raw)
28
+
29
+ # 使用模型提取特徵
30
+ inputs = processor(images=image, return_tensors="pt")
31
+ with torch.no_grad():
32
+ # 提取 768 維度圖片向量
33
+ image_features = model.get_image_features(**inputs)
34
+ # 進行歸一化 (L2 Normalization),這對向量搜尋非常重要
35
+ image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
36
+
37
+ # 將 Tensor 轉換為 Python 的標準陣列 (List)
38
+ embedding_list = image_features.squeeze().tolist()
39
+
40
+ return {
41
+ "status": "success",
42
+ "dimension": len(embedding_list),
43
+ "embedding": embedding_list
44
+ }
45
+ except Exception as e:
46
+ return {"status": "error", "message": str(e)}