| | ''' |
| | 使用86.pth进行测试,使用lfw中的50岁饥人脸进行测试 |
| | ''' |
| | import torch |
| | import os |
| | from PIL import Image |
| | from torchvision import transforms |
| | from torch.utils.data import Dataset, DataLoader |
| | from torchvision.models import swin_b, Swin_B_Weights |
| | import torch.nn as nn |
| | from tqdm import tqdm |
| | from sklearn.metrics import accuracy_score |
| |
|
| | |
def get_test_template_images(root_dir):
    """Collect test/template image paths and identity labels from *root_dir*.

    Expected layout: ``root_dir/<person>/test/*.jpg`` and
    ``root_dir/<person>/template/*.jpg``. Identities missing either
    sub-directory, or with an empty one, are skipped entirely.

    Args:
        root_dir: Directory whose sub-directories are person identities.

    Returns:
        Tuple ``(test_images, template_images, test_labels, template_labels)``.
        Labels are the person directory names, repeated once per image so the
        path and label lists stay index-aligned.
    """

    def _collect(directory):
        # Case-insensitive extension match so files like "A.JPG" or "b.jpeg"
        # are not silently skipped (the original matched lowercase .jpg/.png
        # only).
        return [
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]

    test_images, template_images = [], []
    test_labels, template_labels = [], []

    for person_name in os.listdir(root_dir):
        person_dir = os.path.join(root_dir, person_name)
        test_dir = os.path.join(person_dir, 'test')
        template_dir = os.path.join(person_dir, 'template')

        # Only keep identities that have both galleries present on disk.
        if not (os.path.isdir(test_dir) and os.path.isdir(template_dir)):
            continue

        test_imgs = _collect(test_dir)
        template_imgs = _collect(template_dir)

        # Require at least one image on each side, as the original did.
        if test_imgs and template_imgs:
            test_images.extend(test_imgs)
            test_labels.extend([person_name] * len(test_imgs))
            template_images.extend(template_imgs)
            template_labels.extend([person_name] * len(template_imgs))

    return test_images, template_images, test_labels, template_labels
