from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import torch
import gc

app = FastAPI()

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model
print("Loading BAAI/bge-large-en-v1.5...")
model = SentenceTransformer(
    'BAAI/bge-large-en-v1.5',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    cache_folder='/app/cache'
)
model.eval()

# Use FP16 weights on GPU for faster inference
if torch.cuda.is_available():
    model = model.half()
    print("Model loaded on GPU with FP16 precision")
else:
    print("Model loaded on CPU")


class EmbedRequest(BaseModel):
    texts: list[str]
    batch_size: int = 256  # BGE handles larger batches well


class EmbedResponse(BaseModel):
    embeddings: list[list[float]]
    processed: int
    dimension: int


@app.post("/embed", response_model=EmbedResponse)
async def embed(request: EmbedRequest):
    # Note: model.encode is a blocking call inside an async handler, so the
    # event loop is held for the duration of each batch and concurrent
    # requests are effectively serialized.
    try:
        # BGE recommends prefixing *queries* with
        # "Represent this sentence for searching relevant passages: "
        # for retrieval tasks; passages are embedded as-is. The prefix is
        # optional and not applied here.

        # Clear GPU cache before encoding
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        with torch.no_grad():
            embeddings = model.encode(
                request.texts,
                batch_size=request.batch_size,
                convert_to_numpy=True,
                normalize_embeddings=True,  # BGE embeddings should be normalized for cosine similarity
                show_progress_bar=False
            )

        # Cleanup
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return EmbedResponse(
            embeddings=embeddings.tolist(),
            processed=len(request.texts),
            dimension=embeddings.shape[1]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health():
    return {
        "status": "ready",
        "model": "BAAI/bge-large-en-v1.5",
        "cuda_available": torch.cuda.is_available(),
        "device": str(next(model.parameters()).device)
    }
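
# --- Example usage ---
# A minimal sketch of running and calling the service. The filename (main.py),
# host, and port below are assumptions, not part of the service itself.
#
# Start the server, assuming this file is saved as main.py:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Then POST to /embed; the request/response shapes follow EmbedRequest and
# EmbedResponse above. bge-large-en-v1.5 produces 1024-dimensional vectors.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/embed",
#       json={"texts": [
#           "Represent this sentence for searching relevant passages: what is BGE?",
#           "BGE is a family of open-source embedding models from BAAI.",
#       ]},
#   )
#   data = resp.json()
#   print(data["processed"], data["dimension"])  # expected: 2 1024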