chipling committed on
Commit
fbf7697
·
verified ·
1 Parent(s): 844d9d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -28
app.py CHANGED
@@ -1,52 +1,44 @@
1
  from fastapi import FastAPI, UploadFile, File
2
- from pydantic import BaseModel
3
  from transformers import AutoProcessor, AutoModel
4
  from PIL import Image
5
  import torch
6
  import io
7
- import os
8
-
9
- # Set higher timeout for model downloading
10
- os.environ["HF_HUB_READ_TIMEOUT"] = "60"
11
 
12
  app = FastAPI()
13
-
14
  model_id = "google/siglip2-so400m-patch14-384"
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # FIX 1: Use torch_dtype directly (deprecation fix)
18
- # FIX 2: use low_cpu_mem_usage to prevent RAM spikes on 16GB
19
  model = AutoModel.from_pretrained(
20
  model_id,
21
  torch_dtype=torch.float32,
22
- low_cpu_mem_usage=True
 
23
  ).to(device).eval()
24
 
25
- # FIX 3: Explicitly set use_fast=True to avoid the processor warning
26
- processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
27
-
28
- # OPTIMIZATION: Faster inference
29
- try:
30
- model = torch.compile(model)
31
- except:
32
- pass
33
-
34
- class TextRequest(BaseModel):
35
- text: str
36
 
37
  @app.post("/embed-text")
38
- def embed_text(request: TextRequest):
39
- inputs = processor(text=[request.text], padding="max_length", return_tensors="pt").to(device)
 
 
 
 
 
 
 
40
  with torch.inference_mode():
41
- text_outputs = model.get_text_features(**inputs)
42
- return {"vector": text_outputs[0].cpu().tolist(), "dim": 1152}
43
 
44
  @app.post("/embed-image")
45
  def embed_image(file: UploadFile = File(...)):
46
- image_data = file.file.read()
47
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
48
 
 
49
  inputs = processor(images=image, return_tensors="pt").to(device)
 
50
  with torch.inference_mode():
51
- image_outputs = model.get_image_features(**inputs)
52
- return {"vector": image_outputs[0].cpu().tolist(), "dim": 1152}
 
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Application bootstrap: resolve the device and load the SigLIP 2
# checkpoint exactly once at import time so every request handler
# reuses the same model/processor pair.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "google/siglip2-so400m-patch14-384"

app = FastAPI()

# Optimized loading based on Feb 2025 Docs
model = (
    AutoModel.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,            # avoid a full-size RAM copy while loading
        attn_implementation="sdpa",        # Scaled Dot Product Attention for speed
    )
    .to(device)
    .eval()  # inference only — disable dropout etc.
)
processor = AutoProcessor.from_pretrained(model_id)
@app.post("/embed-text")
def embed_text(text: str):
    """Return the SigLIP 2 embedding vector for a piece of text.

    The Gemma-based tokenizer used by SigLIP 2 expects inputs padded to a
    fixed sequence length of 64, so the processor is called with
    padding="max_length" and max_length=64.
    """
    batch = processor(
        text=[text],
        padding="max_length",
        max_length=64,
        return_tensors="pt",
    ).to(device)

    with torch.inference_mode():
        embedding = model.get_text_features(**batch)

    return {"vector": embedding[0].cpu().tolist()}
@app.post("/embed-image")
def embed_image(file: UploadFile = File(...)):
    """Return the SigLIP 2 embedding vector for an uploaded image.

    Reads the whole upload into memory, decodes it with Pillow, and runs
    the vision tower under inference mode.

    Raises:
        HTTPException: 400 if the upload is empty or not a decodable image.
    """
    payload = file.file.read()
    try:
        # Normalize to RGB so grayscale/RGBA/palette uploads match the
        # model's expected 3-channel input.
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except (OSError, ValueError) as exc:
        # UnidentifiedImageError subclasses OSError; a bad upload is a
        # client error, not a server crash (previously an unhandled 500).
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # NaFlex logic is handled automatically by the processor
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.get_image_features(**inputs)
    return {"vector": outputs[0].cpu().tolist()}