| import os, pickle |
| import torch |
| import numpy as np |
| import pandas as pd |
| from tqdm import tqdm |
| from copy import deepcopy |
| import matplotlib.pyplot as plt |
# Seed NumPy's global RNG once at import time so the random query selection
# (np.random.choice in the evaluation functions below) is reproducible.
np.random.seed(seed=1)
|
|
def cosine_similarity(a, b):
    """
    Cosine similarity between vector(s) ``a`` and every row of ``b``.

    Args:
        a: A single vector of shape (dim,) or a batch of vectors (m, dim).
        b: Matrix of shape (n, dim).

    Returns:
        np.ndarray: Shape (n,) for a 1-D ``a``, or (m, n) for a 2-D ``a``.
        Pairs involving an all-zero vector get similarity 0 instead of nan/inf.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    dot_product = np.dot(a, b.T)
    # Row-wise norms; keepdims lets a 2-D ``a`` broadcast to (m, 1) against (n,).
    norm_a = np.linalg.norm(a, axis=-1, keepdims=True)
    norm_b = np.linalg.norm(b, axis=1)
    denom = norm_a * norm_b
    out = np.zeros(np.shape(dot_product), dtype=np.result_type(dot_product, denom))
    # Skip zero denominators (zero vectors) and leave their similarity at 0.
    return np.divide(dot_product, denom, out=out, where=denom != 0)
|
|
def optimized_cosine_matrix(query_features, gallery_features, chunk_size=50000, device='cuda'):
    """
    Cosine-similarity matrix between query and gallery features, computed in chunks.

    Rows are L2-normalised; all-zero rows are left as zeros, so their similarity
    to everything is 0. The gallery is moved to ``device`` one chunk of
    ``chunk_size`` rows at a time. Each chunk of results is copied back to the
    CPU, so device memory holds the query plus one gallery/result chunk.
    On a CUDA device the matmul runs under ``torch.amp.autocast`` (reduced
    precision). On any other device it runs in float32 with no autocast.

    Args:
        query_features: Array-like of shape (num_query, dim), or (dim,) for a single query.
        gallery_features: Array-like of shape (num_gallery, dim).
        chunk_size (int): Number of gallery rows processed per step (must be > 0).
        device (str | torch.device): Device used for the computation.

    Returns:
        torch.Tensor: (num_query, num_gallery) float32 tensor on the CPU.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    import contextlib

    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    dev = torch.device(device)
    query = torch.as_tensor(query_features, dtype=torch.float32)
    gallery = torch.as_tensor(gallery_features, dtype=torch.float32)
    if query.dim() == 1:
        query = query.unsqueeze(0)

    def safe_normalize(x):
        norm = torch.norm(x, p=2, dim=1, keepdim=True)
        # Divide all-zero rows by 1 instead of 0 to avoid nan.
        return x / torch.where(norm == 0, torch.ones_like(norm), norm)

    def amp_context():
        # autocast('cuda') is meaningless for non-CUDA devices, and it warns when CUDA is absent.
        if dev.type == 'cuda':
            return torch.amp.autocast('cuda')
        return contextlib.nullcontext()

    num_gallery = gallery.size(0)
    if num_gallery == 0:
        # torch.cat([]) would raise; an empty gallery yields an empty matrix.
        return torch.empty((query.size(0), 0), dtype=torch.float32)

    dist_mat = []
    with torch.no_grad():
        query_norm = safe_normalize(query.to(dev))
        for start in range(0, num_gallery, chunk_size):
            # Normalisation is row-wise, so normalising per chunk equals normalising the whole gallery.
            chunk = safe_normalize(gallery[start:start + chunk_size].to(dev))
            with amp_context():
                sim = torch.mm(query_norm, chunk.T)
            dist_mat.append(sim.cpu().to(torch.float32))

    return torch.cat(dist_mat, dim=1)
