yuanjunchai
committed on
Commit
·
05fff77
1
Parent(s):
170c474
add application files
Browse files
app.py
CHANGED
|
@@ -24,8 +24,7 @@ import os
|
|
| 24 |
import gdown
|
| 25 |
from sentence_transformers import SentenceTransformer
|
| 26 |
import matplotlib.pyplot as plt
|
| 27 |
-
import
|
| 28 |
-
|
| 29 |
|
| 30 |
|
| 31 |
### Some predefined utility functions for you to load the text embeddings
|
|
@@ -53,20 +52,24 @@ def get_model_id_gdrive(model_type):
|
|
| 53 |
|
| 54 |
|
| 55 |
def download_glove_embeddings_gdrive(model_type):
|
| 56 |
-
# Get glove embeddings from google drive
|
| 57 |
-
word_index_id, embeddings_id = get_model_id_gdrive(model_type)
|
| 58 |
-
|
| 59 |
# Use gdown to get files from google drive
|
| 60 |
embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
|
| 61 |
word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
|
| 71 |
|
| 72 |
# @st.cache_data()
|
|
|
|
| 24 |
import gdown
|
| 25 |
from sentence_transformers import SentenceTransformer
|
| 26 |
import matplotlib.pyplot as plt
|
| 27 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
### Some predefined utility functions for you to load the text embeddings
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def download_glove_embeddings_gdrive(model_type):
|
|
|
|
|
|
|
|
|
|
| 55 |
# Use gdown to get files from google drive
|
| 56 |
embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
|
| 57 |
word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
|
| 58 |
+
# 100d download
|
| 59 |
+
if model_type == "100d":
|
| 60 |
+
hf_hub_download(repo_id='AveMujica/glove-twitter-100d', filename='embeddings_100d_temp.npy')
|
| 61 |
+
hf_hub_download(repo_id='AveMujica/glove-twitter-100d', filename='word_index_dict_100d_temp.pkl')
|
| 62 |
+
else:
|
| 63 |
+
# Get glove embeddings from google drive
|
| 64 |
+
word_index_id, embeddings_id = get_model_id_gdrive(model_type)
|
| 65 |
|
| 66 |
+
# Download word_index pickle file
|
| 67 |
+
print("Downloading word index dictionary....\n")
|
| 68 |
+
gdown.download(id=word_index_id, output=word_index_temp, quiet=False)
|
| 69 |
|
| 70 |
+
# Download embeddings numpy file
|
| 71 |
+
print("Donwloading embedings...\n\n")
|
| 72 |
+
gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
|
| 73 |
|
| 74 |
|
| 75 |
# @st.cache_data()
|