Soumik Bose commited on
Commit
16530ae
·
1 Parent(s): 4f9495d
Files changed (5) hide show
  1. Dockerfile +14 -16
  2. main.py +92 -36
  3. model_service.py +8 -14
  4. requirements.txt +4 -1
  5. vector_store.py +116 -0
Dockerfile CHANGED
@@ -1,43 +1,41 @@
1
- # Use the official Python 3.11 slim image
2
  FROM python:3.11-slim
3
 
4
  # Set environment variables
5
  ENV PYTHONDONTWRITEBYTECODE=1 \
6
  PYTHONUNBUFFERED=1 \
7
  PYTHONIOENCODING=UTF-8 \
8
- HF_HOME=/app/cache
9
 
10
- # Install system dependencies
11
- RUN apt-get update && apt-get install -y --no-install-recommends curl \
12
- && rm -rf /var/lib/apt/lists/* \
13
- && useradd -m -u 1000 user
14
 
 
 
15
  WORKDIR /app
16
 
17
- # --- LAYER 1: Dependencies ---
18
  COPY --chown=user:user requirements.txt .
19
  RUN pip install --no-cache-dir -r requirements.txt
20
 
21
  # --- LAYER 2: Download Models (Cached) ---
22
- # We download models for 384, 768, and 1024 dimensions.
23
- # 384 dim: BAAI/bge-small-en-v1.5
24
- # 768 dim: BAAI/bge-base-en-v1.5
25
- # 1024 dim: BAAI/bge-large-en-v1.5
26
  RUN python3 -c "from huggingface_hub import snapshot_download; \
27
  snapshot_download(repo_id='BAAI/bge-small-en-v1.5', local_dir='./models/bge-384'); \
28
  snapshot_download(repo_id='BAAI/bge-base-en-v1.5', local_dir='./models/bge-768'); \
29
  snapshot_download(repo_id='BAAI/bge-large-en-v1.5', local_dir='./models/bge-1024')"
30
 
31
- # --- LAYER 3: Application Code ---
32
  COPY --chown=user:user . .
33
 
34
- # Ensure permissions
35
- RUN mkdir -p $HF_HOME && chown -R user:user /app/cache && chown -R user:user /app/models
 
36
 
37
- # Switch user
38
  USER user
39
 
40
- # Expose port
41
  EXPOSE 7860
42
 
43
  # Start script
 
1
# Use Python 3.11 Slim
FROM python:3.11-slim

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONIOENCODING=UTF-8 \
    HF_HOME=/app/cache

# Install basic tools and create the runtime user in a single layer.
# --no-install-recommends keeps the image small; the apt lists are removed
# in the same layer so they never persist in the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/* \
    && useradd -m -u 1000 user

WORKDIR /app

# --- LAYER 1: Install Dependencies ---
COPY --chown=user:user requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# --- LAYER 2: Download Models (Cached) ---
# Models are baked into the image so they load instantly at startup.
# 384 dim: BAAI/bge-small-en-v1.5 | 768 dim: BAAI/bge-base-en-v1.5 | 1024 dim: BAAI/bge-large-en-v1.5
RUN python3 -c "from huggingface_hub import snapshot_download; \
    snapshot_download(repo_id='BAAI/bge-small-en-v1.5', local_dir='./models/bge-384'); \
    snapshot_download(repo_id='BAAI/bge-base-en-v1.5', local_dir='./models/bge-768'); \
    snapshot_download(repo_id='BAAI/bge-large-en-v1.5', local_dir='./models/bge-1024')"

# --- LAYER 3: App Code ---
COPY --chown=user:user . .

# Create writable directories for the vector database and the HF cache.
# NOTE(review): chmod 777 on /app/storage looks intended for platforms that
# run the container under an arbitrary UID (e.g. Spaces) — confirm before
# tightening it.
RUN mkdir -p /app/storage $HF_HOME \
    && chown -R user:user /app/storage $HF_HOME \
    && chmod 777 /app/storage

# Switch to non-root user
USER user

# Expose Port
EXPOSE 7860

# Start script
main.py CHANGED
@@ -11,8 +11,9 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
 
14
- # Import the new MultiEmbeddingService
15
  from model_service import MultiEmbeddingService
 
16
 
17
  # ============================================================================
18
  # LOGGING
@@ -32,34 +33,57 @@ ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
32
  # Global context container
33
  ml_context = {
34
  "service": None,
35
- "executor": None
 
36
  }
37
 
 
 
 
 
 
 
 
 
 
 
 
38
  # ============================================================================
39
  # LIFESPAN MANAGER
40
  # ============================================================================
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
43
- """Lifecycle manager: Loads models and thread pool."""
44
- # --- Startup ---
45
- logger.info("Initializing Multi-Dimensional Embedding Service...")
46
 
47
  # 1. Thread Pool
48
  cpu_count = multiprocessing.cpu_count()
49
  max_workers = cpu_count * 2
50
  executor = ThreadPoolExecutor(max_workers=max_workers)
51
  ml_context["executor"] = executor
52
- logger.info(f"Thread pool ready: {max_workers} workers")
53
 
54
  # 2. Load Models
55
  try:
56
  service = MultiEmbeddingService()
57
- service.load_all_models() # Loads 384, 768, 1024 models
58
  ml_context["service"] = service
59
  except Exception as e:
60
- logger.critical(f"Critical error loading models: {e}", exc_info=True)
61
  raise e
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if AUTH_TOKEN:
64
  logger.info("🔒 Auth enabled.")
65
 
@@ -67,6 +91,7 @@ async def lifespan(app: FastAPI):
67
 
68
  # --- Shutdown ---
69
  logger.info("Shutting down...")
 
70
  if ml_context["executor"]:
71
  ml_context["executor"].shutdown(wait=True)
72
  ml_context.clear()
@@ -75,8 +100,8 @@ async def lifespan(app: FastAPI):
75
  # APP SETUP
76
  # ============================================================================
77
  app = FastAPI(
78
- title="Multi-Dim Embedding API",
79
- version="3.0.0",
80
  lifespan=lifespan
81
  )
82
 
@@ -98,7 +123,7 @@ async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Sec
98
  return True
99
 
100
  # ============================================================================
101
- # MODELS
102
  # ============================================================================
103
  class EmbedRequest(BaseModel):
104
  data: Union[str, List[str]] = Field(..., description="Text string or list of strings")
@@ -114,12 +139,14 @@ class EmbedRequest(BaseModel):
114
  }
115
 
116
  class EmbedResponse(BaseModel):
 
117
  embeddings: Union[List[float], List[List[float]]] = Field(...)
118
  dimension: int
119
  count: int
120
 
121
  class DeEmbedRequest(BaseModel):
122
  vector: List[float] = Field(..., description="The embedding vector to decode")
 
123
 
124
  # ============================================================================
125
  # ENDPOINTS
@@ -138,62 +165,91 @@ async def health_check():
138
  @app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
139
  async def create_embeddings(request: EmbedRequest):
140
  """
141
- Generate embeddings for specific dimensions.
142
- Supported dimensions: 384, 768, 1024.
143
  """
144
  service = ml_context.get("service")
145
  executor = ml_context.get("executor")
 
146
 
147
  if not service or not executor:
148
  raise HTTPException(status_code=503, detail="Service unavailable")
149
 
 
150
  if request.dimension not in service.models:
151
- raise HTTPException(
152
- status_code=400,
153
- detail=f"Dimension {request.dimension} not supported. Use 384, 768, or 1024."
154
- )
155
 
156
  try:
 
157
  is_single = isinstance(request.data, str)
158
- count = 1 if is_single else len(request.data)
 
159
 
 
160
  loop = asyncio.get_running_loop()
161
  embeddings = await loop.run_in_executor(
162
  executor,
163
  service.generate_embedding,
164
- request.data,
165
  request.dimension
166
  )
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  return EmbedResponse(
 
169
  embeddings=embeddings,
170
  dimension=request.dimension,
171
  count=count
172
  )
173
 
174
  except Exception as e:
175
- logger.error(f"Inference error: {e}")
176
  raise HTTPException(status_code=500, detail=str(e))
177
 
178
  @app.post("/deembed", dependencies=[Depends(verify_token)])
179
  async def de_embed_vector(request: DeEmbedRequest):
180
  """
181
- Experimental: Reverse vector to text.
182
-
183
- NOTE: Mathematically, standard embedding models (BERT, BGE) are NOT reversible
184
- because they are lossy compression algorithms.
185
-
186
- To retrieve text from a vector, you must use a Vector Database (retrieval),
187
- not a direct model inversion.
188
  """
189
- # In a real scenario, this would look like:
190
- # result = vector_db.search(vector=request.vector, top_k=1)
191
- # return {"text": result.text}
192
 
193
- raise HTTPException(
194
- status_code=501,
195
- detail=(
196
- "De-embedding (Vector-to-Text) is not possible with standalone embedding models. "
197
- "This endpoint requires a connected Vector Database to perform a similarity search."
 
 
 
 
 
 
 
 
 
 
 
 
198
  )
199
- )
 
 
 
 
 
 
 
 
 
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
 
14
+ # Import Custom Modules
15
  from model_service import MultiEmbeddingService
16
+ from vector_store import SmartVectorStore # <--- MUST HAVE THIS FILE
17
 
18
  # ============================================================================
19
  # LOGGING
 
33
  # Global context container
34
  ml_context = {
35
  "service": None,
36
+ "executor": None,
37
+ "store": None
38
  }
39
 
40
# ============================================================================
# BACKGROUND TASKS
# ============================================================================
async def background_cleanup_task():
    """Runs continuously, pruning stored vectors older than the store's TTL.

    Sleeps one hour between passes. Any error raised by a single prune pass
    is logged and swallowed so the loop keeps running for the lifetime of
    the application (a raised exception would otherwise kill the task for
    good).
    """
    while True:
        await asyncio.sleep(3600)  # Sleep for 1 hour
        # .get() instead of [] so a shutdown-time ml_context.clear() cannot
        # raise KeyError if the task wakes up before being cancelled.
        store = ml_context.get("store")
        if store:
            logger.info("⏰ Running scheduled storage cleanup...")
            try:
                store.prune_expired()
            except Exception:
                logger.exception("Storage cleanup failed; will retry next hour.")
50
+
51
  # ============================================================================
52
  # LIFESPAN MANAGER
53
  # ============================================================================
54
  @asynccontextmanager
55
  async def lifespan(app: FastAPI):
56
+ """Lifecycle manager: Loads models, DB, and thread pool."""
57
+ logger.info("🚀 Initializing Multi-Dimensional Embedding Service...")
 
58
 
59
  # 1. Thread Pool
60
  cpu_count = multiprocessing.cpu_count()
61
  max_workers = cpu_count * 2
62
  executor = ThreadPoolExecutor(max_workers=max_workers)
63
  ml_context["executor"] = executor
 
64
 
65
  # 2. Load Models
66
  try:
67
  service = MultiEmbeddingService()
68
+ service.load_all_models()
69
  ml_context["service"] = service
70
  except Exception as e:
71
+ logger.critical(f"Critical error loading models: {e}", exc_info=True)
72
  raise e
73
 
74
+ # 3. Load Vector Store (Database)
75
+ try:
76
+ # ttl_hours=24 ensures data is deleted after 24 hours
77
+ store = SmartVectorStore(storage_path="./storage", ttl_hours=24)
78
+ ml_context["store"] = store
79
+ logger.info("✅ Vector Store loaded with 24h retention policy.")
80
+ except Exception as e:
81
+ logger.critical(f"❌ Critical error loading Vector Store: {e}")
82
+ raise e
83
+
84
+ # 4. Start Cleanup Task
85
+ cleanup_task = asyncio.create_task(background_cleanup_task())
86
+
87
  if AUTH_TOKEN:
88
  logger.info("🔒 Auth enabled.")
89
 
 
91
 
92
  # --- Shutdown ---
93
  logger.info("Shutting down...")
94
+ cleanup_task.cancel()
95
  if ml_context["executor"]:
96
  ml_context["executor"].shutdown(wait=True)
97
  ml_context.clear()
 
100
  # APP SETUP
101
  # ============================================================================
102
  app = FastAPI(
103
+ title="Multi-Dim Embedding & Retrieval API",
104
+ version="3.1.0",
105
  lifespan=lifespan
106
  )
107
 
 
123
  return True
124
 
125
  # ============================================================================
126
+ # Pydantic MODELS
127
  # ============================================================================
128
  class EmbedRequest(BaseModel):
129
  data: Union[str, List[str]] = Field(..., description="Text string or list of strings")
 
139
  }
140
 
141
class EmbedResponse(BaseModel):
    """Response body for /embed."""
    # Unique store ID(s) assigned at embed time: a single int for a
    # single-string request, a list of ints for a batch request.
    id: Union[int, List[int]] = Field(..., description="Unique ID(s) for retrieval")
    # The generated vector(s); one flat vector or a list of vectors.
    embeddings: Union[List[float], List[List[float]]] = Field(...)
    # Embedding dimension that was requested (384, 768, or 1024).
    dimension: int
    # Number of input strings that were embedded.
    count: int
146
 
147
class DeEmbedRequest(BaseModel):
    """Request body for /deembed: look a previously stored vector back up."""
    vector: List[float] = Field(..., description="The embedding vector to decode")
    # Selects which per-dimension FAISS index to search; defaults to 768.
    dimension: int = Field(768, description="The dimension of the vector")
150
 
151
  # ============================================================================
152
  # ENDPOINTS
 
165
@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
async def create_embeddings(request: EmbedRequest):
    """
    Generate embeddings AND store them for later retrieval.
    Accepts Single String OR Array of Strings.

    Returns the vector(s) plus the unique store ID(s) needed by /deembed.
    Raises 503 if the service is not ready, 400 for an unsupported
    dimension, 500 on inference/storage failure.
    """
    service = ml_context.get("service")
    executor = ml_context.get("executor")
    store = ml_context.get("store")

    # Also require the store: store.add() is called below, so a missing
    # store must be a clean 503 rather than an AttributeError -> 500.
    if not service or not executor or not store:
        raise HTTPException(status_code=503, detail="Service unavailable")

    # Validate Dimension
    if request.dimension not in service.models:
        raise HTTPException(status_code=400, detail=f"Dimension {request.dimension} not supported.")

    try:
        # 1. Normalize Input — always hand the model a list so it batch-encodes.
        is_single = isinstance(request.data, str)
        inputs = [request.data] if is_single else request.data
        count = len(inputs)

        # 2. Generate Embeddings (CPU Thread Pool)
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            executor,
            service.generate_embedding,
            inputs,  # always a list -> encode() returns one vector per input
            request.dimension
        )

        # 3. Store in Vector DB (Get Unique IDs).
        # BUGFIX: since `inputs` is always a list, `embeddings` is always a
        # list of vectors — including the single-string case. Wrapping it
        # again for single inputs (the old `[embeddings] if is_single`)
        # handed store.add a nested list and broke the FAISS add.
        stored_ids = [
            store.add(text, vec, request.dimension)
            for text, vec in zip(inputs, embeddings)
        ]

        # 4. Return Response — unwrap the one-element batch for single input
        # so the response matches the declared single-vector shape.
        return EmbedResponse(
            id=stored_ids[0] if is_single else stored_ids,
            embeddings=embeddings[0] if is_single else embeddings,
            dimension=request.dimension,
            count=count
        )

    except Exception as e:
        logger.error(f"Inference error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
220
 
221
@app.post("/deembed", dependencies=[Depends(verify_token)])
async def de_embed_vector(request: DeEmbedRequest):
    """
    Retrieve original text using the vector.
    Lookups are done via FAISS (Exact/Nearest Neighbor).

    Raises 503 if the store is unavailable, 400 for a dimension with no
    index, 404 when no (near-)exact match exists.
    """
    store = ml_context.get("store")

    if not store:
        raise HTTPException(status_code=503, detail="Store unavailable")

    # Reject dimensions the store has no index for up front — otherwise
    # store.search() raises KeyError and the client sees an opaque 500.
    if request.dimension not in store.indices:
        raise HTTPException(status_code=400, detail=f"Dimension {request.dimension} not supported.")

    # Search in the store
    result = store.search(request.vector, request.dimension)

    if result:
        return {
            "found": True,
            "text": result["text"],
            "created_at_timestamp": result["created_at"],
            "note": "Data expires 24h after creation."
        }
    else:
        raise HTTPException(
            status_code=404,
            detail="Vector not found. It may have expired (24h limit) or was never stored."
        )
247
+
248
@app.get("/check_id/{dimension}/{uid}")
async def check_by_id(dimension: int, uid: int):
    """Debug endpoint: Check if an ID exists without supplying the vector.

    NOTE(review): unlike /embed and /deembed this route has no auth
    dependency — confirm that is intentional for a debug endpoint.
    """
    store = ml_context.get("store")
    if not store:
        # Mirror the other endpoints: clean 503 instead of AttributeError.
        raise HTTPException(status_code=503, detail="Store unavailable")
    try:
        data = store.get_by_id(uid, dimension)
    except KeyError:
        # Unknown dimension — there is no metadata table for it.
        raise HTTPException(status_code=400, detail=f"Dimension {dimension} not supported.")
    if data:
        return data
    raise HTTPException(status_code=404, detail="ID not found")
model_service.py CHANGED
@@ -1,16 +1,14 @@
1
  import logging
2
- from sentence_transformers import SentenceTransformer
3
  import torch
 
4
 
5
  logger = logging.getLogger("EmbedService")
6
 
7
  class MultiEmbeddingService:
8
  def __init__(self):
9
  self.models = {}
10
- # Auto-detect GPU, otherwise use CPU
11
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # Map dimensions to local folders (downloaded in Dockerfile)
14
  self.model_map = {
15
  384: "./models/bge-384",
16
  768: "./models/bge-768",
@@ -18,31 +16,27 @@ class MultiEmbeddingService:
18
  }
19
 
20
  def load_all_models(self):
21
- """Loads all defined models into memory ONCE at startup."""
22
  logger.info(f"🚀 Acceleration Device: {self.device.upper()}")
23
 
24
  for dim, path in self.model_map.items():
25
  try:
26
- logger.info(f"Loading {dim}-dimension model from {path}...")
27
  model = SentenceTransformer(path, device=self.device)
28
- model.eval() # Set to evaluation mode (faster inference)
29
  self.models[dim] = model
30
- logger.info(f"✅ Loaded model for dimension {dim}")
31
  except Exception as e:
32
  logger.error(f"❌ Failed to load {dim}-dim model: {e}")
33
 
34
- def generate_embedding(self, text: str | list[str], dimension: int):
35
- """Generates embeddings using the specific model for the requested dimension."""
36
  if dimension not in self.models:
37
- raise ValueError(f"Dimension {dimension} not supported. Available: {list(self.models.keys())}")
38
 
39
- # --- OPTIMIZATION FIX ---
40
- # show_progress_bar=False prevents the logs you saw
41
- # batch_size=32 ensures efficient processing for lists
42
  return self.models[dimension].encode(
43
  text,
44
  normalize_embeddings=True,
45
  convert_to_numpy=True,
46
- show_progress_bar=False, # <--- THIS STOPS THE LOG SPAM
47
  batch_size=32
48
  ).tolist()
 
1
  import logging
 
2
  import torch
3
+ from sentence_transformers import SentenceTransformer
4
 
5
  logger = logging.getLogger("EmbedService")
6
 
7
  class MultiEmbeddingService:
8
  def __init__(self):
9
  self.models = {}
 
10
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
 
12
  self.model_map = {
13
  384: "./models/bge-384",
14
  768: "./models/bge-768",
 
16
  }
17
 
18
  def load_all_models(self):
19
+ """Loads all defined models into memory."""
20
  logger.info(f"🚀 Acceleration Device: {self.device.upper()}")
21
 
22
  for dim, path in self.model_map.items():
23
  try:
24
+ logger.info(f"Loading {dim}-dimension model...")
25
  model = SentenceTransformer(path, device=self.device)
26
+ model.eval()
27
  self.models[dim] = model
 
28
  except Exception as e:
29
  logger.error(f"❌ Failed to load {dim}-dim model: {e}")
30
 
31
+ def generate_embedding(self, text, dimension):
 
32
  if dimension not in self.models:
33
+ raise ValueError(f"Dimension {dimension} not supported.")
34
 
35
+ # show_progress_bar=False stops the spam
 
 
36
  return self.models[dimension].encode(
37
  text,
38
  normalize_embeddings=True,
39
  convert_to_numpy=True,
40
+ show_progress_bar=False,
41
  batch_size=32
42
  ).tolist()
requirements.txt CHANGED
@@ -10,4 +10,7 @@ numpy==1.26.4
10
 
11
  # Production dependencies
12
  python-multipart==0.0.20
13
- aiofiles==24.1.0
 
 
 
 
10
 
11
  # Production dependencies
12
  python-multipart==0.0.20
13
+ aiofiles==24.1.0
14
+
15
+ # Vector database dependencies
16
+ faiss-cpu==1.9.0.post1
vector_store.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import numpy as np
3
+ import pickle
4
+ import os
5
+ import time
6
+ import logging
7
+ import random
8
+
9
+ logger = logging.getLogger("SmartStore")
10
+
11
class SmartVectorStore:
    """FAISS-backed text store with one index per embedding dimension and a TTL.

    For each supported dimension it keeps an ``IndexIDMap(IndexFlatL2)`` plus
    a metadata dict (id -> {"text", "created_at"}), both persisted under
    `storage_path` after every mutation.
    """

    # Dimensions for which an index is created at startup.
    SUPPORTED_DIMS = (384, 768, 1024)

    def __init__(self, storage_path="./storage", ttl_hours=24):
        self.storage_path = storage_path
        self.ttl_seconds = ttl_hours * 3600
        os.makedirs(storage_path, exist_ok=True)

        self.indices = {}   # dim -> faiss index
        self.metadata = {}  # dim -> {id: {"text": str, "created_at": float}}

        # Initialize indexes for all dimensions
        for dim in self.SUPPORTED_DIMS:
            self._load_or_create_index(dim)

    def _load_or_create_index(self, dim):
        """Load the persisted index/metadata pair for `dim`, or start fresh."""
        index_file = os.path.join(self.storage_path, f"index_{dim}.faiss")
        meta_file = os.path.join(self.storage_path, f"meta_{dim}.pkl")

        if os.path.exists(index_file) and os.path.exists(meta_file):
            try:
                self.indices[dim] = faiss.read_index(index_file)
                # pickle here is acceptable only because these files are
                # written by this same process into its own storage dir.
                with open(meta_file, "rb") as f:
                    self.metadata[dim] = pickle.load(f)
                logger.info(f"📂 Loaded DB for dim {dim} with {self.indices[dim].ntotal} items.")
            except Exception:
                logger.warning(f"⚠️ Corrupt DB for {dim}, creating new.")
                self._create_new_index(dim)
        else:
            self._create_new_index(dim)

    def _create_new_index(self, dim):
        # IndexIDMap lets us assign our own IDs
        self.indices[dim] = faiss.IndexIDMap(faiss.IndexFlatL2(dim))
        self.metadata[dim] = {}

    def add(self, text: str, vector: list[float], dim: int) -> int:
        """Adds text under a fresh unique ID with a timestamp; returns the ID.

        Raises ValueError for a dimension that has no index.
        """
        if dim not in self.indices:
            raise ValueError(f"Dimension {dim} not supported.")

        # Millisecond timestamp + random suffix. BUGFIX: two adds within the
        # same millisecond could previously collide, silently overwriting the
        # old metadata while duplicating the id in FAISS — re-roll until free.
        unique_id = int(time.time() * 1000) + random.randint(0, 999)
        while unique_id in self.metadata[dim]:
            unique_id += random.randint(1, 999)

        vector_np = np.array([vector], dtype=np.float32)
        id_np = np.array([unique_id], dtype=np.int64)

        # Add to FAISS
        self.indices[dim].add_with_ids(vector_np, id_np)

        # Add to Metadata
        self.metadata[dim][unique_id] = {
            "text": text,
            "created_at": time.time()
        }

        # Save to disk
        self._save(dim)
        return unique_id

    def search(self, vector: list[float], dim: int):
        """Return the metadata of a (near-)exact match for `vector`, else None.

        Unknown dimensions and empty indices return None instead of raising.
        """
        if dim not in self.indices or self.indices[dim].ntotal == 0:
            return None

        vector_np = np.array([vector], dtype=np.float32)
        distances, ids = self.indices[dim].search(vector_np, 1)

        found_id = int(ids[0][0])
        distance = float(distances[0][0])  # 0.0 is an exact match

        # Only treat the hit as valid if it is essentially the same vector.
        if found_id != -1 and distance < 1e-4:
            return self.metadata[dim].get(found_id)

        return None

    def get_by_id(self, unique_id: int, dim: int):
        """Direct metadata lookup; None for unknown IDs *or* dimensions."""
        # .get on the outer dict too: an unsupported dim must not KeyError.
        return self.metadata.get(dim, {}).get(unique_id)

    def prune_expired(self):
        """Deletes items older than the configured TTL from every index."""
        current_time = time.time()

        for dim in self.indices:
            ids_to_remove = [
                uid for uid, data in self.metadata[dim].items()
                if current_time - data["created_at"] > self.ttl_seconds
            ]

            if ids_to_remove:
                logger.info(f"🧹 Purging {len(ids_to_remove)} expired items from Dim {dim}...")

                # Remove from Metadata
                for uid in ids_to_remove:
                    del self.metadata[dim][uid]

                # Remove from FAISS
                self.indices[dim].remove_ids(np.array(ids_to_remove, dtype=np.int64))

                self._save(dim)

    def _save(self, dim):
        """Persist the FAISS index and metadata for `dim` to disk."""
        faiss.write_index(self.indices[dim], os.path.join(self.storage_path, f"index_{dim}.faiss"))
        with open(os.path.join(self.storage_path, f"meta_{dim}.pkl"), "wb") as f:
            pickle.dump(self.metadata[dim], f)