"""Evaluate a Swin-B face-recognition model (86.pth checkpoint) on the
LFW 50-year-old face subset: each test image is matched against per-person
template galleries via cosine similarity."""

import torch
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import swin_b, Swin_B_Weights
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score


def get_test_template_images(root_dir):
    """Collect image paths and labels from a root laid out as
    root_dir/<person>/test/*.{jpg,png} and root_dir/<person>/template/*.{jpg,png}.

    Persons missing either a non-empty test or template folder are skipped.
    Extension matching is case-insensitive (accepts .JPG/.PNG too).

    Returns:
        (test_images, template_images, test_labels, template_labels) — parallel
        lists; labels are the person directory names.
    """
    test_images = []
    template_images = []
    test_labels = []
    template_labels = []
    for person_name in os.listdir(root_dir):
        person_dir = os.path.join(root_dir, person_name)
        test_dir = os.path.join(person_dir, 'test')
        template_dir = os.path.join(person_dir, 'template')
        if os.path.isdir(test_dir) and os.path.isdir(template_dir):
            test_imgs = [os.path.join(test_dir, img)
                         for img in os.listdir(test_dir)
                         if img.lower().endswith(('.jpg', '.png'))]
            template_imgs = [os.path.join(template_dir, img)
                             for img in os.listdir(template_dir)
                             if img.lower().endswith(('.jpg', '.png'))]
            # Keep a person only when BOTH sides have at least one image,
            # so every test label has a template gallery to match against.
            if test_imgs and template_imgs:
                test_images.extend(test_imgs)
                test_labels.extend([person_name] * len(test_imgs))
                template_images.extend(template_imgs)
                template_labels.extend([person_name] * len(template_imgs))
    return test_images, template_images, test_labels, template_labels


class LFWTestTemplateDataset(Dataset):
    """Minimal (image, label) dataset over explicit path/label lists."""

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Force RGB so grayscale files don't break the 3-channel transform.
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label


class SwinFaceModel(nn.Module):
    """Swin-B backbone + projection head producing L2-normalized face embeddings.

    The classifier head is only used during training; at inference the model
    returns embeddings unless return_logits=True.
    """

    def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
        super(SwinFaceModel, self).__init__()
        # Load Swin-B and keep only its `features` module (drop the ImageNet head).
        if pretrained:
            self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
        else:
            self.backbone = swin_b(weights=None)
        self.backbone = self.backbone.features
        # FM4 projection: 1024-d pooled features -> embed_dim embedding.
        self.fm4 = nn.Sequential(
            nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU()
        )
        # Classification head used during training; ignored at inference time.
        self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
        # L2 normalization applied to the final embedding.
        self.l2_norm = nn.functional.normalize
        # Global pooling to collapse the spatial map to a vector.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x, return_logits=False):
        # Swin-B features come out channels-last: [batch, H, W, 1024]
        # (7x7 spatial for 224x224 input).
        features = self.backbone(x)
        # Move channels to position 1: [batch, 1024, H, W].
        features = features.permute(0, 3, 1, 2)
        # Global average pool -> [batch, 1024, 1, 1].
        features = self.global_avg_pool(features)
        # Flatten to [batch, 1024].
        features = features.view(features.size(0), -1)
        # Project to the embedding space via FM4.
        embeddings = self.fm4(features)
        # L2-normalize so dot products equal cosine similarities.
        embeddings = self.l2_norm(embeddings, dim=1)
        logits = self.classifier(embeddings)
        if return_logits:
            return embeddings, logits
        return embeddings


def cosine_similarity(embedding1, embedding2):
    """Row-wise dot product; equals cosine similarity for L2-normalized inputs."""
    return torch.sum(embedding1 * embedding2, dim=1)


def evaluate_test_vs_template(model, test_loader, template_loader, device):
    """Match every test embedding against all per-person template galleries.

    A test image counts as correct only when the best-matching gallery belongs
    to the true person AND the similarity exceeds 0.5.

    Returns:
        float: classification accuracy over the test set (0.0 if empty).
    """
    model.eval()
    correct = 0
    total = 0
    template_embeddings = {}

    # Pass 1: extract template embeddings, grouped by person label.
    with torch.no_grad():
        for imgs, lbls in tqdm(template_loader, desc="Extracting Template Features"):
            imgs = imgs.to(device)
            embeddings = model(imgs)
            for emb, lbl in zip(embeddings, lbls):
                lbl = lbl if isinstance(lbl, str) else lbl.item()
                if lbl not in template_embeddings:
                    template_embeddings[lbl] = []
                template_embeddings[lbl].append(emb.cpu())

    # Stack and move each gallery to the device ONCE.
    # (The original re-did torch.stack(...).to(device) for every test image,
    # an O(n_test * n_templates) waste of transfer/alloc work.)
    stacked_templates = {
        label: torch.stack(embs).to(device)
        for label, embs in template_embeddings.items()
    }

    # Pass 2: compare each test embedding against all galleries.
    with torch.no_grad():
        for imgs, lbls in tqdm(test_loader, desc="Evaluating Test Images"):
            imgs = imgs.to(device)
            test_embeddings = model(imgs)
            for test_embedding, true_label in zip(test_embeddings, lbls):
                true_label = true_label if isinstance(true_label, str) else true_label.item()
                similarity_list = []
                for label, templates in stacked_templates.items():
                    similarities = cosine_similarity(test_embedding.unsqueeze(0), templates)
                    # Keep the best-matching template within this gallery.
                    max_similarity = torch.max(similarities).item()
                    similarity_list.append((label, max_similarity))
                # Top-3 galleries by similarity, best first.
                top3_similarities = sorted(similarity_list, key=lambda x: x[1], reverse=True)[:3]
                print(f"\n测试图像真实类别: {true_label}")
                for rank, (label, similarity) in enumerate(top3_similarities, start=1):
                    print(f"Top {rank}: 类别 = {label}, 相似度 = {similarity:.4f}")
                predicted_label = top3_similarities[0][0]
                # Correct only if the label matches AND similarity clears 0.5.
                if predicted_label == true_label and top3_similarities[0][1] > 0.5:
                    correct += 1
                total += 1

    # Guard against an empty test set (original would raise ZeroDivisionError).
    accuracy = correct / total if total else 0.0
    print(f"\n总测试图片数: {total}, 分类正确数: {correct}")
    print(f"分类准确率: {accuracy:.4f}")
    return accuracy


if __name__ == "__main__":
    dataset_root = '../datasets/LFWPairs/lfw-py/lfw_test_template_50_cropped'
    test_pth = "../checkpoints/swin_face.pth"
    test_pth_infected = "../parametersProcess/swin_face/swin_evilfiles_16.pth"
    test_pth_flipped = "../parametersProcess/swin_face/swin_flip_16.pth"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Standard ImageNet preprocessing, matching the Swin-B training recipe.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    test_images, template_images, test_labels, template_labels = get_test_template_images(dataset_root)
    test_dataset = LFWTestTemplateDataset(test_images, test_labels, transform=transform)
    template_dataset = LFWTestTemplateDataset(template_images, template_labels, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
    template_loader = DataLoader(template_dataset, batch_size=32, shuffle=False, num_workers=4)

    model = SwinFaceModel(embed_dim=512, pretrained=False)
    print(model)  # Log the model structure.
    model.load_state_dict(torch.load(test_pth, map_location=device))
    model.to(device)
    evaluate_test_vs_template(model, test_loader, template_loader, device)