Soumik Bose committed on
Commit
3b07301
·
1 Parent(s): 16530ae
Files changed (4) hide show
  1. Dockerfile +16 -14
  2. main.py +36 -92
  3. requirements.txt +1 -4
  4. vector_store.py +0 -116
Dockerfile CHANGED
@@ -1,41 +1,43 @@
1
- # Use Python 3.11 Slim
2
  FROM python:3.11-slim
3
 
4
  # Set environment variables
5
  ENV PYTHONDONTWRITEBYTECODE=1 \
6
  PYTHONUNBUFFERED=1 \
7
  PYTHONIOENCODING=UTF-8 \
8
- HF_HOME=/app/cache
9
 
10
- # Install basic tools
11
- RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
 
 
12
 
13
- # Create user to avoid permission issues
14
- RUN useradd -m -u 1000 user
15
  WORKDIR /app
16
 
17
- # --- LAYER 1: Install Dependencies ---
18
  COPY --chown=user:user requirements.txt .
19
  RUN pip install --no-cache-dir -r requirements.txt
20
 
21
  # --- LAYER 2: Download Models (Cached) ---
22
- # This ensures models are baked into the image and load instantly
 
 
 
23
  RUN python3 -c "from huggingface_hub import snapshot_download; \
24
  snapshot_download(repo_id='BAAI/bge-small-en-v1.5', local_dir='./models/bge-384'); \
25
  snapshot_download(repo_id='BAAI/bge-base-en-v1.5', local_dir='./models/bge-768'); \
26
  snapshot_download(repo_id='BAAI/bge-large-en-v1.5', local_dir='./models/bge-1024')"
27
 
28
- # --- LAYER 3: App Code ---
29
  COPY --chown=user:user . .
30
 
31
- # Create storage directory for the database and set permissions
32
- RUN mkdir -p /app/storage && chown -R user:user /app/storage && chmod 777 /app/storage
33
- RUN mkdir -p $HF_HOME && chown -R user:user $HF_HOME
34
 
35
- # Switch to non-root user
36
  USER user
37
 
38
- # Expose Port
39
  EXPOSE 7860
40
 
41
  # Start script
 
1
+ # Use the official Python 3.11 slim image
2
  FROM python:3.11-slim
3
 
4
  # Set environment variables
5
  ENV PYTHONDONTWRITEBYTECODE=1 \
6
  PYTHONUNBUFFERED=1 \
7
  PYTHONIOENCODING=UTF-8 \
8
+ HF_HOME=/app/cache
9
 
