# Uploaded by: chmielvu
# Commit message: Fix memory accumulation with batch processing and periodic GC
# Commit: 98c2074 (verified)
"""
FastEmbed-based Code Embedding Server
Optimized for CPU Basic (2 vCPU, 16GB RAM)
Models:
- Dense: jinaai/jina-embeddings-v2-base-code (768 dim, ~0.64GB)
- Sparse: Qdrant/bm25 (~0.01GB)
- Reranker: jinaai/jina-reranker-v1-tiny-en (~0.13GB)
Memory optimization:
- Preload all models at startup (avoid runtime loading spikes)
- Use /data for persistent cache (HF Spaces)
- Limit batch_size and parallel workers
- Periodic garbage collection
"""
import gc
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, Literal
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel, ConfigDict, Field
from fastembed import TextEmbedding, SparseTextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder
# Use /data for persistent cache in HF Spaces ( survives restarts)
# Falls back to /tmp for local development
CACHE_DIR = os.environ.get("FASTEMBED_CACHE", "/data/fastembed_cache" if os.path.exists("/data") else "/tmp/fastembed_cache")
# Model names
DENSE_MODEL = "jinaai/jina-embeddings-v2-base-code"
SPARSE_MODEL = "Qdrant/bm25"
RERANKER_MODEL = "jinaai/jina-reranker-v1-tiny-en"
# Memory-optimized settings for 2 vCPU, 16GB RAM
BATCH_SIZE = 32 # Limit batch to avoid memory spikes
PARALLEL_WORKERS = 1 # Single worker to avoid memory duplication
# Module-level model slots (singleton pattern). Each starts as None and is
# assigned by the matching _get_*() accessor the first time it is called.
_dense_model: TextEmbedding | None = None
_sparse_model: SparseTextEmbedding | None = None
_reranker_model: TextCrossEncoder | None = None
# Number of requests counted so far; incremented by _run_periodic_gc().
_request_count = 0
GC_INTERVAL = 50 # Run gc.collect() every 50 requests
def _run_periodic_gc():
    """Count one handled request and collect garbage on every GC_INTERVAL-th call."""
    global _request_count
    _request_count = _request_count + 1
    if _request_count % GC_INTERVAL:
        # Not on an interval boundary yet, so skip the collection.
        return
    gc.collect()
    print(f"GC triggered after {_request_count} requests")
def _get_dense_model() -> TextEmbedding:
    """Return the shared dense embedding model, constructing it on first use."""
    global _dense_model
    if _dense_model is not None:
        return _dense_model
    # First call: build the model against the shared cache directory.
    _dense_model = TextEmbedding(model_name=DENSE_MODEL, cache_dir=CACHE_DIR)
    return _dense_model
def _get_sparse_model() -> SparseTextEmbedding:
    """Return the shared sparse BM25 model, constructing it on first use."""
    global _sparse_model
    if _sparse_model is not None:
        return _sparse_model
    # First call: build the model against the shared cache directory.
    _sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL, cache_dir=CACHE_DIR)
    return _sparse_model
def _get_reranker() -> TextCrossEncoder:
    """Return the shared cross-encoder reranker, constructing it on first use."""
    global _reranker_model
    if _reranker_model is not None:
        return _reranker_model
    # First call: build the model against the shared cache directory.
    _reranker_model = TextCrossEncoder(model_name=RERANKER_MODEL, cache_dir=CACHE_DIR)
    return _reranker_model
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup: preload ALL models to avoid runtime memory spikes.

    Before the ``yield`` the three model accessors are called once each so
    the models are constructed up front, followed by one ``gc.collect()``.
    After the ``yield`` the module-level model references are cleared and
    garbage is collected again.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global _dense_model, _sparse_model, _reranker_model

    banner = "=" * 50
    print(banner)
    print("PRELOADING ALL MODELS...")
    print(f"Cache directory: {CACHE_DIR}")
    print(banner)

    # Preload all models at startup
    _get_dense_model()
    print("Dense model loaded.")
    _get_sparse_model()
    print("Sparse model loaded.")
    _get_reranker()
    print("Reranker model loaded.")
    print("All models ready.")
    print(banner)

    # Initial GC to clean up any loading artifacts
    gc.collect()

    try:
        yield
    finally:
        # Run the teardown even when an exception or cancellation is thrown
        # into the generator at the yield; without try/finally it would be
        # skipped in that case.
        _dense_model = None
        _sparse_model = None
        _reranker_model = None
        gc.collect()
        print("Models cleared on shutdown.")
