# Author: Anisha Bhatnagar
# Commit: reducing number of precomputed regions; updating cache; data url; g2v relaxed filtering (peter)
# Revision: a1e49f6
import ast
import json
import os
import pickle
import traceback
from datetime import datetime

import numpy as np
import pandas as pd
import yaml
from tqdm import tqdm

# Import your actual modules exactly as app.py does
from utils.ui import update_task_display
from utils.visualizations import (
    get_instances,
    handle_zoom_with_retries,
    load_interp_space,
    trigger_precomputed_region,
    visualize_clusters_plotly,
)
def load_config(path="config/config.yaml"):
    """Read the YAML config file at *path* and return the parsed mapping."""
    with open(path, "r") as handle:
        raw_text = handle.read()
    return yaml.safe_load(raw_text)
def _simulate_region_zooms(precomputed_regions_state, bg_proj, bg_ids,
                           background_authors_embeddings_df,
                           task_authors_embeddings_df):
    """Replay the region-selection -> zoom flow for every precomputed region.

    For each region in *precomputed_regions_state* this calls
    trigger_precomputed_region() and, when a valid zoom payload comes back,
    handle_zoom_with_retries() -- the same handlers app.py wires to the
    region radio and axis_ranges.change(). A failure on one region is
    reported and does not stop the remaining regions.
    """
    if not precomputed_regions_state:
        return
    # Robustness: the state may arrive as a repr() string or already as a
    # dict; ast.literal_eval on a dict would raise, so accept both forms.
    if isinstance(precomputed_regions_state, str):
        regions_dict = ast.literal_eval(precomputed_regions_state)
    else:
        regions_dict = precomputed_regions_state
    for region_name in list(regions_dict.keys()):
        try:
            print(f" β Testing region: {region_name}")
            # Step 3a: Simulate region selection (trigger_precomputed_region)
            zoom_payload = trigger_precomputed_region(region_name, regions_dict)
            if zoom_payload:  # Only proceed if we got a valid zoom payload
                # Step 3b: Simulate axis_ranges.change() (handle_zoom_with_retries)
                zoom_results = handle_zoom_with_retries(
                    event_json=zoom_payload,
                    bg_proj=bg_proj,
                    bg_lbls=bg_ids,
                    clustered_authors_df=background_authors_embeddings_df,
                    task_authors_df=task_authors_embeddings_df
                )
                # Unpack exactly like app.py does; the values themselves are
                # not needed here -- calling the handler populates the cache.
                (features_rb_update, gram2vec_rb_update, llm_style_feats_analysis,
                 feature_list_state, visible_zoomed_authors) = zoom_results
                print(f" β LLM features cached for region: {region_name}")
        except Exception as e:
            print(f" β Failed to cache features for region {region_name}: {e}")
            # Continue with other regions even if one fails
            continue