10
+ # Install system dependencies
11
+ RUN apt-get update && apt-get install -y --no-install-recommends curl \
12
+ && rm -rf /var/lib/apt/lists/* \
13
+ && useradd -m -u 1000 user
14
 
 
 
15
  WORKDIR /app
16
 
17
+ # --- LAYER 1: Dependencies ---
18
  COPY --chown=user:user requirements.txt .
19
  RUN pip install --no-cache-dir -r requirements.txt
20
 
21
  # --- LAYER 2: Download Models (Cached) ---
22
+ # We download models for 384, 768, and 1024 dimensions.
23
+ # 384 dim: BAAI/bge-small-en-v1.5
24
+ # 768 dim: BAAI/bge-base-en-v1.5
25
+ # 1024 dim: BAAI/bge-large-en-v1.5
26
  RUN python3 -c "from huggingface_hub import snapshot_download; \
27
  snapshot_download(repo_id='BAAI/bge-small-en-v1.5', local_dir='./models/bge-384'); \
28
  snapshot_download(repo_id='BAAI/bge-base-en-v1.5', local_dir='./models/bge-768'); \
29
  snapshot_download(repo_id='BAAI/bge-large-en-v1.5', local_dir='./models/bge-1024')"
30
 
31
+ # --- LAYER 3: Application Code ---
32
  COPY --chown=user:user . .
33
 
34
+ # Ensure permissions
35
+ RUN mkdir -p $HF_HOME && chown -R user:user /app/cache && chown -R user:user /app/models
 
36
 
37
+ # Switch user
38
  USER user
39
 
40
+ # Expose port
41
  EXPOSE 7860
42
 
43
  # Start script
main.py CHANGED
@@ -11,9 +11,8 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
 
14
- # Import Custom Modules
15
  from model_service import MultiEmbeddingService
16
- from vector_store import SmartVectorStore # <--- MUST HAVE THIS FILE
17
 
18
  # ============================================================================
19
  # LOGGING
@@ -33,57 +32,34 @@ ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
33
  # Global context container
34
  ml_context = {
35
  "service": None,
36
- "executor": None,
37
- "store": None
38
  }
39
 
40
- # ============================================================================
41
- # BACKGROUND TASKS
42
- # ============================================================================
43
- async def background_cleanup_task():
44
- """Runs continuously to clean up data older than 24 hours."""
45
- while True:
46
- await asyncio.sleep(3600) # Sleep for 1 hour
47
- if ml_context["store"]:
48
- logger.info("⏰ Running scheduled storage cleanup...")
49
- ml_context["store"].prune_expired()
50
-
51
  # ============================================================================
52
  # LIFESPAN MANAGER
53
  # ============================================================================
54
  @asynccontextmanager
55
  async def lifespan(app: FastAPI):
56
- """Lifecycle manager: Loads models, DB, and thread pool."""
57
- logger.info("🚀 Initializing Multi-Dimensional Embedding Service...")
 
58
 
59
  # 1. Thread Pool
60
  cpu_count = multiprocessing.cpu_count()
61
  max_workers = cpu_count * 2
62
  executor = ThreadPoolExecutor(max_workers=max_workers)
63
  ml_context["executor"] = executor
 
64
 
65
  # 2. Load Models
66
  try:
67
  service = MultiEmbeddingService()
68
- service.load_all_models()
69
  ml_context["service"] = service
70
  except Exception as e:
71
- logger.critical(f"Critical error loading models: {e}", exc_info=True)
72
  raise e
73
 
74
- # 3. Load Vector Store (Database)
75
- try:
76
- # ttl_hours=24 ensures data is deleted after 24 hours
77
- store = SmartVectorStore(storage_path="./storage", ttl_hours=24)
78
- ml_context["store"] = store
79
- logger.info("✅ Vector Store loaded with 24h retention policy.")
80
- except Exception as e:
81
- logger.critical(f"❌ Critical error loading Vector Store: {e}")
82
- raise e
83
-
84
- # 4. Start Cleanup Task
85
- cleanup_task = asyncio.create_task(background_cleanup_task())
86
-
87
  if AUTH_TOKEN:
88
  logger.info("🔒 Auth enabled.")
89
 
@@ -91,7 +67,6 @@ async def lifespan(app: FastAPI):
91
 
92
  # --- Shutdown ---
93
  logger.info("Shutting down...")
94
- cleanup_task.cancel()
95
  if ml_context["executor"]:
96
  ml_context["executor"].shutdown(wait=True)
97
  ml_context.clear()
@@ -100,8 +75,8 @@ async def lifespan(app: FastAPI):
100
  # APP SETUP
101
  # ============================================================================
102
  app = FastAPI(
103
- title="Multi-Dim Embedding & Retrieval API",
104
- version="3.1.0",
105
  lifespan=lifespan
106
  )
107
 
@@ -123,7 +98,7 @@ async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Sec
123
  return True
124
 
125
  # ============================================================================
126
- # Pydantic MODELS
127
  # ============================================================================
128
  class EmbedRequest(BaseModel):
129
  data: Union[str, List[str]] = Field(..., description="Text string or list of strings")
@@ -139,14 +114,12 @@ class EmbedRequest(BaseModel):
139
  }
140
 
141
  class EmbedResponse(BaseModel):
142
- id: Union[int, List[int]] = Field(..., description="Unique ID(s) for retrieval")
143
  embeddings: Union[List[float], List[List[float]]] = Field(...)
144
  dimension: int
145
  count: int
146
 
147
  class DeEmbedRequest(BaseModel):
148
  vector: List[float] = Field(..., description="The embedding vector to decode")
149
- dimension: int = Field(768, description="The dimension of the vector")
150
 
151
  # ============================================================================
152
  # ENDPOINTS
@@ -165,91 +138,62 @@ async def health_check():
165
  @app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
166
  async def create_embeddings(request: EmbedRequest):
167
  """
168
- Generate embeddings AND store them for later retrieval.
169
- Accepts Single String OR Array of Strings.
170
  """
171
  service = ml_context.get("service")
172
  executor = ml_context.get("executor")
173
- store = ml_context.get("store")
174
 
175
  if not service or not executor:
176
  raise HTTPException(status_code=503, detail="Service unavailable")
177
 
178
- # Validate Dimension
179
  if request.dimension not in service.models:
180
- raise HTTPException(status_code=400, detail=f"Dimension {request.dimension} not supported.")
 
 
 
181
 
182
  try:
183
- # 1. Normalize Input
184
  is_single = isinstance(request.data, str)
185
- inputs = [request.data] if is_single else request.data
186
- count = len(inputs)
187
 
188
- # 2. Generate Embeddings (CPU Thread Pool)
189
  loop = asyncio.get_running_loop()
190
  embeddings = await loop.run_in_executor(
191
  executor,
192
  service.generate_embedding,
193
- inputs, # Pass list for batch processing
194
  request.dimension
195
  )
196
 
197
- # 3. Store in Vector DB (Get Unique IDs)
198
- stored_ids = []
199
-
200
- # If batch processing, embeddings is a list of lists.
201
- # If single, it might be a list of floats, so we wrap it to iterate consistently.
202
- vectors_to_process = [embeddings] if is_single else embeddings
203
-
204
- for text, vec in zip(inputs, vectors_to_process):
205
- # store.add generates a UNIQUE ID (does not overwrite old data)
206
- new_id = store.add(text, vec, request.dimension)
207
- stored_ids.append(new_id)
208
-
209
- # 4. Return Response
210
  return EmbedResponse(
211
- id=stored_ids[0] if is_single else stored_ids,
212
  embeddings=embeddings,
213
  dimension=request.dimension,
214
  count=count
215
  )
216
 
217
  except Exception as e:
218
- logger.error(f"Inference error: {e}", exc_info=True)
219
  raise HTTPException(status_code=500, detail=str(e))
220
 
221
  @app.post("/deembed", dependencies=[Depends(verify_token)])
222
  async def de_embed_vector(request: DeEmbedRequest):
223
  """
224
- Retrieve original text using the vector.
225
- Lookups are done via FAISS (Exact/Nearest Neighbor).
 
 
 
 
 
226
  """
227
- store = ml_context.get("store")
 
 
228
 
229
- if not store:
230
- raise HTTPException(status_code=503, detail="Store unavailable")
231
-
232
- # Search in the store
233
- result = store.search(request.vector, request.dimension)
234
-
235
- if result:
236
- return {
237
- "found": True,
238
- "text": result["text"],
239
- "created_at_timestamp": result["created_at"],
240
- "note": "Data expires 24h after creation."
241
- }
242
- else:
243
- raise HTTPException(
244
- status_code=404,
245
- detail="Vector not found. It may have expired (24h limit) or was never stored."
246
  )
247
-
248
- @app.get("/check_id/{dimension}/{uid}")
249
- async def check_by_id(dimension: int, uid: int):
250
- """Debug endpoint: Check if an ID exists without the vector."""
251
- store = ml_context.get("store")
252
- data = store.get_by_id(uid, dimension)
253
- if data:
254
- return data
255
- raise HTTPException(status_code=404, detail="ID not found")
 
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
 
14
+ # Import the new MultiEmbeddingService
15
  from model_service import MultiEmbeddingService
 
16
 
17
  # ============================================================================
18
  # LOGGING
 
32
  # Global context container
33
  ml_context = {
34
  "service": None,
35
+ "executor": None
 
36
  }
37
 
 
 
 
 
 
 
 
 
 
 
 
38
  # ============================================================================
39
  # LIFESPAN MANAGER
40
  # ============================================================================
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
43
+ """Lifecycle manager: Loads models and thread pool."""
44
+ # --- Startup ---
45
+ logger.info("Initializing Multi-Dimensional Embedding Service...")
46
 
47
  # 1. Thread Pool
48
  cpu_count = multiprocessing.cpu_count()
49
  max_workers = cpu_count * 2
50
  executor = ThreadPoolExecutor(max_workers=max_workers)
51
  ml_context["executor"] = executor
52
+ logger.info(f"Thread pool ready: {max_workers} workers")
53
 
54
  # 2. Load Models
55
  try:
56
  service = MultiEmbeddingService()
57
+ service.load_all_models() # Loads 384, 768, 1024 models
58
  ml_context["service"] = service
59
  except Exception as e:
60
+ logger.critical(f"Critical error loading models: {e}", exc_info=True)
61
  raise e
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if AUTH_TOKEN:
64
  logger.info("🔒 Auth enabled.")
65
 
 
67
 
68
  # --- Shutdown ---
69
  logger.info("Shutting down...")
 
70
  if ml_context["executor"]:
71
  ml_context["executor"].shutdown(wait=True)
72
  ml_context.clear()
 
75
  # APP SETUP
76
  # ============================================================================
77
  app = FastAPI(
78
+ title="Multi-Dim Embedding API",
79
+ version="3.0.0",
80
  lifespan=lifespan
81
  )
82
 
 
98
  return True
99
 
100
  # ============================================================================
101
+ # MODELS
102
  # ============================================================================
103
  class EmbedRequest(BaseModel):
104
  data: Union[str, List[str]] = Field(..., description="Text string or list of strings")
 
114
  }
115
 
116
class EmbedResponse(BaseModel):
    """Response body for /embed: the generated vectors plus bookkeeping."""

    # A single vector (list of floats) when the request sent one string,
    # or a list of vectors when the request sent a batch.
    embeddings: Union[List[float], List[List[float]]] = Field(...)
    # Embedding dimensionality used (the request's `dimension` echoed back).
    dimension: int
    # Number of input texts that were embedded.
    count: int
120
 
121
class DeEmbedRequest(BaseModel):
    """Request body for /deembed: the vector whose source text is sought."""

    vector: List[float] = Field(..., description="The embedding vector to decode")
 
123
 
124
  # ============================================================================
125
  # ENDPOINTS
 
138
@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
async def create_embeddings(request: EmbedRequest):
    """
    Generate embeddings for specific dimensions.
    Supported dimensions: 384, 768, 1024.
    """
    service = ml_context.get("service")
    executor = ml_context.get("executor")

    # Lifespan startup may not have completed (or failed): refuse early.
    if not service or not executor:
        raise HTTPException(status_code=503, detail="Service unavailable")

    # Reject dimensions we have no loaded model for, before doing any work.
    if request.dimension not in service.models:
        raise HTTPException(
            status_code=400,
            detail=f"Dimension {request.dimension} not supported. Use 384, 768, or 1024."
        )

    try:
        # `data` may be a single string or a batch (list of strings).
        is_single = isinstance(request.data, str)
        count = 1 if is_single else len(request.data)

        # Model inference is CPU-bound; run it on the thread pool so the
        # event loop stays responsive.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            executor,
            service.generate_embedding,
            request.data,
            request.dimension
        )

        return EmbedResponse(
            embeddings=embeddings,
            dimension=request.dimension,
            count=count
        )

    except Exception as e:
        # exc_info=True keeps the traceback in the logs (plain f-string
        # logging here previously discarded it).
        logger.error("Inference error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
177
 
178
@app.post("/deembed", dependencies=[Depends(verify_token)])
async def de_embed_vector(request: DeEmbedRequest):
    """
    Experimental: Reverse vector to text.

    NOTE: Mathematically, standard embedding models (BERT, BGE) are NOT reversible
    because they are lossy compression algorithms.

    To retrieve text from a vector, you must use a Vector Database (retrieval),
    not a direct model inversion.
    """
    # If a vector store were wired up, the implementation would be roughly:
    #   result = vector_db.search(vector=request.vector, top_k=1)
    #   return {"text": result.text}
    message = (
        "De-embedding (Vector-to-Text) is not possible with standalone embedding models. "
        "This endpoint requires a connected Vector Database to perform a similarity search."
    )
    raise HTTPException(status_code=501, detail=message)
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -10,7 +10,4 @@ numpy==1.26.4
10
 
11
  # Production dependencies
12
  python-multipart==0.0.20
13
- aiofiles==24.1.0
14
-
15
- # Vector database dependencies
16
- faiss-cpu==1.9.0.post1
 
10
 
11
  # Production dependencies
12
  python-multipart==0.0.20
13
+ aiofiles==24.1.0
 
 
 
vector_store.py DELETED
@@ -1,116 +0,0 @@
1
- import faiss
2
- import numpy as np
3
- import pickle
4
- import os
5
- import time
6
- import logging
7
- import random
8
-
9
- logger = logging.getLogger("SmartStore")
10
-
11
class SmartVectorStore:
    """FAISS-backed reverse lookup: embedding vector -> original text.

    A separate exact (flat L2) FAISS index is kept per supported embedding
    dimension.  Every stored entry carries a creation timestamp and expires
    after ``ttl_hours``; :meth:`prune_expired` removes stale entries.  State
    is persisted to ``storage_path`` (one ``.faiss`` plus one ``.pkl`` file
    per dimension) after every mutation.

    NOTE(review): metadata is persisted with ``pickle`` — only ever load
    storage files this process wrote itself, never untrusted input.
    """

    # Must stay in sync with the embedding models served by the app.
    SUPPORTED_DIMS = (384, 768, 1024)

    def __init__(self, storage_path="./storage", ttl_hours=24):
        """Create (or reload from disk) indexes for all supported dimensions.

        Args:
            storage_path: Directory holding the persisted index/metadata files.
            ttl_hours: Hours an entry remains retrievable before pruning.
        """
        self.storage_path = storage_path
        self.ttl_seconds = ttl_hours * 3600
        os.makedirs(storage_path, exist_ok=True)

        self.indices = {}   # dim -> faiss.IndexIDMap
        self.metadata = {}  # dim -> {id: {"text": str, "created_at": float}}

        for dim in self.SUPPORTED_DIMS:
            self._load_or_create_index(dim)

    def _load_or_create_index(self, dim: int):
        """Reload the persisted index for ``dim``, or start a fresh one."""
        index_file = os.path.join(self.storage_path, f"index_{dim}.faiss")
        meta_file = os.path.join(self.storage_path, f"meta_{dim}.pkl")

        if os.path.exists(index_file) and os.path.exists(meta_file):
            try:
                self.indices[dim] = faiss.read_index(index_file)
                with open(meta_file, "rb") as f:
                    self.metadata[dim] = pickle.load(f)
                logger.info(f"📂 Loaded DB for dim {dim} with {self.indices[dim].ntotal} items.")
                return
            except Exception:
                # Corrupt or partially-written files: fall through to a clean index.
                logger.warning(f"⚠️ Corrupt DB for {dim}, creating new.")
        self._create_new_index(dim)

    def _create_new_index(self, dim: int):
        """Initialize an empty index + metadata map for ``dim``."""
        # IndexIDMap lets us assign our own 64-bit IDs instead of FAISS row numbers.
        self.indices[dim] = faiss.IndexIDMap(faiss.IndexFlatL2(dim))
        self.metadata[dim] = {}

    def add(self, text: str, vector: list[float], dim: int):
        """Store ``text`` with its ``vector``; return the newly assigned ID.

        The index for ``dim`` is persisted to disk before returning.
        """
        # Millisecond timestamp plus a random suffix.  The loop guards
        # against the (rare but possible) collision when two adds land in
        # the same millisecond with the same random offset.
        unique_id = int(time.time() * 1000) + random.randint(0, 999)
        while unique_id in self.metadata[dim]:
            unique_id = int(time.time() * 1000) + random.randint(0, 999)

        vector_np = np.array([vector], dtype=np.float32)
        id_np = np.array([unique_id], dtype=np.int64)

        self.indices[dim].add_with_ids(vector_np, id_np)
        self.metadata[dim][unique_id] = {
            "text": text,
            "created_at": time.time(),
        }

        # NOTE(review): a full index write per add is O(index size); fine for
        # small stores, should be batched/deferred if insert volume grows.
        self._save(dim)
        return unique_id

    def search(self, vector: list[float], dim: int):
        """Return stored metadata for an (essentially) exact match of ``vector``.

        Returns None for an unknown/empty dimension or when the nearest
        neighbor is not a near-exact match.
        """
        index = self.indices.get(dim)
        if index is None or index.ntotal == 0:
            return None

        vector_np = np.array([vector], dtype=np.float32)
        distances, ids = index.search(vector_np, 1)

        found_id = int(ids[0][0])          # FAISS returns int64; normalize for dict lookup
        distance = float(distances[0][0])  # 0.0 means an exact match

        # -1 means "no hit"; only accept essentially-exact matches.
        if found_id != -1 and distance < 1e-4:
            return self.metadata[dim].get(found_id)
        return None

    def get_by_id(self, unique_id: int, dim: int):
        """Direct metadata lookup by ID; None for unknown ID or dimension."""
        return self.metadata.get(dim, {}).get(unique_id)

    def prune_expired(self):
        """Delete every entry older than the configured TTL, then persist."""
        current_time = time.time()

        for dim in self.indices:
            ids_to_remove = [
                uid
                for uid, data in self.metadata[dim].items()
                if current_time - data["created_at"] > self.ttl_seconds
            ]

            if ids_to_remove:
                logger.info(f"🧹 Purging {len(ids_to_remove)} expired items from Dim {dim}...")

                for uid in ids_to_remove:
                    del self.metadata[dim][uid]

                # Remove the same IDs from the FAISS index, then persist.
                self.indices[dim].remove_ids(np.array(ids_to_remove, dtype=np.int64))
                self._save(dim)

    def _save(self, dim: int):
        """Persist the index and metadata for ``dim`` under ``storage_path``."""
        faiss.write_index(self.indices[dim], os.path.join(self.storage_path, f"index_{dim}.faiss"))
        with open(os.path.join(self.storage_path, f"meta_{dim}.pkl"), "wb") as f:
            pickle.dump(self.metadata[dim], f)