# SwinFace / source/swin_b_test_lfw.py
# Commit d5e2b5a: "update swinface source code"
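'''
Evaluate a SwinFace checkpoint on an LFW test/template split.

Originally written to test the 86.pth checkpoint on the LFW faces of people
around age 50 (presumably the ``lfw_test_template_50_cropped`` folder); the
``__main__`` block below loads ``../checkpoints/swin_face.pth``.
'''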
import torch
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import swin_b, Swin_B_Weights
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score
# Collect every test and template image path under the dataset root
def get_test_template_images(root_dir, extensions=('.jpg', '.png')):
    """Collect test and template image paths, labelled by person.

    ``root_dir`` is expected to contain one sub-directory per person, each
    holding a ``test`` and a ``template`` folder. A person is included only
    when both folders exist and each contains at least one file whose name
    ends with one of ``extensions`` (compared case-insensitively, so
    ``.JPG`` is accepted as well as ``.jpg``).

    Directory listings are sorted so the returned order is reproducible:
    ``os.listdir`` returns entries in arbitrary, file-system-dependent order.

    Args:
        root_dir: dataset root directory.
        extensions: accepted file-name suffixes.

    Returns:
        ``(test_images, template_images, test_labels, template_labels)``:
        four lists where ``test_labels[i]`` / ``template_labels[i]`` is the
        person directory name of ``test_images[i]`` / ``template_images[i]``.
    """
    suffixes = tuple(ext.lower() for ext in extensions)

    def list_images(folder):
        # Sorted, extension-filtered full paths of the images in one folder.
        return [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if name.lower().endswith(suffixes)
        ]

    test_images, template_images = [], []
    test_labels, template_labels = [], []
    for person_name in sorted(os.listdir(root_dir)):
        person_dir = os.path.join(root_dir, person_name)
        test_dir = os.path.join(person_dir, 'test')
        template_dir = os.path.join(person_dir, 'template')
        if not (os.path.isdir(test_dir) and os.path.isdir(template_dir)):
            continue
        test_imgs = list_images(test_dir)
        template_imgs = list_images(template_dir)
        # A person without both kinds of image cannot be evaluated.
        if not (test_imgs and template_imgs):
            continue
        test_images.extend(test_imgs)
        test_labels.extend([person_name] * len(test_imgs))
        template_images.extend(template_imgs)
        template_labels.extend([person_name] * len(template_imgs))
    return test_images, template_images, test_labels, template_labels
class LFWTestTemplateDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from parallel lists.

    Args:
        image_paths: paths of the image files.
        labels: label for each entry of ``image_paths`` (same length).
        transform: optional callable applied to the RGB ``PIL.Image``
            before it is returned.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail fast: a length mismatch would otherwise surface later as an
        # IndexError or silently unlabelled images.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released as
        # soon as the RGB copy has been made.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label
# Custom model: Swin-B feature extractor followed by an embedding head
class SwinFaceModel(nn.Module):
    """Swin-B backbone plus an MLP head producing L2-normalised embeddings.

    Args:
        embed_dim: size of the output embedding.
        num_classes: output size of the classifier head, which is used during
            training and only evaluated here when ``return_logits=True``.
        pretrained: if True, initialise the backbone with torchvision's
            ``Swin_B_Weights.IMAGENET1K_V1``; otherwise random initialisation.

    The attribute names (``backbone``, ``fm4``, ``classifier``) define the
    state-dict keys of saved checkpoints and must not be renamed.
    """

    def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
        super().__init__()
        weights = Swin_B_Weights.IMAGENET1K_V1 if pretrained else None
        # Keep only Swin-B's feature stages; torchvision's own norm/avgpool/
        # head are dropped and replaced by the layers below.
        self.backbone = swin_b(weights=weights).features
        # Projection head: 1024-d Swin-B features -> embed_dim embedding.
        self.fm4 = nn.Sequential(
            nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
        )
        # Classification head used for training; ignored at inference.
        self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
        # L2 normalisation applied to the final embeddings.
        self.l2_norm = nn.functional.normalize
        # Global pooling that collapses the spatial grid to 1x1.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x, return_logits=False):
        """Embed a batch of images.

        Args:
            x: image batch, presumably ``[batch, 3, H, W]`` (the evaluation
                script feeds 224x224 normalised RGB) — TODO confirm for other
                callers.
            return_logits: also return the classifier logits (training use).

        Returns:
            ``embeddings`` of shape ``[batch, embed_dim]`` with unit L2 norm
            (an all-zero row stays zero), or ``(embeddings, logits)`` when
            ``return_logits`` is True.
        """
        # The Swin feature stages return channels-last maps: [B, 7, 7, 1024]
        # for 224x224 input.
        features = self.backbone(x)
        # Move channels to dim 1 for pooling: [B, 1024, 7, 7].
        features = features.permute(0, 3, 1, 2)
        # Global average pool and flatten: [B, 1024, 1, 1] -> [B, 1024].
        features = self.global_avg_pool(features).flatten(1)
        embeddings = self.l2_norm(self.fm4(features), dim=1)
        if not return_logits:
            # Skip the embed_dim x num_classes projection when only the
            # embeddings are wanted.
            return embeddings
        logits = self.classifier(embeddings)
        return embeddings, logits
# Row-wise cosine similarity between two embedding batches
def cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between corresponding rows.

    Both inputs are L2-normalised along ``dim=1`` before the dot product, so
    the result is a true cosine similarity even for unnormalised vectors. For
    already-normalised embeddings, such as ``SwinFaceModel`` outputs, it
    equals the plain dot product. Shapes broadcast, so ``[1, D]`` against
    ``[N, D]`` gives ``[N]``. An all-zero row yields similarity 0.
    """
    embedding1 = torch.nn.functional.normalize(embedding1, dim=1)
    embedding2 = torch.nn.functional.normalize(embedding2, dim=1)
    return torch.sum(embedding1 * embedding2, dim=1)
# Evaluation: match every test image against the templates of every identity
def _as_label(label):
    """Return ``label`` as a plain Python value (strings pass through, tensors are unwrapped)."""
    return label if isinstance(label, str) else label.item()


def _build_template_bank(model, template_loader, device):
    """Embed every template image and group the embeddings by label.

    Must be called under ``torch.no_grad()``. Returns a list of
    ``(label, templates)`` pairs in first-seen label order, where
    ``templates`` is a ``[num_templates, embed_dim]`` tensor on ``device``.
    """
    grouped = {}
    for imgs, lbls in tqdm(template_loader, desc="Extracting Template Features"):
        embeddings = model(imgs.to(device))
        for emb, lbl in zip(embeddings, lbls):
            grouped.setdefault(_as_label(lbl), []).append(emb)
    # Stack each identity once here rather than once per test image.
    return [(label, torch.stack(embs)) for label, embs in grouped.items()]


def _score_against_identities(embeddings, template_bank):
    """Return a CPU tensor ``[batch, num_identities]`` with, for each
    embedding, its best (max) dot-product similarity to each identity's templates."""
    best_per_identity = [
        (embeddings @ templates.T).max(dim=1).values
        for _, templates in template_bank
    ]
    return torch.stack(best_per_identity, dim=1).cpu()


def evaluate_test_vs_template(model, test_loader, template_loader, device,
                              threshold=0.5, top_k=3, verbose=True):
    """Closed-set identification accuracy of ``model`` on a test/template split.

    Every template image is embedded once and grouped by label. Each test
    image is scored against every identity as the maximum dot-product
    similarity over that identity's templates. The best-scoring identity is
    the prediction. It counts as correct only if it equals the true label
    and its score is strictly greater than ``threshold``.

    Args:
        model: network mapping an image batch to ``[batch, embed_dim]``
            embeddings. ``SwinFaceModel`` L2-normalises its output, so for it
            the dot product is the cosine similarity.
        test_loader, template_loader: iterables of ``(images, labels)``
            batches. Labels may be strings or tensors holding one value.
        device: device the image batches are moved to (the model is assumed
            to already live there).
        threshold: minimum top-1 similarity for a correct match.
        top_k: number of best identities printed per test image.
        verbose: print the top-``top_k`` identities for every test image.

    Returns:
        float: ``correct / total`` over all test images.

    Raises:
        ValueError: if ``template_loader`` or ``test_loader`` yields no images.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        template_bank = _build_template_bank(model, template_loader, device)
        if not template_bank:
            raise ValueError("template_loader yielded no images; nothing to match against")
        identity_labels = [label for label, _ in template_bank]

        for imgs, lbls in tqdm(test_loader, desc="Evaluating Test Images"):
            # One similarity row per test image, one column per identity.
            scores = _score_against_identities(model(imgs.to(device)), template_bank)
            for row, true_label in zip(scores.tolist(), lbls):
                true_label = _as_label(true_label)
                # The sort is stable, so ties keep the first-seen identity order.
                ranked = sorted(zip(identity_labels, row), key=lambda item: item[1], reverse=True)
                if verbose:
                    print(f"\n测试图像真实类别: {true_label}")
                    for rank, (label, similarity) in enumerate(ranked[:top_k], start=1):
                        print(f"Top {rank}: 类别 = {label}, 相似度 = {similarity:.4f}")
                predicted_label, best_similarity = ranked[0]
                # Correct only if the label matches and the score clears the threshold.
                if predicted_label == true_label and best_similarity > threshold:
                    correct += 1
                total += 1

    if total == 0:
        raise ValueError("test_loader yielded no images; accuracy is undefined")
    accuracy = correct / total
    print(f"\n总测试图片数: {total}, 分类正确数: {correct}")
    print(f"分类准确率: {accuracy:.4f}")
    return accuracy
if __name__ == "__main__":
    dataset_root = '../datasets/LFWPairs/lfw-py/lfw_test_template_50_cropped'
    test_pth = "../checkpoints/swin_face.pth"
    # Alternative checkpoints, currently unused; assign one of them to the
    # torch.load call below to evaluate it instead of test_pth.
    test_pth_infected = "../parametersProcess/swin_face/swin_evilfiles_16.pth"
    test_pth_flipped = "../parametersProcess/swin_face/swin_flip_16.pth"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Same preprocessing for test and template images: resize to 224x224 and
    # normalise with the ImageNet mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    test_images, template_images, test_labels, template_labels = get_test_template_images(dataset_root)
    # Fail early with a clear message instead of a ZeroDivisionError after
    # the model has been built and loaded.
    if not test_images or not template_images:
        raise SystemExit(
            f"No test/template images found under {dataset_root!r}; "
            "expected <person>/test and <person>/template sub-folders."
        )
    test_dataset = LFWTestTemplateDataset(test_images, test_labels, transform=transform)
    template_dataset = LFWTestTemplateDataset(template_images, template_labels, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
    template_loader = DataLoader(template_dataset, batch_size=32, shuffle=False, num_workers=4)

    model = SwinFaceModel(embed_dim=512, pretrained=False)
    print(model)  # print the model architecture
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints, especially the
    # "evilfiles" one above. On PyTorch >= 1.13, consider weights_only=True.
    model.load_state_dict(torch.load(test_pth, map_location=device))
    model.to(device)
    evaluate_test_vs_template(model, test_loader, template_loader, device)