Spaces:
Sleeping
Sleeping
Changed embedding pulls
Browse files - src/api/modern_bert_utils.py +12 -7
src/api/modern_bert_utils.py
CHANGED
|
@@ -97,19 +97,24 @@ def _load_posts_and_index():
|
|
| 97 |
_cached_posts = session.query(Whole_Blogs).all()
|
| 98 |
|
| 99 |
# Load precomputed embeddings
|
| 100 |
-
# If EMBEDDINGS_PATH starts with "s3://", parse it
|
| 101 |
if EMBEDDINGS_PATH.startswith("s3://"):
|
| 102 |
-
|
| 103 |
-
# Remove "s3://"
|
| 104 |
-
path_without_s3 = EMBEDDINGS_PATH[5:]
|
| 105 |
bucket_name, key = path_without_s3.split("/", 1)
|
| 106 |
data = load_embeddings_from_s3(bucket_name, key)
|
| 107 |
else:
|
| 108 |
-
# fallback to local file
|
| 109 |
data = torch.load(EMBEDDINGS_PATH, weights_only=True)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
# Build FAISS index
|
| 115 |
_index = faiss.IndexFlatL2(_embeddings.shape[1])
|
|
|
|
| 97 |
_cached_posts = session.query(Whole_Blogs).all()
|
| 98 |
|
| 99 |
# Load precomputed embeddings
|
|
|
|
| 100 |
if EMBEDDINGS_PATH.startswith("s3://"):
|
| 101 |
+
path_without_s3 = EMBEDDINGS_PATH[5:]
|
|
|
|
|
|
|
| 102 |
bucket_name, key = path_without_s3.split("/", 1)
|
| 103 |
data = load_embeddings_from_s3(bucket_name, key)
|
| 104 |
else:
|
|
|
|
| 105 |
data = torch.load(EMBEDDINGS_PATH, weights_only=True)
|
| 106 |
|
| 107 |
+
# Convert embeddings dict to tensor array
|
| 108 |
+
# Data is {blog_id: embedding_tensor, ...}
|
| 109 |
+
# We need to convert to a stacked tensor aligned with _cached_posts order
|
| 110 |
+
embedding_list = []
|
| 111 |
+
for post in _cached_posts:
|
| 112 |
+
if post.id in data:
|
| 113 |
+
embedding_list.append(data[post.id])
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(f"Missing embedding for blog post ID {post.id}")
|
| 116 |
+
|
| 117 |
+
_embeddings = torch.stack(embedding_list)
|
| 118 |
|
| 119 |
# Build FAISS index
|
| 120 |
_index = faiss.IndexFlatL2(_embeddings.shape[1])
|