chmielvu committed on
Commit
98c2074
·
verified ·
1 Parent(s): ff57cfe

Fix memory accumulation with batch processing and periodic GC

Browse files
Files changed (1) hide show
  1. app.py +401 -0
app.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastEmbed-based Code Embedding Server
3
+ Optimized for CPU Basic (2 vCPU, 16GB RAM)
4
+
5
+ Models:
6
+ - Dense: jinaai/jina-embeddings-v2-base-code (768 dim, ~0.64GB)
7
+ - Sparse: Qdrant/bm25 (~0.01GB)
8
+ - Reranker: jinaai/jina-reranker-v1-tiny-en (~0.13GB)
9
+
10
+ Memory optimization:
11
+ - Preload all models at startup (avoid runtime loading spikes)
12
+ - Use /data for persistent cache (HF Spaces)
13
+ - Limit batch_size and parallel workers
14
+ - Periodic garbage collection
15
+ """
16
+
17
+ import gc
18
+ import os
19
+ import time
20
+ import uuid
21
+ from contextlib import asynccontextmanager
22
+ from typing import Any, Literal
23
+
24
+ import numpy as np
25
+ from fastapi import FastAPI
26
+ from pydantic import BaseModel, ConfigDict, Field
27
+
28
+ from fastembed import TextEmbedding, SparseTextEmbedding
29
+ from fastembed.rerank.cross_encoder import TextCrossEncoder
30
+
31
# Use /data for the persistent cache in HF Spaces (survives restarts).
# Falls back to /tmp for local development. The FASTEMBED_CACHE environment
# variable, when set, overrides both.
CACHE_DIR = os.environ.get(
    "FASTEMBED_CACHE",
    "/data/fastembed_cache" if os.path.exists("/data") else "/tmp/fastembed_cache",
)

# Model names
DENSE_MODEL = "jinaai/jina-embeddings-v2-base-code"
SPARSE_MODEL = "Qdrant/bm25"
RERANKER_MODEL = "jinaai/jina-reranker-v1-tiny-en"

# Memory-optimized settings for 2 vCPU, 16GB RAM
BATCH_SIZE = 32  # Limit batch to avoid memory spikes
# Passed as fastembed's `parallel` argument. fastembed documents None as
# "don't use data-parallel processing"; an integer value can make it start
# worker processes that each load their own copy of the model, which is the
# memory duplication this setting is meant to avoid.
PARALLEL_WORKERS = None
43
+
44
# Global model cache (singleton pattern).
# Each slot stays None until the matching _get_*() accessor builds the model;
# the lifespan handler resets all three to None on shutdown.
_dense_model: TextEmbedding | None = None
_sparse_model: SparseTextEmbedding | None = None
_reranker_model: TextCrossEncoder | None = None

# Request counter for periodic GC (incremented by _run_periodic_gc)
_request_count = 0
GC_INTERVAL = 50  # Run gc.collect() every 50 requests
52
+
53
+
54
def _run_periodic_gc():
    """Count one handled request and collect garbage on every GC_INTERVAL-th call."""
    global _request_count
    _request_count = _request_count + 1
    # Guard clause: nothing to do unless the counter just hit a multiple.
    if _request_count % GC_INTERVAL != 0:
        return
    gc.collect()
    print(f"GC triggered after {_request_count} requests")
61
+
62
+
63
def _get_dense_model() -> TextEmbedding:
    """Return the shared dense model, building it from DENSE_MODEL on first use."""
    global _dense_model
    if _dense_model is not None:
        return _dense_model
    _dense_model = TextEmbedding(model_name=DENSE_MODEL, cache_dir=CACHE_DIR)
    return _dense_model
72
+
73
+
74
def _get_sparse_model() -> SparseTextEmbedding:
    """Return the shared sparse BM25 model, building it from SPARSE_MODEL on first use."""
    global _sparse_model
    if _sparse_model is not None:
        return _sparse_model
    _sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL, cache_dir=CACHE_DIR)
    return _sparse_model
83
+
84
+
85
def _get_reranker() -> TextCrossEncoder:
    """Return the shared reranker, building it from RERANKER_MODEL on first use."""
    global _reranker_model
    if _reranker_model is not None:
        return _reranker_model
    _reranker_model = TextCrossEncoder(model_name=RERANKER_MODEL, cache_dir=CACHE_DIR)
    return _reranker_model
94
+
95
+
96
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup: preload ALL models to avoid runtime memory spikes.

    Everything before ``yield`` runs once at startup: the three model
    accessors are called so the models are built up front, then a garbage
    collection pass cleans up loading artifacts. Everything in the
    ``finally`` clause runs at shutdown and drops the model references.
    """
    global _dense_model, _sparse_model, _reranker_model

    banner = "=" * 50
    print(banner)
    print("PRELOADING ALL MODELS...")
    print(f"Cache directory: {CACHE_DIR}")
    print(banner)

    # Preload all models at startup
    _get_dense_model()
    print("Dense model loaded.")

    _get_sparse_model()
    print("Sparse model loaded.")

    _get_reranker()
    print("Reranker model loaded.")

    print("All models ready.")
    print(banner)

    # Initial GC to clean up any loading artifacts
    gc.collect()

    try:
        yield
    finally:
        # Cleanup on shutdown. The finally clause makes this run even when an
        # exception is raised into the context manager at the yield point.
        _dense_model = None
        _sparse_model = None
        _reranker_model = None
        gc.collect()
        print("Models cleared on shutdown.")
