Anisha Bhatnagar
commited on
Commit
Β·
f4d3b67
1
Parent(s):
051d45c
removed unused code
Browse files- app.py +24 -24
- utils/visualizations.py +1 -65
app.py
CHANGED
|
@@ -55,7 +55,7 @@ def validate_ground_truth(gt1, gt2, gt3):
|
|
| 55 |
return index, f"Candidate {index+1} is marked as the ground truth author."
|
| 56 |
|
| 57 |
|
| 58 |
-
def app(share=False
|
| 59 |
instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
|
| 60 |
|
| 61 |
interp = load_interp_space(cfg)
|
|
@@ -392,12 +392,12 @@ def app(share=False, use_cluster_feats=False):
|
|
| 392 |
llm_style_feats_analysis = gr.State()
|
| 393 |
visible_zoomed_authors = gr.State()
|
| 394 |
|
| 395 |
-
if use_cluster_feats:
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
else:
|
| 400 |
-
|
| 401 |
|
| 402 |
with gr.Row():
|
| 403 |
# ββ LLM Features Column ββββββββββββββββββββββββββββββββββ
|
|
@@ -448,21 +448,21 @@ def app(share=False, use_cluster_feats=False):
|
|
| 448 |
)
|
| 449 |
|
| 450 |
# Populate feature list based on selection.
|
| 451 |
-
if use_cluster_feats:
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
else:
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
|
| 467 |
|
| 468 |
# ββ Show combined featureβspan highlights ββ
|
|
@@ -543,6 +543,6 @@ def app(share=False, use_cluster_feats=False):
|
|
| 543 |
|
| 544 |
if __name__ == "__main__":
|
| 545 |
parser = argparse.ArgumentParser()
|
| 546 |
-
parser.add_argument("--use_cluster_feats", action="store_true", help="Use cluster-based selection for features")
|
| 547 |
args = parser.parse_args()
|
| 548 |
-
app(share=True
|
|
|
|
| 55 |
return index, f"Candidate {index+1} is marked as the ground truth author."
|
| 56 |
|
| 57 |
|
| 58 |
+
def app(share=False):#, use_cluster_feats=False):
|
| 59 |
instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
|
| 60 |
|
| 61 |
interp = load_interp_space(cfg)
|
|
|
|
| 392 |
llm_style_feats_analysis = gr.State()
|
| 393 |
visible_zoomed_authors = gr.State()
|
| 394 |
|
| 395 |
+
# if use_cluster_feats:
|
| 396 |
+
# # ββ Dynamic Cluster Choice dropdown ββββββββββββββββββββββββββββββββββ
|
| 397 |
+
# gr.HTML(instruction_callout("Choose a cluster from the dropdown below to inspect whether its features appear in the mystery authorβs text."))
|
| 398 |
+
# cluster_dropdown.visible = True
|
| 399 |
+
# else:
|
| 400 |
+
gr.HTML(instruction_callout("Zoom in on the plot to select a set of background authors and see the presence of the top features from this set in candidate and mystery authors."))
|
| 401 |
|
| 402 |
with gr.Row():
|
| 403 |
# ββ LLM Features Column ββββββββββββββββββββββββββββββββββ
|
|
|
|
| 448 |
)
|
| 449 |
|
| 450 |
# Populate feature list based on selection.
|
| 451 |
+
# if use_cluster_feats:
|
| 452 |
+
# # Use cluster-based flow
|
| 453 |
+
# cluster_dropdown.change(
|
| 454 |
+
# fn=on_cluster_change,
|
| 455 |
+
# inputs=[cluster_dropdown, style_map_state],
|
| 456 |
+
# outputs=[features_rb, gram2vec_rb , feature_list_state]
|
| 457 |
+
# #adding feature_list_state to persisit all llm features in the app state
|
| 458 |
+
# )
|
| 459 |
+
# else:
|
| 460 |
+
|
| 461 |
+
axis_ranges.change(
|
| 462 |
+
fn=handle_zoom_with_retries,
|
| 463 |
+
inputs=[axis_ranges, bg_proj_state, bg_lbls_state, bg_authors_df, task_authors_embeddings_df],
|
| 464 |
+
outputs=[features_rb, gram2vec_rb , llm_style_feats_analysis, feature_list_state, visible_zoomed_authors]
|
| 465 |
+
)
|
| 466 |
|
| 467 |
|
| 468 |
# ββ Show combined featureβspan highlights ββ
|
|
|
|
| 543 |
|
| 544 |
if __name__ == "__main__":
|
| 545 |
parser = argparse.ArgumentParser()
|
| 546 |
+
# parser.add_argument("--use_cluster_feats", action="store_true", help="Use cluster-based selection for features")
|
| 547 |
args = parser.parse_args()
|
| 548 |
+
app(share=True)#, use_cluster_feats=args.use_cluster_feats)
|
utils/visualizations.py
CHANGED
|
@@ -411,33 +411,8 @@ def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_inp
|
|
| 411 |
# split
|
| 412 |
q_proj = proj[0]
|
| 413 |
c_proj = proj[1:4]
|
| 414 |
-
#bg_proj = proj[4:4+len(bg_lbls)]
|
| 415 |
bg_proj = proj
|
| 416 |
|
| 417 |
-
# cent_proj = proj[4+len(bg_lbls):]
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
# find nearest centroid
|
| 421 |
-
# dists = np.linalg.norm(cent_proj - q_proj, axis=1)
|
| 422 |
-
# idx = int(np.argmin(dists))
|
| 423 |
-
# cluster_label_query = cent_lbl[idx]
|
| 424 |
-
# features of the nearest centroid to display
|
| 425 |
-
# feature_list = style_names[cluster_label_query]
|
| 426 |
-
|
| 427 |
-
# cluster_labels_per_candidate = [
|
| 428 |
-
# cent_lbl[int(np.argmin(np.linalg.norm(cent_proj - c_proj[i], axis=1)))]
|
| 429 |
-
# for i in range(c_proj.shape[0])
|
| 430 |
-
# ]
|
| 431 |
-
|
| 432 |
-
# prepare colorscale
|
| 433 |
-
# n_cent = len(cent_lbl)
|
| 434 |
-
# cent_colors = sample_colorscale("algae", [i/(n_cent-1) for i in range(n_cent)])
|
| 435 |
-
# map each cluster label to its color
|
| 436 |
-
# color_map = { label: cent_colors[i] for i, label in enumerate(cent_lbl) }
|
| 437 |
-
|
| 438 |
-
# uncomment the following line to show background authors
|
| 439 |
-
## background author colors pulled from their cluster label
|
| 440 |
-
# bg_colors = [ color_map[label] for label in bg_lbls ]
|
| 441 |
|
| 442 |
# 2) build Plotly figure
|
| 443 |
fig = go.Figure()
|
|
@@ -450,13 +425,6 @@ def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_inp
|
|
| 450 |
# Enable zoom events
|
| 451 |
dragmode='zoom'
|
| 452 |
)
|
| 453 |
-
|
| 454 |
-
# fig.update_layout(
|
| 455 |
-
# template='plotly_white',
|
| 456 |
-
# margin=dict(l=40,r=40,t=60,b=40),
|
| 457 |
-
# autosize=True,
|
| 458 |
-
# hovermode='closest')
|
| 459 |
-
|
| 460 |
|
| 461 |
# uncomment the following line to show background authors
|
| 462 |
## background authors (light grey dots)
|
|
@@ -468,20 +436,6 @@ def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_inp
|
|
| 468 |
hoverinfo='skip'
|
| 469 |
))
|
| 470 |
|
| 471 |
-
# centroids (rainbow colors + hovertext of your top-k features)
|
| 472 |
-
# hover_texts = [
|
| 473 |
-
# f"Cluster {lbl}<br>" + "<br>".join(style_names[lbl])
|
| 474 |
-
# for lbl in cent_lbl
|
| 475 |
-
# ]
|
| 476 |
-
# fig.add_trace(go.Scattergl(
|
| 477 |
-
# x=cent_proj[:,0], y=cent_proj[:,1],
|
| 478 |
-
# mode='markers',
|
| 479 |
-
# marker=dict(symbol='triangle-up', size=10, color="#d3d3d3"),#color=cent_colors
|
| 480 |
-
# name='Cluster centroids',
|
| 481 |
-
# hovertext=hover_texts,
|
| 482 |
-
# hoverinfo='text'
|
| 483 |
-
# ))
|
| 484 |
-
|
| 485 |
# three candidates
|
| 486 |
marker_syms = ['diamond','pentagon','x']
|
| 487 |
for i in range(3):
|
|
@@ -557,25 +511,7 @@ def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_inp
|
|
| 557 |
)
|
| 558 |
|
| 559 |
print('Done processing....')
|
| 560 |
-
|
| 561 |
-
# all_clusters = sorted(style_names.keys())
|
| 562 |
-
# --- build display names for the dropdown ---
|
| 563 |
-
# sorted_labels = sorted([int(lbl) for lbl in cent_lbl])
|
| 564 |
-
# display_clusters = []
|
| 565 |
-
# for lbl in sorted_labels:
|
| 566 |
-
# name = f"Cluster {lbl}"
|
| 567 |
-
# if lbl == cluster_label_query:
|
| 568 |
-
# name += " (closest to mystery author)"
|
| 569 |
-
# matching_indices = [i + 1 for i, val in enumerate(cluster_labels_per_candidate) if int(val) == lbl]
|
| 570 |
-
# if matching_indices:
|
| 571 |
-
# if len(matching_indices) == 1:
|
| 572 |
-
# name += f" (closest to Candidate {matching_indices[0]} author)"
|
| 573 |
-
# else:
|
| 574 |
-
# candidate_str = ", ".join(f"Candidate {i}" for i in matching_indices)
|
| 575 |
-
# name += f" (closest to {candidate_str} authors)"
|
| 576 |
-
# display_clusters.append(name)
|
| 577 |
-
# print(f"All clusters: {all_clusters}")
|
| 578 |
-
# return: figure, dropdown payload, full style_map
|
| 579 |
return (
|
| 580 |
fig,
|
| 581 |
# update(choices=display_clusters, value=display_clusters[cluster_label_query]),
|
|
|
|
| 411 |
# split
|
| 412 |
q_proj = proj[0]
|
| 413 |
c_proj = proj[1:4]
|
|
|
|
| 414 |
bg_proj = proj
|
| 415 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
# 2) build Plotly figure
|
| 418 |
fig = go.Figure()
|
|
|
|
| 425 |
# Enable zoom events
|
| 426 |
dragmode='zoom'
|
| 427 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
# uncomment the following line to show background authors
|
| 430 |
## background authors (light grey dots)
|
|
|
|
| 436 |
hoverinfo='skip'
|
| 437 |
))
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
# three candidates
|
| 440 |
marker_syms = ['diamond','pentagon','x']
|
| 441 |
for i in range(3):
|
|
|
|
| 511 |
)
|
| 512 |
|
| 513 |
print('Done processing....')
|
| 514 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
return (
|
| 516 |
fig,
|
| 517 |
# update(choices=display_clusters, value=display_clusters[cluster_label_query]),
|