Ojochegbeng committed on
Commit
020892f
·
verified ·
1 Parent(s): 9e9c055

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -35
app.py CHANGED
@@ -65,7 +65,7 @@ def load_model():
65
  return True
66
  except Exception as fallback_error:
67
  logger.error(f"Fallback model loading also failed: {str(fallback_error)}")
68
- return False
69
 
70
  def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
71
  """Generate embeddings for input text(s) using Qwen3-Embedding-0.6B model"""
@@ -129,18 +129,30 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
129
 
130
  except Exception as e:
131
  logger.warning(f"Error generating embedding for text: {str(e)}")
132
- # Return zero vector as last resort
133
- embeddings.append([0.0] * 1024) # Qwen3-Embedding-0.6B has 1024 dimensions
 
 
 
 
 
134
 
135
  return embeddings[0] if single_text else embeddings
136
 
137
  except Exception as e:
138
  logger.error(f"Error in generate_embeddings: {str(e)}")
139
- # Return zero vectors as fallback
 
 
 
 
 
 
 
140
  if single_text:
141
- return [0.0] * 1024
142
  else:
143
- return [[0.0] * 1024] * len(texts)
144
 
145
  def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
146
  """Compute cosine similarity between two embeddings"""
@@ -278,29 +290,61 @@ async def health():
278
  async def predict(data: dict):
279
  """Main prediction endpoint for embeddings"""
280
  try:
281
- if "data" not in data:
282
- raise HTTPException(status_code=400, detail="Missing 'data' field in request")
283
-
284
- input_data = data["data"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
- # Handle single text or batch texts
287
- if isinstance(input_data, str):
288
- # Single text
289
- embeddings = generate_embeddings(input_data)
290
- return {"data": [embeddings]}
291
- elif isinstance(input_data, list):
292
- if len(input_data) > 0 and isinstance(input_data[0], str):
293
- # Single text in list
294
- embeddings = generate_embeddings(input_data[0])
295
- return {"data": [embeddings]}
296
- elif len(input_data) > 0 and isinstance(input_data[0], list):
297
- # Batch texts
298
- embeddings = generate_embeddings(input_data[0])
299
  return {"data": [embeddings]}
 
 
 
 
 
 
 
 
 
 
 
300
  else:
301
- raise HTTPException(status_code=400, detail="Invalid data format")
302
  else:
303
- raise HTTPException(status_code=400, detail="Invalid data type")
304
 
305
  except Exception as e:
306
  logger.error(f"Error in predict endpoint: {str(e)}")
@@ -308,19 +352,42 @@ async def predict(data: dict):
308
 
309
  @app.post("/api/similarity")
310
  async def similarity(data: dict):
311
- """Compute similarity between two embeddings"""
312
  try:
313
- if "embedding1" not in data or "embedding2" not in data:
314
- raise HTTPException(status_code=400, detail="Missing embedding1 or embedding2 field")
315
-
316
- emb1 = data["embedding1"]
317
- emb2 = data["embedding2"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
- if not isinstance(emb1, list) or not isinstance(emb2, list):
320
- raise HTTPException(status_code=400, detail="Embeddings must be lists")
 
 
 
 
 
 
 
 
321
 
322
- sim = compute_similarity(emb1, emb2)
323
- return {"similarity": sim}
324
 
325
  except Exception as e:
326
  logger.error(f"Error in similarity endpoint: {str(e)}")
 
65
  return True
66
  except Exception as fallback_error:
67
  logger.error(f"Fallback model loading also failed: {str(fallback_error)}")
68
+ return False
69
 
70
  def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
71
  """Generate embeddings for input text(s) using Qwen3-Embedding-0.6B model"""
 
129
 
130
  except Exception as e:
131
  logger.warning(f"Error generating embedding for text: {str(e)}")
132
+ # Return zero vector as last resort - use correct dimension based on model type
133
+ if hasattr(model, 'config') and hasattr(model.config, 'hidden_size'):
134
+ # Qwen3 model dimension
135
+ embeddings.append([0.0] * model.config.hidden_size)
136
+ else:
137
+ # Fallback model dimension (384 for all-MiniLM-L6-v2)
138
+ embeddings.append([0.0] * 384)
139
 
140
  return embeddings[0] if single_text else embeddings
141
 
142
  except Exception as e:
143
  logger.error(f"Error in generate_embeddings: {str(e)}")
144
+ # Return zero vectors as fallback - use correct dimension
145
+ if hasattr(model, 'config') and hasattr(model.config, 'hidden_size'):
146
+ # Qwen3 model dimension
147
+ fallback_dim = model.config.hidden_size
148
+ else:
149
+ # Fallback model dimension (384 for all-MiniLM-L6-v2)
150
+ fallback_dim = 384
151
+
152
  if single_text:
153
+ return [0.0] * fallback_dim
154
  else:
155
+ return [[0.0] * fallback_dim] * len(texts)
156
 
157
  def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
158
  """Compute cosine similarity between two embeddings"""
 
290
  async def predict(data: dict):
291
  """Main prediction endpoint for embeddings"""
292
  try:
293
+ # Check for new format first (texts parameter)
294
+ if "texts" in data:
295
+ texts = data["texts"]
296
+ normalize = data.get("normalize", True)
297
+
298
+ if not isinstance(texts, list):
299
+ raise HTTPException(status_code=400, detail="'texts' must be a list")
300
+
301
+ if len(texts) == 0:
302
+ raise HTTPException(status_code=400, detail="'texts' list cannot be empty")
303
+
304
+ # Generate embeddings
305
+ logger.info(f"Generating embeddings for {len(texts)} texts")
306
+ embeddings = generate_embeddings(texts)
307
+ logger.info(f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0]) if embeddings else 0}")
308
+
309
+ # Normalize embeddings if requested
310
+ if normalize:
311
+ import numpy as np
312
+ embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]
313
+ logger.info("Embeddings normalized")
314
+
315
+ return {
316
+ "embeddings": embeddings,
317
+ "model": MODEL_NAME,
318
+ "usage": {
319
+ "prompt_tokens": sum(len(text.split()) for text in texts),
320
+ "total_tokens": sum(len(text.split()) for text in texts)
321
+ }
322
+ }
323
 
