chipling committed on
Commit
e4961b2
·
verified ·
1 Parent(s): fe8fc9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -45
app.py CHANGED
@@ -3,59 +3,187 @@ from transformers import CLIPProcessor, CLIPModel
3
  from PIL import Image
4
  import torch
5
  import io
 
 
 
 
6
 
7
- app = FastAPI()
8
- model_id = "openai/clip-vit-large-patch14"
 
 
 
 
9
 
10
- # Check for GPU, but default to optimized CPU path
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
13
 
14
- # 1. Load with memory-efficient settings
15
- model = CLIPModel.from_pretrained(
16
- model_id,
17
- torch_dtype=dtype,
18
- low_cpu_mem_usage=True
19
- ).to(device).eval()
 
20
 
21
- # 2. COMPILE THE MODEL (The huge speed boost)
22
- # This takes 1 min to start up but makes every search 30% faster
23
- try:
24
- model = torch.compile(model)
25
- except Exception:
26
- print("Torch compile not supported on this environment, skipping...")
27
 
28
- processor = CLIPProcessor.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # 3. USE 'def' (Not 'async def') for CPU-heavy tasks
31
- # This allows FastAPI to run searches in parallel on different CPU cores
32
  @app.post("/embed-text")
33
- def embed_text(text: str):
34
- # CLIP uses max 77 tokens for text
35
- inputs = processor(
36
- text=[text],
37
- padding=True,
38
- truncation=True,
39
- return_tensors="pt"
40
- ).to(device)
41
-
42
- with torch.inference_mode(): # Faster than no_grad()
43
- outputs = model.get_text_features(**inputs)
44
- # Normalize embeddings for cosine similarity
45
- outputs = outputs / outputs.norm(dim=-1, keepdim=True)
46
-
47
- return {"vector": outputs[0].cpu().tolist()}
48
 
49
  @app.post("/embed-image")
50
- def embed_image(file: UploadFile = File(...)):
51
- # Optimized image reading
52
- image = Image.open(file.file).convert("RGB")
 
53
 
54
- inputs = processor(images=image, return_tensors="pt").to(device)
 
55
 
56
- with torch.inference_mode():
57
- outputs = model.get_image_features(**inputs)
58
- # Normalize embeddings for cosine similarity
59
- outputs = outputs / outputs.norm(dim=-1, keepdim=True)
60
-
61
- return {"vector": outputs[0].cpu().tolist()}
 
 
 
3
  from PIL import Image
4
  import torch
5
  import io
6
+ import asyncio
7
+ import time
8
+ from contextlib import asynccontextmanager
9
+ from typing import List, Tuple
10
 
11
# Configuration
MODEL_ID = "openai/clip-vit-large-patch14"
BATCH_SIZE = 32          # max requests folded into one forward pass
BATCH_TIMEOUT = 0.05     # 50ms wait to fill batch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Global State (populated by the lifespan handler before traffic is served)
model = None       # CLIPModel instance, loaded in lifespan()
processor = None   # CLIPProcessor instance, loaded in lifespan()
# NOTE(review): created at import time, before any event loop runs — safe on
# Python 3.10+ where asyncio.Queue is loop-agnostic; confirm target runtime.
request_queue = asyncio.Queue()
22
 
