# shunliwang
# update
# 8bc3305
import os, pickle
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt
# Fix NumPy's global RNG so the np.random.choice query sampling in this module is reproducible.
np.random.seed(seed=1)
def cosine_similarity(a, b):
    """Cosine similarity between query vector(s) ``a`` and every row of ``b``.

    Args:
        a: A single query vector of shape ``(d,)`` or a batch of queries of
            shape ``(n, d)``.
        b: Gallery matrix of shape ``(m, d)``.

    Returns:
        np.ndarray: Shape ``(m,)`` for a 1-D ``a``, ``(n, m)`` for a 2-D ``a``.
        Pairs that involve a zero-length vector get similarity 0 instead of
        NaN/inf.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    dot_product = np.dot(a, b.T)
    if a.ndim > 1:
        # One norm per query row (column vector for broadcasting); the norm of
        # the whole matrix is only correct when there is a single query.
        norm_a = np.linalg.norm(a, axis=1, keepdims=True)
    else:
        norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b, axis=1)
    denom = norm_a * norm_b
    # Zero vectors have a zero dot product; dividing by 1 instead of 0 yields a
    # similarity of 0 (same guard as safe_normalize in optimized_cosine_matrix).
    safe_denom = np.where(denom == 0, 1, denom)
    return dot_product / safe_denom
def optimized_cosine_matrix(query_features, gallery_features, chunk_size=50000, device='cuda'):
    """Compute the cosine-similarity matrix between queries and a gallery.

    Rows are L2-normalised (all-zero rows stay zero), then the gallery is
    multiplied against the queries in chunks of ``chunk_size`` rows to bound
    device memory.

    Args:
        query_features: Array-like of shape ``(n_query, d)``.
        gallery_features: Array-like of shape ``(n_gallery, d)``.
        chunk_size (int): Number of gallery rows processed per matrix product.
        device: Device for the matrix products (``'cuda'`` by default).

    Returns:
        torch.Tensor: float32 CPU tensor of shape ``(n_query, n_gallery)``.
    """
    device = torch.device(device)
    # Cast to float32 so integer / float64 inputs use the same kernels.
    query = torch.as_tensor(query_features, dtype=torch.float32).to(device)
    gallery = torch.as_tensor(gallery_features, dtype=torch.float32).to(device)

    def safe_normalize(x):
        # Divide by 1 instead of 0 so all-zero rows stay zero rather than NaN.
        norm = torch.norm(x, p=2, dim=1, keepdim=True)
        return x / torch.where(norm == 0, torch.ones_like(norm), norm)

    query_norm = safe_normalize(query)
    gallery_norm = safe_normalize(gallery)

    if gallery_norm.size(0) == 0:
        # torch.cat([]) would raise; an empty gallery yields an empty matrix.
        return torch.empty((query_norm.size(0), 0), dtype=torch.float32)

    # Mixed precision is only used on CUDA; CUDA autocast has no effect on CPU
    # tensors, so it is disabled for other devices.
    use_amp = device.type == 'cuda'
    dist_mat = []
    with torch.no_grad(), torch.amp.autocast(device_type='cuda' if use_amp else 'cpu', enabled=use_amp):
        for start in range(0, gallery_norm.size(0), chunk_size):
            chunk = gallery_norm[start:start + chunk_size]
            sim = torch.mm(query_norm, chunk.T)  # (n_query, rows in chunk)
            # Under autocast the product may be float16; store it as float32 on CPU.
            dist_mat.append(sim.cpu().to(torch.float32))
    return torch.cat(dist_mat, dim=1)
def retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10):
    """
    Vectorised CMC and AP computation for large retrieval matrices (e.g. 14 x 800k).

    Args:
        dist_mat (Tensor): Score matrix (num_query, num_gallery).
        labels_query: Query labels (num_query,), list / array / Tensor.
        labels_gallery: Gallery labels (num_gallery,), list / array / Tensor.
        dist_type (str): "l2" ranks ascending (smaller = closer); any other
            value is treated as cosine similarity and ranked descending.
        rank_max (int): Maximum ranking depth (clipped to num_gallery).

    Returns:
        cmc (Tensor): CMC curve of length min(rank_max, num_gallery), averaged
            over valid queries.
        ap_all (np.ndarray): AP for each valid query (num_valid_query,).

    Queries without any matching gallery item are excluded from both outputs.
    If no query has a match, an all-zero CMC and an empty AP array are returned.
    """
    device = dist_mat.device
    # as_tensor avoids copying tensor inputs; labels must live on the same
    # device as the sort indices used to gather them.
    labels_query = torch.as_tensor(labels_query, device=device)
    labels_gallery = torch.as_tensor(labels_gallery, device=device)
    num_query, num_gallery = dist_mat.shape
    rank_max = min(rank_max, num_gallery)

    # Cosine: larger means closer, so sort descending; L2: ascending.
    descending = dist_type != "l2"
    sorted_indices = torch.argsort(dist_mat, dim=1, descending=descending)
    sorted_labels = labels_gallery[sorted_indices]  # gallery labels in rank order
    matches = (sorted_labels == labels_query.view(-1, 1)).long()

    # Drop queries whose class has no gallery item.
    valid_mask = matches.sum(dim=1) > 0
    valid_matches = matches[valid_mask]
    if valid_matches.size(0) == 0:
        # Do not exit() here: that would silently end the whole process with
        # status 0. Return neutral results and let the caller decide.
        print("retrieval_cmc_ap: no query has a matching gallery item")
        return torch.zeros(rank_max), np.empty(0, dtype=np.float32)

    cum_correct = valid_matches.cumsum(dim=1)  # matches seen up to each rank
    # CMC@k is 1 once the first correct item has appeared within the top k.
    cmc_final = (cum_correct[:, :rank_max] > 0).float().mean(dim=0)

    # AP: mean of precision@k over the ranks k that hold a correct item.
    positions = torch.arange(1, num_gallery + 1, device=device).view(1, -1)
    precisions = cum_correct / positions
    ap_values = (precisions * valid_matches).sum(dim=1) / valid_matches.sum(dim=1).clamp(min=1)
    return cmc_final.cpu(), ap_values.cpu().numpy()
def retrieval_real(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None):
    """Recall of each query's own class under a false-rejection budget on real samples.

    Gallery items labelled 0 are treated as real. For every query the gallery is
    ranked by ``dist_mat``, the ranking is cut at the position of the k-th real
    item with ``k = max(1, int(num_real * reject_real_ratio))``, and the fraction
    of the query's own class ranked at or above that cut is reported.

    Args:
        dist_mat (Tensor): (num_query, num_gallery) similarity ("cosine", ranked
            descending) or distance ("l2", ranked ascending) matrix.
        labels_query: (num_query,) query class labels.
        labels_gallery: (num_gallery,) gallery class labels; 0 marks real samples.
        dist_type (str): "l2", or anything else for cosine.
        reject_real_ratio (float): Tolerated fraction of real samples above the cut.
        paths: Unused; kept for call-site compatibility.

    Returns:
        np.ndarray: (num_query,) recall per query (NaN when the query's class has
        no gallery items).
    """
    device = dist_mat.device
    labels_query = torch.as_tensor(labels_query, device=device)
    labels_gallery = torch.as_tensor(labels_gallery, device=device)
    num_gallery = dist_mat.shape[1]

    # Cosine: larger means closer, so sort descending; L2: ascending.
    descending = dist_type != "l2"
    sorted_indices = torch.argsort(dist_mat, dim=1, descending=descending)
    sorted_labels = labels_gallery[sorted_indices]  # gallery labels in rank order

    matches_fake = sorted_labels == labels_query.view(-1, 1)  # query's own class
    matches_real = sorted_labels == 0  # real samples

    real_num = torch.sum(torch.eq(labels_gallery, 0))
    # Allow at least one real: int() truncation gave k = 0 whenever
    # num_real * ratio < 1, and "cum_real == 0" is first true at position 0,
    # so only the top-1 result counted toward recall.
    k = torch.clamp_min((real_num * reject_real_ratio).int(), 1)
    cum_real = matches_real.cumsum(dim=1)
    reject_k = cum_real == k
    first_occurrence = torch.argmax(reject_k.long(), dim=1)  # rank of the k-th real
    # argmax of an all-False row is 0; when fewer than k reals exist, use the
    # whole ranking instead.
    first_occurrence = torch.where(
        reject_k.any(dim=1),
        first_occurrence,
        torch.full_like(first_occurrence, num_gallery - 1),
    )

    cum_fake = matches_fake.cumsum(dim=1)
    recall_nums = cum_fake[torch.arange(cum_fake.size(0), device=device), first_occurrence]
    return (recall_nums / cum_fake[:, -1]).cpu().numpy()
def merge_pkls(pkl_algo, min_samples=10):
    """Load and concatenate the features and paths stored in several pickle files.

    Each pickle must hold a dict with ``'features'`` (array, one row per sample)
    and ``'paths'`` (sequence aligned with those rows).

    When the merged data has fewer than ``min_samples`` rows, the loaded data is
    duplicated repeatedly until at least ``min_samples`` rows exist. Later code
    needs this to draw ``min_samples`` queries without replacement.

    Args:
        pkl_algo: Iterable of pickle file paths belonging to one class.
        min_samples (int): Row count to pad up to (default 10).

    Returns:
        tuple: ``(features_all, paths_all)``. ``features_all`` is the
        concatenated feature array and ``paths_all`` is the flat list of paths
        in the same order.

    Note:
        ``pickle.load`` can execute arbitrary code; only pass trusted files.
    """
    features_all, paths_all = [], []
    for pkl in pkl_algo:
        with open(pkl, 'rb') as f:
            data = pickle.load(f)  # keys: 'features', 'paths'
        features_all.append(data['features'])
        paths_all.append(data['paths'])

    total = sum(len(feats) for feats in features_all)
    if total < min_samples:
        print("Padding:", pkl_algo)
        # One duplication is not enough for classes with < min_samples / 2
        # samples; keep doubling until the target is met (skip empty data,
        # which can never grow).
        while 0 < total < min_samples:
            features_all = features_all + features_all
            paths_all = paths_all + paths_all
            total *= 2

    features_all = np.concatenate(features_all, axis=0)
    paths_all = [element for sublist in paths_all for element in sublist]
    return features_all, paths_all
def group_average(arr, x, discard_remainder=True):
    """
    Split ``arr`` into consecutive groups of ``x`` elements and average each group.

    Args:
        arr (np.ndarray): Input values; non-ndarray inputs are converted with np.array.
        x (int): Number of elements per group (must be positive).
        discard_remainder (bool): If True, a trailing partial group (fewer than
            ``x`` elements) is dropped; otherwise its mean is appended.

    Returns:
        np.ndarray: One mean per group (empty if ``arr`` is empty).

    Raises:
        ValueError: If ``x`` is not positive (checked after the empty-input case).
    """
    values = arr if isinstance(arr, np.ndarray) else np.array(arr)
    if values.size == 0:
        return np.array([])
    if x <= 0:
        raise ValueError("Group size x must be a positive integer")

    total = values.size
    if total < x:
        # Not even one full group.
        return np.array([]) if discard_remainder else np.array([values.mean()])

    n_full, leftover = divmod(total, x)
    cut = n_full * x
    means = values[:cut].reshape(-1, x).mean(axis=1)
    if leftover and not discard_remainder:
        means = np.append(means, values[cut:].mean())
    return means
def calculated_final_result(all_pkl, eval_modes=(2, 3), num_per_class=10):
    '''
    Compute global AP and recall under a 0.0001 false-rejection budget for each
    query-construction strategy.

    Args:
        all_pkl (dict): Maps each class/algorithm name to the list of pickle
            files merged by ``merge_pkls``. Key enumeration order defines the
            integer class labels (first key -> label 0, the label that
            ``retrieval_real`` treats as real).
        eval_modes (iterable of int): Strategies to evaluate, in order.
            1 ('10Q'): ``num_per_class`` random samples per class are removed
            from the gallery and each becomes a query.
            2 ('10Q-Avg'): the mean of ``num_per_class`` removed random samples
            is the single class query.
            3 ('All-Mean'): the mean of the whole class is the query. This is
            optimistic, because those samples stay in the gallery.
        num_per_class (int): Samples drawn per class in modes 1 and 2.

    Returns:
        tuple: ``(df, query_lenth)``.
            ``df`` has an ``'Algorithm'`` column plus ``Value_i`` columns ordered
            as [AP for each mode in ``eval_modes`` order] followed by
            [recall@1e-4 for each mode].
            ``query_lenth`` lists each class's sample count after padding, as
            recorded during the last evaluated mode.
    '''
    result_array = []
    query_lenth = []
    for eval_mode in eval_modes:  # 1:'10Q' 2:'10Q-Avg' 3:'All-Mean'
        query_lenth = []
        features_lists = []  # remaining gallery features of each class
        query_lists, labels_query = [], []
        labels_gallery = []
        #### S1. Load all pkl files and build the query set
        for idx, pkl_key in tqdm(enumerate(all_pkl.keys())):
            features_all, paths_all = merge_pkls(all_pkl[pkl_key])
            query_lenth.append(len(features_all))
            if eval_mode in (1, 2):
                # Draw the query samples; they are removed from the gallery below.
                sel_idxs = np.random.choice(len(paths_all), num_per_class, replace=False)
                if eval_mode == 1:
                    # Method 1: every sampled feature is its own query.
                    # (Fancy indexing already returns a copy.)
                    query_lists.extend(features_all[sel_idxs])
                    labels_query.extend([idx] * num_per_class)
                else:
                    # Method 2: the mean of the sampled features is one query.
                    query_lists.append(np.mean(features_all[sel_idxs], axis=0))
                    labels_query.append(idx)
                features_all = np.delete(features_all, sel_idxs, axis=0)
                selected = set(sel_idxs.tolist())  # O(1) membership test
                paths_all = [path for i, path in enumerate(paths_all) if i not in selected]
            elif eval_mode == 3:
                # Method 3: class centroid as the query (the gallery keeps all samples).
                labels_query.append(idx)
                query_lists.append(np.mean(features_all, axis=0))
            # Ground-truth gallery labels: every remaining sample carries its class index.
            labels_gallery.extend([idx] * len(paths_all))
            features_lists.append(features_all)
        # Modes 2 and 3 produce one query per class, so per-query metrics are not grouped.
        group_size = 1 if eval_mode in (2, 3) else num_per_class
        features = np.concatenate(features_lists, axis=0)
        print('Len(gallery):', len(features))
        #### S2. Similarity matrix [Q_num, G_num]
        dist_mat = optimized_cosine_matrix(np.array(query_lists), features)
        print('Shape(matrix):', dist_mat.shape)
        #### S3. [Global metric] CMC and AP
        cmc_all, ap_all = retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10)
        print("Global metric AP:")
        result_array.append(group_average(ap_all, group_size))
        #### S4. [Local metric] recall of each fake class against all real samples at 0.0001 false rejection
        recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=0.0001)
        print("Local metric: recall under 0.0001 false rejection:")
        result_array.append(group_average(recall, group_size))
        print()
    # result_array alternates [AP, recall] per mode; regroup as all APs (mode
    # order) followed by all recalls, then transpose so rows are classes.
    stacked = np.array(result_array)
    rearranged = np.concatenate([stacked[0::2], stacked[1::2]])
    result_array = np.transpose(rearranged)
    df = pd.DataFrame({
        'Algorithm': list(all_pkl.keys()),
        **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])}
    })
    return df, query_lenth
# Support multiple query strategies
def calculated_final_result_multi_query_(all_pkl, upper_len=10):
    '''
    Sweep how many samples are averaged into each class query and report recall.

    For every ``num_query`` in ``1..upper_len`` the '10Q-Avg' strategy (eval mode 2)
    is run. For each class, ``num_query`` random samples are removed from the
    gallery (their paths are printed) and their mean feature becomes that class's
    single query. Global AP and recall under a 0.0001 false-rejection budget are
    computed for each setting; only the recall values are returned.

    Args:
        all_pkl (dict): class/algorithm name -> list of pickle files for
            ``merge_pkls``; key order defines the integer class labels.
        upper_len (int): Largest number of averaged samples to evaluate.

    Returns:
        pd.DataFrame: ``'Algorithm'`` column plus ``Value_k`` columns, where
        ``Value_k`` is the per-class recall obtained with ``k`` averaged samples.
    '''
    collected = []
    for num_query in range(1, upper_len + 1):  # one evaluation per query count
        for eval_mode in [2]:  # 1:'10Q' 2:'10Q-Avg' 3:'All-Mean'
            query_feats, query_ids = [], []
            gallery_blocks, gallery_ids = [], []
            # S1. Load every class and split off its query samples.
            for class_id, key in tqdm(enumerate(all_pkl.keys())):
                feats, class_paths = merge_pkls(all_pkl[key])
                if eval_mode in (1, 2):
                    picked = np.random.choice(len(class_paths), num_query, replace=False)
                    if eval_mode == 1:
                        # Method 1: every sampled feature is an individual query.
                        query_feats.extend(deepcopy(feats[picked]))
                        query_ids.extend([class_id] * num_query)
                        feats = np.delete(feats, picked, axis=0)
                    else:
                        # Method 2: the mean of the sampled features is one query.
                        query_feats.append(np.mean(deepcopy(feats[picked]), axis=0))
                        query_ids.append(class_id)
                        feats = np.delete(feats, picked, axis=0)
                        print("Selected videos:", [class_paths[i] for i in picked])
                    picked_set = set(picked.tolist())
                    class_paths = [p for i, p in enumerate(class_paths) if i not in picked_set]
                elif eval_mode == 3:
                    # Method 3: whole-class centroid as the query (optimistic).
                    query_ids.append(class_id)
                    query_feats.append(np.mean(np.array(feats), axis=0))
                # Gallery ground truth: every remaining sample carries its class id.
                gallery_ids.extend([class_id] * len(class_paths))
                gallery_blocks.append(feats)
            group = 1 if eval_mode in [2, 3] else num_query
            gallery = np.concatenate(gallery_blocks, axis=0)
            print('Len(gallery):', len(gallery))
            # S2. Similarity matrix [Q_num, G_num].
            sims = optimized_cosine_matrix(np.array(query_feats), gallery)
            print('Shape(matrix):', sims.shape)
            # S3. Global metric: CMC / AP.
            _, ap_all = retrieval_cmc_ap(sims, query_ids, gallery_ids, dist_type="cosine", rank_max=10)
            print("Global metric AP:")
            collected.append(group_average(ap_all, group))
            # S5. Local metric: recall at a 0.0001 false-rejection budget.
            recall = retrieval_real(sims, query_ids, gallery_ids, reject_real_ratio=0.0001)
            print("Local metric: recall under 0.0001 false rejection:")
            collected.append(group_average(recall, group))
            print()
    # collected alternates [AP, recall]; keep only the recall rows (one per query
    # count), then transpose so rows are classes and columns are query counts.
    table = np.transpose(np.array(collected)[1::2])
    return pd.DataFrame({
        'Algorithm': list(all_pkl.keys()),
        **{f'Value_{i+1}': table[:, i] for i in range(table.shape[1])}
    })
# Visual analysis of bad cases
def calculated_final_result_multi_query(all_pkl, upper_len=10, reject_real_ratio=0.00001):
    '''
    Bad-case analysis variant of the query-count sweep.

    For every ``num_query`` in ``1..upper_len`` the '10Q-Avg' strategy (eval mode 2)
    is run. For each class, ``num_query`` random samples are removed from the
    gallery (their paths are printed) and their mean feature becomes that class's
    single query. Global AP and recall under ``reject_real_ratio`` false rejection
    are computed for each setting; only the recall values are returned.

    Args:
        all_pkl (dict): class/algorithm name -> list of pickle files for
            ``merge_pkls``; key order defines the integer class labels.
        upper_len (int): Largest number of averaged samples to evaluate.
        reject_real_ratio (float): False-rejection budget passed to
            ``retrieval_real``. Defaults to 1e-5, which is tighter than the
            0.0001 used by ``calculated_final_result``. The log line prints
            the value actually used.

    Returns:
        pd.DataFrame: ``'Algorithm'`` column plus ``Value_k`` columns, where
        ``Value_k`` is the per-class recall obtained with ``k`` averaged samples.
    '''
    result_array = []
    for num_query in range(1, upper_len + 1):  # Evaluate every query count once
        for eval_mode in [2]:  # 1:'10Q' 2:'10Q-Avg' 3:'All-Mean'
            features_lists = []  # remaining gallery features of each class
            query_lists, labels_query = [], []
            labels_gallery = []
            #### S1. Load all pkl files and build the query set
            for idx, pkl_key in tqdm(enumerate(all_pkl.keys())):
                features_all, paths_all = merge_pkls(all_pkl[pkl_key])
                if eval_mode in (1, 2):
                    sel_idxs = np.random.choice(len(paths_all), num_query, replace=False)
                    if eval_mode == 1:
                        # Method 1: every sampled feature is its own query
                        # (fancy indexing already returns a copy).
                        query_lists.extend(features_all[sel_idxs])
                        labels_query.extend([idx] * num_query)
                        features_all = np.delete(features_all, sel_idxs, axis=0)
                    else:
                        # Method 2: the mean of the sampled features is one query.
                        query_lists.append(np.mean(features_all[sel_idxs], axis=0))
                        labels_query.append(idx)
                        features_all = np.delete(features_all, sel_idxs, axis=0)
                        print("Selected videos:", [paths_all[i] for i in sel_idxs])
                    selected = set(sel_idxs.tolist())  # O(1) membership test
                    paths_all = [path for i, path in enumerate(paths_all) if i not in selected]
                elif eval_mode == 3:
                    # Method 3: class centroid as the query (the gallery keeps all samples).
                    labels_query.append(idx)
                    query_lists.append(np.mean(features_all, axis=0))
                # Ground-truth gallery labels: every remaining sample carries its class index.
                labels_gallery.extend([idx] * len(paths_all))
                features_lists.append(features_all)
            # Modes 2 and 3 produce one query per class, so metrics are not grouped.
            group_size = 1 if eval_mode in (2, 3) else num_query
            features = np.concatenate(features_lists, axis=0)
            print('Len(gallery):', len(features))
            #### S2. Similarity matrix [Q_num, G_num]
            dist_mat = optimized_cosine_matrix(np.array(query_lists), features)
            print('Shape(matrix):', dist_mat.shape)
            #### S3. [Global metric] CMC and AP
            _, ap_all = retrieval_cmc_ap(dist_mat, labels_query, labels_gallery, dist_type="cosine", rank_max=10)
            print("Global metric AP:")
            result_array.append(group_average(ap_all, group_size))
            #### S5. [Local metric] recall of each fake class against all real samples
            recall = retrieval_real(dist_mat, labels_query, labels_gallery, reject_real_ratio=reject_real_ratio)
            print(f"Local metric: recall under {reject_real_ratio} false rejection:")
            result_array.append(group_average(recall, group_size))
            print()
    # result_array alternates [AP, recall]; keep only the recall rows (one per
    # query count), then transpose so rows are classes and columns query counts.
    recall_rows = np.array(result_array)[1::2]
    result_array = np.transpose(recall_rows)
    df = pd.DataFrame({
        'Algorithm': list(all_pkl.keys()),
        **{f'Value_{i+1}': result_array[:, i] for i in range(result_array.shape[1])}
    })
    return df
# NOTE: The triple-quoted string below is an inactive earlier version of
# retrieval_real_p4. It is kept for reference only and is never executed (a
# bare string expression). Unlike the live definition that follows it:
#   - it drops queries whose label equals their mapped real label
#     (cls2real[label]) before ranking;
#   - its union loop stops at len(labels_query) - 4.
"""
def retrieval_real_p4(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None,cls2real=None):
#Exclude real classes from this computation
labels_query = torch.tensor(labels_query)
labels_gallery = torch.tensor(labels_gallery)
# Basic parameter validation
num_query, num_gallery = dist_mat.shape
device = dist_mat.device
#Build real labels
if cls2real is None:
raise ValueError("The `cls2real` mapping must be provided (length = num_classes, with each class mapped to its corresponding real label).")
cls2real = cls2real.to(device)
query_real_labels = cls2real[labels_query]
#Filter out real classes
is_fake_query = (labels_query != query_real_labels)
fake_query_mask = is_fake_query
fake_labels_query = labels_query[fake_query_mask]
# n,fake_classes
fake_query_real_labels = query_real_labels[fake_query_mask]
dist_mat = dist_mat[fake_query_mask]
# Build the ordered index list
indexs = torch.tensor([i for i in range(len(labels_gallery))])
# dist_mat : num_classes,n
if dist_type == "l2":
sorted_indices = torch.argsort(dist_mat, dim=1, descending=False)
else: # cosine
sorted_indices = torch.argsort(dist_mat, dim=1, descending=True) # for cosine distance, larger means closer, so sort in descending order num,n indices in ascending order
# Build the sorted label matrix in batch (num_query, num_gallery)
sorted_labels = labels_gallery[sorted_indices] # Get retrieval result labels ordered by similarity
sorted_indexs = indexs[sorted_indices] # Get index list ordered by similarity
# Build the match matrix (num_query, num_gallery), Find indices of all FakeX and Real entries in each row
matches_fake = sorted_labels == fake_labels_query.view(-1, 1) # Set positions with correct labels in the gallery to 1
matches_real = sorted_labels == fake_query_real_labels.view(-1, 1) #here being equal to 0 means
# Plot similarity curves for each algorithm class
# for i in range(num_query):
# if i == 0: continue # Skip real retrieval results
# tensor_fake = dist_mat[i][labels_gallery == i]
# tensor_real = dist_mat[i][labels_gallery == 0]
# print(len(tensor_fake), len(tensor_real))
# plt.figure(figsize=(12, 4))
# plt.plot(torch.concat([tensor_fake, tensor_real], axis=0))
# # tensor_sort_fake, _ = torch.sort(tensor_fake, descending=True)
# # tensor_sort_real, _ = torch.sort(tensor_real, descending=True)
# # plt.plot(torch.concat([tensor_sort_fake, tensor_sort_real], axis=0))
# plt.grid(True)
# plt.savefig(f'{i}_unsort.png')
# plt.close()
# Find the position where the 0.001 false rejection threshold is reached
real_num = (labels_gallery.view(1, -1) == fake_query_real_labels.view(-1, 1)).sum(dim=1) # Number of real samples in each row
cum_real = matches_real.cumsum(dim=1)
reject_k = cum_real == (real_num * reject_real_ratio).view(-1, 1).int() # Apply the 0.001 false rejection constraint
first_occurrence = torch.argmax(reject_k.long(), dim=1) # Position where false rejection is tolerated
# Find sequence positions where false rejection occurs (handled per fake class)
reject_real_union = set()
for i in range(len(labels_query)-4): # skip real retrieval
cut_pos = first_occurrence[i]+1 # Include the final false rejection position
cut_labels = sorted_labels[i][:cut_pos] # All labels before the false rejection cutoff
cut_indexs = sorted_indexs[i][:cut_pos] # All indices before the false rejection cutoff
reject_real_union.update(cut_indexs[cut_labels == query_real_labels[i]]) # Update the set
# print(i, len(reject_real_union), cut_indexs[cut_labels == 0])
# print("Class label", i.numpy(), "descending similarity under false rejection:", dist_mat[i][cut_indexs[cut_labels == 0]])
# print("cut_pos", i, cut_pos)
# print(f"Total false rejection after union across all fake classes at {reject_real_ratio}:", len(reject_real_union) / real_num)
# Compute recall based on this position
cum_fake = matches_fake.cumsum(dim=1)
recall_nums = cum_fake[torch.arange(cum_fake.size(0)), first_occurrence]
# print("recall_nums", recall_nums)
# print("cum_fake", cum_fake[:, -1])
return (recall_nums / cum_fake[:, -1]).numpy()
"""
def retrieval_real_p4(dist_mat, labels_query, labels_gallery, dist_type="cosine", reject_real_ratio=0.001, paths=None, cls2real=None):
    """Per-query recall under a false-rejection budget, with a per-class real label.

    Like ``retrieval_real``, except the "real" negatives are chosen per query:
    gallery items whose label equals ``cls2real[query_label]``. For every query,
    the ranking is cut at the position of the k-th such real item, with
    ``k = max(1, int(num_real * reject_real_ratio))``. The result is the
    fraction of the query's own class ranked at or above that cut.

    Args:
        dist_mat (Tensor): (num_query, num_gallery) similarity ("cosine", ranked
            descending) or distance ("l2", ranked ascending) matrix.
        labels_query: (num_query,) query class labels.
        labels_gallery: (num_gallery,) gallery class labels.
        dist_type (str): "l2", or anything else for cosine.
        reject_real_ratio (float): Tolerated fraction of real samples above the cut.
        paths: Unused; kept for call-site compatibility.
        cls2real: Tensor or sequence indexed by class label, giving the real
            label paired with each class.

    Returns:
        np.ndarray: (num_query,) recall per query (NaN when the query's class has
        no gallery items). When fewer than k matching reals exist, the whole
        ranking is used.

    Raises:
        ValueError: If ``cls2real`` is not provided.
    """
    if cls2real is None:
        raise ValueError("The `cls2real` mapping must be provided")
    device = dist_mat.device
    labels_query = torch.as_tensor(labels_query, device=device)
    labels_gallery = torch.as_tensor(labels_gallery, device=device)
    cls2real = torch.as_tensor(cls2real, device=device)
    query_real_labels = cls2real[labels_query]  # real label paired with each query
    num_gallery = dist_mat.shape[1]

    # Cosine: larger means closer, so sort descending; L2: ascending.
    descending = dist_type != "l2"
    sorted_indices = torch.argsort(dist_mat, dim=1, descending=descending)
    sorted_labels = labels_gallery[sorted_indices]  # gallery labels in rank order

    matches_fake = sorted_labels == labels_query.view(-1, 1)  # query's own class
    matches_real = sorted_labels == query_real_labels.view(-1, 1)  # its paired real class

    # Number of paired real samples in the gallery, per query.
    real_num = (labels_gallery.view(1, -1) == query_real_labels.view(-1, 1)).sum(dim=1)
    # Allow at least one real so int() truncation cannot yield k = 0.
    k = torch.clamp_min((real_num * reject_real_ratio).int(), 1).view(-1, 1)
    cum_real = matches_real.cumsum(dim=1)
    reject_k = cum_real == k
    first_occurrence = torch.argmax(reject_k.long(), dim=1)  # rank of the k-th real
    # argmax of an all-False row is 0; when fewer than k reals exist, use the
    # whole ranking instead.
    first_occurrence = torch.where(
        reject_k.any(dim=1),
        first_occurrence,
        torch.full_like(first_occurrence, num_gallery - 1),
    )

    cum_fake = matches_fake.cumsum(dim=1)
    recall_nums = cum_fake[torch.arange(cum_fake.size(0), device=device), first_occurrence]
    return (recall_nums / cum_fake[:, -1]).cpu().numpy()