nurulajt committed on
Commit
cc89204
·
verified ·
1 Parent(s): b810e9b

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +169 -19
api.py CHANGED
@@ -1,12 +1,14 @@
1
  """
2
  Embedding Inference API
3
- Supports JobBERT v2, Jina AI, and Voyage AI embeddings
 
4
  """
5
 
6
- from fastapi import FastAPI, HTTPException
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel, Field
9
- from typing import List, Optional
10
  from sentence_transformers import SentenceTransformer
11
  import os
12
  import logging
@@ -30,8 +32,18 @@ app.add_middleware(
30
 
31
  MODELS = {}
32
  VOYAGE_API_KEY = os.environ.get('VOYAGE_API_KEY', '')
 
 
 
 
33
  voyage_client = None
34
 
 
 
 
 
 
 
35
  if VOYAGE_API_KEY:
36
  try:
37
  import voyageai
@@ -62,11 +74,52 @@ def load_models():
62
  logger.error(f"Error loading models: {e}")
63
  raise
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  @app.on_event("startup")
66
  async def startup_event():
67
  load_models()
68
 
69
- class EmbeddingRequest(BaseModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  texts: List[str] = Field(..., description="List of texts to embed", min_items=1)
71
  model: str = Field(..., description="Model to use: 'jobbertv2', 'jobbertv3', 'jina', or 'voyage'")
72
  task: Optional[str] = Field(None, description="Task type for Jina AI: 'retrieval.query', 'retrieval.passage', 'text-matching', etc.")
@@ -81,7 +134,7 @@ class EmbeddingRequest(BaseModel):
81
  }
82
  }
83
 
84
- class EmbeddingResponse(BaseModel):
85
  embeddings: List[List[float]] = Field(..., description="List of embedding vectors")
86
  model: str = Field(..., description="Model used")
87
  dimension: int = Field(..., description="Embedding dimension")
@@ -91,6 +144,7 @@ class HealthResponse(BaseModel):
91
  status: str
92
  models_loaded: List[str]
93
  voyage_available: bool
 
94
 
95
  @app.get("/", response_model=dict)
96
  async def root():
@@ -100,25 +154,121 @@ async def root():
100
  "version": "1.0.0",
101
  "endpoints": {
102
  "/health": "Health check and available models",
103
- "/embed": "Generate embeddings (POST)",
 
 
104
  "/docs": "API documentation"
105
  }
106
  }
107
 
108
  @app.get("/health", response_model=HealthResponse)
109
  async def health():
110
- """Health check endpoint"""
111
  models_loaded = list(MODELS.keys())
112
  return {
113
  "status": "healthy",
114
  "models_loaded": models_loaded,
115
- "voyage_available": voyage_client is not None
 
116
  }
117
 
