Spaces:
Running
Running
| import os | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.decomposition import PCA | |
| from sklearn.metrics import silhouette_score | |
| from torch.utils.data import DistributedSampler | |
| from torchvision.datasets import ImageFolder | |
# Directory that receives the figure written by this script.
PLOTS_DIR = 'plots'
os.makedirs(PLOTS_DIR, exist_ok=True)


def _path_from_val(path):
    """Return *path* trimmed so that it starts at its first ``val`` component.

    The path is split on ``os.sep`` and the kept components are re-joined
    with ``/``. A path without a ``val`` component is returned unchanged.
    """
    components = path.split(os.sep)
    if 'val' not in components:
        return path
    return '/'.join(components[components.index('val'):])


# 1. Reconstruct dataset file names matching the DistributedSampler order
print("Reconstructing dataset file list and sampler indices...")
dataset = ImageFolder('data/images/val')
# NOTE(review): the sampler's shuffle/seed settings are left at their
# defaults, so the order produced here only lines up with the saved
# embeddings if the extraction run used the same sampler settings — confirm.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
indices = list(sampler)

# File paths in exactly the order the sampler yields them.
sampler_filenames = [dataset.imgs[i][0] for i in indices]

# Trim each path to the form used by the parquet file_name column
# (e.g. val/class/file.jpg); dataset.imgs paths look like
# 'data/images/val/03368_.../file.jpg'.
cleaned_filenames = [_path_from_val(path) for path in sampler_filenames]
# 2. Load metadata and align
print("Loading and aligning metadata...")
df = pd.read_parquet('metadata/inat_world_model_master.parquet')
df_val = df[df['split'] == 'val'].copy()
df_val.set_index('file_name', inplace=True)

# If a file name occurs more than once, keep only its first row so that each
# label resolves to exactly one metadata record.
df_val_unique = df_val[~df_val.index.duplicated(keep='first')]

# Report files that have no metadata instead of dropping them silently; a
# large count usually points at a path-format mismatch with file_name.
n_unmatched = int((~pd.Index(cleaned_filenames).isin(df_val_unique.index)).sum())
if n_unmatched:
    print(
        f"Warning: {n_unmatched} of {len(cleaned_filenames)} image files "
        "have no matching metadata row."
    )

# Reorder metadata to match the sampler order in a single vectorized step.
# Files without metadata become all-NaN rows, so row i of df_aligned still
# corresponds to sample i, and the metadata columns exist even when nothing
# matched.
df_aligned = df_val_unique.reindex(cleaned_filenames).reset_index(drop=True)

# Drop rows where we have missing lifestyle/trophic level metadata
valid_mask = df_aligned['Primary.Lifestyle'].notna() & df_aligned['Trophic.Level'].notna()
print(f"Aligned metadata rows: {len(df_aligned)}, Valid rows for analysis: {valid_mask.sum()}")
# 3. Load embeddings
# Maps a display name to the saved embedding file of each model. Only 'path'
# is read in the visible part of this script; 'safe_name' is carried along
# unused.
models_config = {
    'DINOv3': {
        'path': 'outputs/vit_large_patch16_dinov3/macro_val_30percent_rank_0.pt',
        'safe_name': 'vit_large_patch16_dinov3',
    },
    'SigLIP2': {
        'path': 'outputs/google_siglip2_so400m_patch14_384/macro_val_30percent_rank_0.pt',
        'safe_name': 'google_siglip2_so400m_patch14_384',
    },
    'BioCLIP2': {
        'path': 'outputs/hf_hub:imageomics_bioclip_2.5_vith14/macro_val_30percent_rank_0.pt',
        'safe_name': 'hf_hub:imageomics_bioclip_2.5_vith14',
    },
}

# Display name -> NumPy array of embeddings, for every file found on disk.
embeddings_data = {}
for name, cfg in models_config.items():
    # A missing file is reported and skipped rather than treated as fatal.
    if not os.path.exists(cfg['path']):
        print(f"Warning: {cfg['path']} does not exist!")
        continue
    print(f"Loading {name} embeddings from {cfg['path']}...")
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints produced by a trusted run.
    checkpoint = torch.load(cfg['path'], map_location='cpu')
    embeddings_data[name] = checkpoint['embeddings'].numpy()
# Subset metadata to match the length of the 30% macro files
# The 30% macro file contains exactly the first 30% of the validation set batches
if not embeddings_data:
    print("Error: No embeddings loaded. Exiting.")
    # SystemExit works without the interactive-only `exit` helper from `site`.
    raise SystemExit(1)

