Spaces:
Running
Running
import asyncio
import io
import time
from contextlib import asynccontextmanager
from typing import List, Tuple

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# --- Configuration -----------------------------------------------------------
MODEL_ID = "openai/clip-vit-large-patch14"
BATCH_SIZE = 32  # upper bound on the number of requests grouped into one batch
BATCH_TIMEOUT = 0.05  # seconds (50 ms) to wait for a batch to fill

# Half precision is only selected when a CUDA device is available.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# --- Global state ------------------------------------------------------------
# `model` and `processor` stay None until the application lifespan loads them.
model = processor = None
# Pending requests, each enqueued as a (data, 'text' | 'image', future) tuple.
request_queue = asyncio.Queue()
class SmartBatcher:
    """
    Collects individual inference requests and processes them in batches.

    Requests are ``(data, kind, future)`` tuples read from the module-level
    ``request_queue``, where ``kind`` is ``'text'`` or ``'image'``. Each
    future is resolved with the L2-normalised embedding (a list of floats),
    or with the exception that prevented it from being computed.
    """

    def __init__(self):
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            # Constructed outside a running loop: keep the legacy lookup so
            # callers that start the loop afterwards continue to work.
            self.loop = asyncio.get_event_loop()
        self.processing_task = None

    def start(self):
        """Schedule the background batching loop on ``self.loop``."""
        self.processing_task = self.loop.create_task(self.process_batches())
        print("🚀 Smart Batcher started.")

    async def process_batches(self):
        """
        Run forever: gather requests into batches and run inference on them.

        A batch is opened by the first request to arrive and closed once it
        holds ``BATCH_SIZE`` requests or ``BATCH_TIMEOUT`` seconds have
        elapsed since it was opened, whichever happens first.
        """
        while True:
            # Block until there is at least one request to serve.
            batch = [await request_queue.get()]

            # A monotonic clock keeps the window immune to wall-clock jumps.
            deadline = time.monotonic() + BATCH_TIMEOUT
            while len(batch) < BATCH_SIZE:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    batch.append(
                        await asyncio.wait_for(
                            request_queue.get(), timeout=remaining
                        )
                    )
                except asyncio.TimeoutError:
                    break
                except Exception as exc:
                    # Best effort: serve what was collected, but do not hide
                    # the reason the window was cut short.
                    print(f"⚠️ Batch collection interrupted: {exc!r}")
                    break

            try:
                await self.run_inference(batch)
            except Exception as exc:
                # Never let one bad batch kill the loop: report the failure
                # to the waiting callers and keep serving later requests.
                self._fail_pending(batch, exc)

    async def run_inference(self, batch: List[Tuple]):
        """
        Embed every request in *batch* and resolve its future.

        Args:
            batch: ``(data, kind, future)`` tuples. ``kind`` selects the
                text or image encoder; any other value fails that request
                with ``ValueError`` instead of leaving its caller waiting.

        Text requests are encoded together first, then image requests. An
        error in one group is delivered to that group's futures only.
        """
        groups = {"text": [], "image": []}
        for data, kind, fut in batch:
            if fut.done():
                # Already cancelled/resolved (e.g. the caller went away):
                # do not spend model time on it.
                continue
            if kind in groups:
                groups[kind].append((data, fut))
            else:
                fut.set_exception(
                    ValueError(f"Unsupported input kind: {kind!r}")
                )

        for kind, entries in groups.items():
            if not entries:
                continue
            payloads = [data for data, _ in entries]
            try:
                # The forward pass is blocking; run it in a worker thread so
                # the event loop can keep accepting and queueing requests.
                vectors = await asyncio.get_running_loop().run_in_executor(
                    None, self._encode, kind, payloads
                )
            except Exception as exc:
                for _, fut in entries:
                    if not fut.done():
                        fut.set_exception(exc)
            else:
                # Futures are resolved here, back on the event-loop thread.
                for (_, fut), vector in zip(entries, vectors):
                    if not fut.done():
                        fut.set_result(vector)

    @staticmethod
    def _encode(kind, payloads):
        """Return L2-normalised embeddings for *payloads* as nested lists."""
        if kind == "text":
            inputs = processor(
                text=payloads,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(DEVICE)
            extract = model.get_text_features
        else:
            inputs = processor(images=payloads, return_tensors="pt").to(DEVICE)
            extract = model.get_image_features

        with torch.inference_mode():
            features = extract(**inputs)
            features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().tolist()

    @staticmethod
    def _fail_pending(batch, exc):
        """Deliver *exc* to every future in *batch* that is still pending."""
        for item in batch:
            try:
                fut = item[2]
                if not fut.done():
                    fut.set_exception(exc)
            except (IndexError, TypeError, AttributeError):
                # Malformed queue entry: there is no future to notify.
                continue
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan: load the CLIP model and run the request batcher.

    On startup this populates the module-level ``model`` and ``processor``
    and starts a ``SmartBatcher``. On shutdown the batcher's background task
    is cancelled so it does not outlive the application.
    """
    global model, processor
    print("🧠 Loading CLIP Model...")

    # Load Model
    model = CLIPModel.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,
    ).to(DEVICE).eval()

    # Compile model for faster inference; fall back to the eager model when
    # torch.compile is unavailable or rejects the model.
    # NOTE(review): the batcher calls model.get_text_features /
    # model.get_image_features rather than model(...); confirm the compiled
    # wrapper actually accelerates those entry points.
    try:
        model = torch.compile(model)
        print("⚡ Torch Compile enabled.")
    except Exception:
        print("⚠️ Torch Compile skipped (not supported).")

    processor = CLIPProcessor.from_pretrained(MODEL_ID)

    # Start Batcher
    batcher = SmartBatcher()
    batcher.start()

    try:
        yield
    finally:
        print("🛑 Shutting down.")
        task = batcher.processing_task
        if task is not None:
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                # Cancellation is the expected outcome; an earlier crash of
                # the batching task must not break application shutdown.
                pass


app = FastAPI(lifespan=lifespan)
async def embed_text(text: str):
    """
    Queue *text* for batched embedding and return the resulting vector.

    The coroutine suspends until the batch processor resolves the future
    enqueued alongside the text.

    Returns:
        ``{"vector": <embedding>}``.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # file; confirm it is registered on `app`.
    pending = asyncio.get_running_loop().create_future()
    await request_queue.put((text, 'text', pending))
    embedding = await pending
    return {"vector": embedding}
async def embed_image(file: UploadFile = File(...)):
    """
    Queue an uploaded image for batched embedding and return its vector.

    The upload is read and decoded to RGB immediately, so only the decoded
    image (not the file handle) waits in the queue.

    Returns:
        ``{"vector": <embedding>}``.

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image, instead of surfacing as an internal server error.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # file; confirm it is registered on `app`.
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except (OSError, ValueError) as exc:
        # PIL reports unrecognised or truncated data as OSError subclasses.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    fut = asyncio.get_running_loop().create_future()
    await request_queue.put((image, 'image', fut))
    result = await fut
    return {"vector": result}
if __name__ == "__main__":
    # Script entry point: serve `app` on every interface, port 8001.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8001)