"""
FastEmbed-based Code Embedding Server
Optimized for CPU Basic (2 vCPU, 16GB RAM)

Models:
- Dense: jinaai/jina-embeddings-v2-base-code (768 dim, ~0.64GB)
- Sparse: Qdrant/bm25 (~0.01GB)
- Reranker: jinaai/jina-reranker-v1-tiny-en (~0.13GB)

Memory optimization:
- Preload all models at startup (avoid runtime loading spikes)
- Use /data for persistent cache (HF Spaces)
- Limit batch_size and keep inference in-process (no fastembed worker processes)
- Periodic garbage collection
"""

import base64
import gc
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, Literal

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel, ConfigDict, Field

from fastembed import TextEmbedding, SparseTextEmbedding
from fastembed.rerank.cross_encoder import TextCrossEncoder

# Use /data for persistent cache in HF Spaces (survives restarts).
# Falls back to /tmp for local development; FASTEMBED_CACHE overrides both.
CACHE_DIR = os.environ.get(
    "FASTEMBED_CACHE",
    "/data/fastembed_cache" if os.path.exists("/data") else "/tmp/fastembed_cache",
)

# Model names
DENSE_MODEL = "jinaai/jina-embeddings-v2-base-code"
SPARSE_MODEL = "Qdrant/bm25"
RERANKER_MODEL = "jinaai/jina-reranker-v1-tiny-en"

# Memory-optimized settings for 2 vCPU, 16GB RAM
BATCH_SIZE = 32  # Limit batch to avoid memory spikes
# Value for fastembed's `parallel` argument. None keeps encoding in this process
# (onnxruntime's own threading only). In fastembed, any integer -- even 1 --
# switches to a worker-process pool, where each worker loads its own model copy.
# That happens whenever a list of >= batch_size texts is passed, i.e. on every
# full chunk here. That is exactly the memory duplication we want to avoid.
PARALLEL_WORKERS: int | None = None

# Global model cache (singleton pattern)
_dense_model: TextEmbedding | None = None
_sparse_model: SparseTextEmbedding | None = None
_reranker_model: TextCrossEncoder | None = None

# Request counter for periodic GC
_request_count = 0
GC_INTERVAL = 50  # Run gc.collect() every 50 requests


def _run_periodic_gc() -> None:
    """Count one handled request and run ``gc.collect()`` every GC_INTERVAL requests.

    The counter is an unsynchronized module global. Sync endpoints may run
    concurrently on FastAPI's threadpool, so an increment can occasionally be
    lost. That only shifts when GC fires, which is harmless.
    """
    global _request_count
    _request_count += 1
    if _request_count % GC_INTERVAL == 0:
        gc.collect()
        print(f"GC triggered after {_request_count} requests")


def _get_dense_model() -> TextEmbedding:
    """Return the shared dense model, constructing it on first call (preloaded at startup)."""
    global _dense_model
    if _dense_model is None:
        _dense_model = TextEmbedding(
            model_name=DENSE_MODEL,
            cache_dir=CACHE_DIR,
        )
    return _dense_model


def _get_sparse_model() -> SparseTextEmbedding:
    """Return the shared sparse BM25 model, constructing it on first call (preloaded at startup)."""
    global _sparse_model
    if _sparse_model is None:
        _sparse_model = SparseTextEmbedding(
            model_name=SPARSE_MODEL,
            cache_dir=CACHE_DIR,
        )
    return _sparse_model


