import os
import pickle
from copy import deepcopy

import numpy as np
import pandas as pd
import torch

try:
    from tqdm import tqdm
except ImportError:  # progress bars are optional; degrade to a plain iterator
    def tqdm(iterable, *args, **kwargs):
        """Stand-in for ``tqdm.tqdm`` when tqdm is not installed: return ``iterable``."""
        return iterable

try:
    import matplotlib.pyplot as plt  # kept for ad-hoc debugging plots; unused below
except ImportError:
    plt = None

np.random.seed(1)  # fixed seed: query sampling below draws from the global NumPy RNG


def cosine_similarity(a, b):
    """Return the cosine similarity between vector ``a`` and each row of ``b``.

    Intended for a 1-D ``a`` of shape (dim,) and a 2-D ``b`` of shape (n, dim);
    the result then has shape (n,). Zero-norm inputs are not guarded and yield
    nan/inf.
    """
    dot_product = np.dot(a, b.T)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b, axis=1)
    return dot_product / (norm_a * norm_b)


def optimized_cosine_matrix(query_features, gallery_features, chunk_size=50000, device='cuda'):
    """Compute the cosine-similarity matrix between queries and gallery items.

    Rows are L2-normalized (all-zero rows stay zero instead of producing nan),
    then the gallery is multiplied in chunks of ``chunk_size`` rows so device
    memory stays bounded.

    Args:
        query_features: array-like of shape (n_query, dim).
        gallery_features: array-like of shape (n_gallery, dim).
        chunk_size (int): gallery rows processed per matrix multiplication.
        device: device to compute on. If CUDA is requested but unavailable the
            computation falls back to the CPU.

    Returns:
        torch.Tensor: float32 CPU tensor of shape (n_query, n_gallery).

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')

    # Floating-point inputs are required for the normalization / matmul below.
    query = torch.as_tensor(query_features, dtype=torch.float32).to(device)
    gallery = torch.as_tensor(gallery_features, dtype=torch.float32).to(device)

    def safe_normalize(x):
        # Divide by the row norm, using 1 for zero rows to avoid division by zero.
        norm = torch.norm(x, p=2, dim=1, keepdim=True)
        return x / torch.where(norm == 0, torch.ones_like(norm), norm)

    query_norm = safe_normalize(query)
    gallery_norm = safe_normalize(gallery)

    if gallery_norm.size(0) == 0:
        return torch.empty((query_norm.size(0), 0), dtype=torch.float32)

    chunks = []
    # Mixed precision only pays off on CUDA; on CPU keep plain float32 math.
    with torch.no_grad(), torch.amp.autocast(device_type='cuda', enabled=device.type == 'cuda'):
        for start in range(0, gallery_norm.size(0), chunk_size):
            chunk = gallery_norm[start:start + chunk_size]
            sim = torch.mm(query_norm, chunk.T)  # (n_query, chunk)
            chunks.append(sim.cpu().to(torch.float32))  # restore float32 precision
    return torch.cat(chunks, dim=1)


def _rank_gallery(dist_mat, dist_type):
    """Return, per query row, gallery indices ordered from closest to farthest.

    ``"l2"`` values are sorted ascending (smaller is closer); anything else is
    treated as a cosine similarity and sorted descending (larger is closer).
    """
    return torch.argsort(dist_mat, dim=1, descending=(dist_type != "l2"))


def retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10):
    """
    Vectorized CMC and AP computation, designed for large matrices (e.g. 14 x 800k).

    Args:
        dist_mat (Tensor): Distance/similarity matrix (num_query, num_gallery)
        labels_query (array-like): Query labels (num_query,)
        labels_gallery (array-like): Gallery labels (num_gallery,)
        dist_type (str): Distance type ["cosine"|"l2"]
        rank_max (int): Maximum ranking depth (clipped to num_gallery)

    Returns:
        cmc (Tensor): CMC curve (rank_max,), averaged over valid queries
        ap_all (np.ndarray): AP for each valid query (num_valid_query,)

    A query is "valid" when at least one gallery item shares its label. Invalid
    queries are dropped, so ``ap_all`` can be shorter than ``num_query``. If no
    query is valid, an all-zero CMC and an empty AP array are returned.
    """
    device = dist_mat.device
    labels_query = torch.as_tensor(labels_query).to(device)
    labels_gallery = torch.as_tensor(labels_gallery).to(device)
    num_gallery = dist_mat.shape[1]
    rank_max = min(rank_max, num_gallery)

    # (num_query, num_gallery): 1 where the k-th ranked gallery item has the query's label
    sorted_labels = labels_gallery[_rank_gallery(dist_mat, dist_type)]
    matches = (sorted_labels == labels_query.view(-1, 1)).long()

    valid_matches = matches[matches.sum(dim=1) > 0]
    if valid_matches.size(0) == 0:
        # Nothing to score: return empty results instead of terminating the process.
        return torch.zeros(rank_max), np.zeros(0, dtype=np.float32)

    cum_correct = valid_matches.cumsum(dim=1)
    # CMC@k: fraction of queries with at least one hit within the top k.
    cmc_final = (cum_correct[:, :rank_max] > 0).float().mean(dim=0)

    # AP: mean of precision@k over the ranks k holding a relevant item.
    positions = torch.arange(1, num_gallery + 1, device=device).view(1, -1)
    precisions = cum_correct / positions
    ap_values = (precisions * valid_matches).sum(dim=1) / valid_matches.sum(dim=1).clamp(min=1e-6)
    return cmc_final.cpu(), ap_values.cpu().numpy()


def _recall_at_real_rejection(dist_mat, labels_query, labels_gallery, query_real_labels,
                              dist_type, reject_real_ratio):
    """Per-query recall of the query's own class at a tolerated real-sample rejection rate.

    For query row q, gallery items are ranked by ``dist_mat``. The cutoff is the
    rank of the k-th gallery item whose label equals ``query_real_labels[q]``,
    with ``k = max(int(n_real * reject_real_ratio), 1)`` and ``n_real`` the number
    of such items. Recall is the number of items labelled ``labels_query[q]`` at
    or above the cutoff divided by their total count (nan if that total is 0).
    Rows that never reach k real items (e.g. no real items at all) use the whole
    ranking as the cutoff.

    Returns:
        np.ndarray of shape (num_query,).
    """
    device = dist_mat.device
    num_query, num_gallery = dist_mat.shape
    if num_gallery == 0:
        return np.full(num_query, np.nan, dtype=np.float32)
    labels_query = labels_query.to(device)
    labels_gallery = labels_gallery.to(device)
    query_real_labels = query_real_labels.to(device)

    sorted_labels = labels_gallery[_rank_gallery(dist_mat, dist_type)]
    matches_fake = sorted_labels == labels_query.view(-1, 1)
    matches_real = sorted_labels == query_real_labels.view(-1, 1)

    real_num = matches_real.sum(dim=1)  # real items per row (sorting is a permutation)
    # int() truncates to 0 when fewer than 1/ratio real items exist; a 0 target
    # would put the cutoff at rank 0, so at least one real item is tolerated.
    reject_k = (real_num * reject_real_ratio).int().clamp(min=1)
    cum_real = matches_real.cumsum(dim=1)
    reached = cum_real == reject_k.view(-1, 1)
    cut_pos = torch.argmax(reached.long(), dim=1)  # first rank where the target is reached
    # argmax of an all-False row is 0; such rows never reject k real items.
    cut_pos = torch.where(reached.any(dim=1), cut_pos, torch.full_like(cut_pos, num_gallery - 1))

    cum_fake = matches_fake.cumsum(dim=1)
    recall_nums = cum_fake[torch.arange(num_query, device=device), cut_pos]
    return (recall_nums / cum_fake[:, -1]).cpu().numpy()


def retrieval_real(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None):
    """Recall of each query's class when at most ``reject_real_ratio`` of real samples are rejected.

    Label 0 is the real class for every query. ``paths`` is accepted for
    interface compatibility and is not used.

    Returns:
        np.ndarray: recall per query (num_query,).
    """
    labels_query = torch.as_tensor(labels_query)
    labels_gallery = torch.as_tensor(labels_gallery)
    query_real_labels = torch.zeros_like(labels_query)
    return _recall_at_real_rejection(dist_mat, labels_query, labels_gallery, query_real_labels,
                                     dist_type, reject_real_ratio)


def retrieval_real_p4(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None, cls2real=None):
    """Like ``retrieval_real`` but each class maps to its own real label.

    Args:
        cls2real (Tensor or array-like): ``cls2real[c]`` is the real label paired
            with class ``c``; indexed by the query labels.
        paths: accepted for interface compatibility; not used.

    Returns:
        np.ndarray: recall per query (num_query,).

    Raises:
        ValueError: if ``cls2real`` is not provided.
    """
    if cls2real is None:
        raise ValueError("The `cls2real` mapping must be provided")
    device = dist_mat.device
    labels_query = torch.as_tensor(labels_query)
    labels_gallery = torch.as_tensor(labels_gallery)
    cls2real = torch.as_tensor(cls2real, device=device)
    query_real_labels = cls2real[labels_query.to(device)]
    return _recall_at_real_rejection(dist_mat, labels_query, labels_gallery, query_real_labels,
                                     dist_type, reject_real_ratio)


def merge_pkls(pkl_algo, min_samples=10):
    """Load several pickle files and concatenate their ``features`` and ``paths``.

    Each file must hold a dict with keys ``'features'`` (array, one row per
    sample) and ``'paths'`` (list, one entry per sample). When the total is
    below ``min_samples`` the file list is repeated (doubling) until it reaches
    ``min_samples``, so later sampling of that many queries without replacement
    works. Padding duplicates samples, so queries may have exact copies in the gallery.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load trusted files.

    Returns:
        (np.ndarray, list): concatenated features and flattened paths.
    """
    features_all, paths_all = [], []
    for pkl in pkl_algo:
        with open(pkl, 'rb') as f:
            data = pickle.load(f)  # trusted, locally produced files only
        features_all.append(data['features'])
        paths_all.append(data['paths'])

    total = sum(len(feats) for feats in features_all)
    if total < min_samples:
        print("Padding:", pkl_algo)
        repeats = 2
        # One doubling is not enough below min_samples / 2; the `0 <` guard stops
        # an empty class from looping forever.
        while 0 < total * repeats < min_samples:
            repeats *= 2
        features_all = features_all * repeats
        paths_all = paths_all * repeats

    features_all = np.concatenate(features_all, axis=0)
    paths_all = [element for sublist in paths_all for element in sublist]
    return features_all, paths_all


def group_average(arr, x, discard_remainder=True):
    """
    Group the array into chunks of x elements and compute the average of each group.

    Args:
        arr (np.ndarray): Input array (anything np.array accepts)
        x (int): Number of elements in each group
        discard_remainder (bool): Whether to discard a trailing group with fewer than x elements

    Returns:
        np.ndarray: Array of group averages

    Raises:
        ValueError: if x is not positive (checked only for non-empty input).
    """
    if not isinstance(arr, np.ndarray):
        arr = np.array(arr)
    if arr.size == 0:
        return np.array([])
    if x <= 0:
        raise ValueError("Group size x must be a positive integer")
    if x > arr.size:
        return np.array([]) if discard_remainder else np.array([arr.mean()])

    n_groups = arr.size // x
    averages = arr[:n_groups * x].reshape(-1, x).mean(axis=1)
    if not discard_remainder and arr.size % x != 0:
        averages = np.append(averages, arr[n_groups * x:].mean())
    return averages


def _build_query_gallery(all_pkl, eval_mode, num_per_class, show_selected=False):
    """Build the query set and the gallery for one evaluation mode.

    The i-th key of ``all_pkl`` (iteration order) is class label i.

    eval_mode:
        1 ('10Q') -- ``num_per_class`` random samples per class become individual
            queries and are removed from the gallery.
        2 ('10Q-Avg') -- the mean of ``num_per_class`` random samples is the single
            query for the class; those samples are removed from the gallery.
        3 ('All-Mean') -- the mean of all class samples is the query; the gallery
            keeps every sample (optimistic, "cheating" setting).

    Returns:
        (query_lists, labels_query, features, labels_gallery, class_sizes), where
        ``class_sizes`` are per-class sample counts after padding, before query removal.
    """
    if eval_mode not in (1, 2, 3):
        raise ValueError(f"Unknown eval_mode: {eval_mode}")
    query_lists, labels_query, labels_gallery = [], [], []
    features_lists, class_sizes = [], []
    for idx, pkl_key in tqdm(enumerate(all_pkl.keys())):
        features_all, paths_all = merge_pkls(all_pkl[pkl_key])
        class_sizes.append(len(features_all))

        if eval_mode in (1, 2):
            sel_idxs = np.random.choice(len(paths_all), num_per_class, replace=False)
            selected = features_all[sel_idxs]  # fancy indexing already returns a copy
            if eval_mode == 1:
                query_lists.extend(selected)
                labels_query.extend([idx] * num_per_class)
            else:
                query_lists.append(np.mean(selected, axis=0))
                labels_query.append(idx)
            features_all = np.delete(features_all, sel_idxs, axis=0)
            if show_selected:
                print("Selected videos:", [paths_all[i] for i in sel_idxs])
            selected_set = set(sel_idxs.tolist())
            paths_all = [path for i, path in enumerate(paths_all) if i not in selected_set]
        else:
            labels_query.append(idx)
            query_lists.append(np.mean(np.array(features_all), axis=0))

        # Ground-truth labels for the gallery entries that remain for this class.
        labels_gallery.extend([idx] * len(paths_all))
        features_lists.append(features_all)

    features = np.concatenate(features_lists, axis=0)
    return query_lists, labels_query, features, labels_gallery, class_sizes


def _evaluate_queries(query_lists, labels_query, features, labels_gallery, group_size, reject_real_ratio):
    """Compute per-class AP (global metric) and recall under ``reject_real_ratio`` (local metric).

    Per-query values are averaged over consecutive groups of ``group_size`` queries.

    Returns:
        (np.ndarray, np.ndarray): grouped AP values and grouped recall values.
    """
    print('Len(gallery):', len(features))
    # [Q_num, G_num] cosine similarity matrix
    dist_mat = optimized_cosine_matrix(np.array(query_lists), features)
    print('Shape(matrix):', dist_mat.shape)

    _, ap_all = retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10)
    print("Global metric AP:")
    ap_grouped = group_average(ap_all, group_size)

    recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=reject_real_ratio)
    print(f"Local metric: recall under {reject_real_ratio} false rejection:")
    recall_grouped = group_average(recall, group_size)
    print()
    return ap_grouped, recall_grouped


def _results_frame(all_pkl, metric_rows):
    """Build a DataFrame with one row per ``all_pkl`` key and one ``Value_i`` column per metric row."""
    table = np.transpose(np.array(metric_rows))
    return pd.DataFrame({
        'Algorithm': list(all_pkl.keys()),
        **{f'Value_{i + 1}': table[:, i] for i in range(table.shape[1])},
    })


def calculated_final_result(all_pkl, eval_modes=(2, 3), num_per_class=10, reject_real_ratio=0.0001):
    '''
    Directly compute the final metrics from the pkl dictionary: global AP plus
    local recall under ``reject_real_ratio`` false rejection of real samples.

    Args:
        all_pkl (dict): maps an algorithm/class name to its list of pkl files.
            The i-th key gets label i, and the recall metric treats label 0 as
            real, so the first key is expected to be the real class.
        eval_modes (iterable of int): query-building modes, 1:'10Q' 2:'10Q-Avg' 3:'All-Mean'.
        num_per_class (int): samples drawn per class in modes 1 and 2.
        reject_real_ratio (float): tolerated false-rejection rate of real samples.

    Returns:
        (pd.DataFrame, list): one row per key, columns ``Value_1..`` = AP for each
        mode followed by recall for each mode; plus the per-class sample counts
        recorded during the last mode.

    Raises:
        ValueError: if ``eval_modes`` is empty or contains an unknown mode.
    '''
    eval_modes = list(eval_modes)
    if not eval_modes:
        raise ValueError("eval_modes must not be empty")

    result_array = []
    query_lenth = []
    for eval_mode in eval_modes:
        query_lists, labels_query, features, labels_gallery, query_lenth = _build_query_gallery(
            all_pkl, eval_mode, num_per_class)
        # Modes 2 and 3 produce one query per class, so no averaging is needed.
        group_size = num_per_class if eval_mode == 1 else 1
        result_array.extend(_evaluate_queries(query_lists, labels_query, features, labels_gallery,
                                              group_size, reject_real_ratio))

    # result_array is [AP_m0, R_m0, AP_m1, R_m1, ...]; regroup to all AP columns, then all recall columns.
    order = [metric + 2 * mode for metric in range(2) for mode in range(len(eval_modes))]
    return _results_frame(all_pkl, [result_array[i] for i in order]), query_lenth


def _multi_query_recall_table(all_pkl, upper_len, reject_real_ratio):
    """Recall (mode '10Q-Avg') for query-set sizes 1..``upper_len``; column i uses i queries."""
    if upper_len < 1:
        raise ValueError("upper_len must be at least 1")
    recall_rows = []
    for num_query in range(1, upper_len + 1):
        query_lists, labels_query, features, labels_gallery, _ = _build_query_gallery(
            all_pkl, eval_mode=2, num_per_class=num_query, show_selected=True)
        _, recall_grouped = _evaluate_queries(query_lists, labels_query, features, labels_gallery,
                                              1, reject_real_ratio)
        recall_rows.append(recall_grouped)
    return _results_frame(all_pkl, recall_rows)


# Support multiple query strategies
def calculated_final_result_multi_query_(all_pkl, upper_len=10, reject_real_ratio=0.0001):
    '''
    Recall under ``reject_real_ratio`` false rejection for query-averaged queries
    built from 1..``upper_len`` samples per class.

    Returns:
        pd.DataFrame: one row per key of ``all_pkl``; ``Value_i`` is the recall
        obtained with i query samples per class.
    '''
    return _multi_query_recall_table(all_pkl, upper_len, reject_real_ratio)


# Visual analysis of bad cases
def calculated_final_result_multi_query(all_pkl, upper_len=10, reject_real_ratio=0.00001):
    '''
    Same as ``calculated_final_result_multi_query_`` but with a stricter default
    false-rejection ratio (1e-5), used for bad-case analysis.

    Returns:
        pd.DataFrame: one row per key of ``all_pkl``; ``Value_i`` is the recall
        obtained with i query samples per class.
    '''
    return _multi_query_recall_table(all_pkl, upper_len, reject_real_ratio)