Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| import pickle | |
| from tqdm import tqdm | |
| import random | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# ===============================================================
# --- CONFIGURATION: SET YOUR ANALYSIS PARAMETERS HERE ---
# ===============================================================
# Exactly one of the two "Task" blocks below should be active at a time.

# --- Task: Analyze TEXT Embeddings ---
# POOL_MODE should be True for text: text embeddings are variable-length,
# so multi-axis embeddings must be pooled to a fixed-size vector before
# pairwise distances can be computed.
# NOTE: INPUT_FILE_PATH is a machine-specific absolute path; adjust it for
# your environment.
INPUT_FILE_PATH = "/media/RTCIN7TBDriveA/Interns/RDT2/gte3kor/complex_image_search-main/results/coreset/embeddings/text_embeddings.pkl"
OUTPUT_PLOT_PATH = "text_distance_histogram.png"
POOL_MODE = True  # <<< IMPORTANT: True for Text
PLOT_TITLE = "Distance Distribution of Text Embeddings"

# --- Task: Analyze IMAGE Embeddings ---
# To analyze images, uncomment the block below and comment out the Text
# block above. POOL_MODE should be False because image embeddings are
# already fixed-size.
# INPUT_FILE_PATH = "/media/RTCIN7TBDriveA/Interns/RDT2/gte3kor/complex_image_search-main/results/coreset/embeddings/image_embeddings.pkl"
# OUTPUT_PLOT_PATH = "image_distance_histogram.png"
# POOL_MODE = False # <<< IMPORTANT: False for Images
# PLOT_TITLE = "Distance Distribution of Image Embeddings"

# --- General Settings ---
SAMPLE_SIZE = 2000  # Maximum number of embeddings randomly sampled for the pairwise analysis.
# Requested torch device string (e.g. 'cuda', 'cuda:2', 'cpu').
# See main() for the rule that decides when it falls back to the CPU.
DEVICE = 'cuda:2'
# ===============================================================
def load_embeddings(file_path):
    """Unpickle and return the object stored at ``file_path``.

    Prints a progress message first. If the path does not exist, an error
    message is printed and None is returned instead of raising.

    NOTE: pickle can execute arbitrary code while loading; only use this on
    files produced by a trusted source.
    """
    print(f"Loading embeddings from '{file_path}'...")
    path_missing = not os.path.exists(file_path)
    if path_missing:
        print(f"Error: File not found at '{file_path}'")
        return None
    with open(file_path, 'rb') as handle:
        payload = pickle.load(handle)
    return payload
def _resolve_device(requested):
    """Map a requested device string to a usable ``torch.device``.

    CUDA requests ('cuda' or 'cuda:N') are honored only when CUDA is
    available and, for an explicit index, when that GPU exists. Every other
    request (including non-CUDA device types) falls back to the CPU.
    """
    device = torch.device(requested)
    if device.type != 'cuda' or not torch.cuda.is_available():
        return torch.device('cpu')
    gpu_count = torch.cuda.device_count()
    if device.index is not None and device.index >= gpu_count:
        print(f"Warning: '{requested}' requested but only {gpu_count} CUDA device(s) found. Using CPU.")
        return torch.device('cpu')
    return device


def _to_unit_vector(emb, pool):
    """Convert one stored embedding to a 1-D, L2-normalized float32 CPU tensor.

    Pooling happens BEFORE flattening, so variable-length inputs collapse
    to a fixed-size vector. When ``pool`` is true and the embedding has more
    than one axis, it is mean-pooled over every axis except the last.
    This assumes the last axis is the feature axis (TODO confirm against
    the embedding producer). Any remaining shape is then flattened to 1-D
    so normalization runs over the whole vector.
    """
    if isinstance(emb, np.ndarray):
        # ascontiguousarray avoids from_numpy failing on negative-stride views.
        vec = torch.from_numpy(np.ascontiguousarray(emb))
    else:
        vec = torch.as_tensor(emb)
    # detach/cpu: allow grad-tracking or GPU tensors to be stacked and converted to NumPy later.
    vec = vec.detach().float().cpu()
    if pool and vec.dim() > 1:
        vec = vec.reshape(-1, vec.shape[-1]).mean(dim=0)
    vec = vec.reshape(-1)
    return F.normalize(vec, p=2, dim=0)


def _upper_triangle_distances(sample):
    """Return the L2 distances between rows i < j of ``sample`` as a 1-D NumPy array."""
    n = sample.shape[0]
    rows, cols = torch.triu_indices(n, n, offset=1, device=sample.device)
    return torch.cdist(sample, sample)[rows, cols].cpu().numpy()