118
- @app.post("/embed", response_model=EmbeddingResponse)
119
- async def create_embeddings(request: EmbeddingRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  """
121
- Generate embeddings for input texts
122
 
123
  **Models:**
124
  - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
@@ -147,16 +297,16 @@ async def create_embeddings(request: EmbeddingRequest):
147
  )
148
 
149
  try:
150
- input_type = request.input_type or "document"
151
  result = voyage_client.embed(
152
  texts=request.texts,
153
  model="voyage-3",
154
- input_type=input_type
155
  )
156
  embeddings = result.embeddings
157
  dimension = len(embeddings[0]) if embeddings else 0
158
 
159
- return EmbeddingResponse(
160
  embeddings=embeddings,
161
  model="voyage-3",
162
  dimension=dimension,
@@ -167,16 +317,16 @@ async def create_embeddings(request: EmbeddingRequest):
167
 
168
  elif model_name in MODELS:
169
  try:
170
- model = MODELS[model_name]
171
 
172
  if model_name == "jina" and request.task:
173
- embeddings = model.encode(
174
  request.texts,
175
  task=request.task,
176
  convert_to_numpy=True
177
  )
178
  else:
179
- embeddings = model.encode(
180
  request.texts,
181
  convert_to_numpy=True
182
  )
@@ -184,7 +334,7 @@ async def create_embeddings(request: EmbeddingRequest):
184
  embeddings_list = embeddings.tolist()
185
  dimension = len(embeddings_list[0]) if embeddings_list else 0
186
 
187
- return EmbeddingResponse(
188
  embeddings=embeddings_list,
189
  model=model_name,
190
  dimension=dimension,
@@ -200,7 +350,7 @@ async def create_embeddings(request: EmbeddingRequest):
200
  )
201
 
202
  @app.get("/models")
203
- async def list_models():
204
  """List available models and their specifications"""
205
  models_info = {
206
  "jobbertv2": {
 
1
  """
2
  Embedding Inference API
3
+ Supports JobBERT v2/v3, Jina AI, and Voyage AI embeddings
4
+ Compatible with Elasticsearch inference endpoint format
5
  """
6
 
7
+ from fastapi import FastAPI, HTTPException, Query, Security, Depends
8
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel, Field
11
+ from typing import List, Optional, Union
12
  from sentence_transformers import SentenceTransformer
13
  import os
14
  import logging
 
32
 
33
  MODELS = {}
34
  VOYAGE_API_KEY = os.environ.get('VOYAGE_API_KEY', '')
35
+ API_KEY = os.environ.get('API_KEY', '')
36
+ REQUIRE_API_KEY = os.environ.get('REQUIRE_API_KEY', 'false').lower() == 'true'
37
+
38
+ security = HTTPBearer(auto_error=False)
39
  voyage_client = None
40
 
41
+ logger.info(f"API Key authentication: {'ENABLED' if REQUIRE_API_KEY else 'DISABLED'}")
42
+ if API_KEY:
43
+ logger.info(f"✓ API Key configured (length: {len(API_KEY)})")
44
+ else:
45
+ logger.info("ℹ️ No API Key set")
46
+
47
  if VOYAGE_API_KEY:
48
  try:
49
  import voyageai
 
74
  logger.error(f"Error loading models: {e}")
75
  raise
76
 
77
async def verify_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)):
    """Validate the bearer token sent in the ``Authorization`` header.

    Intended for use as a FastAPI dependency on protected routes.

    Args:
        credentials: Bearer credentials resolved by ``security``; ``None`` when
            none were provided (``security`` is created with ``auto_error=False``).

    Returns:
        ``True`` when authentication is disabled (``REQUIRE_API_KEY`` is false)
        or when the supplied token matches ``API_KEY``.

    Raises:
        HTTPException: 500 if authentication is required but ``API_KEY`` is
            empty, 401 if no credentials were supplied, 403 if the token does
            not match.
    """
    # Function-scope import: the comparison below only needs the stdlib.
    import secrets

    if not REQUIRE_API_KEY:
        return True

    if not API_KEY:
        raise HTTPException(
            status_code=500,
            detail="API key authentication is enabled but no API key is configured on the server"
        )

    if credentials is None:
        raise HTTPException(
            status_code=401,
            detail="Missing authentication credentials. Use: Authorization: Bearer YOUR_API_KEY",
            # A 401 response should tell the client which auth scheme to use.
            headers={"WWW-Authenticate": "Bearer"}
        )

    # Compare in constant time so response timing does not reveal how much of
    # the key matched. Encode to bytes because compare_digest rejects
    # non-ASCII str arguments.
    supplied_key = credentials.credentials.encode("utf-8")
    expected_key = API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied_key, expected_key):
        raise HTTPException(
            status_code=403,
            detail="Invalid API key"
        )

    return True
101
+
102
  @app.on_event("startup")
103
  async def startup_event():
104
  load_models()
105
 
106
class ElasticsearchInferenceRequest(BaseModel):
    # Request body for the Elasticsearch-style /embed endpoint.
    # `input` may be a single string or a list of strings.
    input: Union[str, List[str]] = Field(
        ...,
        description="Text or list of texts to embed",
    )

    class Config:
        # Example payload attached to the model's schema.
        schema_extra = {"example": {"input": "Software Engineer"}}
115
+
116
class ElasticsearchInferenceResponse(BaseModel):
    # Response shape used when the request carried a single string.
    embedding: List[float] = Field(
        ...,
        description="Embedding vector for single input",
    )
118
+
119
class ElasticsearchInferenceBatchResponse(BaseModel):
    # Response shape used when the request carried a list of strings.
    embeddings: List[List[float]] = Field(
        ...,
        description="List of embedding vectors for batch input",
    )
121
+
122
+ class BatchEmbeddingRequest(BaseModel):
123
  texts: List[str] = Field(..., description="List of texts to embed", min_items=1)
124
  model: str = Field(..., description="Model to use: 'jobbertv2', 'jobbertv3', 'jina', or 'voyage'")
125
  task: Optional[str] = Field(None, description="Task type for Jina AI: 'retrieval.query', 'retrieval.passage', 'text-matching', etc.")
 
134
  }
135
  }
136
 
137
+ class BatchEmbeddingResponse(BaseModel):
138
  embeddings: List[List[float]] = Field(..., description="List of embedding vectors")
139
  model: str = Field(..., description="Model used")
140
  dimension: int = Field(..., description="Embedding dimension")
 
144
  status: str
145
  models_loaded: List[str]
146
  voyage_available: bool
147
+ api_key_required: bool
148
 
149
  @app.get("/", response_model=dict)
150
  async def root():
 
154
  "version": "1.0.0",
155
  "endpoints": {
156
  "/health": "Health check and available models",
157
+ "/embed": "Generate embeddings - Elasticsearch compatible (POST)",
158
+ "/embed/batch": "Generate batch embeddings (POST)",
159
+ "/models": "List available models",
160
  "/docs": "API documentation"
161
  }
162
  }
163
 
164
  @app.get("/health", response_model=HealthResponse)
165
  async def health():
166
+ """Health check endpoint (no authentication required)"""
167
  models_loaded = list(MODELS.keys())
168
  return {
169
  "status": "healthy",
170
  "models_loaded": models_loaded,
171
+ "voyage_available": voyage_client is not None,
172
+ "api_key_required": REQUIRE_API_KEY
173
  }
174
 
175
def _to_es_response(vectors, is_single):
    """Wrap embedding vectors in the single- or batch-shaped response model."""
    if is_single:
        return ElasticsearchInferenceResponse(embedding=vectors[0])
    return ElasticsearchInferenceBatchResponse(embeddings=vectors)


@app.post("/embed", response_model=Union[ElasticsearchInferenceResponse, ElasticsearchInferenceBatchResponse])
async def create_embeddings_elasticsearch(
    request: ElasticsearchInferenceRequest,
    model: str = Query("jobbertv3", description="Model: jobbertv2, jobbertv3, jina, or voyage"),
    task: Optional[str] = Query(None, description="Task for Jina AI: retrieval.query, retrieval.passage, text-matching, etc."),
    input_type: Optional[str] = Query(None, description="Input type for Voyage AI: document or query"),
    authenticated: bool = Depends(verify_api_key)
):
    """
    Generate embeddings - Elasticsearch inference endpoint compatible format

    **Usage:**
    - Single text: `POST /embed?model=jobbertv3` with body `{"input": "Software Engineer"}`
    - Multiple texts: `POST /embed?model=jina` with body `{"input": ["text1", "text2"]}`

    **Models (via query parameter):**
    - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
    - `jobbertv3`: JobBERT-v3 (768-dim, job-specific, improved performance) - default
    - `jina`: Jina AI embeddings-v3 (1024-dim, general purpose)
    - `voyage`: Voyage AI (1024-dim, requires API key)

    **Jina AI Tasks (via query parameter):**
    - `retrieval.query`: For search queries
    - `retrieval.passage`: For documents/passages
    - `text-matching`: For similarity matching (default)

    **Voyage AI Input Types (via query parameter):**
    - `document`: For documents/passages
    - `query`: For search queries

    **Responses:**
    - A string input returns `{"embedding": [...]}`; a list input returns `{"embeddings": [[...], ...]}`.

    **Errors:**
    - `400`: empty input list, or unknown model name
    - `503`: `voyage` requested but the Voyage client is not configured
    - `500`: the selected backend raised while embedding
    """
    model_name = model.lower()

    # Normalise the payload to a list; remember the original shape so the
    # response can mirror it (single vector vs. list of vectors).
    is_single = isinstance(request.input, str)
    texts = [request.input] if is_single else request.input

    # The batch request model requires at least one text (min_items=1); apply
    # the same rule here instead of forwarding an empty list to a backend.
    if not texts:
        raise HTTPException(
            status_code=400,
            detail="Input must contain at least one text"
        )

    if model_name == "voyage":
        if not voyage_client:
            raise HTTPException(
                status_code=503,
                detail="Voyage AI not available. Set VOYAGE_API_KEY environment variable."
            )

        try:
            result = voyage_client.embed(
                texts=texts,
                model="voyage-3",
                input_type=input_type or "document"
            )
            return _to_es_response(result.embeddings, is_single)
        except Exception as e:
            logger.exception("Voyage AI embedding request failed")
            raise HTTPException(status_code=500, detail=f"Voyage AI error: {str(e)}") from e

    if model_name in MODELS:
        # Only the Jina model is given a task, and only when the caller set one.
        encode_kwargs = {"convert_to_numpy": True}
        if model_name == "jina" and task:
            encode_kwargs["task"] = task

        try:
            # NOTE(review): encode() is called synchronously inside an async
            # handler — confirm whether blocking the event loop is acceptable.
            vectors = MODELS[model_name].encode(texts, **encode_kwargs)
            return _to_es_response(vectors.tolist(), is_single)
        except Exception as e:
            logger.exception("Embedding with model '%s' failed", model_name)
            raise HTTPException(status_code=500, detail=f"Model error: {str(e)}") from e

    raise HTTPException(
        status_code=400,
        detail=f"Invalid model '{model_name}'. Choose from: jobbertv2, jobbertv3, jina, voyage"
    )
264
+
265
+ @app.post("/embed/batch", response_model=BatchEmbeddingResponse)
266
+ async def create_embeddings_batch(
267
+ request: BatchEmbeddingRequest,
268
+ authenticated: bool = Depends(verify_api_key)
269
+ ):
270
  """
271
+ Generate embeddings for multiple texts - Original batch format
272
 
273
  **Models:**
274
  - `jobbertv2`: JobBERT-v2 (768-dim, job-specific)
 
297
  )
