Anisha Bhatnagar committed on
Commit
6aef776
·
1 Parent(s): ab8b2e5

script to precompute cache

Browse files
Files changed (1) hide show
  1. precompute_caches.py +173 -0
precompute_caches.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pickle
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import pandas as pd
7
+ from datetime import datetime
8
+ import yaml
9
+
10
+ # Import your actual modules exactly as app.py does
11
+ from utils.visualizations import get_instances, load_interp_space, compute_tsne_with_cache, compute_precomputed_regions
12
+ from utils.ui import update_task_display, instance_to_df
13
+ from utils.interp_space_utils import cached_generate_style_embedding, compute_g2v_features, compute_predicted_author
14
+
15
def load_config(path="config/config.yaml"):
    """Load the YAML configuration file at *path*.

    Args:
        path: Location of the YAML file; defaults to ``config/config.yaml``.

    Returns:
        Whatever ``yaml.safe_load`` parses from the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
18
+
19
def _replicate_load_flow(instance_id, model_name, instances, clustered_authors_df):
    """Replay the app's load step for one (model, instance) pair.

    Calls ``update_task_display`` in "Predefined HRS Task" mode with no
    uploaded files and no ground-truth author, then prints how many rows
    the returned task/background embedding frames contain.

    Returns:
        Tuple ``(task_authors_embeddings_df,
        background_authors_embeddings_df, predicted_author,
        ground_truth_author)`` taken from the ``update_task_display`` result.
    """
    print(" → Replicating load_button.click() flow...")

    task_results = update_task_display(
        mode="Predefined HRS Task",  # Always use predefined for caching
        iid=f"Task {instance_id}",
        instances=instances,
        background_df=clustered_authors_df,
        mystery_file=None,  # Not used for predefined
        cand1_file=None,  # Not used for predefined
        cand2_file=None,  # Not used for predefined
        cand3_file=None,  # Not used for predefined
        true_author=None,  # Placeholder; the real value comes back in the result
        model_radio=model_name,
        custom_model_input="",
    )

    # Full unpack on purpose: it raises if the number of returned values
    # ever changes, and the names record the tuple layout.
    (header_html, mystery_html, c0_html, c1_html, c2_html,
     mystery_state, c0_state, c1_state, c2_state,
     task_authors_embeddings_df, background_authors_embeddings_df,
     predicted_author, ground_truth_author) = task_results

    print(f" ✓ Embeddings generated for {len(task_authors_embeddings_df)} task authors")
    print(f" ✓ Background embeddings: {len(background_authors_embeddings_df)} authors")

    return (task_authors_embeddings_df, background_authors_embeddings_df,
            predicted_author, ground_truth_author)


def _replicate_run_flow(instance_id, model_name, cfg, instances,
                        task_authors_embeddings_df,
                        background_authors_embeddings_df,
                        predicted_author, ground_truth_author):
    """Replay the app's run step for one (model, instance) pair.

    Calls ``visualize_clusters_plotly`` with the outputs of the load step.
    The returned values are discarded.
    NOTE(review): presumably the call populates the t-SNE / region caches
    as a side effect -- confirm against utils.visualizations.
    """
    print(" → Replicating run_btn.click() flow...")

    viz_results = visualize_clusters_plotly(
        iid=int(instance_id),
        cfg=cfg,
        instances=instances,
        model_radio=model_name,
        custom_model_input="",
        task_authors_df=task_authors_embeddings_df,
        background_authors_embeddings_df=background_authors_embeddings_df,
        pred_idx=predicted_author,
        gt_idx=ground_truth_author,
    )

    # Unpacked only to verify that seven values were returned.
    (fig, style_names, bg_proj, bg_ids, bg_authors_df,
     precomputed_regions_state, precomputed_regions_radio) = viz_results

    print(" ✓ t-SNE projection computed")
    print(" ✓ Precomputed regions generated")


def precompute_all_caches(
    models_to_test=None,
    instances_to_process=None,
    config_path="config/config.yaml",
    force_regenerate=False
):
    """
    Precompute all cache files using the EXACT same methods as app.py.
    This follows the exact flow: load_task → update_task_display → run_visualization

    Args:
        models_to_test: Model identifiers to sweep over. ``None`` selects
            the four built-in default models.
        instances_to_process: Instance ids to process. ``None`` selects
            every id returned by ``get_instances``.
        config_path: Path of the YAML configuration file.
        force_regenerate: Currently NOT honoured -- the flag is never
            forwarded anywhere. A warning is printed when it is truthy.

    Returns:
        dict with integer counters ``'embeddings_generated'``,
        ``'tsne_computed'``, ``'regions_computed'`` and a list of error
        message strings under ``'errors'``.
    """
    import traceback

    if models_to_test is None:
        models_to_test = [
            'gabrielloiseau/LUAR-MUD-sentence-transformers',
            'gabrielloiseau/LUAR-CRUD-sentence-transformers',
            'miladalsh/light-luar',
            'AnnaWegmann/Style-Embedding'
        ]

    print("=" * 60)
    print("CACHE PRECOMPUTATION STARTED")
    print(f"Timestamp: {datetime.now()}")
    print(f"Models to test: {len(models_to_test)}")
    print("=" * 60)

    if force_regenerate:
        # Be explicit rather than silently ignoring the caller's request.
        print("WARNING: force_regenerate is currently ignored by precompute_all_caches.")

    # Load configuration and instances EXACTLY like app.py
    cfg = load_config(config_path)
    print(f"Configuration loaded from {config_path}")
    print(f"config : \n{cfg}")
    instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
    interp = load_interp_space(cfg)
    clustered_authors_df = interp['clustered_authors_df']

    if instances_to_process is None:
        instances_to_process = instance_ids

    print(f"Processing {len(instances_to_process)} instances with {len(models_to_test)} models")

    total_combinations = len(models_to_test) * len(instances_to_process)
    current_combination = 0

    cache_stats = {
        'embeddings_generated': 0,
        'tsne_computed': 0,
        'regions_computed': 0,
        'errors': []
    }

    for model_name in models_to_test:
        print(f"\n{'=' * 40}")
        print(f"PROCESSING MODEL: {model_name}")
        print(f"{'=' * 40}")

        for instance_id in tqdm(instances_to_process, desc=f"Processing instances for {model_name.split('/')[-1]}"):
            current_combination += 1
            try:
                print(f"\n[{current_combination}/{total_combinations}] Processing Instance {instance_id}")

                # STEP 1: Replicate the exact flow from load_button.click()
                (task_authors_embeddings_df, background_authors_embeddings_df,
                 predicted_author, ground_truth_author) = _replicate_load_flow(
                    instance_id, model_name, instances, clustered_authors_df
                )
                cache_stats['embeddings_generated'] += 1

                # STEP 2: Replicate the exact flow from run_btn.click()
                _replicate_run_flow(
                    instance_id, model_name, cfg, instances,
                    task_authors_embeddings_df,
                    background_authors_embeddings_df,
                    predicted_author, ground_truth_author,
                )
                cache_stats['tsne_computed'] += 1
                cache_stats['regions_computed'] += 1

                print(f" ✓ Instance {instance_id} with model {model_name} completed successfully")

            except Exception as e:
                # Best-effort sweep: record the failure and move on to the
                # next (model, instance) pair instead of aborting the run.
                error_msg = f"Error processing instance {instance_id} with model {model_name}: {e}"
                print(f" ✗ {error_msg}")
                cache_stats['errors'].append(error_msg)
                traceback.print_exc()

    # Print final statistics
    print("\n" + "=" * 60)
    print("CACHE PRECOMPUTATION COMPLETED")
    print("=" * 60)
    print(f"Embeddings generated: {cache_stats['embeddings_generated']}")
    print(f"t-SNE projections computed: {cache_stats['tsne_computed']}")
    print(f"Region sets computed: {cache_stats['regions_computed']}")
    print(f"Errors encountered: {len(cache_stats['errors'])}")

    if cache_stats['errors']:
        print("\nERROR DETAILS:")
        for error in cache_stats['errors']:
            print(f" - {error}")

    return cache_stats
158
+
159
+ # Import the exact functions your app uses
160
+ from utils.visualizations import visualize_clusters_plotly
161
+
162
if __name__ == "__main__":
    # Smoke-test run: limit the sweep to instance ids 0 and 1 and one model.
    smoke_test_instance_ids = list(range(2))
    cache_stats = precompute_all_caches(
        models_to_test=['gabrielloiseau/LUAR-MUD-sentence-transformers'],
        instances_to_process=smoke_test_instance_ids,
        force_regenerate=False,
    )

    error_count = len(cache_stats['errors'])
    print(f"\nCache precomputation completed with {error_count} errors.")