def _get_reranker() -> TextCrossEncoder:
    """Return the shared cross-encoder reranker, constructing it on first call (preloaded at startup)."""
    global _reranker_model
    if _reranker_model is None:
        _reranker_model = TextCrossEncoder(
            model_name=RERANKER_MODEL,
            cache_dir=CACHE_DIR,
        )
    return _reranker_model


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Preload ALL models at startup (avoid runtime memory spikes) and drop them on shutdown."""
    global _dense_model, _sparse_model, _reranker_model

    print("=" * 50)
    print("PRELOADING ALL MODELS...")
    print(f"Cache directory: {CACHE_DIR}")
    print("=" * 50)

    # Preload all models at startup
    _get_dense_model()
    print("Dense model loaded.")
    _get_sparse_model()
    print("Sparse model loaded.")
    _get_reranker()
    print("Reranker model loaded.")

    print("All models ready.")
    print("=" * 50)

    # Initial GC to clean up any loading artifacts
    gc.collect()

    try:
        yield
    finally:
        # Cleanup on shutdown. Using `finally` means this also runs when the
        # lifespan is exited via an exception rather than a clean shutdown.
        _dense_model = None
        _sparse_model = None
        _reranker_model = None
        gc.collect()
        print("Models cleared on shutdown.")


app = FastAPI(
    title="FastEmbed Code Embeddings",
    summary="CPU-optimized code embeddings with BM25 sparse and reranking",
    version="2.2.0",
    lifespan=lifespan,
)


# ==================== Request Models ====================


class EmbeddingRequest(BaseModel):
    """Dense embedding request. Extra fields are accepted (extra="allow") but unused."""

    model_config = ConfigDict(extra="allow")

    input: str | list[str]
    model: str = "code-embed"  # Echoed back in the response; does not select a model
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: int = 0  # <= 0, or >= the model's dimension, means full dimensions


class SparseEmbeddingRequest(BaseModel):
    """Sparse BM25 embedding request. Extra fields are accepted but unused."""

    model_config = ConfigDict(extra="allow")

    input: str | list[str]
    model: str = "bm25"  # Echoed back in the response; does not select a model


class RerankRequest(BaseModel):
    """Rerank request: score each document against ``query`` with the cross-encoder."""

    model_config = ConfigDict(extra="allow")

    query: str = Field(..., max_length=8192)
    documents: list[str] = Field(..., min_length=1, max_length=256)
    return_documents: bool = False
    # NOTE(review): accepted for API compatibility but currently unused; scores
    # are always returned exactly as TextCrossEncoder.rerank produces them.
    raw_scores: bool = False
    model: str = "code-rerank"  # Echoed back in the response; does not select a model
    # Number of top results to return (None = all). Negative values used to
    # slice from the end (e.g. -1 dropped the lowest-scored result), so they are rejected.
    top_n: int | None = Field(default=None, ge=0)


class HybridRequest(BaseModel):
    """Request for hybrid search embeddings (dense + sparse)."""

    model_config = ConfigDict(extra="allow")

    input: str | list[str]
    dense_model: str = "code-embed"  # Echoed back only
    sparse_model: str = "bm25"  # Echoed back only


# ==================== Helper Functions ====================


def _now_ts() -> int:
    """Return the current Unix time in whole seconds (used for ``created``)."""
    return int(time.time())


def _make_id(prefix: str) -> str:
    """Return a random response id of the form ``<prefix>-<uuid4 hex>``."""
    return f"{prefix}-{uuid.uuid4().hex}"


def _normalize_input(texts: str | list[str]) -> list[str]:
    """Wrap a single string in a list; return a list unchanged (not copied)."""
    if isinstance(texts, str):
        return [texts]
    return texts


def _approx_tokens(texts: list[str]) -> int:
    """Approximate token usage as a whitespace word count (not model-tokenizer tokens)."""
    return sum(len(text.split()) for text in texts)


def _truncate_embedding(vector: np.ndarray, dimensions: int) -> np.ndarray:
    """Keep the first ``dimensions`` components when 0 < dimensions < len(vector).

    The sliced vector is not re-normalized.
    """
    if 0 < dimensions < len(vector):
        return vector[:dimensions]
    return vector


def _vector_to_payload(vector: np.ndarray, encoding_format: str) -> list[float] | str:
    """Serialize a vector.

    With "base64", the result is base64 of its float32 bytes in native byte
    order. Otherwise it is a plain list of floats.
    """
    if encoding_format == "base64":
        return base64.b64encode(vector.astype(np.float32).tobytes()).decode()
    return vector.tolist()


def _chunk_batch(texts: list[str], batch_size: int) -> list[list[str]]:
    """Split texts into consecutive chunks of at most ``batch_size`` to limit memory per call."""
    if len(texts) <= batch_size:
        return [texts]
    return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]


def _embed_in_chunks(model: TextEmbedding | SparseTextEmbedding, texts: list[str]) -> list[Any]:
    """Embed ``texts`` chunk by chunk with ``model``; one result per text, in input order."""
    results: list[Any] = []
    for chunk in _chunk_batch(texts, BATCH_SIZE):
        results.extend(model.embed(chunk, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS))
    return results


# ==================== API Endpoints ====================


@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe; also reports the configured model names."""
    return {"status": "ok", "models": f"{DENSE_MODEL} + {SPARSE_MODEL} + {RERANKER_MODEL}"}


