# event_retrieval/generate_histogram.py
# Uploaded by sanskar753 using huggingface_hub (commit 02d3a85, verified).
import os
import torch
import torch.nn.functional as F
import pickle
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt
# ===============================================================
# --- CONFIGURATION: SET YOUR ANALYSIS PARAMETERS HERE ---
# ===============================================================
# Exactly one of the two task blocks below should be active at a time.
# --- Task: Analyze TEXT Embeddings ---
# Set POOL_MODE to True because text embeddings are variable-length and
# must be reduced to a single vector each before they can be compared.
INPUT_FILE_PATH = "/media/RTCIN7TBDriveA/Interns/RDT2/gte3kor/complex_image_search-main/results/coreset/embeddings/text_embeddings.pkl"
OUTPUT_PLOT_PATH = "text_distance_histogram.png"
POOL_MODE = True # <<< IMPORTANT: True for Text
PLOT_TITLE = "Distance Distribution of Text Embeddings"
# --- Task: Analyze IMAGE Embeddings ---
# To switch tasks, uncomment the block below and comment out the Text block above.
# Set POOL_MODE to False because image embeddings are already fixed-size.
# INPUT_FILE_PATH = "/media/RTCIN7TBDriveA/Interns/RDT2/gte3kor/complex_image_search-main/results/coreset/embeddings/image_embeddings.pkl"
# OUTPUT_PLOT_PATH = "image_distance_histogram.png"
# POOL_MODE = False # <<< IMPORTANT: False for Images
# PLOT_TITLE = "Distance Distribution of Image Embeddings"
# --- General Settings ---
SAMPLE_SIZE = 2000 # Maximum number of embeddings used; pairwise cost grows with its square.
DEVICE = 'cuda:2' # Requested torch device string; main() decides whether it or the CPU is used.
# ===============================================================
def load_embeddings(file_path):
    """Read and return the object pickled at ``file_path``.

    A progress line is printed first. When nothing exists at the path an
    error line is printed and ``None`` is returned; otherwise the unpickled
    object is returned unchanged.
    """
    print(f"Loading embeddings from '{file_path}'...")
    if os.path.exists(file_path):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(file_path, 'rb') as handle:
            contents = pickle.load(handle)
        return contents
    print(f"Error: File not found at '{file_path}'")
    return None
def _resolve_device(requested):
    """Return the torch device to compute on, falling back to the CPU.

    Any CUDA device string ('cuda', 'cuda:2', ...) is honoured when CUDA is
    available and the requested index exists; every other request resolves
    to the CPU.
    """
    if not (str(requested).startswith('cuda') and torch.cuda.is_available()):
        return torch.device('cpu')
    device = torch.device(requested)
    if device.index is not None and device.index >= torch.cuda.device_count():
        print(f"Warning: device '{requested}' is not available. Falling back to CPU.")
        return torch.device('cpu')
    return device


def _prepare_embedding(emb, pool):
    """Convert one raw embedding into a 1-D, L2-normalised float tensor.

    When ``pool`` is true and the embedding has more than one dimension it is
    mean-pooled over its first axis (assumed to be the token/sequence axis).
    The result is flattened AFTER pooling so pooling is never skipped, and
    so every sample ends up as a single vector.
    """
    emb_tensor = torch.as_tensor(emb).detach().float()
    if pool and emb_tensor.dim() > 1:
        emb_tensor = emb_tensor.mean(dim=0)
    return F.normalize(emb_tensor.flatten(), p=2, dim=0)


def main():
    """Sample embeddings, compute pairwise L2 distances and plot a histogram.

    Reads the module-level configuration (INPUT_FILE_PATH, OUTPUT_PLOT_PATH,
    POOL_MODE, PLOT_TITLE, SAMPLE_SIZE, DEVICE). Prints summary statistics
    and saves the histogram to OUTPUT_PLOT_PATH. Returns ``None``; exits early
    with a message when no usable embeddings are available.
    """
    device = _resolve_device(DEVICE)
    print("=" * 60)
    print("--- Running Embedding Distance Diagnosis ---")
    print(f" Input File: {INPUT_FILE_PATH}")
    print(f" Output Plot: {OUTPUT_PLOT_PATH}")
    print(f" Pooling Mode: {'Enabled' if POOL_MODE else 'Disabled'}")
    print(f" Sample Size: {SAMPLE_SIZE}")
    print(f" Device: {device}")
    print("=" * 60)

    # 1. Load the embeddings (expected to be a mapping of key -> embedding)
    original_embeddings_dict = load_embeddings(INPUT_FILE_PATH)
    if not original_embeddings_dict:
        print("No embeddings loaded. Exiting.")
        return

    # 2. Choose the sample first so only the sampled embeddings are processed
    all_keys = list(original_embeddings_dict)
    if len(all_keys) < SAMPLE_SIZE:
        print(f"Warning: Dataset size ({len(all_keys)}) is smaller than sample size. Using all data.")
        keys_to_sample = all_keys
    else:
        keys_to_sample = random.sample(all_keys, SAMPLE_SIZE)
    if len(keys_to_sample) < 2:
        # A single point has no pairwise distances; the statistics would fail.
        print("Need at least two embeddings to compute pairwise distances. Exiting.")
        return

    # 3. Pool (optionally), flatten and L2-normalize the sampled embeddings
    print("Pre-processing embeddings...")
    vectors = [
        _prepare_embedding(original_embeddings_dict[key], POOL_MODE)
        for key in tqdm(keys_to_sample, desc="Processing Embeddings")
    ]
    sizes = {vec.shape[0] for vec in vectors}
    if len(sizes) > 1:
        print(f"Error: Embeddings have inconsistent sizes {sorted(sizes)}. "
              "Enable POOL_MODE for variable-length embeddings.")
        return
    sample_tensor = torch.stack(vectors).to(device)

    # 4. Calculate pairwise distances, keeping each unordered pair once
    print(f"\nCalculating pairwise distances for {len(keys_to_sample)} samples...")
    distances = torch.cdist(sample_tensor, sample_tensor)
    count = distances.shape[0]
    rows, cols = torch.triu_indices(count, count, offset=1, device=distances.device)
    distances_flat = distances[rows, cols].cpu().numpy()

    # 5. Analyze and print statistics
    min_dist = np.min(distances_flat)
    max_dist = np.max(distances_flat)
    mean_dist = np.mean(distances_flat)
    median_dist = np.median(distances_flat)
    print("\n--- DISTANCE ANALYSIS (from sample) ---")
    print(f"Minimum Distance: {min_dist:.4f}")
    print(f"Maximum Distance: {max_dist:.4f}")
    print(f"Mean Distance: {mean_dist:.4f}")
    print(f"Median Distance: {median_dist:.4f}")
    print("---------------------------------------")

    # 6. Plot and save the histogram
    print("Generating distance histogram...")
    fig = plt.figure(figsize=(10, 6))
    plt.hist(distances_flat, bins=100, color='skyblue', edgecolor='black')
    plt.title(PLOT_TITLE)
    plt.xlabel('L2 Distance (after normalization)')
    plt.ylabel('Frequency')
    plt.axvline(mean_dist, color='r', linestyle='dashed', linewidth=2, label=f'Mean: {mean_dist:.2f}')
    plt.axvline(median_dist, color='g', linestyle='dashed', linewidth=2, label=f'Median: {median_dist:.2f}')
    plt.legend()
    output_dir = os.path.dirname(OUTPUT_PLOT_PATH)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(OUTPUT_PLOT_PATH)
    plt.close(fig)  # release the figure once it is on disk
    print(f"✅ Histogram saved to '{OUTPUT_PLOT_PATH}'")
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()