from fastapi import FastAPI, UploadFile, File, HTTPException
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import io
import asyncio
import time
from contextlib import asynccontextmanager
from typing import List, Tuple

# Configuration
MODEL_ID = "openai/clip-vit-large-patch14"
BATCH_SIZE = 32
BATCH_TIMEOUT = 0.05  # 50ms wait to fill a batch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Global State
model = None
processor = None
request_queue = asyncio.Queue()


class SmartBatcher:
    """
    Collects individual inference requests and processes them in optimal batches.
    """
    def __init__(self):
        self.processing_task = None

    def start(self):
        # Must be called from inside a running event loop (the lifespan hook).
        self.processing_task = asyncio.create_task(self.process_batches())
        print("🚀 Smart Batcher started.")

    async def process_batches(self):
        while True:
            # 1. Collect Requests
            batch = []
            # Block until the first item arrives
            item = await request_queue.get()
            batch.append(item)

            # Try to fill the batch within the timeout window
            start_wait = time.time()
            while len(batch) < BATCH_SIZE:
                # Remaining time in the timeout window
                remaining = BATCH_TIMEOUT - (time.time() - start_wait)
                if remaining <= 0:
                    break
                try:
                    # wait_for bounds the queue read by the remaining window
                    additional_item = await asyncio.wait_for(
                        request_queue.get(), timeout=remaining
                    )
                    batch.append(additional_item)
                except asyncio.TimeoutError:
                    break

            # 2. Process Batch
            if batch:
                await self.run_inference(batch)

    async def run_inference(self, batch: List[Tuple]):
        # Each batch entry is (data, 'text' | 'image', future).
        # Note: the model calls below run on the event loop thread, so new
        # requests cannot be enqueued while a batch is on the GPU.
        text_inputs = []
        image_inputs = []

        for data, kind, fut in batch:
            if kind == 'text':
                text_inputs.append((data, fut))
            elif kind == 'image':
                image_inputs.append((data, fut))

        # Run Text Batch
        if text_inputs:
            texts = [data for data, _ in text_inputs]
            try:
                # Prepare Inputs
                inputs = processor(
                    text=texts,
                    padding=True,
                    truncation=True,
                    return_tensors="pt"
                ).to(DEVICE)

                # Inference (L2-normalize so similarity reduces to a dot product)
                with torch.inference_mode():
                    outputs = model.get_text_features(**inputs)
                    outputs = outputs / outputs.norm(dim=-1, keepdim=True)

                vectors = outputs.cpu().tolist()

                # Distribute Results (queue order matches output row order)
                for (_, fut), vector in zip(text_inputs, vectors):
                    if not fut.done():
                        fut.set_result(vector)
            except Exception as e:
                for _, fut in text_inputs:
                    if not fut.done():
                        fut.set_exception(e)

        # Run Image Batch
        if image_inputs:
            images = [data for data, _ in image_inputs]
            try:
                # Prepare Inputs; cast pixel values to the model dtype
                # (fp16 on GPU) to avoid a dtype mismatch in the vision tower
                inputs = processor(images=images, return_tensors="pt").to(DEVICE)
                pixel_values = inputs["pixel_values"].to(DTYPE)

                # Inference
                with torch.inference_mode():
                    outputs = model.get_image_features(pixel_values=pixel_values)
                    outputs = outputs / outputs.norm(dim=-1, keepdim=True)

                vectors = outputs.cpu().tolist()

                # Distribute Results
                for (_, fut), vector in zip(image_inputs, vectors):
                    if not fut.done():
                        fut.set_result(vector)
            except Exception as e:
                for _, fut in image_inputs:
                    if not fut.done():
                        fut.set_exception(e)


@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, processor
    print("🧠 Loading CLIP Model...")

    # Load Model
    model = CLIPModel.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True
    ).to(DEVICE).eval()

    # Compile for faster inference (graceful fallback). torch.compile(model)
    # would only capture forward(), which this service never calls, so compile
    # the two feature methods the batcher actually uses. Note that variable
    # batch sizes may trigger recompilations.
    try:
        model.get_text_features = torch.compile(model.get_text_features)
        model.get_image_features = torch.compile(model.get_image_features)
        print("⚡ Torch Compile enabled.")
    except Exception:
        print("⚠️ Torch Compile skipped (not supported).")

    processor = CLIPProcessor.from_pretrained(MODEL_ID)
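
    # Optional warm-up (an addition, not part of the original design): one
    # dummy text batch triggers torch.compile's first-call compilation and
    # CUDA kernel initialization before real traffic arrives. Safe to remove.
    try:
        with torch.inference_mode():
            warmup = processor(text=["warmup"], return_tensors="pt").to(DEVICE)
            model.get_text_features(**warmup)
    except Exception:
        print("⚠️ Warm-up pass failed; continuing without it.")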
supported).") processor = CLIPProcessor.from_pretrained(MODEL_ID) # Start Batcher batcher = SmartBatcher() batcher.start() yield print("🛑 Shutting down.") app = FastAPI(lifespan=lifespan) @app.post("/embed-text") async def embed_text(text: str): loop = asyncio.get_running_loop() fut = loop.create_future() await request_queue.put((text, 'text', fut)) # Wait for batch processor to set result result = await fut return {"vector": result} @app.post("/embed-image") async def embed_image(file: UploadFile = File(...)): # Read image immediately to avoid holding file handle in queue too long content = await file.read() image = Image.open(io.BytesIO(content)).convert("RGB") loop = asyncio.get_running_loop() fut = loop.create_future() await request_queue.put((image, 'image', fut)) result = await fut return {"vector": result} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8001)