|
|
def retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10):
    """
    CMC curve and per-query average precision for large retrieval matrices
    (written for settings such as 14 x 800k).

    For each query the gallery is ranked by ``dist_mat``. The ranking is
    descending for "cosine" (similarity) and ascending for "l2" (distance); any
    other value is treated like "cosine". A gallery item is a match when its
    label equals the query label. Queries with no match anywhere in the gallery
    are excluded from both metrics.

    Args:
        dist_mat (Tensor): Score matrix of shape (num_query, num_gallery).
        labels_query: Query labels, length num_query (list, array or tensor).
        labels_gallery: Gallery labels, length num_gallery (list, array or tensor).
        dist_type (str): "cosine" or "l2".
        rank_max (int): Maximum CMC rank; clipped to num_gallery.

    Returns:
        cmc (Tensor): CPU tensor of shape (rank_max,) with the fraction of valid
            queries having a match within the top r results.
        ap_all (np.ndarray): AP of each valid query, in query order (num_valid_query,).

    Raises:
        ValueError: If no query has any matching gallery item.
    """
    device = dist_mat.device
    # as_tensor avoids the copy warning for tensor inputs; put labels on dist_mat's device for indexing.
    labels_query = torch.as_tensor(labels_query, device=device)
    labels_gallery = torch.as_tensor(labels_gallery, device=device)

    num_query, num_gallery = dist_mat.shape
    rank_max = min(rank_max, num_gallery)

    sorted_indices = torch.argsort(dist_mat, dim=1, descending=(dist_type != "l2"))
    sorted_labels = labels_gallery[sorted_indices]
    matches = (sorted_labels == labels_query.view(-1, 1)).long()

    valid_mask = matches.sum(dim=1) > 0
    valid_matches = matches[valid_mask]
    if valid_matches.size(0) == 0:
        # Previously exit(); raising lets callers handle the empty case instead of killing the process.
        raise ValueError("No query has a matching label in the gallery; CMC/AP are undefined.")

    cum_correct = valid_matches.cumsum(dim=1)
    cmc_final = (cum_correct[:, :rank_max] > 0).float().mean(dim=0)

    positions = torch.arange(1, num_gallery + 1, device=device).view(1, -1)
    precisions = cum_correct / positions
    # Every valid row has at least one match, so the denominator is >= 1.
    ap_values = (precisions * valid_matches).sum(dim=1) / valid_matches.sum(dim=1)

    return cmc_final.cpu(), ap_values.cpu().numpy()
|
|
def retrieval_real(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None):
    """
    Per-query recall of same-label gallery items at a fixed budget of falsely
    retrieved "real" (label 0) gallery items.

    For each query the gallery is ranked by ``dist_mat``: descending unless
    ``dist_type == "l2"``. The ranking is cut at the k-th gallery item whose
    label is 0, where ``k = max(1, floor(num_real * reject_real_ratio))`` and
    ``num_real`` is the number of label-0 gallery items. If fewer than k such
    items exist, the whole ranking is kept. The value for a query is the
    fraction of gallery items sharing its label that rank at or before the cut.

    Args:
        dist_mat (Tensor): (num_query, num_gallery) score matrix.
        labels_query: Query labels, length num_query.
        labels_gallery: Gallery labels, length num_gallery; label 0 marks real items.
        dist_type (str): "l2" sorts ascending; anything else sorts descending.
        reject_real_ratio (float): Fraction of real gallery items tolerated before the cut.
        paths: Unused; kept for backward compatibility.

    Returns:
        np.ndarray: (num_query,) recall values; nan for a query whose label has
        no gallery items (0/0).
    """
    device = dist_mat.device
    labels_query = torch.as_tensor(labels_query, device=device)
    labels_gallery = torch.as_tensor(labels_gallery, device=device)
    num_gallery = dist_mat.size(1)

    sorted_indices = torch.argsort(dist_mat, dim=1, descending=(dist_type != "l2"))
    sorted_labels = labels_gallery[sorted_indices]

    matches_fake = sorted_labels == labels_query.view(-1, 1)
    matches_real = sorted_labels == 0

    # Clamp to >= 1: a threshold of 0 made the cut land at position 0 whenever
    # num_real * ratio < 1 (same fix as retrieval_real_p4). Computed in float64.
    real_num = int((labels_gallery == 0).sum())
    k = max(1, int(real_num * reject_real_ratio))
    reached = matches_real.cumsum(dim=1) == k
    # argmax returns the first maximal index, i.e. the rank of the k-th real item.
    cut = torch.argmax(reached.long(), dim=1)
    # If the k-th real item never occurs, nothing is rejected: keep the full ranking.
    cut = torch.where(reached.any(dim=1), cut, torch.full_like(cut, num_gallery - 1))

    cum_fake = matches_fake.cumsum(dim=1)
    recall_nums = cum_fake[torch.arange(cum_fake.size(0), device=device), cut]

    return (recall_nums / cum_fake[:, -1]).cpu().numpy()