def precompute_all_caches(
    models_to_test=None,
    instances_to_process=None,
    config_path="config/config.yaml"
):
    """
    Precompute all cache files using the EXACT same methods as app.py.

    Follows the exact UI flow: load_task -> update_task_display ->
    run_visualization, then replays region zooms so the per-region feature
    caches are warmed as well.

    Parameters
    ----------
    models_to_test : list[str] | None
        Model identifiers to run; defaults to the four production models.
    instances_to_process : list | None
        Instance ids to cache; defaults to every instance found in the
        configured instances file.
    config_path : str
        Path to the YAML configuration consumed by load_config().

    Returns
    -------
    dict
        Counters ('embeddings_generated', 'tsne_computed',
        'regions_computed') plus an 'errors' list of messages for every
        model/instance combination that failed.
    """
    if models_to_test is None:
        models_to_test = [
            'gabrielloiseau/LUAR-MUD-sentence-transformers',
            'gabrielloiseau/LUAR-CRUD-sentence-transformers',
            'miladalsh/light-luar',
            'AnnaWegmann/Style-Embedding'
        ]
    print("=" * 60)
    print("CACHE PRECOMPUTATION STARTED")
    print(f"Timestamp: {datetime.now()}")
    print(f"Models to test: {len(models_to_test)}")
    print("=" * 60)
    # Load configuration and instances EXACTLY like app.py
    cfg = load_config(config_path)
    print(f"Configuration loaded from {config_path}")
    print(f"config : \n{cfg}")
    instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
    interp = load_interp_space(cfg)
    clustered_authors_df = interp['clustered_authors_df']
    if instances_to_process is None:
        instances_to_process = instance_ids
    print(f"Processing {len(instances_to_process)} instances with {len(models_to_test)} models")
    total_combinations = len(models_to_test) * len(instances_to_process)
    current_combination = 0
    cache_stats = {
        'embeddings_generated': 0,
        'tsne_computed': 0,
        'regions_computed': 0,
        'errors': []
    }
    for model_name in models_to_test:
        print(f"\n{'=' * 40}")
        print(f"PROCESSING MODEL: {model_name}")
        print(f"{'=' * 40}")
        for instance_id in tqdm(instances_to_process, desc=f"Processing instances for {model_name.split('/')[-1]}"):
            current_combination += 1
            try:
                print(f"\n[{current_combination}/{total_combinations}] Processing Instance {instance_id}")
                # STEP 1: Replicate the exact flow from load_button.click()
                print(" β Replicating load_button.click() flow...")
                # Create ground truth (using placeholder since we're caching)
                ground_truth_author = None  # Will be determined by the instance data
                # Call update_task_display EXACTLY like app.py does
                task_results = update_task_display(
                    mode="Predefined HRS Task",  # Always use predefined for caching
                    iid=f"Task {instance_id}",
                    instances=instances,
                    background_df=clustered_authors_df,
                    mystery_file=None,  # Not used for predefined
                    cand1_file=None,  # Not used for predefined
                    cand2_file=None,  # Not used for predefined
                    cand3_file=None,  # Not used for predefined
                    true_author=ground_truth_author,
                    model_radio=model_name,
                    custom_model_input=""
                )
                # Extract the results exactly like app.py expects
                (header_html, mystery_html, c0_html, c1_html, c2_html,
                 mystery_state, c0_state, c1_state, c2_state,
                 task_authors_embeddings_df, background_authors_embeddings_df,
                 predicted_author, ground_truth_author) = task_results
                print(f" β Embeddings generated for {len(task_authors_embeddings_df)} task authors")
                print(f" β Background embeddings: {len(background_authors_embeddings_df)} authors")
                cache_stats['embeddings_generated'] += 1
                # STEP 2: Replicate the exact flow from run_btn.click()
                print(" β Replicating run_btn.click() flow...")
                # Call visualize_clusters_plotly EXACTLY like app.py does
                viz_results = visualize_clusters_plotly(
                    iid=int(instance_id),
                    cfg=cfg,
                    instances=instances,
                    model_radio=model_name,
                    custom_model_input="",
                    task_authors_df=task_authors_embeddings_df,
                    background_authors_embeddings_df=background_authors_embeddings_df,
                    pred_idx=predicted_author,
                    gt_idx=ground_truth_author
                )
                # Extract results exactly like app.py expects
                (fig, style_names, bg_proj, bg_ids, bg_authors_df,
                 precomputed_regions_state, precomputed_regions_radio) = viz_results
                print(f" β t-SNE projection computed")
                print(f" β Precomputed regions generated")
                cache_stats['tsne_computed'] += 1
                cache_stats['regions_computed'] += 1
                print(f" β Instance {instance_id} with model {model_name} completed successfully")
                print(" β Testing region zoom simulation...")
                # STEP 3: warm the per-region feature caches.
                _simulate_region_zooms(
                    precomputed_regions_state, bg_proj, bg_ids,
                    background_authors_embeddings_df, task_authors_embeddings_df
                )
            except Exception as e:
                error_msg = f"Error processing instance {instance_id} with model {model_name}: {str(e)}"
                print(f" β {error_msg}")
                cache_stats['errors'].append(error_msg)
                # traceback is imported at module level; print the full stack
                # so a failing combination is diagnosable from the log.
                traceback.print_exc()
                continue
    # Print final statistics
    print("\n" + "=" * 60)
    print("CACHE PRECOMPUTATION COMPLETED")
    print("=" * 60)
    print(f"Embeddings generated: {cache_stats['embeddings_generated']}")
    print(f"t-SNE projections computed: {cache_stats['tsne_computed']}")
    print(f"Region sets computed: {cache_stats['regions_computed']}")
    print(f"Errors encountered: {len(cache_stats['errors'])}")
    if cache_stats['errors']:
        print("\nERROR DETAILS:")
        for error in cache_stats['errors']:
            print(f" - {error}")
    return cache_stats
| # Import the exact functions your app uses | |
| from utils.visualizations import visualize_clusters_plotly | |
if __name__ == "__main__":
    # Test with a small subset first: only the first 20 instances and a
    # single model, so a full run can be validated cheaply before caching
    # every model/instance combination.
    instances = list(range(20))  # First 20 instances for testing
    cache_stats = precompute_all_caches(
        models_to_test=[
            'gabrielloiseau/LUAR-MUD-sentence-transformers'
        ],
        instances_to_process=instances
    )
    print(f"\nCache precomputation completed with {len(cache_stats['errors'])} errors.")