Soumik Bose committed on
Commit
58f4a9c
·
1 Parent(s): 08a63bd
Files changed (3) hide show
  1. Dockerfile +9 -5
  2. main.py +84 -107
  3. model_service.py +34 -40
Dockerfile CHANGED
@@ -5,7 +5,6 @@ FROM python:3.11-slim
5
  ENV PYTHONDONTWRITEBYTECODE=1 \
6
  PYTHONUNBUFFERED=1 \
7
  PYTHONIOENCODING=UTF-8 \
8
- # Set HF_HOME to a writable directory
9
  HF_HOME=/app/cache \
10
  TRANSFORMERS_CACHE=/app/cache
11
 
@@ -20,10 +19,15 @@ WORKDIR /app
20
  COPY --chown=user:user requirements.txt .
21
  RUN pip install --no-cache-dir -r requirements.txt
22
 
23
- # --- LAYER 2: Model Download (Cached) ---
24
- # Instead of copying local files, we download the model during the build.
25
- # This layer will be CACHED and won't run again unless you change this line.
26
- RUN python3 -c "from huggingface_hub import snapshot_download; snapshot_download(repo_id='BAAI/bge-base-en-v1.5', local_dir='./models/bge-base-en-v1.5')"
 
 
 
 
 
27
 
28
  # --- LAYER 3: Application Code ---
29
  COPY --chown=user:user . .
 
5
  ENV PYTHONDONTWRITEBYTECODE=1 \
6
  PYTHONUNBUFFERED=1 \
7
  PYTHONIOENCODING=UTF-8 \
 
8
  HF_HOME=/app/cache \
9
  TRANSFORMERS_CACHE=/app/cache
10
 
 
19
  COPY --chown=user:user requirements.txt .
20
  RUN pip install --no-cache-dir -r requirements.txt
21
 
22
+ # --- LAYER 2: Download Models (Cached) ---
23
+ # We download models for 384, 768, and 1024 dimensions.
24
+ # 384 dim: BAAI/bge-small-en-v1.5
25
+ # 768 dim: BAAI/bge-base-en-v1.5
26
+ # 1024 dim: BAAI/bge-large-en-v1.5
27
+ RUN python3 -c "from huggingface_hub import snapshot_download; \
28
+ snapshot_download(repo_id='BAAI/bge-small-en-v1.5', local_dir='./models/bge-384'); \
29
+ snapshot_download(repo_id='BAAI/bge-base-en-v1.5', local_dir='./models/bge-768'); \
30
+ snapshot_download(repo_id='BAAI/bge-large-en-v1.5', local_dir='./models/bge-1024')"
31
 
32
  # --- LAYER 3: Application Code ---
33
  COPY --chown=user:user . .
main.py CHANGED
@@ -11,11 +11,11 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
 
14
- # Ensure this module exists in your project
15
- from model_service import LocalEmbeddingService
16
 
17
  # ============================================================================
18
- # LOGGING CONFIGURATION
19
  # ============================================================================
