File size: 6,089 Bytes
d6a117a
fe8fc9b
d6a117a
1ef70b2
d6a117a
6ee8956
 
 
 
d6a117a
6ee8956
 
 
 
 
 
38dda6c
6ee8956
 
 
 
1ef70b2
6ee8956
 
 
 
 
 
 
98537c7
6ee8956
 
 
38dda6c
6ee8956
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6a117a
 
6ee8956
 
 
 
 
 
 
 
 
d6a117a
 
6ee8956
 
 
 
d6a117a
6ee8956
 
fbf7697
6ee8956
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import asyncio
import io
import time
from contextlib import asynccontextmanager
from typing import List, Tuple

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Configuration
MODEL_ID = "openai/clip-vit-large-patch14"
BATCH_SIZE = 32  # most requests merged into a single forward pass
BATCH_TIMEOUT = 0.05  # seconds (50ms) to wait for a batch to fill

# Half precision on GPU, full precision on CPU.
DEVICE, DTYPE = (
    ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
)

# Global State: model/processor are assigned at startup by `lifespan`;
# the endpoints put requests on the queue and the batcher consumes them.
model = processor = None
request_queue = asyncio.Queue()

class SmartBatcher:
    """
    Collects individual inference requests and processes them in optimal batches.

    Requests are ``(payload, kind, future)`` tuples taken from the module-level
    ``request_queue``; ``kind`` is ``'text'`` or ``'image'``. Every pending
    future is settled either with its L2-normalised embedding (a list of
    floats) or with the exception raised while encoding its group.

    The blocking model calls run in the event loop's default thread-pool
    executor, so the loop keeps accepting and queueing requests while a batch
    is being computed. Batches are processed one at a time.
    """

    def __init__(self):
        # Expected to be constructed while the event loop is running
        # (e.g. inside the FastAPI lifespan), so this resolves to that loop.
        self.loop = asyncio.get_event_loop()
        self.processing_task = None

    def start(self):
        """Schedule :meth:`process_batches` as a background task on ``self.loop``."""
        self.processing_task = self.loop.create_task(self.process_batches())
        print("🚀 Smart Batcher started.")

    async def process_batches(self):
        """Collect and run batches until the task is cancelled."""
        while True:
            batch = await self._collect_batch()
            try:
                await self.run_inference(batch)
            except Exception as exc:
                # Unexpected failure outside the per-group handlers: fail this
                # batch instead of letting the task die, which would leave every
                # current and future caller awaiting a future nobody resolves.
                for *_, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)

    async def _collect_batch(self) -> List[Tuple]:
        """Wait for one request, then top the batch up until full or timed out."""
        batch = [await request_queue.get()]
        # Monotonic clock: unaffected by wall-clock adjustments (NTP, manual changes).
        deadline = time.monotonic() + BATCH_TIMEOUT
        while len(batch) < BATCH_SIZE:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(request_queue.get(), timeout=remaining))
            except asyncio.TimeoutError:
                break
        return batch

    async def run_inference(self, batch: List[Tuple]):
        """Encode one batch and settle each request's future.

        Requests are grouped by kind so each modality needs a single model
        call; texts are encoded before images. Requests of any other kind are
        failed with ``ValueError`` instead of being left pending forever.
        """
        texts = []
        images = []
        for data, kind, fut in batch:
            if kind == 'text':
                texts.append((data, fut))
            elif kind == 'image':
                images.append((data, fut))
            elif not fut.done():
                fut.set_exception(ValueError(f"Unknown request kind: {kind!r}"))

        if texts:
            await self._run_group(texts, self._encode_texts)
        if images:
            await self._run_group(images, self._encode_images)

    @staticmethod
    async def _run_group(group, encode):
        """Run blocking ``encode`` on the group's payloads in a worker thread, then settle futures.

        ``group`` is a list of ``(payload, future)`` pairs. On any failure,
        including a result count that does not match the request count, every
        still-pending future receives the exception. Futures are only touched
        from the event-loop thread, since asyncio futures are not thread-safe.
        """
        payloads = [data for data, _ in group]
        try:
            vectors = await asyncio.get_running_loop().run_in_executor(None, encode, payloads)
            if len(vectors) != len(group):
                raise RuntimeError(
                    f"Expected {len(group)} embeddings, got {len(vectors)}"
                )
        except Exception as exc:
            for _, fut in group:
                if not fut.done():
                    fut.set_exception(exc)
            return

        for (_, fut), vector in zip(group, vectors):
            # A caller may have gone away (cancelled future); skip it.
            if not fut.done():
                fut.set_result(vector)

    @staticmethod
    def _encode_texts(texts):
        """Return unit-norm CLIP text embeddings for ``texts`` as nested lists (blocking)."""
        inputs = processor(
            text=texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(DEVICE)
        # inference_mode is thread-local, so it is entered in the worker thread.
        with torch.inference_mode():
            outputs = model.get_text_features(**inputs)
            outputs = outputs / outputs.norm(dim=-1, keepdim=True)
            return outputs.cpu().tolist()

    @staticmethod
    def _encode_images(images):
        """Return unit-norm CLIP image embeddings for ``images`` as nested lists (blocking)."""
        inputs = processor(images=images, return_tensors="pt").to(DEVICE)
        with torch.inference_mode():
            outputs = model.get_image_features(**inputs)
            outputs = outputs / outputs.norm(dim=-1, keepdim=True)
            return outputs.cpu().tolist()


def _compile_clip_towers(clip_model, clip_processor):
    """Compile CLIP's text and vision towers in place, falling back to eager mode.

    ``torch.compile(model)`` only compiles ``forward``; ``get_text_features`` /
    ``get_image_features`` looked up on the returned wrapper resolve to the
    original, uncompiled methods. Compiling the ``text_model`` / ``vision_model``
    submodules those methods call (in transformers' CLIPModel) makes compilation
    actually apply. Compilation is lazy, so a warm-up pass that mirrors the
    request path is run here. Backend failures then surface at startup and
    trigger the fallback, instead of failing every request later. Batch shapes
    not seen during warm-up may still trigger a recompile on first use.
    """
    eager_text, eager_vision = clip_model.text_model, clip_model.vision_model
    try:
        clip_model.text_model = torch.compile(eager_text)
        clip_model.vision_model = torch.compile(eager_vision)

        text_inputs = clip_processor(
            text=["warm-up"], padding=True, truncation=True, return_tensors="pt"
        ).to(DEVICE)
        image_inputs = clip_processor(
            images=[Image.new("RGB", (224, 224))], return_tensors="pt"
        ).to(DEVICE)
        with torch.inference_mode():
            clip_model.get_text_features(**text_inputs)
            clip_model.get_image_features(**image_inputs)
        print("⚡ Torch Compile enabled.")
    except Exception:
        # Restore the eager towers so inference keeps working without compilation.
        clip_model.text_model, clip_model.vision_model = eager_text, eager_vision
        print("⚠️ Torch Compile skipped (not supported).")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and processor, run the batcher, and stop it on shutdown.

    Assigns the module-level ``model`` and ``processor`` before yielding control
    to the application. On shutdown the batcher's background task is cancelled
    and awaited, so it does not outlive the app.
    """
    global model, processor
    print("🧠 Loading CLIP Model...")

    # Load Model
    model = CLIPModel.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True
    ).to(DEVICE).eval()
    # The processor is needed by the compile warm-up, so load it first.
    processor = CLIPProcessor.from_pretrained(MODEL_ID)

    # Compile for faster inference (graceful fallback to eager mode).
    _compile_clip_towers(model, processor)

    # Start Batcher
    batcher = SmartBatcher()
    batcher.start()

    try:
        yield
    finally:
        print("🛑 Shutting down.")
        task = batcher.processing_task
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
    print("🛑 Shutting down.")

# ASGI application. `lifespan` loads the CLIP model/processor and starts the
# request batcher before the app begins serving.
app = FastAPI(lifespan=lifespan)

@app.post("/embed-text")
async def embed_text(text: str):
    """Submit ``text`` to the shared batch queue and return its CLIP embedding.

    The request is enqueued as ``(text, 'text', future)``. The batcher settles
    the future with the embedding vector; if it sets an exception instead,
    awaiting the future re-raises it here.
    """
    pending = asyncio.get_running_loop().create_future()
    await request_queue.put((text, 'text', pending))
    # Suspends until the batch containing this text has been processed.
    vector = await pending
    return {"vector": vector}

def _decode_rgb_image(content: bytes):
    """Decode ``content`` into a fully loaded RGB PIL image (blocking, CPU-bound).

    ``convert`` returns a new, loaded image, so the source image can be closed.
    """
    with Image.open(io.BytesIO(content)) as source:
        return source.convert("RGB")


@app.post("/embed-image")
async def embed_image(file: UploadFile = File(...)):
    """Decode an uploaded image and return its CLIP embedding via the batcher.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image
            (unrecognised/truncated data or a decompression-bomb sized image).
    """
    # Read image immediately to avoid holding file handle in queue too long
    content = await file.read()
    loop = asyncio.get_running_loop()
    try:
        # Decoding can be expensive; keep it off the event-loop thread.
        image = await loop.run_in_executor(None, _decode_rgb_image, content)
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is an OSError; report bad input as a client error, not a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    fut = loop.create_future()
    await request_queue.put((image, 'image', fut))

    result = await fut
    return {"vector": result}

if __name__ == "__main__":
    # Imported lazily: the server runner is only needed when executed as a script.
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")