|
|
def merge_pkls(pkl_algo, min_len=10):
    """
    Load several pickle files and concatenate their features and paths.

    Each pickle must contain a dict with a 'features' entry (array-like, one row
    per sample) and a 'paths' entry (sequence, one entry per sample). If fewer
    than ``min_len`` samples are found in total, the loaded lists are duplicated
    repeatedly until at least ``min_len`` samples exist. This lets callers draw
    ``min_len`` samples without replacement. The padded samples are exact copies.

    Security: ``pickle.load`` can execute arbitrary code, so only pass trusted files.

    Args:
        pkl_algo (Iterable[str]): Paths of the pickle files to merge.
        min_len (int): Minimum number of samples to return (default 10).

    Returns:
        tuple[np.ndarray, list]: The concatenated features and the flattened path list.
    """
    features_all, paths_all = [], []

    for pkl in pkl_algo:
        with open(pkl, 'rb') as f:
            data = pickle.load(f)  # NOTE: unsafe on untrusted input
        features_all.append(data['features'])
        paths_all.append(data['paths'])

    total = sum(len(feats) for feats in features_all)
    if total < min_len:
        print("Padding:", pkl_algo)
        # A single doubling (the previous behaviour) is not enough when total < min_len / 2.
        # The total > 0 guard prevents an infinite loop on empty inputs.
        while 0 < total < min_len:
            features_all = features_all + features_all
            paths_all = paths_all + paths_all
            total *= 2

    features_all = np.concatenate(features_all, axis=0)
    paths_all = [element for sublist in paths_all for element in sublist]
    return features_all, paths_all
|
|
def group_average(arr, x, discard_remainder=True):
    """
    Split the elements of an array into consecutive groups of ``x`` and average each group.

    The input is flattened first, so grouping is always over elements (in C
    order), whatever the input shape.

    Args:
        arr (array-like): Input values.
        x (int): Number of elements in each group.
        discard_remainder (bool): If True, drop trailing elements that do not fill
            a whole group; otherwise average them as one extra group.

    Returns:
        np.ndarray: One average per group. The result is empty if ``arr`` is
        empty, or if ``x`` exceeds its size and ``discard_remainder`` is True.

    Raises:
        ValueError: If ``x`` is not positive and ``arr`` is non-empty.
    """
    # Flatten so slicing and element counts agree; previously N-D input was
    # sized by elements but sliced by rows, which crashed or mis-grouped.
    arr = np.asarray(arr).ravel()

    if arr.size == 0:
        return np.array([])

    if x <= 0:
        raise ValueError("Group size x must be a positive integer")

    if x > arr.size:
        return np.array([]) if discard_remainder else np.array([arr.mean()])

    n_groups = arr.size // x
    averages = arr[:n_groups * x].reshape(n_groups, x).mean(axis=1)

    if not discard_remainder and arr.size % x != 0:
        averages = np.append(averages, arr[n_groups * x:].mean())

    return averages
