import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# ===============================================================
# --- CONFIGURATION: SET YOUR ANALYSIS PARAMETERS HERE ---
# ===============================================================

# --- Task: Analyze TEXT Embeddings ---
# Text embeddings are variable-length, so POOL_MODE must be True here.
INPUT_FILE_PATH = "/media/RTCIN7TBDriveA/Interns/RDT2/gte3kor/complex_image_search-main/results/coreset/embeddings/text_embeddings.pkl"
OUTPUT_PLOT_PATH = "text_distance_histogram.png"
POOL_MODE = True  # <<< IMPORTANT: True for Text
PLOT_TITLE = "Distance Distribution of Text Embeddings"

# --- Task: Analyze IMAGE Embeddings ---
# To analyze images instead, uncomment the block below and comment out the
# Text block above. Image embeddings are already fixed-size, so POOL_MODE
# must be False.
# INPUT_FILE_PATH = "/media/RTCIN7TBDriveA/Interns/RDT2/gte3kor/complex_image_search-main/results/coreset/embeddings/image_embeddings.pkl"
# OUTPUT_PLOT_PATH = "image_distance_histogram.png"
# POOL_MODE = False  # <<< IMPORTANT: False for Images
# PLOT_TITLE = "Distance Distribution of Image Embeddings"

# --- General Settings ---
SAMPLE_SIZE = 2000  # How many embeddings are sampled for the analysis.
DEVICE = 'cuda:2'  # Device to use for computation ('cuda' or 'cpu').
# ===============================================================


def load_embeddings(file_path):
    """Load a pickled embeddings object from disk.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The unpickled object, or ``None`` when ``file_path`` does not exist
        (an error message is printed in that case).
    """
    print(f"Loading embeddings from '{file_path}'...")
    if not os.path.exists(file_path):
        print(f"Error: File not found at '{file_path}'")
        return None
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this script at embedding files from a trusted source.
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def _resolve_device(requested):
    """Return the torch device for ``requested``, falling back to CPU."""
    # Accept indexed names such as 'cuda:2', not only the bare 'cuda' string;
    # comparing with == 'cuda' would silently force every indexed device to CPU.
    if requested.startswith('cuda') and torch.cuda.is_available():
        candidate = torch.device(requested)
        if candidate.index is None or candidate.index < torch.cuda.device_count():
            return candidate
        print(f"Warning: '{requested}' is not available "
              f"({torch.cuda.device_count()} CUDA device(s) found). Using CPU.")
    return torch.device('cpu')


def _prepare_embedding(emb, pool):
    """Convert one embedding to a 1-D, L2-normalized float tensor."""
    if isinstance(emb, np.ndarray):
        emb_tensor = torch.from_numpy(emb).float()
    else:
        emb_tensor = emb.float()

    # Pool BEFORE flattening: flattening first would make dim() == 1 and
    # the pooling branch would never run for multi-row embeddings.
    if pool and emb_tensor.dim() > 1:
        emb_tensor = emb_tensor.mean(dim=0)

    # Whatever is left (e.g. a (1, D) array) is collapsed to a single vector
    # so that normalization and stacking operate on 1-D tensors.
    emb_tensor = emb_tensor.reshape(-1)

    # Always L2-normalize for fair comparison.
    return F.normalize(emb_tensor, p=2, dim=0)


def _upper_triangle_distances(sample_tensor):
    """Return the pairwise L2 distances of all distinct row pairs as a NumPy array."""
    distances = torch.cdist(sample_tensor, sample_tensor)
    # diagonal=1 drops self-distances and keeps each unordered pair once.
    pair_mask = torch.triu(torch.ones_like(distances), diagonal=1).bool()
    return distances[pair_mask].cpu().numpy()


def main():
    """Run the embedding distance diagnosis configured at module level.

    Loads the embeddings from ``INPUT_FILE_PATH``, optionally mean-pools them
    (``POOL_MODE``), L2-normalizes them, samples up to ``SAMPLE_SIZE`` of
    them, prints min/max/mean/median of the pairwise L2 distances and saves
    a histogram to ``OUTPUT_PLOT_PATH``.
    """
    device = _resolve_device(DEVICE)

    print("=" * 60)
    print("--- Running Embedding Distance Diagnosis ---")
    print(f" Input File: {INPUT_FILE_PATH}")
    print(f" Output Plot: {OUTPUT_PLOT_PATH}")
    print(f" Pooling Mode: {'Enabled' if POOL_MODE else 'Disabled'}")
    print(f" Sample Size: {SAMPLE_SIZE}")
    print(f" Device: {device}")
    print("=" * 60)

    # 1. Load the embeddings
    original_embeddings_dict = load_embeddings(INPUT_FILE_PATH)
    if not original_embeddings_dict:
        print("No embeddings loaded. Exiting.")
        return

    # 2. Pre-process embeddings based on the POOL_MODE
    print("Pre-processing embeddings...")
    processed_embeddings = {
        key: _prepare_embedding(emb, POOL_MODE)
        for key, emb in tqdm(original_embeddings_dict.items(), desc="Processing Embeddings")
    }

    # 3. Take a random sample
    all_keys = list(processed_embeddings)
    if len(all_keys) < SAMPLE_SIZE:
        print(f"Warning: Dataset size ({len(processed_embeddings)}) is smaller than sample size. Using all data.")
        keys_to_sample = all_keys
    else:
        keys_to_sample = random.sample(all_keys, SAMPLE_SIZE)

    # With fewer than two vectors there are no pairs, and the statistics
    # below would fail on an empty array.
    if len(keys_to_sample) < 2:
        print("Error: At least two embeddings are required to compute pairwise distances. Exiting.")
        return

    sample_tensor = torch.stack([processed_embeddings[key] for key in keys_to_sample]).to(device)

    # 4. Calculate pairwise distances
    print(f"\nCalculating pairwise distances for {len(keys_to_sample)} samples...")
    distances_flat = _upper_triangle_distances(sample_tensor)

    # 5. Analyze and print statistics
    min_dist = np.min(distances_flat)
    max_dist = np.max(distances_flat)
    mean_dist = np.mean(distances_flat)
    median_dist = np.median(distances_flat)

    print("\n--- DISTANCE ANALYSIS (from sample) ---")
    print(f"Minimum Distance: {min_dist:.4f}")
    print(f"Maximum Distance: {max_dist:.4f}")
    print(f"Mean Distance: {mean_dist:.4f}")
    print(f"Median Distance: {median_dist:.4f}")
    print("---------------------------------------")

    # 6. Plot and save the histogram
    print("Generating distance histogram...")
    plt.figure(figsize=(10, 6))
    plt.hist(distances_flat, bins=100, color='skyblue', edgecolor='black')
    plt.title(PLOT_TITLE)
    plt.xlabel('L2 Distance (after normalization)')
    plt.ylabel('Frequency')
    plt.axvline(mean_dist, color='r', linestyle='dashed', linewidth=2, label=f'Mean: {mean_dist:.2f}')
    plt.axvline(median_dist, color='g', linestyle='dashed', linewidth=2, label=f'Median: {median_dist:.2f}')
    plt.legend()

    output_dir = os.path.dirname(OUTPUT_PLOT_PATH)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    plt.savefig(OUTPUT_PLOT_PATH)
    plt.close()  # release the figure once it has been written
    print(f"✅ Histogram saved to '{OUTPUT_PLOT_PATH}'")


if __name__ == '__main__':
    main()