Spaces:
Build error
Build error
Commit ·
b2ee56b
1
Parent(s): c27bf27
Update model version.
Browse files- pages/predict.py +8 -3
- requirements.txt +1 -2
pages/predict.py
CHANGED
|
@@ -36,16 +36,21 @@ st.markdown(f"**Query:** {st.session_state.query['source_node'].replace('_', ' '
|
|
| 36 |
|
| 37 |
@st.cache_data(show_spinner = 'Downloading AI model...')
|
| 38 |
def get_embeddings():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
# Get paths to embeddings, relation weights, and edge types
|
| 40 |
# with st.spinner('Downloading AI model...'):
|
| 41 |
embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 42 |
-
filename="
|
| 43 |
token=st.secrets["HF_TOKEN"])
|
| 44 |
relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 45 |
-
filename="
|
| 46 |
token=st.secrets["HF_TOKEN"])
|
| 47 |
edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 48 |
-
filename="
|
| 49 |
token=st.secrets["HF_TOKEN"])
|
| 50 |
return embed_path, relation_weights_path, edge_types_path
|
| 51 |
|
|
|
|
| 36 |
|
| 37 |
@st.cache_data(show_spinner = 'Downloading AI model...')
|
| 38 |
def get_embeddings():
|
| 39 |
+
# Get checkpoint name
|
| 40 |
+
# best_ckpt = "2024_05_22_11_59_43_epoch=18-step=22912"
|
| 41 |
+
best_ckpt = "2024_05_15_13_05_33_epoch=2-step=40383"
|
| 42 |
+
# best_ckpt = "2024_03_29_04_12_52_epoch=3-step=54291"
|
| 43 |
+
|
| 44 |
# Get paths to embeddings, relation weights, and edge types
|
| 45 |
# with st.spinner('Downloading AI model...'):
|
| 46 |
embed_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 47 |
+
filename=(best_ckpt + "-thresh=4000_embeddings.pt"),
|
| 48 |
token=st.secrets["HF_TOKEN"])
|
| 49 |
relation_weights_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 50 |
+
filename=(best_ckpt + "_relation_weights.pt"),
|
| 51 |
token=st.secrets["HF_TOKEN"])
|
| 52 |
edge_types_path = hf_hub_download(repo_id="ayushnoori/galaxy",
|
| 53 |
+
filename=(best_ckpt + "_edge_types.pt"),
|
| 54 |
token=st.secrets["HF_TOKEN"])
|
| 55 |
return embed_path, relation_weights_path, edge_types_path
|
| 56 |
|
requirements.txt
CHANGED
|
@@ -8,5 +8,4 @@ torch
|
|
| 8 |
altair<5
|
| 9 |
gspread
|
| 10 |
oauth2client
|
| 11 |
-
huggingface_hub
|
| 12 |
-
matplotlib
|
|
|
|
| 8 |
altair<5
|
| 9 |
gspread
|
| 10 |
oauth2client
|
| 11 |
+
huggingface_hub
|
|
|