dejanseo committed on
Commit
dc71e16
·
verified ·
1 Parent(s): 896fe0d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +12 -7
src/streamlit_app.py CHANGED
@@ -60,9 +60,15 @@ if st.button("Calculate Similarities"):
60
  st.stop()
61
 
62
  try:
63
- st.info("Generating embeddings with MRL=256, ubinary")
64
- primary_embedding = model.encode(primary_keyword, convert_to_tensor=False, normalize_embeddings=True, device="cpu", mrl=256, ubinary=True)
65
- keyword_embeddings = model.encode(keyword_list, convert_to_tensor=False, normalize_embeddings=True, device="cpu", mrl=256, ubinary=True)
 
 
 
 
 
 
66
  except Exception as e:
67
  st.error(f"Embedding failed: {e}")
68
  st.stop()
@@ -90,8 +96,7 @@ if st.button("Calculate Similarities"):
90
  # 3D PCA Plot
91
  st.subheader("3D PCA of Embeddings")
92
  pca = PCA(n_components=3)
93
- all_embeddings = np.vstack([primary_embedding] + keyword_embeddings)
94
- pca_result = pca.fit_transform(all_embeddings)
95
 
96
  pca_df = pd.DataFrame(pca_result, columns=["PC1", "PC2", "PC3"])
97
  pca_df["Label"] = ["Primary"] + keyword_list
@@ -99,8 +104,8 @@ if st.button("Calculate Similarities"):
99
  st.plotly_chart(fig, use_container_width=True)
100
 
101
  with st.expander("🔧 Technical Details (click to expand)"):
102
- st.write("Primary Embedding:", primary_embedding)
103
- st.write("Keyword Embeddings:", keyword_embeddings)
104
 
105
  # Footer
106
  st.markdown("---")
 
60
  st.stop()
61
 
62
  try:
63
+ st.info("Generating embeddings...")
64
+ all_texts = [primary_keyword] + keyword_list
65
+ embeddings = model.encode(all_texts, normalize_embeddings=True)
66
+
67
+ # Apply Matryoshka Representation Learning: slice to 256 dims
68
+ mrl_embeddings = embeddings[:, :256]
69
+
70
+ primary_embedding = mrl_embeddings[0]
71
+ keyword_embeddings = mrl_embeddings[1:]
72
  except Exception as e:
73
  st.error(f"Embedding failed: {e}")
74
  st.stop()
 
96
  # 3D PCA Plot
97
  st.subheader("3D PCA of Embeddings")
98
  pca = PCA(n_components=3)
99
+ pca_result = pca.fit_transform(mrl_embeddings)
 
100
 
101
  pca_df = pd.DataFrame(pca_result, columns=["PC1", "PC2", "PC3"])
102
  pca_df["Label"] = ["Primary"] + keyword_list
 
104
  st.plotly_chart(fig, use_container_width=True)
105
 
106
  with st.expander("🔧 Technical Details (click to expand)"):
107
+ st.write("Primary Embedding:", primary_embedding.tolist())
108
+ st.write("Keyword Embeddings:", keyword_embeddings.tolist())
109
 
110
  # Footer
111
  st.markdown("---")