|
|
def calculated_final_result(all_pkl, eval_modes=(2, 3), num_per_class=10, reject_real_ratio=0.0001):
    '''
    Compute retrieval metrics for every algorithm in ``all_pkl``.

    For each evaluation mode, a query set and a gallery are built from every
    algorithm's merged features (``merge_pkls``). A cosine-similarity matrix is
    computed, and two metrics are collected per algorithm:
      * global AP (``retrieval_cmc_ap``);
      * recall under a ``reject_real_ratio`` false-rejection budget of label-0
        gallery items (``retrieval_real``). Label 0 is the first key of ``all_pkl``.

    Evaluation modes:
        1: ``num_per_class`` random samples per algorithm become individual
           queries and are removed from the gallery. Metrics are averaged over
           those queries.
        2: ``num_per_class`` random samples are averaged into one query and
           removed from the gallery.
        3: the mean of all samples is the single query, and the gallery keeps
           every sample.
    Modes 1 and 2 draw with ``np.random.choice``, so results depend on the
    global NumPy RNG state.

    Args:
        all_pkl (dict): Algorithm name -> list of pkl paths. A key's position in
            iteration order is used as that algorithm's label.
        eval_modes (sequence of int): Modes to run, in order (default (2, 3)).
        num_per_class (int): Samples drawn per algorithm in modes 1 and 2.
        reject_real_ratio (float): False-rejection budget passed to retrieval_real.

    Returns:
        tuple: ``(df, query_lenth)``.
          * ``df`` has an 'Algorithm' column followed by
            ``Value_1 .. Value_{2*len(eval_modes)}``: AP for each mode in order,
            then recall for each mode in order.
          * ``query_lenth`` holds each algorithm's merged sample count, appended
            once per mode.

    Raises:
        ValueError: If ``eval_modes`` is empty or contains a mode other than 1, 2 or 3.
    '''
    if not eval_modes:
        raise ValueError("eval_modes must contain at least one mode")
    unknown = [m for m in eval_modes if m not in (1, 2, 3)]
    if unknown:
        raise ValueError(f"Unsupported eval_modes: {unknown}")

    result_array = []
    query_lenth = []

    for eval_mode in eval_modes:
        features_lists = []
        query_lists, labels_query = [], []
        labels_gallery = []

        for idx, pkl_key in tqdm(enumerate(all_pkl.keys())):
            features_all, paths_all = merge_pkls(all_pkl[pkl_key])
            query_lenth.append(len(features_all))

            if eval_mode == 3:
                labels_query.append(idx)
                query_lists.append(np.mean(np.asarray(features_all), axis=0))
            else:
                sel_idxs = np.random.choice(len(paths_all), num_per_class, replace=False)
                selected = features_all[sel_idxs]  # fancy indexing already returns a copy
                if eval_mode == 1:
                    query_lists.extend(selected)
                    labels_query.extend([idx] * num_per_class)
                else:
                    query_lists.append(np.mean(selected, axis=0))
                    labels_query.append(idx)
                # Held-out query samples must not be retrievable from the gallery.
                features_all = np.delete(features_all, sel_idxs, axis=0)
                held_out = set(sel_idxs.tolist())
                paths_all = [path for i, path in enumerate(paths_all) if i not in held_out]

            labels_gallery.extend([idx] * len(paths_all))
            features_lists.append(features_all)

        # Modes 2 and 3 yield one query per algorithm, so per-algorithm grouping is 1.
        group_size = num_per_class if eval_mode == 1 else 1

        features = np.concatenate(features_lists, axis=0)
        print('Len(gallery):', len(features))

        dist_mat = optimized_cosine_matrix(np.array(query_lists), features)
        print('Shape(matrix):', dist_mat.shape)

        _, ap_all = retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10)
        print("Global metric AP:")
        result_array.append(group_average(ap_all, group_size))

        recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=reject_real_ratio)
        print(f"Local metric: recall under {reject_real_ratio} false rejection:")
        result_array.append(group_average(recall, group_size))
        print()

    # result_array alternates [AP, recall] per mode. Reorder to all APs, then all
    # recalls. This works for any number of modes (previously only 2 or 3).
    n_rows = len(result_array)
    order = list(range(0, n_rows, 2)) + list(range(1, n_rows, 2))
    result_array = np.transpose(np.array(result_array)[order])

    df = pd.DataFrame({
        'Algorithm': list(all_pkl.keys()),
        **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])}
    })

    return df, query_lenth