129
+
130
+
131
# Application object that every route below is registered on. The `lifespan`
# handler is attached here so model preloading and cleanup wrap the app's run.
app = FastAPI(
    title="FastEmbed Code Embeddings",
    summary="CPU-optimized code embeddings with BM25 sparse and reranking",
    version="2.2.0",
    lifespan=lifespan,
)
137
+
138
+
139
+ # ==================== Request Models ====================
140
+
141
+
142
class EmbeddingRequest(BaseModel):
    """Request body for the dense embeddings endpoints."""
    # Unknown fields sent by clients are accepted rather than rejected.
    model_config = ConfigDict(extra="allow")

    # One text or a list of texts to embed.
    input: str | list[str]
    # Echoed back in the response; it does not select a different model.
    model: str = "code-embed"
    encoding_format: Literal["float", "base64"] = "float"
    # 0 = full dimensions. Negative values are rejected at validation time
    # instead of being silently treated as "full".
    dimensions: int = Field(default=0, ge=0)
149
+
150
+
151
class SparseEmbeddingRequest(BaseModel):
    """Request body for the sparse (BM25) embeddings endpoints."""
    # Unknown fields sent by clients are accepted rather than rejected.
    model_config = ConfigDict(extra="allow")

    # One text or a list of texts to embed.
    input: str | list[str]
    # Echoed back in the response; it does not select a different model.
    model: str = "bm25"
156
+
157
+
158
class RerankRequest(BaseModel):
    """Request body for the rerank endpoints."""
    # Unknown fields sent by clients are accepted rather than rejected.
    model_config = ConfigDict(extra="allow")

    query: str = Field(..., max_length=8192)
    # Between 1 and 256 candidate documents per request.
    documents: list[str] = Field(..., min_length=1, max_length=256)
    # When true, each result also carries the document text.
    return_documents: bool = False
    # NOTE(review): accepted but not read by the rerank handler in this file.
    raw_scores: bool = False
    # Echoed back in the response; it does not select a different model.
    model: str = "code-rerank"
    # Maximum number of results to return; None returns all of them.
    # Negative values are rejected because they would be used as a slice
    # bound and cut results off the tail of the ranking.
    top_n: int | None = Field(default=None, ge=0)
167
+
168
+
169
class HybridRequest(BaseModel):
    """Request for hybrid search embeddings (dense + sparse)."""
    # Unknown fields sent by clients are accepted rather than rejected.
    model_config = ConfigDict(extra="allow")

    # One text or a list of texts; each gets both a dense and a sparse vector.
    input: str | list[str]
    # Both names are only echoed back in the response's "model" string.
    dense_model: str = "code-embed"
    sparse_model: str = "bm25"
176
+
177
+
178
+ # ==================== Helper Functions ====================
179
+
180
+
181
def _now_ts() -> int:
    """Return the current Unix time truncated to whole seconds."""
    seconds = time.time()
    return int(seconds)
183
+
184
+
185
def _make_id(prefix: str) -> str:
    """Build a fresh identifier shaped like ``<prefix>-<32 hex characters>``."""
    suffix = uuid.uuid4().hex
    return "-".join((prefix, suffix))
