chipling committed
Commit 1ef70b2 · verified · 1 Parent(s): 35daee1

Update app.py

Files changed (1): app.py (+24 -11)
app.py CHANGED
@@ -1,29 +1,42 @@
 from fastapi import FastAPI, UploadFile, File
 from pydantic import BaseModel
-from sentence_transformers import SentenceTransformer
+from transformers import AutoProcessor, AutoModel
 from PIL import Image
+import torch
 import io
 
 app = FastAPI()
 
-# Load model into memory (do this globally so it only happens once)
-model = SentenceTransformer('google/siglip-so400m-patch14-384')
+# Load SigLIP 2
+model_id = "google/siglip2-so400m-patch14-384"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load Model and Processor
+model = AutoModel.from_pretrained(model_id).to(device).eval()
+processor = AutoProcessor.from_pretrained(model_id)
 
 class TextRequest(BaseModel):
     text: str
 
 @app.post("/embed-text")
 async def embed_text(request: TextRequest):
-    # Convert text to vector
-    vector = model.encode(request.text).tolist()
-    return {"vector": vector}
+    inputs = processor(text=[request.text], padding="max_length", return_tensors="pt").to(device)
+    with torch.no_grad():
+        # Get the text embeddings
+        text_outputs = model.get_text_features(**inputs)
+
+    vector = text_outputs[0].cpu().tolist()
+    return {"vector": vector, "dim": len(vector)}
 
 @app.post("/embed-image")
 async def embed_image(file: UploadFile = File(...)):
-    # Read uploaded image
     image_data = await file.read()
-    image = Image.open(io.BytesIO(image_data))
-
-    # Convert image to vector
-    vector = model.encode(image).tolist()
-    return {"vector": vector}
+    image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+    inputs = processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        # Get the image embeddings
+        image_outputs = model.get_image_features(**inputs)
+
+    vector = image_outputs[0].cpu().tolist()
+    return {"vector": vector, "dim": len(vector)}