dcorcoran committed on
Commit
79fc3f2
·
1 Parent(s): 1323f77

Changed embedding pulls

Browse files
Files changed (1) hide show
  1. src/api/modern_bert_utils.py +12 -7
src/api/modern_bert_utils.py CHANGED
@@ -97,19 +97,24 @@ def _load_posts_and_index():
97
  _cached_posts = session.query(Whole_Blogs).all()
98
 
99
  # Load precomputed embeddings
100
- # If EMBEDDINGS_PATH starts with "s3://", parse it
101
  if EMBEDDINGS_PATH.startswith("s3://"):
102
- # Example: s3://my-bucket/path/to/travel_blog_embeddings.pt
103
- # Remove "s3://"
104
- path_without_s3 = EMBEDDINGS_PATH[5:]
105
  bucket_name, key = path_without_s3.split("/", 1)
106
  data = load_embeddings_from_s3(bucket_name, key)
107
  else:
108
- # fallback to local file
109
  data = torch.load(EMBEDDINGS_PATH, weights_only=True)
110
 
111
-
112
- _embeddings = data["embeddings"]
 
 
 
 
 
 
 
 
 
113
 
114
  # Build FAISS index
115
  _index = faiss.IndexFlatL2(_embeddings.shape[1])
 
97
  _cached_posts = session.query(Whole_Blogs).all()
98
 
99
  # Load precomputed embeddings
 
100
  if EMBEDDINGS_PATH.startswith("s3://"):
101
+ path_without_s3 = EMBEDDINGS_PATH[5:]
 
 
102
  bucket_name, key = path_without_s3.split("/", 1)
103
  data = load_embeddings_from_s3(bucket_name, key)
104
  else:
 
105
  data = torch.load(EMBEDDINGS_PATH, weights_only=True)
106
 
107
+ # Convert embeddings dict to tensor array
108
+ # Data is {blog_id: embedding_tensor, ...}
109
+ # We need to convert to a stacked tensor aligned with _cached_posts order
110
+ embedding_list = []
111
+ for post in _cached_posts:
112
+ if post.id in data:
113
+ embedding_list.append(data[post.id])
114
+ else:
115
+ raise ValueError(f"Missing embedding for blog post ID {post.id}")
116
+
117
+ _embeddings = torch.stack(embedding_list)
118
 
119
  # Build FAISS index
120
  _index = faiss.IndexFlatL2(_embeddings.shape[1])