298
 
299
  try:
300
+ voyage_input_type = request.input_type or "document"
301
  result = voyage_client.embed(
302
  texts=request.texts,
303
  model="voyage-3",
304
+ input_type=voyage_input_type
305
  )
306
  embeddings = result.embeddings
307
  dimension = len(embeddings[0]) if embeddings else 0
308
 
309
+ return BatchEmbeddingResponse(
310
  embeddings=embeddings,
311
  model="voyage-3",
312
  dimension=dimension,
 
317
 
318
  elif model_name in MODELS:
319
  try:
320
+ selected_model = MODELS[model_name]
321
 
322
  if model_name == "jina" and request.task:
323
+ embeddings = selected_model.encode(
324
  request.texts,
325
  task=request.task,
326
  convert_to_numpy=True
327
  )
328
  else:
329
+ embeddings = selected_model.encode(
330
  request.texts,
331
  convert_to_numpy=True
332
  )
 
334
  embeddings_list = embeddings.tolist()
335
  dimension = len(embeddings_list[0]) if embeddings_list else 0
336
 
337
+ return BatchEmbeddingResponse(
338
  embeddings=embeddings_list,
339
  model=model_name,
340
  dimension=dimension,
 
350
  )
351
 
352
  @app.get("/models")
353
+ async def list_models(authenticated: bool = Depends(verify_api_key)):
354
  """List available models and their specifications"""
355
  models_info = {
356
  "jobbertv2": {