Spaces:
Sleeping
Sleeping
File size: 2,318 Bytes
import os
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import torch
import gc
app = FastAPI()

# CORS is opened wide so browser clients on any origin can call the
# embedding endpoints.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],   # accept requests from every origin
    allow_methods=["*"],   # all HTTP verbs
    allow_headers=["*"],   # all request headers
)
# Load the embedding model once at startup; all requests share this instance.
print("Loading BAAI/bge-large-en-v1.5...")
# Resolve the device once instead of re-querying torch.cuda.is_available()
# for every branch below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer(
    'BAAI/bge-large-en-v1.5',
    device=device,
    cache_folder='/app/cache'
)
model.eval()
# FP16 halves GPU memory use and speeds up inference; the CPU path stays
# FP32 because half precision is slow or unsupported on most CPUs.
if device == 'cuda':
    model = model.half()
    print("Model loaded on GPU with FP16 precision")
else:
    print("Model loaded on CPU")
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""

    # Texts to embed in a single call.
    texts: list[str]
    # Micro-batch size forwarded to model.encode; BGE handles larger batches well.
    batch_size: int = 256
class EmbedResponse(BaseModel):
    """Response body for POST /embed."""

    # One embedding vector per input text, in input order.
    embeddings: list[list[float]]
    # Number of input texts that were embedded.
    processed: int
    # Dimensionality of each embedding vector.
    dimension: int
@app.post("/embed", response_model=EmbedResponse)
def embed(request: EmbedRequest):
    """Embed a batch of texts with BGE and return L2-normalized vectors.

    Declared as a plain ``def`` (not ``async``) so FastAPI runs the
    blocking ``model.encode`` call in its worker threadpool instead of
    stalling the event loop for the duration of the encode.

    Raises:
        HTTPException: 400 for an empty ``texts`` list; 500 for any
            model/runtime failure (detail carries the error string).
    """
    if not request.texts:
        # encode([]) yields a zero-row array whose shape[1] would raise an
        # opaque IndexError below; reject the empty batch explicitly.
        raise HTTPException(status_code=400, detail="texts must be a non-empty list")
    try:
        # BGE recommends prefixing retrieval queries with
        # "Represent this sentence for searching relevant passages: ",
        # but it is optional and left to the caller.

        # Release cached GPU allocations from prior requests before encoding.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        with torch.no_grad():
            embeddings = model.encode(
                request.texts,
                batch_size=request.batch_size,
                convert_to_numpy=True,
                normalize_embeddings=True,  # BGE embeddings should be unit-norm
                show_progress_bar=False
            )
        # Cleanup transient allocations to keep memory growth bounded
        # between requests.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return EmbedResponse(
            embeddings=embeddings.tolist(),
            processed=len(request.texts),
            dimension=embeddings.shape[1]
        )
    except Exception as e:
        # Surface the failure as a 500 with the underlying message.
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
    """Report service readiness plus model and device information."""
    cuda_ok = torch.cuda.is_available()
    # Device of the first model parameter tells us where weights actually live.
    model_device = next(model.parameters()).device
    return {
        "status": "ready",
        "model": "BAAI/bge-large-en-v1.5",
        "cuda_available": cuda_ok,
        "device": str(model_device),
    }