def _save_histogram(distances_flat, mean_dist, median_dist):
    """Plot a distance histogram with mean/median markers and save it to OUTPUT_PLOT_PATH."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.hist(distances_flat, bins=100, color='skyblue', edgecolor='black')
        ax.set_title(PLOT_TITLE)
        ax.set_xlabel('L2 Distance (after normalization)')
        ax.set_ylabel('Frequency')
        ax.axvline(mean_dist, color='r', linestyle='dashed', linewidth=2, label=f'Mean: {mean_dist:.2f}')
        ax.axvline(median_dist, color='g', linestyle='dashed', linewidth=2, label=f'Median: {median_dist:.2f}')
        ax.legend()
        output_dir = os.path.dirname(OUTPUT_PLOT_PATH)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(OUTPUT_PLOT_PATH)
    finally:
        # Release the figure so repeated runs in one process do not accumulate figures.
        plt.close(fig)
    print(f"✅ Histogram saved to '{OUTPUT_PLOT_PATH}'")


def main(seed=None):
    """Run the embedding distance diagnosis configured by the module constants.

    Steps performed:
      1. Load the dictionary stored at INPUT_FILE_PATH.
      2. Convert every embedding to a 1-D L2-normalized vector, mean-pooling
         multi-axis embeddings first when POOL_MODE is set.
      3. Randomly sample up to SAMPLE_SIZE of them.
      4. Print summary statistics of their pairwise L2 distances.
      5. Save a histogram of those distances to OUTPUT_PLOT_PATH.

    Args:
        seed: Optional seed that makes the random sample reproducible. With
            None (the default), the global ``random`` state is used, as before.
    """
    device = _resolve_device(DEVICE)
    print("=" * 60)
    print("--- Running Embedding Distance Diagnosis ---")
    print(f" Input File: {INPUT_FILE_PATH}")
    print(f" Output Plot: {OUTPUT_PLOT_PATH}")
    print(f" Pooling Mode: {'Enabled' if POOL_MODE else 'Disabled'}")
    print(f" Sample Size: {SAMPLE_SIZE}")
    print(f" Device: {device}")
    print("=" * 60)

    # 1. Load the embeddings
    original_embeddings_dict = load_embeddings(INPUT_FILE_PATH)
    if not original_embeddings_dict:
        print("No embeddings loaded. Exiting.")
        return

    # 2. Pre-process embeddings (pool if configured, flatten, L2-normalize)
    print("Pre-processing embeddings...")
    processed_embeddings = {
        key: _to_unit_vector(emb, POOL_MODE)
        for key, emb in tqdm(original_embeddings_dict.items(), desc="Processing Embeddings")
    }

    # 3. Take a random sample
    all_keys = list(processed_embeddings)
    if len(all_keys) < SAMPLE_SIZE:
        print(f"Warning: Dataset size ({len(processed_embeddings)}) is smaller than sample size. Using all data.")
        keys_to_sample = all_keys
    else:
        rng = random if seed is None else random.Random(seed)
        keys_to_sample = rng.sample(all_keys, SAMPLE_SIZE)

    if len(keys_to_sample) < 2:
        # Pairwise statistics are undefined with fewer than two vectors.
        print("Error: At least two embeddings are required to compute pairwise distances. Exiting.")
        return

    vectors = [processed_embeddings[key] for key in keys_to_sample]
    sizes = sorted({vec.numel() for vec in vectors})
    if len(sizes) > 1:
        print(f"Error: Sampled embeddings have inconsistent sizes {sizes[:10]}. "
              f"Check POOL_MODE (it should be True for variable-length embeddings). Exiting.")
        return
    sample_tensor = torch.stack(vectors).to(device)

    # 4. Calculate pairwise distances (upper triangle only, excluding the diagonal)
    print(f"\nCalculating pairwise distances for {len(keys_to_sample)} samples...")
    distances_flat = _upper_triangle_distances(sample_tensor)

    # 5. Analyze and print statistics
    min_dist = np.min(distances_flat)
    max_dist = np.max(distances_flat)
    mean_dist = np.mean(distances_flat)
    median_dist = np.median(distances_flat)
    print("\n--- DISTANCE ANALYSIS (from sample) ---")
    print(f"Minimum Distance: {min_dist:.4f}")
    print(f"Maximum Distance: {max_dist:.4f}")
    print(f"Mean Distance: {mean_dist:.4f}")
    print(f"Median Distance: {median_dist:.4f}")
    print("---------------------------------------")

    # 6. Plot and save the histogram
    print("Generating distance histogram...")
    _save_histogram(distances_flat, mean_dist, median_dist)
# Run the diagnosis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()