Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -113,7 +113,7 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
|
|
| 113 |
embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
|
| 114 |
else:
|
| 115 |
# Simple mean pooling without attention mask
|
| 116 |
-
|
| 117 |
else:
|
| 118 |
# Fallback to pooled output if available
|
| 119 |
embedding = outputs.pooler_output.squeeze().cpu().numpy()
|
|
@@ -123,9 +123,9 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
|
|
| 123 |
elif model and hasattr(model, 'encode'):
|
| 124 |
# Method 2: Using sentence transformer fallback
|
| 125 |
embedding = model.encode(text)
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
|
| 130 |
except Exception as e:
|
| 131 |
logger.warning(f"Error generating embedding for text: {str(e)}")
|
|
@@ -309,8 +309,12 @@ async def predict(data: dict):
|
|
| 309 |
# Normalize embeddings if requested
|
| 310 |
if normalize:
|
| 311 |
import numpy as np
|
| 312 |
-
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
return {
|
| 316 |
"embeddings": embeddings,
|
|
|
|
| 113 |
embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
|
| 114 |
else:
|
| 115 |
# Simple mean pooling without attention mask
|
| 116 |
+
embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
|
| 117 |
else:
|
| 118 |
# Fallback to pooled output if available
|
| 119 |
embedding = outputs.pooler_output.squeeze().cpu().numpy()
|
|
|
|
| 123 |
elif model and hasattr(model, 'encode'):
|
| 124 |
# Method 2: Using sentence transformer fallback
|
| 125 |
embedding = model.encode(text)
|
| 126 |
+
embeddings.append(embedding.tolist())
|
| 127 |
+
else:
|
| 128 |
+
raise Exception("No model available")
|
| 129 |
|
| 130 |
except Exception as e:
|
| 131 |
logger.warning(f"Error generating embedding for text: {str(e)}")
|
|
|
|
| 309 |
# Normalize embeddings if requested
|
| 310 |
if normalize:
|
| 311 |
import numpy as np
|
| 312 |
+
try:
|
| 313 |
+
embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]
|
| 314 |
+
logger.info("Embeddings normalized")
|
| 315 |
+
except Exception as norm_error:
|
| 316 |
+
logger.warning(f"Normalization failed: {str(norm_error)}, returning unnormalized embeddings")
|
| 317 |
+
# Continue with unnormalized embeddings
|
| 318 |
|
| 319 |
return {
|
| 320 |
"embeddings": embeddings,
|