|
|
| |
def calculated_final_result_multi_query_(all_pkl, upper_len=10, reject_real_ratio=0.0001):
    '''
    Recall as a function of the number of samples averaged into each query.

    Runs ``upper_len`` rounds. In round n (n = 1 .. upper_len), for every
    algorithm in ``all_pkl``:
      * n samples are drawn with ``np.random.choice`` (global NumPy RNG);
      * they are averaged into one query and removed from the gallery.
    A cosine-similarity matrix is then built. AP is computed and announced, but
    only the recall at a ``reject_real_ratio`` false-rejection budget of label-0
    gallery items (``retrieval_real``) is kept.

    Args:
        all_pkl (dict): Algorithm name -> list of pkl paths. A key's position in
            iteration order is used as that algorithm's label.
        upper_len (int): Largest number of samples averaged into a query.
        reject_real_ratio (float): False-rejection budget passed to retrieval_real.

    Returns:
        pd.DataFrame: An 'Algorithm' column plus ``Value_1 .. Value_{upper_len}``,
        where ``Value_n`` is that algorithm's recall in round n.
    '''
    recall_rows = []

    for num_query in range(1, upper_len + 1):
        features_lists = []
        query_lists, labels_query = [], []
        labels_gallery = []

        for idx, pkl_key in tqdm(enumerate(all_pkl.keys())):
            features_all, paths_all = merge_pkls(all_pkl[pkl_key])

            sel_idxs = np.random.choice(len(paths_all), num_query, replace=False)
            query_lists.append(np.mean(features_all[sel_idxs], axis=0))
            labels_query.append(idx)
            # Held-out query samples must not be retrievable from the gallery.
            features_all = np.delete(features_all, sel_idxs, axis=0)
            print("Selected videos:", [paths_all[i] for i in sel_idxs])
            held_out = set(sel_idxs.tolist())
            paths_all = [path for i, path in enumerate(paths_all) if i not in held_out]

            labels_gallery.extend([idx] * len(paths_all))
            features_lists.append(features_all)

        features = np.concatenate(features_lists, axis=0)
        print('Len(gallery):', len(features))

        dist_mat = optimized_cosine_matrix(np.array(query_lists), features)
        print('Shape(matrix):', dist_mat.shape)

        # AP is still computed (and validates that some query has a match) but is not reported.
        retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10)
        print("Global metric AP:")

        recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=reject_real_ratio)
        print(f"Local metric: recall under {reject_real_ratio} false rejection:")
        # One averaged query per algorithm, so the group size is 1.
        recall_rows.append(group_average(recall, 1))
        print()

    result_array = np.transpose(np.array(recall_rows))

    df = pd.DataFrame({
        'Algorithm': list(all_pkl.keys()),
        **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])}
    })

    return df
|
|
| |
def calculated_final_result_multi_query(all_pkl, upper_len=10, reject_real_ratio=0.00001):
    '''
    Recall as a function of the number of samples averaged into each query,
    at a default false-rejection budget of 0.00001.

    Runs ``upper_len`` rounds. In round n (n = 1 .. upper_len), for every
    algorithm in ``all_pkl``:
      * n samples are drawn with ``np.random.choice`` (global NumPy RNG);
      * they are averaged into one query and removed from the gallery.
    A cosine-similarity matrix is then built. AP is computed and announced, but
    only the recall at a ``reject_real_ratio`` false-rejection budget of label-0
    gallery items (``retrieval_real``) is kept.

    Args:
        all_pkl (dict): Algorithm name -> list of pkl paths. A key's position in
            iteration order is used as that algorithm's label.
        upper_len (int): Largest number of samples averaged into a query.
        reject_real_ratio (float): False-rejection budget passed to retrieval_real.

    Returns:
        pd.DataFrame: An 'Algorithm' column plus ``Value_1 .. Value_{upper_len}``,
        where ``Value_n`` is that algorithm's recall in round n.
    '''
    recall_rows = []

    for num_query in range(1, upper_len + 1):
        features_lists = []
        query_lists, labels_query = [], []
        labels_gallery = []

        for idx, pkl_key in tqdm(enumerate(all_pkl.keys())):
            features_all, paths_all = merge_pkls(all_pkl[pkl_key])

            sel_idxs = np.random.choice(len(paths_all), num_query, replace=False)
            query_lists.append(np.mean(features_all[sel_idxs], axis=0))
            labels_query.append(idx)
            # Held-out query samples must not be retrievable from the gallery.
            features_all = np.delete(features_all, sel_idxs, axis=0)
            print("Selected videos:", [paths_all[i] for i in sel_idxs])
            held_out = set(sel_idxs.tolist())
            paths_all = [path for i, path in enumerate(paths_all) if i not in held_out]

            labels_gallery.extend([idx] * len(paths_all))
            features_lists.append(features_all)

        features = np.concatenate(features_lists, axis=0)
        print('Len(gallery):', len(features))

        dist_mat = optimized_cosine_matrix(np.array(query_lists), features)
        print('Shape(matrix):', dist_mat.shape)

        # AP is still computed (and validates that some query has a match) but is not reported.
        retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10)
        print("Global metric AP:")

        recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=reject_real_ratio)
        # The message previously said 0.0001 while the ratio actually used was 0.00001.
        print(f"Local metric: recall under {reject_real_ratio} false rejection:")
        # One averaged query per algorithm, so the group size is 1.
        recall_rows.append(group_average(recall, 1))
        print()

    result_array = np.transpose(np.array(recall_rows))

    df = pd.DataFrame({
        'Algorithm': list(all_pkl.keys()),
        **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])}
    })

    return df