20
  logging.basicConfig(
21
  level=logging.INFO,
@@ -24,76 +24,60 @@ logging.basicConfig(
24
  logger = logging.getLogger("EmbedAPI")
25
 
26
  # ============================================================================
27
- # CONFIGURATION & STATE
28
  # ============================================================================
29
- LOCAL_MODEL_PATH = os.getenv('MODEL_PATH', './models/bge-base-en-v1.5')
30
  AUTH_TOKEN = os.getenv('AUTH_TOKEN', None)
31
  ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
32
 
33
- # Global resource container
34
  ml_context = {
35
  "service": None,
36
  "executor": None
37
  }
38
 
39
  # ============================================================================
40
- # LIFESPAN MANAGER (Replaces deprecated startup/shutdown events)
41
  # ============================================================================
42
  @asynccontextmanager
43
  async def lifespan(app: FastAPI):
44
- """
45
- Manages the application lifecycle.
46
- Initializes the model and thread pool on startup, and cleans them up on shutdown.
47
- """
48
- # --- Startup Phase ---
49
- logger.info("Initializing BGE Embedding Service...")
50
-
51
- # 1. Setup Thread Pool for CPU-bound inference
 
 
 
 
52
  try:
53
- cpu_count = multiprocessing.cpu_count()
54
- max_workers = cpu_count * 2
55
- executor = ThreadPoolExecutor(max_workers=max_workers)
56
- ml_context["executor"] = executor
57
- logger.info(f"Thread pool initialized with {max_workers} workers.")
58
- except Exception as e:
59
- logger.error(f"Failed to initialize thread pool: {e}")
60
- raise e
61
-
62
- # 2. Load ML Model
63
- try:
64
- logger.info(f"Loading model from: {LOCAL_MODEL_PATH}")
65
- service = LocalEmbeddingService(LOCAL_MODEL_PATH)
66
  ml_context["service"] = service
67
- logger.info(f"Model loaded successfully. Dimension: {service.embedding_dim}")
68
  except Exception as e:
69
- logger.critical(f"Critical error loading model: {e}", exc_info=True)
70
  raise e
71
 
72
- # 3. Log Auth Status
73
  if AUTH_TOKEN:
74
- logger.info("Authentication enabled (Bearer token required).")
75
- else:
76
- logger.warning("Authentication disabled (no AUTH_TOKEN set).")
77
-
78
- yield # Application runs here
79
-
80
- # --- Shutdown Phase ---
81
- logger.info("Shutting down service...")
82
  if ml_context["executor"]:
83
  ml_context["executor"].shutdown(wait=True)
84
  ml_context.clear()
85
- logger.info("Shutdown complete.")
86
 
87
  # ============================================================================
88
- # APP INITIALIZATION
89
  # ============================================================================
90
  app = FastAPI(
91
- title="BGE Embedding API",
92
- description="Production-grade embedding inference API.",
93
- version="2.0.0",
94
- lifespan=lifespan,
95
- docs_url="/docs",
96
- redoc_url="/redoc"
97
  )
98
 
99
  app.add_middleware(
@@ -104,88 +88,58 @@ app.add_middleware(
104
  allow_headers=["*"],
105
  )
106
 
107
- # ============================================================================
108
- # SECURITY
109
- # ============================================================================
110
  security = HTTPBearer(auto_error=False)
111
 
112
  async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
113
- """Dependency to verify Bearer token if configured."""
114
  if not AUTH_TOKEN:
115
  return True
116
-
117
- if not credentials:
118
- raise HTTPException(
119
- status_code=401,
120
- detail="Authentication required",
121
- headers={"WWW-Authenticate": "Bearer"},
122
- )
123
-
124
- if credentials.credentials != AUTH_TOKEN:
125
- raise HTTPException(
126
- status_code=401,
127
- detail="Invalid authentication token",
128
- headers={"WWW-Authenticate": "Bearer"},
129
- )
130
  return True
131
 
132
  # ============================================================================
133
- # DATA MODELS (Pydantic V2)
134
  # ============================================================================
135
  class EmbedRequest(BaseModel):
136
- text: Union[str, List[str]] = Field(
137
- ...,
138
- description="Single text string or list of texts to embed"
139
- )
140
 
141
  model_config = {
142
  "json_schema_extra": {
143
  "example": {
144
- "text": ["First sentence to embed.", "Second sentence to embed."]
 
145
  }
146
  }
147
  }
148
 
149
  class EmbedResponse(BaseModel):
150
- embeddings: Union[List[float], List[List[float]]] = Field(..., description="Generated vector(s)")
151
- dimension: int = Field(..., description="Embedding dimension")
152
- count: int = Field(..., description="Number of texts processed")
 
 
 
153
 
154
  # ============================================================================
155
  # ENDPOINTS
156
  # ============================================================================
157
 
158
- @app.get("/")
159
- async def root():
160
- """API Metadata."""
161
- return {
162
- "service": "BGE Embedding API",
163
- "status": "running",
164
- "version": "2.0.0",
165
- "authentication": "enabled" if AUTH_TOKEN else "disabled"
166
- }
167
-
168
  @app.get("/health")
169
  async def health_check():
170
- """Liveness probe to ensure model is loaded."""
171
- if not ml_context["service"]:
172
  raise HTTPException(status_code=503, detail="Service not ready")
173
-
174
  return {
175
  "status": "healthy",
176
- "dimension": ml_context["service"].embedding_dim
177
  }
178
 
179
- @app.get("/ping")
180
- async def ping():
181
- """Simple keep-alive endpoint."""
182
- return {"status": "ok", "message": "pong"}
183
-
184
  @app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
185
  async def create_embeddings(request: EmbedRequest):
186
  """
187
- Generate embeddings.
188
- Runs inference in a separate thread pool to prevent blocking the async event loop.
189
  """
190
  service = ml_context.get("service")
191
  executor = ml_context.get("executor")
@@ -193,30 +147,53 @@ async def create_embeddings(request: EmbedRequest):
193
  if not service or not executor:
194
  raise HTTPException(status_code=503, detail="Service unavailable")
195
 
 
 
 
 
 
 
196
  try:
197
- # Determine if input is single string or list
198
- is_single = isinstance(request.text, str)
199
- count = 1 if is_single else len(request.text)
200
 
201
- # Execute blocking model code in the thread pool
202
  loop = asyncio.get_running_loop()
203
  embeddings = await loop.run_in_executor(
204
  executor,
205
  service.generate_embedding,
206
- request.text
 
207
  )
208
 
209
  return EmbedResponse(
210
  embeddings=embeddings,
211
- dimension=service.embedding_dim,
212
  count=count
213
  )
214
 
215
  except Exception as e:
216
- logger.error(f"Inference failed: {e}", exc_info=True)
217
- raise HTTPException(status_code=500, detail="Internal processing error")
218
 
219
- @app.post("/embeddings", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
220
- async def create_embeddings_alias(request: EmbedRequest):
221
- """Alias for /embed endpoint."""
222
- return await create_embeddings(request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
 
14
+ # Import the new MultiEmbeddingService
15
+ from model_service import MultiEmbeddingService
16
 
17
  # ============================================================================
18
+ # LOGGING
19
  # ============================================================================
20
  logging.basicConfig(
21
  level=logging.INFO,
 
24
  logger = logging.getLogger("EmbedAPI")
25
 
26
  # ============================================================================
27
+ # CONFIGURATION
28
  # ============================================================================
 
29
  AUTH_TOKEN = os.getenv('AUTH_TOKEN', None)
30
  ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
31
 
32
+ # Global context container
33
  ml_context = {
34
  "service": None,
35
  "executor": None
36
  }
37
 
38
  # ============================================================================
39
+ # LIFESPAN MANAGER
40
  # ============================================================================
41
  @asynccontextmanager
42
  async def lifespan(app: FastAPI):
43
+ """Lifecycle manager: Loads models and thread pool."""
44
+ # --- Startup ---
45
+ logger.info("Initializing Multi-Dimensional Embedding Service...")
46
+
47
+ # 1. Thread Pool
48
+ cpu_count = multiprocessing.cpu_count()
49
+ max_workers = cpu_count * 2
50
+ executor = ThreadPoolExecutor(max_workers=max_workers)
51
+ ml_context["executor"] = executor
52
+ logger.info(f"Thread pool ready: {max_workers} workers")
53
+
54
+ # 2. Load Models
55
  try:
56
+ service = MultiEmbeddingService()
57
+ service.load_all_models() # Loads 384, 768, 1024 models
 
 
 
 
 
 
 
 
 
 
 
58
  ml_context["service"] = service
 
59
  except Exception as e:
60
+ logger.critical(f"Critical error loading models: {e}", exc_info=True)
61
  raise e
62
 
 
63
  if AUTH_TOKEN:
64
+ logger.info("🔒 Auth enabled.")
65
+
66
+ yield
67
+
68
+ # --- Shutdown ---
69
+ logger.info("Shutting down...")
 
 
70
  if ml_context["executor"]:
71
  ml_context["executor"].shutdown(wait=True)
72
  ml_context.clear()
 
73
 
74
  # ============================================================================
75
+ # APP SETUP
76
  # ============================================================================
77
  app = FastAPI(
78
+ title="Multi-Dim Embedding API",
79
+ version="3.0.0",
80
+ lifespan=lifespan
 
 
 
81
  )
82
 
83
  app.add_middleware(
 
88
  allow_headers=["*"],
89
  )
90
 
 
 
 
91
  security = HTTPBearer(auto_error=False)
92
 
93
  async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
 
94
  if not AUTH_TOKEN:
95
  return True
96
+ if not credentials or credentials.credentials != AUTH_TOKEN:
97
+ raise HTTPException(status_code=401, detail="Invalid token")
 
 
 
 
 
 
 
 
 
 
 
 
98
  return True
99
 
100
  # ============================================================================
101
+ # MODELS
102
  # ============================================================================
103
  class EmbedRequest(BaseModel):
104
+ data: Union[str, List[str]] = Field(..., description="Text string or list of strings")
105
+ dimension: int = Field(768, description="Target dimension (384, 768, or 1024)")
 
 
106
 
107
  model_config = {
108
  "json_schema_extra": {
109
  "example": {
110
+ "data": ["Hello world", "Machine learning is great"],
111
+ "dimension": 768
112
  }
113
  }
114
  }
115
 
116
  class EmbedResponse(BaseModel):
117
+ embeddings: Union[List[float], List[List[float]]] = Field(...)
118
+ dimension: int
119
+ count: int
120
+
121
+ class DeEmbedRequest(BaseModel):
122
+ vector: List[float] = Field(..., description="The embedding vector to decode")
123
 
124
  # ============================================================================
125
  # ENDPOINTS
126
  # ============================================================================
127
 
 
 
 
 
 
 
 
 
 
 
128
  @app.get("/health")
129
  async def health_check():
130
+ service = ml_context.get("service")
131
+ if not service:
132
  raise HTTPException(status_code=503, detail="Service not ready")
 
133
  return {
134
  "status": "healthy",
135
+ "loaded_dimensions": list(service.models.keys())
136
  }
137
 
 
 
 
 
 
138
  @app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
139
  async def create_embeddings(request: EmbedRequest):
140
  """
141
+ Generate embeddings for specific dimensions.
142
+ Supported dimensions: 384, 768, 1024.
143
  """
144
  service = ml_context.get("service")
145
  executor = ml_context.get("executor")
 
147
  if not service or not executor:
148
  raise HTTPException(status_code=503, detail="Service unavailable")
149
 
150
+ if request.dimension not in service.models:
151
+ raise HTTPException(
152
+ status_code=400,
153
+ detail=f"Dimension {request.dimension} not supported. Use 384, 768, or 1024."
154
+ )
155
+
156
  try:
157
+ is_single = isinstance(request.data, str)
158
+ count = 1 if is_single else len(request.data)
 
159
 
 
160
  loop = asyncio.get_running_loop()
161
  embeddings = await loop.run_in_executor(
162
  executor,
163
  service.generate_embedding,
164
+ request.data,
165
+ request.dimension
166
  )
167
 
168
  return EmbedResponse(
169
  embeddings=embeddings,
170
+ dimension=request.dimension,
171
  count=count
172
  )
173
 
174
  except Exception as e:
175
+ logger.error(f"Inference error: {e}")
176
+ raise HTTPException(status_code=500, detail=str(e))
177
 
178
+ @app.post("/deembed", dependencies=[Depends(verify_token)])
179
+ async def de_embed_vector(request: DeEmbedRequest):
180
+ """
181
+ Experimental: Reverse vector to text.
182
+
183
+ NOTE: Mathematically, standard embedding models (BERT, BGE) are NOT reversible
184
+ because they are lossy compression algorithms.
185
+
186
+ To retrieve text from a vector, you must use a Vector Database (retrieval),
187
+ not a direct model inversion.
188
+ """
189
+ # In a real scenario, this would look like:
190
+ # result = vector_db.search(vector=request.vector, top_k=1)
191
+ # return {"text": result.text}
192
+
193
+ raise HTTPException(
194
+ status_code=501,
195
+ detail=(
196
+ "De-embedding (Vector-to-Text) is not possible with standalone embedding models. "
197
+ "This endpoint requires a connected Vector Database to perform a similarity search."
198
+ )
199
+ )
model_service.py CHANGED
@@ -1,47 +1,41 @@
1
- import os
2
- from typing import List, Union
3
  from sentence_transformers import SentenceTransformer
 
4
 
5
- class LocalEmbeddingService:
6
- """Service for generating embeddings using a locally stored model."""
7
-
8
- def __init__(self, model_folder: str):
9
- """
10
- Initialize the service by loading the model from a local path.
11
 
12
- Args:
13
- model_folder: Path to the folder containing the saved model
14
- """
15
- if not os.path.exists(model_folder):
16
- raise FileNotFoundError(
17
- f"Model folder not found at: {model_folder}. "
18
- "Please run download_model.py first."
19
- )
20
-
21
- print(f"Loading model from {model_folder}...")
22
- self.model = SentenceTransformer(model_folder)
23
- self.embedding_dim = self.model.get_sentence_embedding_dimension()
24
- print(f"✅ Model loaded successfully. Dimension: {self.embedding_dim}")
 
 
 
 
 
 
25
 
26
- def generate_embedding(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
27
- """
28
- Generate embeddings for the given text(s).
 
29
 
30
- Args:
31
- text: A single string or list of strings to embed
32
-
33
- Returns:
34
- A single embedding (list of floats) or list of embeddings
35
- """
36
- # Encode the text with normalization for cosine similarity
37
- embeddings = self.model.encode(
38
  text,
39
  normalize_embeddings=True,
40
- convert_to_tensor=False
41
- )
42
-
43
- # Convert to list for JSON serialization
44
- if isinstance(text, str):
45
- return embeddings.tolist()
46
-
47
- return embeddings.tolist()
 
1
+ import logging
 
2
  from sentence_transformers import SentenceTransformer
3
+ import torch
4
 
5
+ logger = logging.getLogger("EmbedService")
6
+
7
+ class MultiEmbeddingService:
8
+ def __init__(self):
9
+ self.models = {}
10
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # Map dimensions to local folders (downloaded in Dockerfile)
13
+ self.model_map = {
14
+ 384: "./models/bge-384",
15
+ 768: "./models/bge-768",
16
+ 1024: "./models/bge-1024"
17
+ }
18
+
19
+ def load_all_models(self):
20
+ """Loads all defined models into memory."""
21
+ for dim, path in self.model_map.items():
22
+ try:
23
+ logger.info(f"Loading {dim}-dimension model from {path}...")
24
+ model = SentenceTransformer(path, device=self.device)
25
+ model.eval() # Set to evaluation mode
26
+ self.models[dim] = model
27
+ logger.info(f"✅ Loaded model for dimension {dim}")
28
+ except Exception as e:
29
+ logger.error(f"❌ Failed to load {dim}-dim model: {e}")
30
+ # We don't raise here, so partial failures don't crash the whole app
31
 
32
+ def generate_embedding(self, text: str | list[str], dimension: int):
33
+ """Generates embeddings using the specific model for the requested dimension."""
34
+ if dimension not in self.models:
35
+ raise ValueError(f"Dimension {dimension} not supported. Available: {list(self.models.keys())}")
36
 
37
+ return self.models[dimension].encode(
 
 
 
 
 
 
 
38
  text,
39
  normalize_embeddings=True,
40
+ convert_to_numpy=True
41
+ ).tolist()