chipling committed on
Commit
38dda6c
·
verified ·
1 Parent(s): fbf7697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -6,21 +6,33 @@ import io
6
 
7
  app = FastAPI()
8
  model_id = "google/siglip2-so400m-patch14-384"
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
10
 
11
- # Optimized loading based on Feb 2025 Docs
12
  model = AutoModel.from_pretrained(
13
  model_id,
14
- torch_dtype=torch.float32,
15
  low_cpu_mem_usage=True,
16
- attn_implementation="sdpa" # Scaled Dot Product Attention for speed
17
  ).to(device).eval()
18
 
 
 
 
 
 
 
 
19
  processor = AutoProcessor.from_pretrained(model_id)
20
 
 
 
21
  @app.post("/embed-text")
22
  def embed_text(text: str):
23
- # Docs specify max_length=64 for the Gemma-based tokenizer in SigLIP 2
24
  inputs = processor(
25
  text=[text],
26
  padding="max_length",
@@ -28,17 +40,19 @@ def embed_text(text: str):
28
  return_tensors="pt"
29
  ).to(device)
30
 
31
- with torch.inference_mode():
32
  outputs = model.get_text_features(**inputs)
 
33
  return {"vector": outputs[0].cpu().tolist()}
34
 
35
  @app.post("/embed-image")
36
  def embed_image(file: UploadFile = File(...)):
37
- image = Image.open(io.BytesIO(file.file.read())).convert("RGB")
 
38
 
39
- # NaFlex logic is handled automatically by the processor
40
  inputs = processor(images=image, return_tensors="pt").to(device)
41
 
42
  with torch.inference_mode():
43
  outputs = model.get_image_features(**inputs)
 
44
  return {"vector": outputs[0].cpu().tolist()}
 
# --- Application and model setup -------------------------------------------
app = FastAPI()
model_id = "google/siglip2-so400m-patch14-384"

# Pick the compute device once and derive everything else from it, so the
# CUDA availability check is not repeated.
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 halves memory/bandwidth on GPU; CPU keeps float32 for accuracy.
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Load with memory-efficient settings.
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    attn_implementation="sdpa"  # Use Scaled Dot Product Attention
).to(device).eval()

# Compile the model for faster repeated inference (slower first request).
# NOTE(review): torch.compile is lazy -- most backend failures surface on
# the first *call*, not here, so this except mainly catches environments
# where compile itself is unavailable.
try:
    model = torch.compile(model)
except Exception:
    print("Torch compile not supported on this environment, skipping...")

processor = AutoProcessor.from_pretrained(model_id)
# Plain `def` (not `async def`): FastAPI runs sync handlers in its
# threadpool, so CPU-heavy embedding calls don't block the event loop.
@app.post("/embed-text")
def embed_text(text: str):
    """Return the SigLIP 2 text embedding for `text` as {"vector": [...]}."""
    # SigLIP 2 uses a Gemma-based tokenizer; the text tower expects
    # fixed-length input, hence padding="max_length" with max_length=64.
    # NOTE(review): the max_length=64 line is folded out of the diff view;
    # reconstructed from both versions' comments -- confirm against the repo.
    inputs = processor(
        text=[text],
        padding="max_length",
        max_length=64,
        return_tensors="pt"
    ).to(device)

    # inference_mode() disables autograd tracking (faster than no_grad()).
    with torch.inference_mode():
        outputs = model.get_text_features(**inputs)

    return {"vector": outputs[0].cpu().tolist()}
 
48
  @app.post("/embed-image")
49
  def embed_image(file: UploadFile = File(...)):
50
+ # Optimized image reading
51
+ image = Image.open(file.file).convert("RGB")
52
 
 
53
  inputs = processor(images=image, return_tensors="pt").to(device)
54
 
55
  with torch.inference_mode():
56
  outputs = model.get_image_features(**inputs)
57
+
58
  return {"vector": outputs[0].cpu().tolist()}