# Application object that the route decorators below register against.
# `lifespan` is passed in so model preloading and cleanup are tied to the
# application's startup and shutdown.
app = FastAPI(
    title="FastEmbed Code Embeddings",
    summary="CPU-optimized code embeddings with BM25 sparse and reranking",
    version="2.2.0",
    lifespan=lifespan,
)
# ==================== Request Models ====================
class EmbeddingRequest(BaseModel):
    # Request body for the dense embedding endpoints
    # (/embeddings and /v1/embeddings).
    # extra="allow": fields not declared here are accepted, not rejected.
    model_config = ConfigDict(extra="allow")
    # A single text or a list of texts to embed.
    input: str | list[str]
    # Only echoed back in the response; the handler always uses the dense model.
    model: str = "code-embed"
    # "float" -> JSON list of numbers; "base64" -> base64 of the float32 bytes.
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: int = 0 # 0 = full dimensions
class SparseEmbeddingRequest(BaseModel):
    # Request body for the sparse embedding endpoints
    # (/sparse/embeddings and /v1/sparse/embeddings).
    # extra="allow": fields not declared here are accepted, not rejected.
    model_config = ConfigDict(extra="allow")
    # A single text or a list of texts to embed.
    input: str | list[str]
    # Only echoed back in the response; the handler always uses the sparse model.
    model: str = "bm25"
class RerankRequest(BaseModel):
    # Request body for the rerank endpoints (/rerank and /v1/rerank).
    # extra="allow": fields not declared here are accepted, not rejected.
    model_config = ConfigDict(extra="allow")
    # Query text that every document is scored against.
    query: str = Field(..., max_length=8192)
    # Between 1 and 256 candidate documents.
    documents: list[str] = Field(..., min_length=1, max_length=256)
    # When true, each result also carries the document text.
    return_documents: bool = False
    # Accepted, but not read by the rerank handler in this file.
    raw_scores: bool = False
    # Only echoed back in the response.
    model: str = "code-rerank"
    # Optional cap on the number of results returned. The handler uses it as
    # a slice bound on the sorted results, so a negative value would silently
    # drop items from the END of the ranking; reject negatives at validation.
    top_n: int | None = Field(default=None, ge=0)
class HybridRequest(BaseModel):
    """Request for hybrid search embeddings (dense + sparse)."""
    # extra="allow": fields not declared here are accepted, not rejected.
    model_config = ConfigDict(extra="allow")
    # A single text or a list of texts; each gets a dense and a sparse vector.
    input: str | list[str]
    # Both names are only echoed back (joined with " + ") in the response.
    dense_model: str = "code-embed"
    sparse_model: str = "bm25"
# ==================== Helper Functions ====================
def _now_ts() -> int:
return int(time.time())
def _make_id(prefix: str) -> str:
return f"{prefix}-{uuid.uuid4().hex}"
def _normalize_input(input: str | list[str]) -> list[str]:
if isinstance(input, str):
return [input]
return input
def _truncate_embedding(vector: np.ndarray, dimensions: int) -> np.ndarray:
if dimensions > 0 and dimensions < len(vector):
return vector[:dimensions]
return vector
def _vector_to_payload(vector: np.ndarray, encoding_format: str) -> list[float] | str:
if encoding_format == "base64":
import base64
return base64.b64encode(vector.astype(np.float32).tobytes()).decode()
return vector.tolist()
def _chunk_batch(texts: list[str], batch_size: int) -> list[list[str]]:
"""Split texts into chunks to limit memory per batch."""
if len(texts) <= batch_size:
return [texts]
return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
# ==================== API Endpoints ====================
@app.get("/health")
def health() -> dict[str, str]:
    # Reports a fixed "ok" status plus the three configured model names
    # joined with " + ".
    configured = " + ".join((DENSE_MODEL, SPARSE_MODEL, RERANKER_MODEL))
    return {"status": "ok", "models": configured}
@app.post("/embeddings")
@app.post("/v1/embeddings")
def embeddings(request: EmbeddingRequest) -> dict[str, Any]:
    """Generate dense embeddings using jina-embeddings-v2-base-code."""
    # Flow: normalise the input to a list, embed it in BATCH_SIZE chunks,
    # optionally truncate each vector to `request.dimensions`, then encode
    # each vector as a float list or a base64 string.
    # Returns a dict with "object", "data" (one entry per input text, in
    # input order), "model", "usage", "id" and "created".
    texts = _normalize_input(request.input)
    model = _get_dense_model()

    data = []
    for chunk in _chunk_batch(texts, BATCH_SIZE):
        # Convert each vector to its response form as soon as it is produced
        # instead of first collecting every raw array for the whole request.
        for vector in model.embed(chunk, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS):
            vector = _truncate_embedding(vector, request.dimensions)
            data.append({
                "object": "embedding",
                "embedding": _vector_to_payload(vector, request.encoding_format),
                # len(data) is read before the append, so it is the 0-based
                # position of this vector in the overall input.
                "index": len(data),
            })

    _run_periodic_gc()

    # Whitespace-separated word count, used as an approximate token count.
    prompt_tokens = sum(len(text.split()) for text in texts)
    return {
        "object": "list",
        "data": data,
        "model": request.model,
        # An embedding call has no completion tokens, so the total equals the
        # prompt count (it was previously hard-coded to 0).
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
        "id": _make_id("emb"),
        "created": _now_ts(),
    }
