Ojochegbeng committed on
Commit
68092ea
·
verified ·
1 Parent(s): 020892f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -113,7 +113,7 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
113
  embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
114
  else:
115
  # Simple mean pooling without attention mask
116
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
117
  else:
118
  # Fallback to pooled output if available
119
  embedding = outputs.pooler_output.squeeze().cpu().numpy()
@@ -123,9 +123,9 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
123
  elif model and hasattr(model, 'encode'):
124
  # Method 2: Using sentence transformer fallback
125
  embedding = model.encode(text)
126
- embeddings.append(embedding.tolist())
127
- else:
128
- raise Exception("No model available")
129
 
130
  except Exception as e:
131
  logger.warning(f"Error generating embedding for text: {str(e)}")
@@ -309,8 +309,12 @@ async def predict(data: dict):
309
  # Normalize embeddings if requested
310
  if normalize:
311
  import numpy as np
312
- embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]
313
- logger.info("Embeddings normalized")
 
 
 
 
314
 
315
  return {
316
  "embeddings": embeddings,
 
113
  embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
114
  else:
115
  # Simple mean pooling without attention mask
116
+ embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
117
  else:
118
  # Fallback to pooled output if available
119
  embedding = outputs.pooler_output.squeeze().cpu().numpy()
 
123
  elif model and hasattr(model, 'encode'):
124
  # Method 2: Using sentence transformer fallback
125
  embedding = model.encode(text)
126
+ embeddings.append(embedding.tolist())
127
+ else:
128
+ raise Exception("No model available")
129
 
130
  except Exception as e:
131
  logger.warning(f"Error generating embedding for text: {str(e)}")
 
309
  # Normalize embeddings if requested
310
  if normalize:
311
  import numpy as np
312
+ try:
313
+ embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]
314
+ logger.info("Embeddings normalized")
315
+ except Exception as norm_error:
316
+ logger.warning(f"Normalization failed: {str(norm_error)}, returning unnormalized embeddings")
317
+ # Continue with unnormalized embeddings
318
 
319
  return {
320
  "embeddings": embeddings,