Update app.py for Part4
Browse files
app.py
CHANGED
|
@@ -236,42 +236,35 @@ def get_sorted_cosine_similarity(embeddings_metadata):
|
|
| 236 |
|
| 237 |
return sorted_cosine_sim
|
| 238 |
|
| 239 |
-
|
| 240 |
-
def plot_piechart(sorted_cosine_scores_items):
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
|
| 255 |
|
| 256 |
def plot_piechart_helper(sorted_cosine_scores_items):
|
| 257 |
-
sorted_cosine_scores = np.array(
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
for index in range(len(sorted_cosine_scores_items))
|
| 261 |
-
]
|
| 262 |
-
)
|
| 263 |
-
categories = st.session_state.categories.split(" ")
|
| 264 |
-
categories_sorted = [
|
| 265 |
-
categories[sorted_cosine_scores_items[index][0]]
|
| 266 |
-
for index in range(len(sorted_cosine_scores_items))
|
| 267 |
-
]
|
| 268 |
fig, ax = plt.subplots(figsize=(3, 3))
|
| 269 |
my_explode = np.zeros(len(categories_sorted))
|
| 270 |
my_explode[0] = 0.2
|
| 271 |
if len(categories_sorted) == 3:
|
| 272 |
-
my_explode[1] = 0.1
|
| 273 |
elif len(categories_sorted) > 3:
|
| 274 |
my_explode[2] = 0.05
|
|
|
|
| 275 |
ax.pie(
|
| 276 |
sorted_cosine_scores,
|
| 277 |
labels=categories_sorted,
|
|
@@ -314,10 +307,13 @@ def plot_piecharts(sorted_cosine_scores_models):
|
|
| 314 |
|
| 315 |
|
| 316 |
def plot_alatirchart(sorted_cosine_scores_models):
|
|
|
|
|
|
|
| 317 |
models = list(sorted_cosine_scores_models.keys())
|
| 318 |
tabs = st.tabs(models)
|
| 319 |
figs = {}
|
| 320 |
for model in models:
|
|
|
|
| 321 |
figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
|
| 322 |
|
| 323 |
for index in range(len(tabs)):
|
|
@@ -325,12 +321,6 @@ def plot_alatirchart(sorted_cosine_scores_models):
|
|
| 325 |
st.pyplot(figs[models[index]])
|
| 326 |
|
| 327 |
|
| 328 |
-
# Test
|
| 329 |
-
|
| 330 |
-
import os
|
| 331 |
-
print("Current Working Directory:", os.getcwd())
|
| 332 |
-
|
| 333 |
-
|
| 334 |
### Text Search ###
|
| 335 |
st.sidebar.title("GloVe Twitter")
|
| 336 |
st.sidebar.markdown(
|
|
@@ -343,17 +333,14 @@ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Gl
|
|
| 343 |
)
|
| 344 |
|
| 345 |
|
| 346 |
-
# Initialize Session State variables
|
| 347 |
if 'categories' not in st.session_state:
|
| 348 |
st.session_state['categories'] = "Flowers Colors Cars Weather Food"
|
| 349 |
if 'text_search' not in st.session_state:
|
| 350 |
st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
|
| 351 |
|
| 352 |
-
# ... [rest of the Streamlit code]
|
| 353 |
|
| 354 |
model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
|
| 355 |
|
| 356 |
-
|
| 357 |
st.title("Search Based Retrieval Demo")
|
| 358 |
st.subheader(
|
| 359 |
"Pass in space separated categories you want this search demo to be about."
|
|
@@ -364,12 +351,11 @@ st.subheader(
|
|
| 364 |
# )
|
| 365 |
|
| 366 |
|
| 367 |
-
#
|
| 368 |
user_categories = st.text_input(
|
| 369 |
label="Categories", value=st.session_state.categories
|
| 370 |
)
|
| 371 |
|
| 372 |
-
# Update the Session State variable - modified here
|
| 373 |
st.session_state.categories = user_categories
|
| 374 |
|
| 375 |
# st.text_input(
|
|
@@ -395,7 +381,6 @@ user_text_search = st.text_input(
|
|
| 395 |
|
| 396 |
)
|
| 397 |
|
| 398 |
-
# Update the Session State variable - modified here
|
| 399 |
st.session_state.text_search = user_text_search
|
| 400 |
# st.session_state.text_search = text_search
|
| 401 |
|
|
@@ -449,25 +434,21 @@ if st.session_state.text_search:
|
|
| 449 |
+ " as per different Embeddings"
|
| 450 |
)
|
| 451 |
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
# plot_alatirchart(
|
| 457 |
-
# {
|
| 458 |
-
# "glove_" + str(model_type): sorted_cosine_sim_glove,
|
| 459 |
-
# "sentence_transformer_384": sorted_cosine_sim_transformer,
|
| 460 |
-
# }
|
| 461 |
-
# )
|
| 462 |
-
# "distilbert_512": sorted_distilbert})
|
| 463 |
-
|
| 464 |
-
# Modified here!
|
| 465 |
-
# Display the closest category result for GloVe and Sentence Transformer Embeddings
|
| 466 |
-
st.write(f"The closest category in GloVe embeddings is: {list(sorted_cosine_sim_glove.keys())[0]}")
|
| 467 |
st.write(
|
| 468 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
st.write("")
|
| 471 |
st.write(
|
| 472 |
-
"Demo developed by [
|
| 473 |
)
|
|
|
|
| 236 |
|
| 237 |
return sorted_cosine_sim
|
| 238 |
|
| 239 |
+
#
|
| 240 |
+
# def plot_piechart(sorted_cosine_scores_items):
|
| 241 |
+
# sorted_cosine_scores = np.array([
|
| 242 |
+
# sorted_cosine_scores_items[index][1]
|
| 243 |
+
# for index in range(len(sorted_cosine_scores_items))
|
| 244 |
+
# ]
|
| 245 |
+
# )
|
| 246 |
+
# categories = st.session_state.categories.split(" ")
|
| 247 |
+
# categories_sorted = [
|
| 248 |
+
# categories[sorted_cosine_scores_items[index][0]]
|
| 249 |
+
# for index in range(len(sorted_cosine_scores_items))
|
| 250 |
+
# ]
|
| 251 |
+
# fig, ax = plt.subplots()
|
| 252 |
+
# ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
|
| 253 |
+
# st.pyplot(fig) # Figure
|
| 254 |
|
| 255 |
|
| 256 |
def plot_piechart_helper(sorted_cosine_scores_items):
|
| 257 |
+
sorted_cosine_scores = np.array(list(sorted_cosine_scores_items.values()))
|
| 258 |
+
categories_sorted = list(sorted_cosine_scores_items.keys())
|
| 259 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
fig, ax = plt.subplots(figsize=(3, 3))
|
| 261 |
my_explode = np.zeros(len(categories_sorted))
|
| 262 |
my_explode[0] = 0.2
|
| 263 |
if len(categories_sorted) == 3:
|
| 264 |
+
my_explode[1] = 0.1
|
| 265 |
elif len(categories_sorted) > 3:
|
| 266 |
my_explode[2] = 0.05
|
| 267 |
+
|
| 268 |
ax.pie(
|
| 269 |
sorted_cosine_scores,
|
| 270 |
labels=categories_sorted,
|
|
|
|
| 307 |
|
| 308 |
|
| 309 |
def plot_alatirchart(sorted_cosine_scores_models):
|
| 310 |
+
|
| 311 |
+
|
| 312 |
models = list(sorted_cosine_scores_models.keys())
|
| 313 |
tabs = st.tabs(models)
|
| 314 |
figs = {}
|
| 315 |
for model in models:
|
| 316 |
+
# modified
|
| 317 |
figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
|
| 318 |
|
| 319 |
for index in range(len(tabs)):
|
|
|
|
| 321 |
st.pyplot(figs[models[index]])
|
| 322 |
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
### Text Search ###
|
| 325 |
st.sidebar.title("GloVe Twitter")
|
| 326 |
st.sidebar.markdown(
|
|
|
|
| 333 |
)
|
| 334 |
|
| 335 |
|
|
|
|
| 336 |
if 'categories' not in st.session_state:
|
| 337 |
st.session_state['categories'] = "Flowers Colors Cars Weather Food"
|
| 338 |
if 'text_search' not in st.session_state:
|
| 339 |
st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
|
| 340 |
|
|
|
|
| 341 |
|
| 342 |
model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
|
| 343 |
|
|
|
|
| 344 |
st.title("Search Based Retrieval Demo")
|
| 345 |
st.subheader(
|
| 346 |
"Pass in space separated categories you want this search demo to be about."
|
|
|
|
| 351 |
# )
|
| 352 |
|
| 353 |
|
| 354 |
+
# categories of user input
|
| 355 |
user_categories = st.text_input(
|
| 356 |
label="Categories", value=st.session_state.categories
|
| 357 |
)
|
| 358 |
|
|
|
|
| 359 |
st.session_state.categories = user_categories
|
| 360 |
|
| 361 |
# st.text_input(
|
|
|
|
| 381 |
|
| 382 |
)
|
| 383 |
|
|
|
|
| 384 |
st.session_state.text_search = user_text_search
|
| 385 |
# st.session_state.text_search = text_search
|
| 386 |
|
|
|
|
| 434 |
+ " as per different Embeddings"
|
| 435 |
)
|
| 436 |
|
| 437 |
+
print(sorted_cosine_sim_glove)
|
| 438 |
+
print(sorted_cosine_sim_transformer)
|
| 439 |
+
|
| 440 |
+
st.write(f"Closest category using GloVe embeddings : {list(sorted_cosine_sim_glove.keys())[0]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
st.write(
|
| 442 |
+
f"Closest category using Sentence Transformer embeddings : {list(sorted_cosine_sim_transformer.keys())[0]}")
|
| 443 |
+
|
| 444 |
+
plot_alatirchart(
|
| 445 |
+
{
|
| 446 |
+
"glove_" + str(model_type): sorted_cosine_sim_glove,
|
| 447 |
+
"sentence_transformer_384": sorted_cosine_sim_transformer,
|
| 448 |
+
}
|
| 449 |
+
)
|
| 450 |
|
| 451 |
st.write("")
|
| 452 |
st.write(
|
| 453 |
+
"Demo developed by [V50](https://huggingface.co/spaces/ericlkc/V50)"
|
| 454 |
)
|