@app.post("/sparse/embeddings")
@app.post("/v1/sparse/embeddings")
def sparse_embeddings(request: SparseEmbeddingRequest) -> dict[str, Any]:
    """Generate sparse BM25 embeddings."""
    texts = _normalize_input(request.input)
    model = _get_sparse_model()

    # Embed chunk by chunk so no single call sees more than BATCH_SIZE texts.
    vectors = []
    for chunk in _chunk_batch(texts, BATCH_SIZE):
        vectors.extend(model.embed(chunk, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS))

    # One entry per input text, in input order.
    data = [
        {
            "object": "sparse_embedding",
            "indices": vec.indices.tolist(),
            "values": vec.values.tolist(),
            "index": position,
        }
        for position, vec in enumerate(vectors)
    ]

    _run_periodic_gc()

    response: dict[str, Any] = {
        "object": "list",
        "data": data,
        "model": request.model,
        "id": _make_id("sparse"),
        "created": _now_ts(),
    }
    return response
@app.post("/rerank")
@app.post("/v1/rerank")
def rerank(request: RerankRequest) -> dict[str, Any]:
    """Rerank documents using cross-encoder."""
    # Scores every document against the query, sorts by descending score and
    # optionally keeps only the first `top_n` results.
    # Returns a dict with "object", "results", "model", "usage", "id" and
    # "created". Each result's "index" is the document's position in the
    # request. `request.raw_scores` is not consulted here.
    reranker = _get_reranker()

    # Compute rerank scores
    scores = reranker.rerank(request.query, request.documents)

    results = []
    for idx, score in enumerate(scores):
        item = {"index": idx, "relevance_score": float(score)}
        if request.return_documents:
            item["document"] = request.documents[idx]
        results.append(item)

    # Sort by relevance, highest score first.
    results.sort(key=lambda entry: entry["relevance_score"], reverse=True)

    if request.top_n is not None:
        results = results[:request.top_n]

    _run_periodic_gc()

    # Whitespace-separated word counts, used as approximate token counts.
    query_tokens = len(request.query.split())
    document_tokens = sum(len(doc.split()) for doc in request.documents)
    return {
        "object": "rerank",
        "results": results,
        "model": request.model,
        "usage": {
            "prompt_tokens": query_tokens,
            # The total covers the query and the documents; it previously
            # counted documents only and could be smaller than prompt_tokens.
            "total_tokens": query_tokens + document_tokens,
        },
        "id": _make_id("rerank"),
        "created": _now_ts(),
    }
@app.post("/hybrid/embeddings")
@app.post("/v1/hybrid/embeddings")
def hybrid_embeddings(request: HybridRequest) -> dict[str, Any]:
    """Generate both dense and sparse embeddings for hybrid search."""
    texts = _normalize_input(request.input)
    dense_model = _get_dense_model()
    sparse_model = _get_sparse_model()

    # For every chunk, run the dense model to completion first, then the
    # sparse model, and only then store both sets of results.
    dense_vectors = []
    sparse_vectors = []
    for chunk in _chunk_batch(texts, BATCH_SIZE):
        dense_part = list(dense_model.embed(chunk, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS))
        sparse_part = list(sparse_model.embed(chunk, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS))
        dense_vectors += dense_part
        sparse_vectors += sparse_part

    # Pair the dense and sparse vector of each text; positions follow input order.
    data = [
        {
            "object": "hybrid_embedding",
            "dense": {
                "vector": dense_vec.tolist(),
                "dim": len(dense_vec),
            },
            "sparse": {
                "indices": sparse_vec.indices.tolist(),
                "values": sparse_vec.values.tolist(),
            },
            "index": position,
        }
        for position, (dense_vec, sparse_vec) in enumerate(zip(dense_vectors, sparse_vectors))
    ]

    _run_periodic_gc()

    response: dict[str, Any] = {
        "object": "list",
        "data": data,
        "model": f"{request.dense_model} + {request.sparse_model}",
        "id": _make_id("hybrid"),
        "created": _now_ts(),
    }
    return response
# ==================== Model Info ====================
@app.get("/models")
def list_models() -> dict[str, Any]:
    """List supported models and their specs."""
    # Each section is built separately; key order matches the response layout.
    dense_info = {
        "model": DENSE_MODEL,
        "dim": 768,
        "size_gb": 0.64,
        "type": "code-optimized",
    }
    sparse_info = {
        "model": SPARSE_MODEL,
        "type": "bm25",
        "size_gb": 0.01,
        "requires_idf": True,
    }
    reranker_info = {
        "model": RERANKER_MODEL,
        "size_gb": 0.13,
        "type": "cross-encoder",
    }
    return {"dense": dense_info, "sparse": sparse_info, "reranker": reranker_info}
# Script entry point: when the file is executed directly, serve `app` with
# uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)