# Every model must be analysed on the same rows. Use the common prefix of all
# embedding files and of the aligned metadata, so the boolean mask below has
# exactly one entry per embedding row for every model.
# NOTE(review): truncating to a common prefix assumes each file stores its
# samples in the same sampler order, as the comment above states — confirm.
sample_counts = {name: len(embs) for name, embs in embeddings_data.items()}
n_samples = min(len(df_aligned), *sample_counts.values())
if any(count != n_samples for count in sample_counts.values()):
    print(
        f"Warning: embedding row counts {sample_counts} and metadata rows "
        f"({len(df_aligned)}) differ; truncating all to {n_samples}."
    )
    embeddings_data = {
        name: embs[:n_samples] for name, embs in embeddings_data.items()
    }

print(f"Analyzing first {n_samples} samples...")
df_subset = df_aligned.iloc[:n_samples].copy()
valid_mask_sub = valid_mask.iloc[:n_samples].values

# Filter out rows with invalid metadata from both embeddings and metadata
# (the embeddings are masked with valid_mask_sub where they are consumed).
df_filtered = df_subset[valid_mask_sub].copy()
# Apply PCA and calculate Silhouette Scores
# The label vectors are the same for every model, so build them once.
lifestyle_labels = df_filtered['Primary.Lifestyle'].astype(str)
trophic_labels = df_filtered['Trophic.Level'].astype(str)

# Display name -> 2D PCA coordinates plus one silhouette score per label set.
results = {}
for name, embs in embeddings_data.items():
    # Keep only the rows whose metadata is complete.
    embs_filtered = embs[valid_mask_sub]

    # 2D PCA Projection (used for plotting only)
    print(f"Running PCA for {name}...")
    coords = PCA(n_components=2, random_state=42).fit_transform(embs_filtered)

    # Silhouette Scores are computed on the full-dimensional embeddings,
    # not on the 2D projection.
    print(f"Calculating Silhouette Scores for {name}...")
    results[name] = {
        'coords': coords,
        'sil_lifestyle': silhouette_score(embs_filtered, lifestyle_labels),
        'sil_trophic': silhouette_score(embs_filtered, trophic_labels),
    }
# 4. Generate the plot grid: one column per loaded model, one row per label set
def _scatter_panel(ax, coords, hue, palette, title, legend_title, show_legend):
    """Draw one PCA scatter panel on *ax*, coloured by *hue*.

    coords holds the 2D PCA coordinates (column 0 on x, column 1 on y).
    The legend is drawn outside the axes, and only when show_legend is true.
    """
    sns.scatterplot(
        x=coords[:, 0], y=coords[:, 1],
        hue=hue,
        palette=palette, alpha=0.5, s=8, ax=ax, legend=show_legend,
    )
    ax.set_title(title)
    if show_legend:
        ax.legend(title=legend_title, bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_xlabel("PCA 1")
    ax.set_ylabel("PCA 2")


sns.set_theme(style="whitegrid")

# Size the grid from the models that actually loaded, so a missing embedding
# file neither leaves blank axes nor loses the legends. With three models this
# is the original 2x3 grid at 18x11 inches.
n_models = len(results)
# squeeze=False keeps axes two-dimensional even when only one model loaded.
fig, axes = plt.subplots(2, n_models, figsize=(6 * n_models, 11), squeeze=False)

for col_idx, (model_name, res) in enumerate(results.items()):
    # Only the right-most column carries legends; they sit outside the axes.
    is_last_col = col_idx == n_models - 1

    # Row 1: Primary Lifestyle
    _scatter_panel(
        axes[0, col_idx], res['coords'], df_filtered['Primary.Lifestyle'], 'Set1',
        f"{model_name}\nLifestyle Silhouette: {res['sil_lifestyle']:.3f}",
        'Primary Lifestyle', is_last_col,
    )

    # Row 2: Trophic Level
    _scatter_panel(
        axes[1, col_idx], res['coords'], df_filtered['Trophic.Level'], 'Dark2',
        f"Trophic Level Silhouette: {res['sil_trophic']:.3f}",
        'Trophic Level', is_last_col,
    )

fig.suptitle("Model Feature Space Representation vs. Ecological Metadata", y=0.98, fontsize=16, fontweight='bold')
fig.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, 'feature_representation_comparison.png'), bbox_inches='tight', dpi=200)
plt.close(fig)
# Print both silhouette scores for every analysed model.
print("\n=== SILHOUETTE SCORE SUMMARY ===")
for name, res in results.items():
    print(f"{name}:")
    print(f" Primary Lifestyle Silhouette Score: {res['sil_lifestyle']:.4f}")
    print(f" Trophic Level Silhouette Score: {res['sil_trophic']:.4f}")

# Derive the reported location from PLOTS_DIR so the message cannot drift
# from the directory the figure is saved into.
output_path = os.path.join(PLOTS_DIR, 'feature_representation_comparison.png')
print(f"\nPlots generated successfully at {output_path}")