324
+ # Fallback to old format for backward compatibility
325
+ elif "data" in data:
326
+ input_data = data["data"]
327
+
328
+ # Handle single text or batch texts
329
+ if isinstance(input_data, str):
330
+ # Single text
331
+ embeddings = generate_embeddings(input_data)
 
 
 
 
 
332
  return {"data": [embeddings]}
333
+ elif isinstance(input_data, list):
334
+ if len(input_data) > 0 and isinstance(input_data[0], str):
335
+ # Single text in list
336
+ embeddings = generate_embeddings(input_data[0])
337
+ return {"data": [embeddings]}
338
+ elif len(input_data) > 0 and isinstance(input_data[0], list):
339
+ # Batch texts
340
+ embeddings = generate_embeddings(input_data[0])
341
+ return {"data": [embeddings]}
342
+ else:
343
+ raise HTTPException(status_code=400, detail="Invalid data format")
344
  else:
345
+ raise HTTPException(status_code=400, detail="Invalid data type")
346
  else:
347
+ raise HTTPException(status_code=400, detail="Missing 'texts' or 'data' field in request")
348
 
349
  except Exception as e:
350
  logger.error(f"Error in predict endpoint: {str(e)}")
 
352
 
353
  @app.post("/api/similarity")
354
  async def similarity(data: dict):
355
+ """Compute similarity between two texts or embeddings"""
356
  try:
357
+ # Check for new format first (text1, text2 parameters)
358
+ if "text1" in data and "text2" in data:
359
+ text1 = data["text1"]
360
+ text2 = data["text2"]
361
+
362
+ if not isinstance(text1, str) or not isinstance(text2, str):
363
+ raise HTTPException(status_code=400, detail="text1 and text2 must be strings")
364
+
365
+ # Generate embeddings for both texts
366
+ emb1 = generate_embeddings(text1)
367
+ emb2 = generate_embeddings(text2)
368
+
369
+ # Compute similarity
370
+ sim = compute_similarity(emb1, emb2)
371
+ return {
372
+ "similarity": sim,
373
+ "model": MODEL_NAME,
374
+ "text1": text1,
375
+ "text2": text2
376
+ }
377
 
378
+ # Fallback to old format (embedding1, embedding2 parameters)
379
+ elif "embedding1" in data and "embedding2" in data:
380
+ emb1 = data["embedding1"]
381
+ emb2 = data["embedding2"]
382
+
383
+ if not isinstance(emb1, list) or not isinstance(emb2, list):
384
+ raise HTTPException(status_code=400, detail="Embeddings must be lists")
385
+
386
+ sim = compute_similarity(emb1, emb2)
387
+ return {"similarity": sim}
388
 
389
+ else:
390
+ raise HTTPException(status_code=400, detail="Missing 'text1' and 'text2' or 'embedding1' and 'embedding2' fields")
391
 
392
  except Exception as e:
393
  logger.error(f"Error in similarity endpoint: {str(e)}")