# FastAPI embedding service (Hugging Face Space) serving BAAI/bge-large-en-v1.5.
import gc
import os

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Application instance; route handlers below are registered against it.
app = FastAPI()

# Enable CORS so browser-based clients can call the API.
# NOTE(review): wildcard origins/methods/headers is wide open — acceptable for
# an internal embedding service, but tighten before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the embedding model once at startup so every request reuses it.
print("Loading BAAI/bge-large-en-v1.5...")

# Decide the device a single time instead of re-querying CUDA twice.
_on_gpu = torch.cuda.is_available()
model = SentenceTransformer(
    'BAAI/bge-large-en-v1.5',
    device='cuda' if _on_gpu else 'cpu',
    cache_folder='/app/cache'
)
model.eval()

if _on_gpu:
    # Halve the weights to FP16 for faster GPU inference.
    model = model.half()
    print(f"Model loaded on GPU with FP16 precision")
else:
    print(f"Model loaded on CPU")
class EmbedRequest(BaseModel):
    """Request body for the embedding endpoint."""
    # Texts to embed in one call.
    texts: list[str]
    # Per-forward-pass batch size; BGE handles larger batches well.
    batch_size: int = 256
class EmbedResponse(BaseModel):
    """Response body: one embedding vector per input text."""
    # Normalized embedding vectors, one per input text.
    embeddings: list[list[float]]
    # Number of texts embedded (== len(embeddings)).
    processed: int
    # Width of each embedding vector.
    dimension: int
@app.post("/embed")  # Fix: handler was defined but never registered on the app.
async def embed(request: EmbedRequest):
    """Embed a batch of texts with BGE and return normalized vectors.

    Args:
        request: Texts to embed and the encode batch size.

    Returns:
        EmbedResponse with the vectors, the count processed, and the
        embedding dimensionality.

    Raises:
        HTTPException: 500 carrying the underlying error message on failure.
    """
    # BGE recommends prefixing queries with "Represent this sentence for
    # searching relevant passages: " for retrieval, but it's optional.
    if not request.texts:
        # Guard: encoding an empty list yields an array with no second axis,
        # so embeddings.shape[1] below would raise IndexError (opaque 500).
        return EmbedResponse(embeddings=[], processed=0, dimension=0)
    try:
        # Clear GPU cache before a potentially large forward pass.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        with torch.no_grad():
            embeddings = model.encode(
                request.texts,
                batch_size=request.batch_size,
                convert_to_numpy=True,
                normalize_embeddings=True,  # BGE embeddings should be normalized
                show_progress_bar=False,
            )

        # Cleanup to keep memory bounded between requests.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return EmbedResponse(
            embeddings=embeddings.tolist(),
            processed=len(request.texts),
            dimension=embeddings.shape[1],
        )
    except Exception as e:
        # Surface the failure as a 500; chain the cause for server-side logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")  # Fix: handler was defined but never registered on the app.
async def health():
    """Liveness/readiness probe reporting model name and device placement.

    Returns:
        dict with service status, model identifier, CUDA availability,
        and the device the model's parameters currently live on.
    """
    return {
        "status": "ready",
        "model": "BAAI/bge-large-en-v1.5",
        "cuda_available": torch.cuda.is_available(),
        "device": str(next(model.parameters()).device),
    }