chipling committed on
Commit
844d9d8
·
verified ·
1 Parent(s): 98537c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -4,49 +4,49 @@ from transformers import AutoProcessor, AutoModel
4
  from PIL import Image
5
  import torch
6
  import io
 
 
 
 
7
 
8
  app = FastAPI()
9
 
10
  model_id = "google/siglip2-so400m-patch14-384"
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # OPTIMIZATION 1: Load with half-precision if on GPU, or optimized CPU settings
 
14
  model = AutoModel.from_pretrained(
15
  model_id,
16
- torch_dtype=torch.float32, # Use float16 if you have a GPU
17
  low_cpu_mem_usage=True
18
  ).to(device).eval()
19
 
20
- # OPTIMIZATION 2: Compile the model (Requires PyTorch 2.0+)
 
 
 
21
  try:
22
  model = torch.compile(model)
23
- except Exception as e:
24
- print(f"Compilation skipped: {e}")
25
-
26
- processor = AutoProcessor.from_pretrained(model_id)
27
 
28
  class TextRequest(BaseModel):
29
  text: str
30
 
31
- # OPTIMIZATION 3: Remove 'async' so FastAPI uses thread pools for CPU work
32
  @app.post("/embed-text")
33
  def embed_text(request: TextRequest):
34
  inputs = processor(text=[request.text], padding="max_length", return_tensors="pt").to(device)
35
- with torch.no_grad():
36
- # OPTIMIZATION 4: Use Inference Mode (faster than no_grad)
37
- with torch.inference_mode():
38
- text_outputs = model.get_text_features(**inputs)
39
-
40
  return {"vector": text_outputs[0].cpu().tolist(), "dim": 1152}
41
 
42
  @app.post("/embed-image")
43
  def embed_image(file: UploadFile = File(...)):
44
- # Reading file is still async-friendly
45
  image_data = file.file.read()
46
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
47
 
48
  inputs = processor(images=image, return_tensors="pt").to(device)
49
  with torch.inference_mode():
50
  image_outputs = model.get_image_features(**inputs)
51
-
52
  return {"vector": image_outputs[0].cpu().tolist(), "dim": 1152}
 
4
  from PIL import Image
5
  import torch
6
  import io
7
import os

# Raise the Hugging Face Hub read timeout (default 10s) so large checkpoint
# shards don't time out while downloading on slow connections.
os.environ["HF_HUB_READ_TIMEOUT"] = "60"

app = FastAPI()

model_id = "google/siglip2-so400m-patch14-384"
device = "cuda" if torch.cuda.is_available() else "cpu"

# low_cpu_mem_usage streams weights in instead of materializing two full
# copies at once, preventing RAM spikes on small-memory hosts.
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # NOTE(review): float16 would halve memory on GPU — confirm accuracy needs
    low_cpu_mem_usage=True,
).to(device).eval()

# use_fast=True selects the fast image processor and avoids the
# slow-processor warning emitted by recent transformers versions.
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

# Best-effort speedup: torch.compile (PyTorch 2.0+) is not supported on every
# platform, so failure is logged rather than fatal. A bare `except:` here
# would also swallow KeyboardInterrupt/SystemExit — catch Exception only.
try:
    model = torch.compile(model)
except Exception as e:
    print(f"torch.compile skipped: {e}")
 
 
33
 
34
  class TextRequest(BaseModel):
35
  text: str
36
 
 
37
@app.post("/embed-text")
def embed_text(request: TextRequest):
    """Embed the request text with the SigLIP text tower.

    Returns a JSON object with the embedding as a flat list of floats
    ("vector") and its dimensionality ("dim").
    """
    # padding="max_length" matches the fixed sequence length SigLIP expects.
    batch = processor(
        text=[request.text], padding="max_length", return_tensors="pt"
    ).to(device)
    # inference_mode disables autograd bookkeeping for a cheaper forward pass.
    with torch.inference_mode():
        features = model.get_text_features(**batch)
    vector = features[0].cpu().tolist()
    return {"vector": vector, "dim": 1152}
43
 
44
  @app.post("/embed-image")
45
  def embed_image(file: UploadFile = File(...)):
 
46
  image_data = file.file.read()
47
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
48
 
49
  inputs = processor(images=image, return_tensors="pt").to(device)
50
  with torch.inference_mode():
51
  image_outputs = model.get_image_features(**inputs)
 
52
  return {"vector": image_outputs[0].cpu().tolist(), "dim": 1152}