@app.post("/embeddings")
@app.post("/v1/embeddings")
def embeddings(request: EmbeddingRequest) -> dict[str, Any]:
    """Generate dense embeddings using jina-embeddings-v2-base-code."""
    texts = _normalize_input(request.input)
    vectors = _embed_in_chunks(_get_dense_model(), texts)

    data = [
        {
            "object": "embedding",
            "embedding": _vector_to_payload(
                _truncate_embedding(vector, request.dimensions),
                request.encoding_format,
            ),
            "index": idx,
        }
        for idx, vector in enumerate(vectors)
    ]

    _run_periodic_gc()

    # Embeddings have no completion part, so total usage equals prompt usage
    # (previously total_tokens was hard-coded to 0).
    prompt_tokens = _approx_tokens(texts)
    return {
        "object": "list",
        "data": data,
        "model": request.model,
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
        "id": _make_id("emb"),
        "created": _now_ts(),
    }


@app.post("/sparse/embeddings")
@app.post("/v1/sparse/embeddings")
def sparse_embeddings(request: SparseEmbeddingRequest) -> dict[str, Any]:
    """Generate sparse BM25 embeddings as parallel ``indices``/``values`` lists."""
    texts = _normalize_input(request.input)
    vectors = _embed_in_chunks(_get_sparse_model(), texts)

    data = [
        {
            "object": "sparse_embedding",
            "indices": emb.indices.tolist(),
            "values": emb.values.tolist(),
            "index": idx,
        }
        for idx, emb in enumerate(vectors)
    ]

    _run_periodic_gc()

    return {
        "object": "list",
        "data": data,
        "model": request.model,
        "id": _make_id("sparse"),
        "created": _now_ts(),
    }


@app.post("/rerank")
@app.post("/v1/rerank")
def rerank(request: RerankRequest) -> dict[str, Any]:
    """Rerank documents with the cross-encoder, highest relevance first.

    Each result carries the document's original ``index``. When
    ``return_documents`` is set, the result also includes its text.
    ``top_n`` truncates the list after sorting.
    """
    reranker = _get_reranker()

    # Pass the module-wide batch size; fastembed's default (64) would ignore
    # the memory limit that every other endpoint respects.
    scores = reranker.rerank(request.query, request.documents, batch_size=BATCH_SIZE)

    results: list[dict[str, Any]] = []
    for idx, score in enumerate(scores):
        item: dict[str, Any] = {"index": idx, "relevance_score": float(score)}
        if request.return_documents:
            item["document"] = request.documents[idx]
        results.append(item)

    # Sort by relevance
    results.sort(key=lambda item: item["relevance_score"], reverse=True)
    if request.top_n is not None:
        results = results[:request.top_n]

    _run_periodic_gc()

    # total_tokens covers everything scored (query + documents), so it is never
    # smaller than prompt_tokens.
    query_tokens = _approx_tokens([request.query])
    document_tokens = _approx_tokens(request.documents)
    return {
        "object": "rerank",
        "results": results,
        "model": request.model,
        "usage": {
            "prompt_tokens": query_tokens,
            "total_tokens": query_tokens + document_tokens,
        },
        "id": _make_id("rerank"),
        "created": _now_ts(),
    }


@app.post("/hybrid/embeddings")
@app.post("/v1/hybrid/embeddings")
def hybrid_embeddings(request: HybridRequest) -> dict[str, Any]:
    """Generate both dense and sparse embeddings for hybrid search."""
    texts = _normalize_input(request.input)
    all_dense = _embed_in_chunks(_get_dense_model(), texts)
    all_sparse = _embed_in_chunks(_get_sparse_model(), texts)

    data = [
        {
            "object": "hybrid_embedding",
            "dense": {
                "vector": dense.tolist(),
                "dim": len(dense),
            },
            "sparse": {
                "indices": sparse.indices.tolist(),
                "values": sparse.values.tolist(),
            },
            "index": idx,
        }
        for idx, (dense, sparse) in enumerate(zip(all_dense, all_sparse))
    ]

    _run_periodic_gc()

    return {
        "object": "list",
        "data": data,
        "model": f"{request.dense_model} + {request.sparse_model}",
        "id": _make_id("hybrid"),
        "created": _now_ts(),
    }


# ==================== Model Info ====================


@app.get("/models")
def list_models() -> dict[str, Any]:
    """List supported models and their specs."""
    return {
        "dense": {
            "model": DENSE_MODEL,
            "dim": 768,
            "size_gb": 0.64,
            "type": "code-optimized",
        },
        "sparse": {
            "model": SPARSE_MODEL,
            "type": "bm25",
            "size_gb": 0.01,
            "requires_idf": True,
        },
        "reranker": {
            "model": RERANKER_MODEL,
            "size_gb": 0.13,
            "type": "cross-encoder",
        },
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)