Soumik Bose committed on
Commit
9ab4c8b
·
1 Parent(s): 0ba7ee8
Files changed (1) hide show
  1. main.py +114 -156
main.py CHANGED
@@ -1,13 +1,17 @@
1
- from fastapi import FastAPI, HTTPException, Security, Depends, Header
2
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from pydantic import BaseModel, Field
5
- from typing import List, Union, Optional
6
  import os
7
  import logging
8
  import asyncio
9
- from concurrent.futures import ThreadPoolExecutor
10
  import multiprocessing
 
 
 
 
 
 
 
 
 
 
11
  from model_service import LocalEmbeddingService
12
 
13
  # ============================================================================
@@ -15,39 +19,83 @@ from model_service import LocalEmbeddingService
15
  # ============================================================================
16
  logging.basicConfig(
17
  level=logging.INFO,
18
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
19
- handlers=[
20
- logging.StreamHandler()
21
- ]
22
  )
23
- logger = logging.getLogger(__name__)
24
 
25
  # ============================================================================
26
- # CONFIGURATION
27
  # ============================================================================
28
  LOCAL_MODEL_PATH = os.getenv('MODEL_PATH', './models/bge-base-en-v1.5')
29
- AUTH_TOKEN = os.getenv('AUTH_TOKEN', None) # Set via environment variable
30
  ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
31
 