23
class SmartBatcher:
    """
    Collects individual inference requests and processes them in optimal batches.

    Requests arrive on the module-level ``request_queue`` as
    ``(data, 'text'|'image', future)`` tuples; results (L2-normalized embedding
    vectors as lists of floats) are delivered by resolving each request's future.
    """

    def __init__(self):
        # The event loop is resolved lazily in start(): asyncio.get_event_loop()
        # in __init__ is deprecated (since 3.10) when no loop is running yet.
        self.processing_task = None

    def start(self):
        """Launch the background batching loop on the currently running loop."""
        loop = asyncio.get_running_loop()
        self.processing_task = loop.create_task(self.process_batches())
        print("🚀 Smart Batcher started.")

    async def process_batches(self):
        """Forever: gather up to BATCH_SIZE requests within BATCH_TIMEOUT, run them."""
        while True:
            # Block until at least one request is available.
            batch = [await request_queue.get()]

            # Opportunistically fill the batch inside the timeout window.
            # time.monotonic() is immune to wall-clock adjustments.
            deadline = time.monotonic() + BATCH_TIMEOUT
            while len(batch) < BATCH_SIZE:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    batch.append(
                        await asyncio.wait_for(request_queue.get(), timeout=remaining)
                    )
                except asyncio.TimeoutError:
                    break

            await self.run_inference(batch)

    async def run_inference(self, batch: List[Tuple]):
        """Split a mixed batch by modality and embed each sub-batch."""
        text_inputs = [(data, fut) for data, kind, fut in batch if kind == 'text']
        image_inputs = [(data, fut) for data, kind, fut in batch if kind == 'image']

        if text_inputs:
            await self._embed(text_inputs, self._text_features)
        if image_inputs:
            await self._embed(image_inputs, self._image_features)

    @staticmethod
    def _text_features(texts):
        # Blocking: tokenization + forward pass (runs inside a worker thread).
        inputs = processor(
            text=texts,
            padding=True,
            truncation=True,   # CLIP's text tower caps input at 77 tokens
            return_tensors="pt"
        ).to(DEVICE)
        with torch.inference_mode():
            return model.get_text_features(**inputs)

    @staticmethod
    def _image_features(images):
        # Blocking: preprocessing + forward pass (runs inside a worker thread).
        inputs = processor(images=images, return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            return model.get_image_features(**inputs)

    async def _embed(self, pairs, feature_fn):
        """Run ``feature_fn`` over the batched inputs and resolve each future.

        The model call is pushed to a thread via asyncio.to_thread so the
        event loop stays free to accept and enqueue new requests while
        inference runs; running it inline would stall the batcher's own
        request intake and defeat the batching.
        """
        try:
            outputs = await asyncio.to_thread(
                feature_fn, [data for data, _ in pairs]
            )
            # Normalize so downstream cosine similarity is a plain dot product.
            outputs = outputs / outputs.norm(dim=-1, keepdim=True)
            vectors = outputs.cpu().tolist()

            for (_, fut), vector in zip(pairs, vectors):
                if not fut.done():
                    fut.set_result(vector)
        except Exception as e:
            # Propagate the failure to every caller waiting on this sub-batch.
            for _, fut in pairs:
                if not fut.done():
                    fut.set_exception(e)
130
+
131
+
132
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the CLIP model once at startup; stop the batcher on shutdown."""
    global model, processor
    print("🧠 Loading CLIP Model...")

    # Load with memory-efficient settings; eval() disables dropout etc.
    model = CLIPModel.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True
    ).to(DEVICE).eval()

    # Compile model for faster inference (Linux/CUDA mostly, graceful fallback).
    # NOTE(review): torch.compile is lazy — most backend failures surface on
    # the first forward pass, not here, so this try/except is best-effort.
    try:
        model = torch.compile(model)
        print("⚡ Torch Compile enabled.")
    except Exception:
        print("⚠️ Torch Compile skipped (not supported).")

    processor = CLIPProcessor.from_pretrained(MODEL_ID)

    # Start the background batching loop.
    batcher = SmartBatcher()
    batcher.start()

    yield

    # Cancel the background task so shutdown does not leave it running.
    if batcher.processing_task is not None:
        batcher.processing_task.cancel()
    print("🛑 Shutting down.")
159
+
160
# Wire the lifespan handler so model loading happens before traffic is served.
app = FastAPI(lifespan=lifespan)
161
 
 
 
162
@app.post("/embed-text")
async def embed_text(text: str):
    """Embed a text query; the work is handed to the batcher and awaited."""
    future = asyncio.get_running_loop().create_future()
    await request_queue.put((text, 'text', future))
    # The batch processor resolves the future with the normalized vector.
    return {"vector": await future}
 
 
 
 
 
 
172
 
173
@app.post("/embed-image")
async def embed_image(file: UploadFile = File(...)):
    """Decode the upload immediately, then queue it for batched embedding."""
    # Read the bytes up front so the file handle is not held while queued.
    raw = await file.read()
    pil_image = Image.open(io.BytesIO(raw)).convert("RGB")

    future = asyncio.get_running_loop().create_future()
    await request_queue.put((pil_image, 'image', future))

    return {"vector": await future}
186
+
187
# Allow running directly: `python app.py` serves on all interfaces, port 8001.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)