Anisha Bhatnagar committed on
Commit
08f53a7
·
1 Parent(s): 258c7f3

fixed logging; ensured Reddit data files are correctly downloaded

Browse files
Files changed (1) hide show
  1. precompute_caches.py +24 -13
precompute_caches.py CHANGED
@@ -8,18 +8,30 @@ import pandas as pd
8
  from datetime import datetime
9
  import yaml
10
 
11
- # Import your actual modules exactly as app.py does
12
- from utils.visualizations import get_instances, load_interp_space, trigger_precomputed_region, handle_zoom_with_retries
13
- from utils.ui import update_task_display
14
 
15
  def load_config(path="config/config.yaml"):
16
  with open(path, "r") as f:
17
  return yaml.safe_load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def precompute_all_caches(
20
  models_to_test=None,
21
  instances_to_process=None,
22
- config_path="config/config.yaml"
23
  ):
24
  """
25
  Precompute all cache files using the EXACT same methods as app.py.
@@ -34,16 +46,12 @@ def precompute_all_caches(
34
  'AnnaWegmann/Style-Embedding'
35
  ]
36
 
37
- print("=" * 60)
38
  print("CACHE PRECOMPUTATION STARTED")
39
  print(f"Timestamp: {datetime.now()}")
40
  print(f"Models to test: {len(models_to_test)}")
41
  print("=" * 60)
42
-
43
- # Load configuration and instances EXACTLY like app.py
44
- cfg = load_config(config_path)
45
- print(f"Configuration loaded from {config_path}")
46
- print(f"config : \n{cfg}")
47
  instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
48
  # interp = load_interp_space(cfg)
49
  # clustered_authors_df = interp['clustered_authors_df']
@@ -72,7 +80,9 @@ def precompute_all_caches(
72
  for instance_id in tqdm(instances_to_process, desc=f"Processing instances for {model_name.split('/')[-1]}"):
73
  current_combination += 1
74
  try:
75
- print(f"\n[{current_combination}/{total_combinations}] Processing Instance {instance_id}")
 
 
76
 
77
  # STEP 1: Replicate the exact flow from load_button.click()
78
  print(" → Replicating load_button.click() flow...")
@@ -82,7 +92,7 @@ def precompute_all_caches(
82
 
83
  # Call update_task_display EXACTLY like app.py does
84
  task_results = update_task_display(
85
- mode="Predefined HRS Task", # Always use predefined for caching
86
  iid=f"Task {instance_id}",
87
  instances=instances,
88
  background_df=clustered_authors_df,
@@ -137,6 +147,7 @@ def precompute_all_caches(
137
  if precomputed_regions_state:
138
  regions_dict = ast.literal_eval(precomputed_regions_state)
139
  test_regions = list(regions_dict.keys())
 
140
 
141
  for region_name in test_regions:
142
  try:
@@ -194,7 +205,7 @@ from utils.visualizations import visualize_clusters_plotly
194
 
195
  if __name__ == "__main__":
196
  # Test with a small subset first
197
- instances=[i for i in range(10)] # First 20 instances for testing
198
  cache_stats = precompute_all_caches(
199
  models_to_test=[
200
  'AnnaWegmann/Style-Embedding'
 
8
  from datetime import datetime
9
  import yaml
10
 
11
+ CONFIG_PATH="config/config.yaml"
 
 
12
 
13
  def load_config(path="config/config.yaml"):
14
  with open(path, "r") as f:
15
  return yaml.safe_load(f)
16
+
17
+ # Load configuration and instances EXACTLY like app.py
18
+ cfg = load_config(CONFIG_PATH)
19
+ print(f"Configuration loaded from {CONFIG_PATH}")
20
+ print(f"config : \n{cfg}")
21
+
22
+ # Import your actual modules exactly as app.py does
23
+ from utils.file_download import download_file_override
24
+
25
+ download_file_override(cfg.get('background_authors_df_url'), cfg.get('background_authors_df_path'))
26
+ download_file_override(cfg.get('instances_to_explain_url'), cfg.get('instances_to_explain_path'))
27
+ download_file_override(cfg.get('gram2vec_feats_url'), cfg.get('gram2vec_feats_path'))
28
+
29
+ from utils.visualizations import get_instances, trigger_precomputed_region, handle_zoom_with_retries
30
+ from utils.ui import update_task_display
31
 
32
  def precompute_all_caches(
33
  models_to_test=None,
34
  instances_to_process=None,
 
35
  ):
36
  """
37
  Precompute all cache files using the EXACT same methods as app.py.
 
46
  'AnnaWegmann/Style-Embedding'
47
  ]
48
 
49
+ print("\n\n" + "=" * 60)
50
  print("CACHE PRECOMPUTATION STARTED")
51
  print(f"Timestamp: {datetime.now()}")
52
  print(f"Models to test: {len(models_to_test)}")
53
  print("=" * 60)
54
+
 
 
 
 
55
  instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
56
  # interp = load_interp_space(cfg)
57
  # clustered_authors_df = interp['clustered_authors_df']
 
80
  for instance_id in tqdm(instances_to_process, desc=f"Processing instances for {model_name.split('/')[-1]}"):
81
  current_combination += 1
82
  try:
83
+ # print(f"\n\n\n[{current_combination}/{total_combinations}] Processing Instance {instance_id}")
84
+ print(f"\n\n\n\033[1m\033[93m>>> [{current_combination}/{total_combinations}] Processing Instance {instance_id} <<<\033[0m\n")
85
+
86
 
87
  # STEP 1: Replicate the exact flow from load_button.click()
88
  print(" → Replicating load_button.click() flow...")
 
92
 
93
  # Call update_task_display EXACTLY like app.py does
94
  task_results = update_task_display(
95
+ mode="Predefined Reddit Task", # Always use predefined for caching
96
  iid=f"Task {instance_id}",
97
  instances=instances,
98
  background_df=clustered_authors_df,
 
147
  if precomputed_regions_state:
148
  regions_dict = ast.literal_eval(precomputed_regions_state)
149
  test_regions = list(regions_dict.keys())
150
+ print(f" → Found {len(test_regions)} regions to test")
151
 
152
  for region_name in test_regions:
153
  try:
 
205
 
206
  if __name__ == "__main__":
207
  # Test with a small subset first
208
+ instances=[i for i in range(20)] # First 20 instances for testing
209
  cache_stats = precompute_all_caches(
210
  models_to_test=[
211
  'AnnaWegmann/Style-Embedding'