# embedder2/app.py — FastAPI service exposing BAAI/bge-large-en-v1.5 embeddings.
import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import torch
import gc
# Application instance; route handlers below attach themselves via decorators.
app = FastAPI()

# Open CORS policy: the service is called from browser clients on other hosts,
# so every origin, method, and header is permitted.
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# Load the embedding model once at import time; all requests share this
# single instance. Weights are cached under /app/cache between restarts.
print("Loading BAAI/bge-large-en-v1.5...")
model = SentenceTransformer(
    'BAAI/bge-large-en-v1.5',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    cache_folder='/app/cache'
)
model.eval()  # inference mode: disables dropout/batch-norm updates

# On GPU, cast to FP16 for faster inference and half the VRAM footprint.
# (Fix: the original print calls were f-strings with no placeholders.)
if torch.cuda.is_available():
    model = model.half()
    print("Model loaded on GPU with FP16 precision")
else:
    print("Model loaded on CPU")
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""
    # Texts to encode in one call; each yields one embedding vector.
    texts: list[str]
    # Encoder micro-batch size. BGE handles larger batches well.
    batch_size: int = 256  # BGE handles larger batches well
class EmbedResponse(BaseModel):
    """Response body for POST /embed."""
    # One embedding (list of floats) per input text, in input order.
    embeddings: list[list[float]]
    # Number of input texts that were encoded.
    processed: int
    # Dimensionality of each embedding vector.
    dimension: int
@app.post("/embed", response_model=EmbedResponse)
def embed(request: EmbedRequest):
    """Encode a batch of texts into normalized BGE embeddings.

    Declared as a sync ``def`` (not ``async def``) on purpose: FastAPI runs
    sync endpoints in its threadpool, so the blocking ``model.encode`` call
    no longer stalls the event loop for concurrent requests.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Empty input: model.encode([]) returns a 1-D array, so the
        # embeddings.shape[1] below would raise IndexError. Short-circuit
        # with the model's known output dimension instead.
        if not request.texts:
            return EmbedResponse(
                embeddings=[],
                processed=0,
                dimension=model.get_sentence_embedding_dimension(),
            )

        # Free cached GPU blocks before a potentially large batch.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        with torch.no_grad():
            embeddings = model.encode(
                request.texts,
                # Guard against non-positive client-supplied batch sizes.
                batch_size=max(1, request.batch_size),
                convert_to_numpy=True,
                normalize_embeddings=True,  # BGE embeddings should be normalized
                show_progress_bar=False
            )

        # Cleanup: release Python-side and CUDA-side memory promptly.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return EmbedResponse(
            embeddings=embeddings.tolist(),
            processed=len(request.texts),
            dimension=embeddings.shape[1]
        )
    except Exception as e:
        # Service boundary: surface any failure as a 500 with its message.
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
    """Readiness probe: reports model identity and where its weights live."""
    device = next(model.parameters()).device
    payload = {
        "status": "ready",
        "model": "BAAI/bge-large-en-v1.5",
        "cuda_available": torch.cuda.is_available(),
        "device": str(device),
    }
    return payload