| |
|
class LFWTestTemplateDataset(Dataset):
    """Minimal dataset pairing image file paths with string identity labels.

    ``image_paths`` and ``labels`` are parallel lists; ``transform`` (if
    given) is applied to each loaded PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load lazily per item; force RGB so grayscale files get 3 channels.
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
| |
|
| |
|
| | |
class SwinFaceModel(nn.Module):
    """Swin-B face recognition model: backbone -> MLP embedding head -> classifier.

    Produces L2-normalized embeddings (for cosine-similarity matching) and,
    optionally, class logits for closed-set classification.

    NOTE: the attribute names (``backbone``, ``fm4``, ``classifier``) are part
    of the checkpoint's state_dict contract — renaming them would break
    ``load_state_dict`` against existing ``.pth`` files.
    """

    def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
        """
        Args:
            embed_dim: Dimension of the face embedding (default 512).
            num_classes: Number of identities for the classifier head
                (default 10177 — presumably the CelebA identity count; verify
                against the training setup).
            pretrained: If True, initialize the backbone with ImageNet-1k
                weights; otherwise random init (weights come from a checkpoint).
        """
        super(SwinFaceModel, self).__init__()
        if pretrained:
            self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
        else:
            self.backbone = swin_b(weights=None)
        # Keep only the convolutional/transformer feature extractor; drop the
        # backbone's own pooling and ImageNet classification head.
        self.backbone = self.backbone.features

        # Two-layer projection head mapping the 1024-dim Swin-B feature to the
        # embedding space. Linear layers have no bias because each is followed
        # by BatchNorm, which supplies the shift.
        self.fm4 = nn.Sequential(
            nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU()
        )

        # Bias-free identity classifier over the normalized embedding.
        self.classifier = nn.Linear(embed_dim, num_classes, bias=False)

        # Function handle, not a module: L2 normalization of embeddings.
        self.l2_norm = nn.functional.normalize

        # Collapses the spatial grid to a single 1024-dim vector per image.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x, return_logits=False):
        """Embed a batch of images.

        Args:
            x: Image batch (N, 3, H, W); 224x224 per the eval transform.
            return_logits: When True, also return classifier logits.

        Returns:
            L2-normalized embeddings (N, embed_dim), or the tuple
            (embeddings, logits) when ``return_logits`` is True.
        """
        features = self.backbone(x)
        # Assumes backbone emits channels-last (N, H, W, C) — the torchvision
        # Swin convention — and converts to (N, C, H, W) for 2D pooling.
        features = features.permute(0, 3, 1, 2)
        features = self.global_avg_pool(features)
        # (N, C, 1, 1) -> (N, C)
        features = features.view(features.size(0), -1)
        embeddings = self.fm4(features)
        # Unit-norm embeddings so dot product == cosine similarity downstream.
        embeddings = self.l2_norm(embeddings, dim=1)
        logits = self.classifier(embeddings)
        if return_logits:
            return embeddings, logits
        return embeddings
| |
|
| |
|
| | |
def cosine_similarity(embedding1, embedding2):
    """Row-wise dot product of two embedding batches.

    Equals true cosine similarity only when both inputs are already
    L2-normalized (as the model's forward pass guarantees).
    """
    return (embedding1 * embedding2).sum(dim=1)
| |
|
| | |
def evaluate_test_vs_template(model, test_loader, template_loader, device):
    """Closed-set identification: match each test embedding against per-identity
    template galleries and report top-1 accuracy.

    A test image counts as correct only when the best-matching identity equals
    the true label AND its maximum similarity exceeds 0.5.

    Args:
        model: Embedding network; ``model(imgs)`` must return L2-normalized
            embeddings of shape (batch, dim).
        test_loader: Yields (images, labels) batches for the probe set.
        template_loader: Yields (images, labels) batches for the gallery set.
        device: Torch device used for inference.

    Returns:
        Top-1 accuracy in [0, 1]; 0.0 when the test set is empty.
    """
    model.eval()
    correct = 0
    total = 0
    template_embeddings = {}

    # Pass 1: extract gallery embeddings, grouped by identity label.
    with torch.no_grad():
        for imgs, lbls in tqdm(template_loader, desc="Extracting Template Features"):
            imgs = imgs.to(device)
            embeddings = model(imgs)
            for emb, lbl in zip(embeddings, lbls):
                # DataLoader may yield string labels as-is or tensors.
                lbl = lbl if isinstance(lbl, str) else lbl.item()
                template_embeddings.setdefault(lbl, []).append(emb.cpu())

    # Stack each identity's gallery ONCE. The original re-ran
    # torch.stack(...).to(device) inside the per-test-image loop, repeating
    # the same stacking and host->device transfer for every probe.
    template_embeddings = {
        label: torch.stack(templates).to(device)
        for label, templates in template_embeddings.items()
    }

    # Pass 2: score every test embedding against every identity's gallery.
    with torch.no_grad():
        for imgs, lbls in tqdm(test_loader, desc="Evaluating Test Images"):
            imgs = imgs.to(device)
            test_embeddings = model(imgs)

            for test_embedding, true_label in zip(test_embeddings, lbls):
                true_label = true_label if isinstance(true_label, str) else true_label.item()

                # (identity, best-template similarity) per gallery identity.
                similarity_list = []
                for label, templates in template_embeddings.items():
                    similarities = cosine_similarity(test_embedding.unsqueeze(0), templates)
                    similarity_list.append((label, torch.max(similarities).item()))

                # Keep the three best identities for logging.
                top3_similarities = sorted(similarity_list, key=lambda x: x[1], reverse=True)[:3]

                print(f"\n测试图像真实类别: {true_label}")
                for rank, (label, similarity) in enumerate(top3_similarities, start=1):
                    print(f"Top {rank}: 类别 = {label}, 相似度 = {similarity:.4f}")

                predicted_label = top3_similarities[0][0]

                # Correct only if the right identity also clears the 0.5
                # acceptance threshold.
                if predicted_label == true_label and top3_similarities[0][1] > 0.5:
                    correct += 1
                total += 1

    # Guard against an empty test set (the original raised ZeroDivisionError).
    accuracy = correct / total if total else 0.0
    print(f"\n总测试图片数: {total}, 分类正确数: {correct}")
    print(f"分类准确率: {accuracy:.4f}")
    return accuracy
| |
|
| |
|
if __name__ == "__main__":
    # Candidate checkpoints: the clean model plus two tampered variants
    # (infected / flipped paths kept around for quick manual switching).
    test_pth = "../checkpoints/swin_face.pth"
    test_pth_infected = "../parametersProcess/swin_face/swin_evilfiles_16.pth"
    test_pth_flipped = "../parametersProcess/swin_face/swin_flip_16.pth"
    dataset_root = '../datasets/LFWPairs/lfw-py/lfw_test_template_50_cropped'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # ImageNet normalization, matching the Swin-B backbone's expected input.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Gather probe/gallery image paths and their identity labels.
    splits = get_test_template_images(dataset_root)
    test_images, template_images, test_labels, template_labels = splits

    test_dataset = LFWTestTemplateDataset(test_images, test_labels, transform=transform)
    template_dataset = LFWTestTemplateDataset(template_images, template_labels, transform=transform)

    # Identical loader settings for both splits; no shuffling during eval.
    loader_kwargs = dict(batch_size=32, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, **loader_kwargs)
    template_loader = DataLoader(template_dataset, **loader_kwargs)

    model = SwinFaceModel(embed_dim=512, pretrained=False)
    print(model)
    model.load_state_dict(torch.load(test_pth, map_location=device))
    model.to(device)

    evaluate_test_vs_template(model, test_loader, template_loader, device)
| |
|