187
+
188
+
189
+ def _normalize_input(input: str | list[str]) -> list[str]:
190
+ if isinstance(input, str):
191
+ return [input]
192
+ return input
193
+
194
+
195
def _truncate_embedding(vector: np.ndarray, dimensions: int) -> np.ndarray:
    """Return the first ``dimensions`` components, or the vector untouched.

    The vector is returned as-is when ``dimensions`` is zero or negative, or
    when it is not smaller than the vector's length.
    """
    keep_all = dimensions <= 0 or dimensions >= len(vector)
    return vector if keep_all else vector[:dimensions]
199
+
200
+
201
+ def _vector_to_payload(vector: np.ndarray, encoding_format: str) -> list[float] | str:
202
+ if encoding_format == "base64":
203
+ import base64
204
+ return base64.b64encode(vector.astype(np.float32).tobytes()).decode()
205
+ return vector.tolist()
206
+
207
+
208
def _chunk_batch(texts: list[str], batch_size: int) -> list[list[str]]:
    """Split texts into consecutive chunks to limit memory per batch.

    Args:
        texts: The texts to split; order is preserved.
        batch_size: Maximum number of texts per chunk; must be at least 1.

    Returns:
        A list of chunks, each holding at most ``batch_size`` texts. An empty
        input yields an empty list, so callers never process an empty chunk.

    Raises:
        ValueError: If ``batch_size`` is less than 1. A negative step would
            otherwise produce no chunks and silently drop every text.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    return [texts[start:start + batch_size] for start in range(0, len(texts), batch_size)]
213
+
214
+
215
+ # ==================== API Endpoints ====================
216
+
217
+
218
@app.get("/health")
def health() -> dict[str, str]:
    """Report a fixed "ok" status together with the configured model names."""
    configured = " + ".join((DENSE_MODEL, SPARSE_MODEL, RERANKER_MODEL))
    return {"status": "ok", "models": configured}
221
+
222
+
223
@app.post("/embeddings")
@app.post("/v1/embeddings")
def embeddings(request: EmbeddingRequest) -> dict[str, Any]:
    """Generate dense embeddings using jina-embeddings-v2-base-code.

    The input is normalized to a list, embedded in chunks of BATCH_SIZE, and
    each vector is optionally truncated to ``request.dimensions`` and encoded
    according to ``request.encoding_format``.

    Returns:
        A list-style response whose ``data`` entries carry the embedding and
        its position in the input, plus ``usage``, ``id`` and ``created``.
    """
    texts = _normalize_input(request.input)
    model = _get_dense_model()

    # Process in batches to limit memory. Each vector is converted to its
    # response form as soon as it is produced, so the raw arrays of earlier
    # chunks are not kept alive alongside the serialized payload.
    data = []
    for chunk in _chunk_batch(texts, BATCH_SIZE):
        for embedding in model.embed(chunk, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS):
            vector = _truncate_embedding(embedding, request.dimensions)
            data.append({
                "object": "embedding",
                "embedding": _vector_to_payload(vector, request.encoding_format),
                "index": len(data),
            })

    _run_periodic_gc()

    # Whitespace-separated word count, used as a rough token estimate.
    # An embeddings request has no completion tokens, so the total equals the
    # prompt count rather than a hard-coded 0.
    token_estimate = sum(len(t.split()) for t in texts)

    return {
        "object": "list",
        "data": data,
        "model": request.model,
        "usage": {"prompt_tokens": token_estimate, "total_tokens": token_estimate},
        "id": _make_id("emb"),
        "created": _now_ts(),
    }
255
+
256
+
257
@app.post("/sparse/embeddings")
@app.post("/v1/sparse/embeddings")
def sparse_embeddings(request: SparseEmbeddingRequest) -> dict[str, Any]:
    """Produce one sparse BM25 vector (indices + values) per input text."""
    docs = _normalize_input(request.input)
    encoder = _get_sparse_model()

    # Embed chunk by chunk so a large request never goes through in one call.
    vectors = []
    for batch in _chunk_batch(docs, BATCH_SIZE):
        vectors += list(encoder.embed(batch, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS))

    data = [
        {
            "object": "sparse_embedding",
            "indices": vec.indices.tolist(),
            "values": vec.values.tolist(),
            "index": position,
        }
        for position, vec in enumerate(vectors)
    ]

    _run_periodic_gc()

    response: dict[str, Any] = {
        "object": "list",
        "data": data,
        "model": request.model,
        "id": _make_id("sparse"),
        "created": _now_ts(),
    }
    return response
288
+
289
+
290
@app.post("/rerank")
@app.post("/v1/rerank")
def rerank(request: RerankRequest) -> dict[str, Any]:
    """Rerank documents using cross-encoder.

    Every document is scored against the query, results are sorted by
    descending score, and the list is optionally cut to ``request.top_n``.

    Returns:
        A response whose ``results`` entries carry the document's original
        ``index``, its ``relevance_score`` and, when
        ``request.return_documents`` is set, the document text.
    """
    reranker = _get_reranker()

    # Compute rerank scores
    scores = reranker.rerank(request.query, request.documents)

    results = []
    for idx, score in enumerate(scores):
        item = {"index": idx, "relevance_score": float(score)}
        if request.return_documents:
            item["document"] = request.documents[idx]
        results.append(item)

    # Sort by relevance
    results.sort(key=lambda x: x["relevance_score"], reverse=True)

    if request.top_n is not None:
        # Clamp at zero: a negative slice bound would drop entries from the
        # tail of the ranking instead of limiting how many are returned.
        results = results[:max(request.top_n, 0)]

    _run_periodic_gc()

    # Whitespace-separated word counts, used as rough token estimates.
    query_tokens = len(request.query.split())
    document_tokens = sum(len(d.split()) for d in request.documents)

    return {
        "object": "rerank",
        "results": results,
        "model": request.model,
        "usage": {
            "prompt_tokens": query_tokens,
            # The total includes the query, so it is never below prompt_tokens.
            "total_tokens": query_tokens + document_tokens,
        },
        "id": _make_id("rerank"),
        "created": _now_ts(),
    }
325
+
326
+
327
@app.post("/hybrid/embeddings")
@app.post("/v1/hybrid/embeddings")
def hybrid_embeddings(request: HybridRequest) -> dict[str, Any]:
    """Return a dense vector and a sparse vector for every input text."""
    docs = _normalize_input(request.input)

    dense_encoder = _get_dense_model()
    sparse_encoder = _get_sparse_model()

    # Run both encoders over the same chunk before moving to the next one.
    dense_vectors = []
    sparse_vectors = []
    for batch in _chunk_batch(docs, BATCH_SIZE):
        dense_part = list(dense_encoder.embed(batch, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS))
        sparse_part = list(sparse_encoder.embed(batch, batch_size=BATCH_SIZE, parallel=PARALLEL_WORKERS))
        dense_vectors += dense_part
        sparse_vectors += sparse_part

    data = [
        {
            "object": "hybrid_embedding",
            "dense": {
                "vector": dense_vec.tolist(),
                "dim": len(dense_vec),
            },
            "sparse": {
                "indices": sparse_vec.indices.tolist(),
                "values": sparse_vec.values.tolist(),
            },
            "index": position,
        }
        for position, (dense_vec, sparse_vec) in enumerate(zip(dense_vectors, sparse_vectors))
    ]

    _run_periodic_gc()

    response: dict[str, Any] = {
        "object": "list",
        "data": data,
        "model": f"{request.dense_model} + {request.sparse_model}",
        "id": _make_id("hybrid"),
        "created": _now_ts(),
    }
    return response
370
+
371
+
372
+ # ==================== Model Info ====================
373
+
374
+
375
@app.get("/models")
def list_models() -> dict[str, Any]:
    """Describe the three configured models using the static specs below."""
    dense_info = {
        "model": DENSE_MODEL,
        "dim": 768,
        "size_gb": 0.64,
        "type": "code-optimized",
    }
    sparse_info = {
        "model": SPARSE_MODEL,
        "type": "bm25",
        "size_gb": 0.01,
        "requires_idf": True,
    }
    reranker_info = {
        "model": RERANKER_MODEL,
        "size_gb": 0.13,
        "type": "cross-encoder",
    }
    return {"dense": dense_info, "sparse": sparse_info, "reranker": reranker_info}
397
+
398
+
399
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
# uvicorn is imported here so it is only needed when the file is run directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)