| """ |
| def retrieval_real_p4(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None,cls2real=None): |
| #Exclude real classes from this computation |
| labels_query = torch.tensor(labels_query) |
| labels_gallery = torch.tensor(labels_gallery) |
| |
| # Basic parameter validation |
| num_query, num_gallery = dist_mat.shape |
| device = dist_mat.device |
| #Build real labels |
| if cls2real is None: |
| raise ValueError("The `cls2real` mapping must be provided (length = num_classes, with each class mapped to its corresponding real label).") |
| cls2real = cls2real.to(device) |
| query_real_labels = cls2real[labels_query] |
| #Filter out real classes |
| is_fake_query = (labels_query != query_real_labels) |
| fake_query_mask = is_fake_query |
| fake_labels_query = labels_query[fake_query_mask] |
| # n,fake_classes |
| fake_query_real_labels = query_real_labels[fake_query_mask] |
| dist_mat = dist_mat[fake_query_mask] |
| # Build the ordered index list |
| indexs = torch.tensor([i for i in range(len(labels_gallery))]) |
| # dist_mat : num_classes,n |
| if dist_type == "l2": |
| sorted_indices = torch.argsort(dist_mat, dim=1, descending=False) |
| else: # cosine |
| sorted_indices = torch.argsort(dist_mat, dim=1, descending=True) # for cosine distance, larger means closer, so sort in descending order num,n indices in ascending order |
| |
| # Build the sorted label matrix in batch (num_query, num_gallery) |
| sorted_labels = labels_gallery[sorted_indices] # Get retrieval result labels ordered by similarity |
| sorted_indexs = indexs[sorted_indices] # Get index list ordered by similarity |
| |
| # Build the match matrix (num_query, num_gallery), Find indices of all FakeX and Real entries in each row |
| matches_fake = sorted_labels == fake_labels_query.view(-1, 1) # Set positions with correct labels in the gallery to 1 |
| matches_real = sorted_labels == fake_query_real_labels.view(-1, 1) #here being equal to 0 means |
| |
| # Plot similarity curves for each algorithm class |
| # for i in range(num_query): |
| # if i == 0: continue # Skip real retrieval results |
| # tensor_fake = dist_mat[i][labels_gallery == i] |
| # tensor_real = dist_mat[i][labels_gallery == 0] |
| # print(len(tensor_fake), len(tensor_real)) |
| # plt.figure(figsize=(12, 4)) |
| # plt.plot(torch.concat([tensor_fake, tensor_real], axis=0)) |
| # # tensor_sort_fake, _ = torch.sort(tensor_fake, descending=True) |
| # # tensor_sort_real, _ = torch.sort(tensor_real, descending=True) |
| # # plt.plot(torch.concat([tensor_sort_fake, tensor_sort_real], axis=0)) |
| # plt.grid(True) |
| # plt.savefig(f'{i}_unsort.png') |
| # plt.close() |
| |
| # Find the position where the 0.001 false rejection threshold is reached |
| real_num = (labels_gallery.view(1, -1) == fake_query_real_labels.view(-1, 1)).sum(dim=1) # Number of real samples in each row |
| cum_real = matches_real.cumsum(dim=1) |
| reject_k = cum_real == (real_num * reject_real_ratio).view(-1, 1).int() # Apply the 0.001 false rejection constraint |
| first_occurrence = torch.argmax(reject_k.long(), dim=1) # Position where false rejection is tolerated |
| |
| # Find sequence positions where false rejection occurs (handled per fake class) |
| reject_real_union = set() |
| for i in range(len(labels_query)-4): # skip real retrieval |
| cut_pos = first_occurrence[i]+1 # Include the final false rejection position |
| cut_labels = sorted_labels[i][:cut_pos] # All labels before the false rejection cutoff |
| cut_indexs = sorted_indexs[i][:cut_pos] # All indices before the false rejection cutoff |
| reject_real_union.update(cut_indexs[cut_labels == query_real_labels[i]]) # Update the set |
| # print(i, len(reject_real_union), cut_indexs[cut_labels == 0]) |
| # print("Class label", i.numpy(), "descending similarity under false rejection:", dist_mat[i][cut_indexs[cut_labels == 0]]) |
| # print("cut_pos", i, cut_pos) |
| # print(f"Total false rejection after union across all fake classes at {reject_real_ratio}:", len(reject_real_union) / real_num) |
| |
| # Compute recall based on this position |
| cum_fake = matches_fake.cumsum(dim=1) |
| recall_nums = cum_fake[torch.arange(cum_fake.size(0)), first_occurrence] |
| |
| # print("recall_nums", recall_nums) |
| # print("cum_fake", cum_fake[:, -1]) |
| |
| return (recall_nums / cum_fake[:, -1]).numpy() |
| |
| |
| """ |
def retrieval_real_p4(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None, cls2real=None):
    """
    Per-query recall of same-label gallery items at a budget of falsely retrieved
    items from the query's paired "real" class.

    ``cls2real[c]`` is the real-class label paired with class ``c``. For a query
    with label ``q`` and real class ``r = cls2real[q]``:
      * the gallery is ranked by ``dist_mat`` (descending unless ``dist_type == "l2"``);
      * the ranking is cut at the k-th gallery item labelled ``r``, where
        ``k = max(1, floor(count_r * reject_real_ratio))`` and ``count_r`` is the
        number of gallery items labelled ``r``;
      * if fewer than k such items exist, the whole ranking is kept.
    The value for the query is the fraction of label-``q`` gallery items ranked
    at or before the cut.

    Args:
        dist_mat (Tensor): (num_query, num_gallery) score matrix.
        labels_query: Query labels, length num_query.
        labels_gallery: Gallery labels, length num_gallery.
        dist_type (str): "l2" sorts ascending; anything else sorts descending.
        reject_real_ratio (float): Fraction of the real class tolerated before the cut.
        paths: Unused; kept for backward compatibility.
        cls2real (Tensor | sequence of int): Class label -> paired real-class label.

    Returns:
        np.ndarray: (num_query,) recall values; nan for a query whose label has
        no gallery items (0/0).

    Raises:
        ValueError: If ``cls2real`` is not provided.
    """
    if cls2real is None:
        raise ValueError("The `cls2real` mapping must be provided")

    device = dist_mat.device
    # Keep every label tensor on dist_mat's device so indexing/comparisons never mix devices.
    labels_query = torch.as_tensor(labels_query, device=device)
    labels_gallery = torch.as_tensor(labels_gallery, device=device)
    cls2real = torch.as_tensor(cls2real, device=device)
    query_real_labels = cls2real[labels_query]
    num_gallery = dist_mat.size(1)

    sorted_indices = torch.argsort(dist_mat, dim=1, descending=(dist_type != "l2"))
    sorted_labels = labels_gallery[sorted_indices]

    matches_fake = sorted_labels == labels_query.view(-1, 1)
    matches_real = sorted_labels == query_real_labels.view(-1, 1)

    # Per-query size of the paired real class and the tolerated count k (>= 1), computed in float64.
    real_num = (labels_gallery.view(1, -1) == query_real_labels.view(-1, 1)).sum(dim=1)
    k = torch.floor(real_num.double() * reject_real_ratio).long().clamp_min(1)
    reached = matches_real.cumsum(dim=1) == k.view(-1, 1)
    # argmax returns the first maximal index, i.e. the rank of the k-th real item.
    cut = torch.argmax(reached.long(), dim=1)
    # If the k-th real item never occurs, nothing is rejected: keep the full ranking.
    cut = torch.where(reached.any(dim=1), cut, torch.full_like(cut, num_gallery - 1))

    cum_fake = matches_fake.cumsum(dim=1)
    recall_nums = cum_fake[torch.arange(cum_fake.size(0), device=device), cut]

    return (recall_nums / cum_fake[:, -1]).cpu().numpy()
|
|
|
|