32
- # Detect CPU cores for optimal workers
33
- CPU_COUNT = multiprocessing.cpu_count()
34
- MAX_WORKERS = CPU_COUNT * 2 # 2x CPU cores for I/O-bound operations
35
- logger.info(f"Detected {CPU_COUNT} CPU cores. Using {MAX_WORKERS} max workers for thread pool.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # ============================================================================
38
- # FASTAPI APP INITIALIZATION
39
  # ============================================================================
40
  app = FastAPI(
41
  title="BGE Embedding API",
42
- description="Production-grade embedding inference API using BAAI/bge-base-en-v1.5",
43
  version="2.0.0",
 
44
  docs_url="/docs",
45
  redoc_url="/redoc"
46
  )
47
 
48
- # ============================================================================
49
- # CORS MIDDLEWARE
50
- # ============================================================================
51
  app.add_middleware(
52
  CORSMiddleware,
53
  allow_origins=ALLOWED_ORIGINS,
@@ -55,7 +103,6 @@ app.add_middleware(
55
  allow_methods=["*"],
56
  allow_headers=["*"],
57
  )
58
- logger.info(f"CORS enabled for origins: {ALLOWED_ORIGINS}")
59
 
60
  # ============================================================================
61
  # SECURITY
@@ -63,101 +110,44 @@ logger.info(f"CORS enabled for origins: {ALLOWED_ORIGINS}")
63
  security = HTTPBearer(auto_error=False)
64
 
65
  async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
66
- """Verify Bearer token if AUTH_TOKEN is set."""
67
- if AUTH_TOKEN is None:
68
- # No authentication required
69
  return True
70
-
71
- if credentials is None:
72
- logger.warning("Authentication required but no token provided")
73
  raise HTTPException(
74
  status_code=401,
75
  detail="Authentication required",
76
  headers={"WWW-Authenticate": "Bearer"},
77
  )
78
-
79
  if credentials.credentials != AUTH_TOKEN:
80
- logger.warning(f"Invalid token attempt: {credentials.credentials[:10]}...")
81
  raise HTTPException(
82
  status_code=401,
83
  detail="Invalid authentication token",
84
  headers={"WWW-Authenticate": "Bearer"},
85
  )
86
-
87
  return True
88
 
89
  # ============================================================================
90
- # GLOBAL STATE
91
- # ============================================================================
92
- service = None
93
- executor = None
94
-
95
- @app.on_event("startup")
96
- async def startup_event():
97
- """Load the model on startup and initialize thread pool."""
98
- global service, executor
99
-
100
- try:
101
- logger.info("=" * 60)
102
- logger.info("Starting BGE Embedding Service")
103
- logger.info("=" * 60)
104
-
105
- # Initialize thread pool executor for non-blocking operations
106
- executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
107
- logger.info(f"Thread pool executor initialized with {MAX_WORKERS} workers")
108
-
109
- # Load model
110
- logger.info(f"Loading model from: {LOCAL_MODEL_PATH}")
111
- service = LocalEmbeddingService(LOCAL_MODEL_PATH)
112
- logger.info(f"✅ Model loaded successfully! Dimension: {service.embedding_dim}")
113
-
114
- # Authentication status
115
- if AUTH_TOKEN:
116
- logger.info("🔒 Authentication enabled (Bearer token required)")
117
- else:
118
- logger.warning("⚠️ Authentication disabled (no AUTH_TOKEN set)")
119
-
120
- logger.info("=" * 60)
121
- logger.info("Service ready to accept requests")
122
- logger.info("=" * 60)
123
-
124
- except Exception as e:
125
- logger.error(f"❌ Failed to initialize service: {e}", exc_info=True)
126
- raise
127
-
128
- @app.on_event("shutdown")
129
- async def shutdown_event():
130
- """Cleanup on shutdown."""
131
- global executor
132
- logger.info("Shutting down service...")
133
-
134
- if executor:
135
- executor.shutdown(wait=True)
136
- logger.info("Thread pool executor shut down")
137
-
138
- logger.info("Service shutdown complete")
139
-
140
- # ============================================================================
141
- # REQUEST/RESPONSE MODELS
142
  # ============================================================================
143
  class EmbedRequest(BaseModel):
144
  text: Union[str, List[str]] = Field(
145
- ...,
146
  description="Single text string or list of texts to embed"
147
  )
148
-
149
- class Config:
150
- schema_extra = {
151
  "example": {
152
- "text": "Ginger was also a smart giraffe. She knew what was wrong."
153
  }
154
  }
 
155
 
156
  class EmbedResponse(BaseModel):
157
- embeddings: Union[List[float], List[List[float]]] = Field(
158
- ...,
159
- description="Generated embedding(s)"
160
- )
161
  dimension: int = Field(..., description="Embedding dimension")
162
  count: int = Field(..., description="Number of texts processed")
163
 
@@ -167,98 +157,66 @@ class EmbedResponse(BaseModel):
167
 
168
  @app.get("/")
169
  async def root():
170
- """API information."""
171
  return {
172
- "message": "BGE Embedding API - Production Ready",
173
- "model": "BAAI/bge-base-en-v1.5",
174
- "dimension": 768,
175
  "version": "2.0.0",
176
- "authentication": "enabled" if AUTH_TOKEN else "disabled",
177
- "endpoints": {
178
- "health": "/health",
179
- "ping": "/ping",
180
- "embed": "/embed",
181
- "embeddings": "/embeddings",
182
- "docs": "/docs"
183
- }
184
  }
185
 
186
  @app.get("/health")
187
  async def health_check():
188
- """Check if the service is healthy."""
189
- if service is None:
190
- logger.error("Health check failed: service not initialized")
191
- raise HTTPException(status_code=503, detail="Service not initialized")
192
 
193
  return {
194
  "status": "healthy",
195
- "model_dimension": service.embedding_dim,
196
- "model_path": LOCAL_MODEL_PATH,
197
- "max_workers": MAX_WORKERS,
198
- "cpu_count": CPU_COUNT
199
  }
200
 
201
  @app.get("/ping")
202
  async def ping():
203
- """Simple ping endpoint for keep-alive."""
204
  return {"status": "ok", "message": "pong"}
205
 
206
- @app.post("/embed", response_model=EmbedResponse)
207
- async def create_embeddings(
208
- request: EmbedRequest,
209
- authenticated: bool = Depends(verify_token)
210
- ):
211
  """
212
- Generate embeddings for the provided text(s) - Non-blocking operation.
213
-
214
- - **text**: Single string or list of strings to embed
215
-
216
- Returns normalized 768-dimensional embeddings suitable for cosine similarity.
217
-
218
- Requires Bearer token authentication if AUTH_TOKEN is set.
219
  """
220
- if service is None:
221
- logger.error("Embedding request failed: service not initialized")
222
- raise HTTPException(status_code=503, detail="Service not initialized")
223
-
 
 
224
  try:
225
- # Determine input type and count
226
  is_single = isinstance(request.text, str)
227
  count = 1 if is_single else len(request.text)
228
-
229
- logger.info(f"Processing embedding request for {count} text(s)")
230
-
231
- # Run embedding generation in thread pool (non-blocking)
232
- loop = asyncio.get_event_loop()
233
  embeddings = await loop.run_in_executor(
234
  executor,
235
  service.generate_embedding,
236
  request.text
237
  )
238
-
239
- logger.info(f"✅ Successfully generated {count} embedding(s)")
240
-
241
  return EmbedResponse(
242
  embeddings=embeddings,
243
  dimension=service.embedding_dim,
244
  count=count
245
  )
246
-
247
  except Exception as e:
248
- logger.error(f" Embedding generation failed: {e}", exc_info=True)
249
- raise HTTPException(
250
- status_code=500,
251
- detail=f"Embedding generation failed: {str(e)}"
252
- )
253
 
254
- @app.post("/embeddings", response_model=EmbedResponse)
255
- async def create_embeddings_batch(
256
- request: EmbedRequest,
257
- authenticated: bool = Depends(verify_token)
258
- ):
259
- """
260
- Alias for /embed endpoint - Non-blocking batch embedding generation.
261
-
262
- Requires Bearer token authentication if AUTH_TOKEN is set.
263
- """
264
- return await create_embeddings(request, authenticated)
 
 
 
 
 
 
1
import asyncio
import logging
import multiprocessing
import os
import secrets
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Any, List, Optional, Union

from fastapi import FastAPI, HTTPException, Security, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Ensure this module exists in your project
from model_service import LocalEmbeddingService
16
 
17
  # ============================================================================
 
19
  # ============================================================================
20
  logging.basicConfig(
21
  level=logging.INFO,
22
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 
 
 
23
  )
24
+ logger = logging.getLogger("EmbedAPI")
25
 
26
# ============================================================================
# CONFIGURATION & STATE
# ============================================================================
# Filesystem path of the local BGE model (override via MODEL_PATH env var).
LOCAL_MODEL_PATH = os.getenv('MODEL_PATH', './models/bge-base-en-v1.5')
# Optional shared secret; when unset, the API accepts unauthenticated requests.
AUTH_TOKEN = os.getenv('AUTH_TOKEN', None)
# Comma-separated list of CORS origins; defaults to "*" (all origins).
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')

# Global resource container, populated by the lifespan manager:
# "service"  -> the loaded LocalEmbeddingService instance
# "executor" -> the ThreadPoolExecutor used for blocking inference calls
ml_context = {
    "service": None,
    "executor": None
}
38
+
39
+ # ============================================================================
40
+ # LIFESPAN MANAGER (Replaces deprecated startup/shutdown events)
41
+ # ============================================================================
42
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manage the application lifecycle.

    Startup: creates the inference thread pool and loads the embedding
    model, storing both in ``ml_context``. Shutdown: drains the pool and
    resets the context values to None — NOT ``clear()`` — so readiness
    probes that read ``ml_context["service"]`` report 503 instead of
    raising KeyError after teardown.
    """
    # --- Startup Phase ---
    logger.info("Initializing BGE Embedding Service...")

    # 1. Thread pool that hosts the blocking model-inference calls.
    try:
        cpu_count = multiprocessing.cpu_count()
        max_workers = cpu_count * 2
        ml_context["executor"] = ThreadPoolExecutor(max_workers=max_workers)
        logger.info(f"Thread pool initialized with {max_workers} workers.")
    except Exception as e:
        logger.error(f"Failed to initialize thread pool: {e}")
        raise  # bare raise preserves the original traceback

    # 2. Load the ML model from disk.
    try:
        logger.info(f"Loading model from: {LOCAL_MODEL_PATH}")
        service = LocalEmbeddingService(LOCAL_MODEL_PATH)
        ml_context["service"] = service
        logger.info(f"Model loaded successfully. Dimension: {service.embedding_dim}")
    except Exception as e:
        logger.critical(f"Critical error loading model: {e}", exc_info=True)
        raise

    # 3. Surface the auth configuration in the logs.
    if AUTH_TOKEN:
        logger.info("Authentication enabled (Bearer token required).")
    else:
        logger.warning("Authentication disabled (no AUTH_TOKEN set).")

    yield  # Application serves requests here

    # --- Shutdown Phase ---
    logger.info("Shutting down service...")
    executor = ml_context.get("executor")
    if executor:
        executor.shutdown(wait=True)
    # Reset rather than clear: keeps the keys present so endpoint guards
    # degrade to "not ready" (503) instead of crashing with KeyError.
    ml_context["service"] = None
    ml_context["executor"] = None
    logger.info("Shutdown complete.")
86
 
87
  # ============================================================================
88
+ # APP INITIALIZATION
89
  # ============================================================================
90
# Application instance; the lifespan context manages model + thread pool.
app = FastAPI(
    title="BGE Embedding API",
    description="Production-grade embedding inference API.",
    version="2.0.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc"
)
98
 
 
 
 
99
# CORS: origins come from the ALLOWED_ORIGINS env var ("*" by default).
# NOTE(review): one unchanged line of this call is hidden in this diff view
# (likely allow_credentials) — confirm against the full file.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
 
106
 
107
  # ============================================================================
108
  # SECURITY
 
110
  security = HTTPBearer(auto_error=False)
111
 
112
async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
    """
    FastAPI dependency enforcing Bearer-token auth when AUTH_TOKEN is set.

    Returns:
        True when auth is disabled or the presented token matches.

    Raises:
        HTTPException: 401 when a token is required but missing or invalid.
    """
    if not AUTH_TOKEN:
        # Auth is opt-in: no configured token means open access.
        return True

    if not credentials:
        raise HTTPException(
            status_code=401,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Constant-time comparison prevents timing side-channel attacks on the
    # secret token (plain `!=` short-circuits on the first differing byte).
    if not secrets.compare_digest(credentials.credentials, AUTH_TOKEN):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return True
131
 
132
  # ============================================================================
133
+ # DATA MODELS (Pydantic V2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  # ============================================================================
135
class EmbedRequest(BaseModel):
    # Accepts either one string or a batch of strings to embed.
    text: Union[str, List[str]] = Field(
        ...,
        description="Single text string or list of texts to embed"
    )

    # Pydantic v2 config: example payload surfaced in the OpenAPI schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "text": ["First sentence to embed.", "Second sentence to embed."]
            }
        }
    }
148
 
149
class EmbedResponse(BaseModel):
    # One vector for single-string input, a list of vectors for batch input.
    embeddings: Union[List[float], List[List[float]]] = Field(..., description="Generated vector(s)")
    # Length of each embedding vector (model-dependent).
    dimension: int = Field(..., description="Embedding dimension")
    # Number of input texts processed in this request.
    count: int = Field(..., description="Number of texts processed")
153
 
 
157
 
158
@app.get("/")
async def root():
    """Service metadata for the API landing route."""
    auth_state = "disabled"
    if AUTH_TOKEN:
        auth_state = "enabled"
    info = {
        "service": "BGE Embedding API",
        "status": "running",
        "version": "2.0.0",
        "authentication": auth_state,
    }
    return info
167
 
168
@app.get("/health")
async def health_check():
    """
    Readiness probe: reports 503 until the embedding model is loaded.

    Uses dict.get so the probe degrades to a clean 503 (rather than a
    KeyError -> 500) if the lifespan context never ran or was torn down.
    """
    service = ml_context.get("service")
    if not service:
        raise HTTPException(status_code=503, detail="Service not ready")

    return {
        "status": "healthy",
        "dimension": service.embedding_dim
    }
178
 
179
@app.get("/ping")
async def ping():
    """Lightweight keep-alive check used by uptime monitors."""
    response = {"status": "ok", "message": "pong"}
    return response
183
 
184
@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
async def create_embeddings(request: EmbedRequest):
    """
    Generate embeddings for a single string or a batch of strings.

    Inference runs in the shared thread pool so the async event loop is
    never blocked.

    Raises:
        HTTPException: 503 when the model/executor are not initialized;
            500 when inference itself fails.
    """
    service = ml_context.get("service")
    executor = ml_context.get("executor")

    if not service or not executor:
        raise HTTPException(status_code=503, detail="Service unavailable")

    # Single string vs. batch determines the reported count.
    is_single = isinstance(request.text, str)
    count = 1 if is_single else len(request.text)

    try:
        # Keep the try minimal: only the blocking model call can fail here.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            executor,
            service.generate_embedding,
            request.text
        )
    except Exception as e:
        logger.error(f"Inference failed: {e}", exc_info=True)
        # Chain the cause so debuggers/log handlers can see the real error.
        raise HTTPException(status_code=500, detail="Internal processing error") from e

    return EmbedResponse(
        embeddings=embeddings,
        dimension=service.embedding_dim,
        count=count
    )
 
 
 
218
 
219
@app.post("/embeddings", response_model=EmbedResponse, dependencies=[Depends(verify_token)])
async def create_embeddings_alias(request: EmbedRequest):
    """Compatibility alias that delegates to the /embed handler."""
    result = await create_embeddings(request)
    return result