# Antispoofing / finetune.py
# (Hugging Face upload header: le312113, "Upload full package: YOLO + ArcFace + Scripts", commit dec999c)
import random
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
# ============================================================
# BACKBONE: WideResNet101-2 (chỉ đến layer2)
# ============================================================
class WRN101_2_Early(nn.Module):
    """Wide-ResNet101-2 truncated after layer2.

    Layers 3/4 and the fc head are dropped, so loading a full checkpoint is
    expected to report those keys as missing. Output is a 512-channel feature
    map at 1/8 of the input resolution.
    """

    def __init__(self, ckpt_path=None):
        """
        Args:
            ckpt_path: Optional checkpoint path. If the loaded dict has a
                "model" entry, keys prefixed with "backbone." are remapped onto
                the torchvision model; otherwise the dict is loaded directly.
                When None the backbone keeps its random initialization
                (e.g. when a full-model checkpoint will be loaded later).
        """
        super().__init__()
        base = models.wide_resnet101_2(weights=None)
        # ONLY load when a ckpt_path is given (training from scratch).
        if ckpt_path is not None:
            print(f"Loading backbone checkpoint from: {ckpt_path}")
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            if "model" in checkpoint:
                full_state_dict = checkpoint["model"]
                # Strip the leading "backbone." prefix. Slicing (instead of
                # str.replace) avoids corrupting keys that happen to contain
                # "backbone." somewhere other than the start.
                prefix = "backbone."
                backbone_state_dict = {
                    key[len(prefix):]: value
                    for key, value in full_state_dict.items()
                    if key.startswith(prefix)
                }
                missing, unexpected = base.load_state_dict(backbone_state_dict, strict=False)
                print(f"✓ Loaded backbone weights successfully")
                print(f" Missing keys: {len(missing)} (expected: layer3, layer4, fc)")
                print(f" Unexpected keys: {len(unexpected)}")
            else:
                base.load_state_dict(checkpoint, strict=False)
                print(f"✓ Loaded weights directly")
        # Keep only the stem through layer2; the deeper stages are unused.
        self.conv1 = base.conv1
        self.bn1 = base.bn1
        self.relu = base.relu
        self.maxpool = base.maxpool
        self.layer1 = base.layer1
        self.layer2 = base.layer2

    def forward(self, x):
        """Map an image batch (B, 3, H, W) to features (B, 512, H/8, W/8)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        return x
# ============================================================
# TEMPORAL ENCODER: Transformer
# ============================================================
class Transformer(nn.Module):
    """Temporal encoder: spatially pools per-frame feature maps, adds learned
    positional embeddings, runs a Transformer encoder, and mean-pools over time.
    """

    def __init__(self, embed_dim=512, num_heads=8, depth=4, mlp_ratio=4.0,
                 dropout=0.1, max_len=20):
        """
        Args:
            embed_dim: Token width; must match the backbone channel count and
                be divisible by num_heads.
            num_heads: Attention heads per encoder layer.
            depth: Number of encoder layers.
            mlp_ratio: Feed-forward width as a multiple of embed_dim.
            dropout: Dropout rate inside the encoder layers.
            max_len: Maximum supported clip length — the size of the learned
                positional-embedding table. Default 20 preserves the original
                hard-coded behavior.
        """
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            batch_first=True,
            activation="gelu"
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        # One learned position vector per possible frame index.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_dim))

    def forward(self, x):
        """x: (B, T, C, H, W) feature maps → (B, C) clip embedding."""
        B, T, C, H, W = x.shape
        if T > self.pos_embedding.shape[1]:
            # Previously this path died with an opaque broadcasting error.
            raise ValueError(
                f"Clip length {T} exceeds positional-embedding capacity "
                f"{self.pos_embedding.shape[1]}; increase max_len."
            )
        x = x.mean(dim=[3, 4])               # global average pool over H, W
        x = x + self.pos_embedding[:, :T, :]
        x = self.encoder(x)
        return x.mean(dim=1)                 # average over time
# ============================================================
# FULL MODEL: DeepFake Detection
# ============================================================
class DeepFakeModel(nn.Module):
    """Video-level spoof/deepfake classifier.

    Per-frame features from a truncated WideResNet101-2 are aggregated over
    time by a Transformer encoder, then scored by a small MLP head producing
    a single logit per clip.
    """

    def __init__(self, backbone_ckpt=None, freeze_backbone=True):
        super().__init__()
        self.backbone = WRN101_2_Early(ckpt_path=backbone_ckpt)
        if freeze_backbone:
            # NOTE(review): this freezes the weights, but BatchNorm layers
            # still update running stats while the model is in train() mode —
            # confirm that is intended.
            for p in self.backbone.parameters():
                p.requires_grad = False
            print("✓ Backbone frozen")
        else:
            print("✓ Backbone trainable")
        # Temporal encoder over per-frame feature tokens.
        self.temporal_encoder = Transformer(
            embed_dim=512,
            num_heads=8,
            depth=4,
            mlp_ratio=2.0,
            dropout=0.1
        )
        # Binary classification head (one logit).
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, frames, labels=None):
        """frames: (B, T, 3, H, W). Returns logits (B, 1); when labels are
        given, returns (logits, BCE-with-logits loss) instead."""
        batch, clip_len = frames.shape[:2]
        # Fold time into the batch dimension so all frames run in parallel.
        flat = frames.view(batch * clip_len, *frames.shape[2:])
        # Gradients through the backbone only while it is in training mode.
        with torch.set_grad_enabled(self.backbone.training):
            flat_feats = self.backbone(flat)
        feats = flat_feats.view(batch, clip_len, *flat_feats.shape[1:])
        logits = self.classifier(self.temporal_encoder(feats))
        if labels is None:
            return logits
        loss = F.binary_cross_entropy_with_logits(
            logits.view(-1),
            labels.float()
        )
        return logits, loss
# ============================================================
# DATASET: Face Video Dataset
# ============================================================
class FaceVideoDataset(Dataset):
    """Clips of face frames sampled from per-video frame directories.

    Expects root_dir/<video_id>/<frame>.{jpg,png,jpeg}. Each video with at
    least `num_frames` frames contributes `clips_per_video` dataset entries;
    the actual frames of a clip (random start + random temporal stride) are
    re-sampled on every __getitem__ call.
    """

    # Accepted frame extensions (matched case-insensitively).
    _FRAME_EXTS = (".jpg", ".png", ".jpeg")

    def __init__(
        self,
        root_dir,
        label,
        num_frames=8,
        stride_range=(1, 15),
        clips_per_video=10,
        transform=None
    ):
        """
        Args:
            root_dir: Directory whose sub-directories each hold one video's frames.
            label: Integer class label for every clip (e.g. 0=real, 1=attack).
            num_frames: Frames per clip.
            stride_range: Inclusive (min, max) temporal stride sampled per clip.
            clips_per_video: Dataset entries per eligible video.
            transform: Optional per-frame transform applied to each PIL image.
        """
        self.samples = []
        self.num_frames = num_frames
        self.stride_range = stride_range
        self.transform = transform
        video_dirs = sorted(os.listdir(root_dir))
        for vid in video_dirs:
            frame_dir = os.path.join(root_dir, vid)
            if not os.path.isdir(frame_dir):
                continue
            # Case-insensitive extension check so frames exported as
            # .JPG / .PNG are not silently dropped.
            frames = sorted(
                os.path.join(frame_dir, f)
                for f in os.listdir(frame_dir)
                if f.lower().endswith(self._FRAME_EXTS)
            )
            if len(frames) < num_frames:
                continue
            for _ in range(clips_per_video):
                self.samples.append((frames, label))
        print(f"✓ Loaded {len(self.samples)} clips from {len(video_dirs)} videos (label={label})")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        frames, label = self.samples[idx]
        # Random temporal stride; shrink it when the video is too short for
        # the requested span so we still return num_frames frames.
        stride = random.randint(*self.stride_range)
        required_length = stride * self.num_frames
        if len(frames) < required_length:
            stride = max(1, len(frames) // self.num_frames)
            required_length = stride * self.num_frames
        max_start = max(0, len(frames) - required_length)
        start = random.randint(0, max_start) if max_start > 0 else 0
        selected = [
            frames[min(start + i * stride, len(frames) - 1)]
            for i in range(self.num_frames)
        ]
        imgs = []
        for path in selected:
            # Open under a context manager so the file handle is released
            # promptly (PIL otherwise keeps lazily-loaded files open).
            with Image.open(path) as im:
                img = im.convert("RGB")
            if self.transform:
                img = self.transform(img)
            imgs.append(img)
        imgs = torch.stack(imgs, dim=0)
        return imgs, torch.tensor(label, dtype=torch.long)
# ============================================================
# TRAINING FUNCTION - FIXED: NO SEPARATE BACKBONE LOADING
# ============================================================
def train_model(
    model,
    train_loader,
    val_loader=None,
    epochs=20,
    lr_backbone=0.0,
    lr_temporal=5e-5,
    lr_head=1e-4,
    weight_decay=1e-4,
    device="cuda",
    save_path="antispoofing_full.pth",
    warmup_epochs=2,
    use_scheduler=True,
    resume_from=None
):
    """
    Train with a FULL checkpoint (backbone weights included).

    Uses per-module learning rates (backbone / temporal encoder / head),
    optional linear-warmup + cosine LR schedule stepped per batch, gradient
    clipping, and best-AUC checkpointing. Resuming restores the full model,
    optimizer, scheduler, epoch counter and best AUC.

    Args:
        model: DeepFakeModel instance (must expose .backbone,
            .temporal_encoder and .classifier sub-modules).
        train_loader: DataLoader yielding (frames, labels) batches.
        val_loader: Optional validation DataLoader; enables AUC-based saving.
        epochs: Total number of training epochs.
        lr_backbone: Backbone LR; 0 (or a fully frozen backbone) excludes it
            from the optimizer.
        lr_temporal: LR for the temporal encoder parameters.
        lr_head: LR for the classifier head parameters.
        weight_decay: AdamW weight decay.
        device: Device string for model and batches.
        save_path: Where the best checkpoint is written.
        warmup_epochs: Epochs of linear LR warmup before cosine decay.
        use_scheduler: Whether to build the warmup+cosine scheduler.
        resume_from: Optional path to a full checkpoint to resume from.

    Checkpoint format:
    {
        'model_state_dict': FULL model (backbone + temporal + classifier),
        'optimizer_state_dict': optimizer,
        'scheduler_state_dict': scheduler (optional),
        'epoch': epoch number,
        'auc': validation AUC,
        'val_loss': validation loss,
        'train_loss': training loss,
        'info': description
    }
    """
    # Setup optimizer: one param group per module so each gets its own LR.
    param_groups = []
    # Backbone joins only when it has trainable params AND a positive LR.
    if lr_backbone > 0 and any(p.requires_grad for p in model.backbone.parameters()):
        param_groups.append({
            "params": [p for p in model.backbone.parameters() if p.requires_grad],
            "lr": lr_backbone
        })
        print(f"✓ Backbone learning rate: {lr_backbone}")
    param_groups.append({
        "params": model.temporal_encoder.parameters(),
        "lr": lr_temporal
    })
    param_groups.append({
        "params": model.classifier.parameters(),
        "lr": lr_head
    })
    optimizer = torch.optim.AdamW(
        param_groups,
        weight_decay=weight_decay
    )
    # Learning rate scheduler: linear warmup then cosine annealing.
    # Schedulers are sized in BATCHES because scheduler.step() runs per batch.
    scheduler = None
    if use_scheduler:
        from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=0.1,
            total_iters=warmup_epochs * len(train_loader)
        )
        cosine_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=(epochs - warmup_epochs) * len(train_loader),
            eta_min=1e-6
        )
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_epochs * len(train_loader)]
        )
        print(f"✓ Using LR scheduler: Warmup {warmup_epochs} epochs → Cosine annealing")
    model.to(device)
    best_auc = 0.0
    max_grad_norm = 1.0  # gradient-clipping threshold
    start_epoch = 1
    # ============================================================
    # LOAD FULL CHECKPOINT (backbone + temporal + classifier)
    # ============================================================
    if resume_from is not None and os.path.exists(resume_from):
        print(f"\n{'=' * 60}")
        print(f"📥 LOADING FULL CHECKPOINT: {resume_from}")
        print(f"{'=' * 60}")
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # resume from trusted checkpoints.
        checkpoint = torch.load(resume_from, map_location=device, weights_only=False)
        # Load the FULL model (backbone weights are already inside this checkpoint).
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'], strict=True)
            print(f"✓ Full model loaded (backbone + temporal + classifier)")
        else:
            # Fallback: the file is a bare state dict rather than a wrapper dict.
            model.load_state_dict(checkpoint, strict=True)
            print(f"✓ Model loaded")
        # Load optimizer state (momentum/Adam moments).
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f"✓ Optimizer loaded")
        # Load scheduler state so the LR curve continues where it left off.
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print(f"✓ Scheduler loaded")
        # Load training bookkeeping (resume at the epoch after the saved one).
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_auc = checkpoint.get('auc', 0.0)
        print(f"\n✓ Resuming from epoch {start_epoch}")
        print(f"✓ Best AUC: {best_auc:.4f}")
        print(f"✓ Train Loss: {checkpoint.get('train_loss', 0):.4f}")
        print(f"✓ Val Loss: {checkpoint.get('val_loss', 0):.4f}")
        if 'info' in checkpoint:
            print(f"ℹ {checkpoint['info']}")
        print(f"{'=' * 60}\n")
    elif resume_from is not None:
        print(f"⚠ Checkpoint not found: {resume_from}")
        print(f" Starting from scratch...\n")
    # ============================================================
    # TRAINING LOOP
    # ============================================================
    for epoch in range(start_epoch, epochs + 1):
        model.train()
        epoch_loss = 0.0
        running_loss = 0.0  # exponential moving average of the batch loss
        pbar = tqdm(
            train_loader,
            desc=f"Epoch [{epoch}/{epochs}]",
            dynamic_ncols=True
        )
        for batch_idx, (frames, labels) in enumerate(pbar):
            frames = frames.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            logits, loss = model(frames, labels)
            loss.backward()
            # Clip gradients to stabilize transformer training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()  # per-batch LR step
            epoch_loss += loss.item()
            # EMA smoothing; first batch seeds the average directly.
            running_loss = 0.95 * running_loss + 0.05 * loss.item() if batch_idx > 0 else loss.item()
            postfix = {
                'loss': f"{loss.item():.4f}",
                'smooth': f"{running_loss:.4f}"
            }
            if scheduler is not None:
                # Displays the first param group's LR (backbone if present,
                # otherwise the temporal encoder's).
                current_lr = optimizer.param_groups[0]['lr']
                postfix['lr'] = f"{current_lr:.2e}"
            pbar.set_postfix(postfix)
        avg_loss = epoch_loss / len(train_loader)
        print(f"\n📊 Epoch {epoch} | Train Loss: {avg_loss:.4f} | Smooth: {running_loss:.4f}")
        # Validation + best-AUC checkpointing.
        if val_loader is not None:
            val_metrics = validate(model, val_loader, device)
            auc = val_metrics['auc']
            val_loss = val_metrics['loss']
            print(f"📊 Validation | AUC: {auc:.4f} | Loss: {val_loss:.4f}")
            # Save best model (only when validation AUC improves).
            if auc > best_auc:
                best_auc = auc
                checkpoint = {
                    'model_state_dict': model.state_dict(),  # ← FULL: backbone + temporal + classifier
                    'optimizer_state_dict': optimizer.state_dict(),
                    'auc': auc,
                    'val_loss': val_loss,
                    'train_loss': avg_loss,
                    'epoch': epoch,
                    'info': f"Full model | Epoch {epoch} | AUC {auc:.4f}"
                }
                if scheduler is not None:
                    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
                torch.save(checkpoint, save_path)
                print(f"💾 Saved → {save_path}")
                print(f" ✓ Full model (backbone included)")
                print(f" ✓ AUC: {auc:.4f}")
            print(f"🏆 Best AUC: {best_auc:.4f}\n")
    print(f"\n{'=' * 60}")
    print(f"✅ TRAINING COMPLETED")
    print(f"{'=' * 60}")
    print(f"🏆 Best AUC: {best_auc:.4f}")
    print(f"💾 Saved to: {save_path}")
    print(f"\n📌 Use in inference:")
    print(f" detector = AntiSpoofingDetector(")
    print(f" model_path='{save_path}',")
    print(f" device='cuda'")
    print(f" )")
    print(f"{'=' * 60}\n")
# ============================================================
# VALIDATION
# ============================================================
@torch.no_grad()
def validate(model, val_loader, device="cuda"):
    """Evaluate the model on val_loader.

    Returns a dict with 'auc' (ROC-AUC over sigmoid scores) and 'loss'
    (mean BCE-with-logits loss over batches).
    """
    model.eval()
    score_chunks, label_chunks = [], []
    loss_sum = 0.0
    for frames, labels in tqdm(val_loader, desc="Validating", leave=False):
        frames, labels = frames.to(device), labels.to(device)
        flat_logits = model(frames).view(-1)
        loss_sum += F.binary_cross_entropy_with_logits(
            flat_logits,
            labels.float()
        ).item()
        score_chunks.append(torch.sigmoid(flat_logits).cpu())
        label_chunks.append(labels.cpu())
    y_score = torch.cat(score_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()
    return {
        'auc': roc_auc_score(y_true, y_score),
        'loss': loss_sum / len(val_loader),
    }
# ============================================================
# MAIN
# ============================================================
if __name__ == "__main__":
    # Training-time augmentation: heavy geometric + photometric transforms,
    # then random erasing, normalized with ImageNet statistics.
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomRotation(45, interpolation=transforms.InterpolationMode.BILINEAR, fill=0),
        transforms.RandomAffine(0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10,
                                interpolation=transforms.InterpolationMode.BILINEAR, fill=0),
        transforms.RandomPerspective(0.2, p=0.3, interpolation=transforms.InterpolationMode.BILINEAR, fill=0),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
        transforms.RandomApply([transforms.GaussianBlur(5, sigma=(0.1, 2.0))], p=0.3),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3), value='random'),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Validation: deterministic resize + center crop, same normalization.
    transform_val = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Load datasets: label 0 = real, label 1 = attack/spoof.
    # NOTE(review): train and val draw from the SAME directories — validation
    # videos overlap training videos; confirm this is intended.
    print("=" * 60)
    print("LOADING DATASETS")
    print("=" * 60)
    full_real_train = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/real_frame_cropped",
        label=0, num_frames=10, stride_range=(1, 15), clips_per_video=70, transform=transform_train
    )
    full_fake_train = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/attack_frame_cropped2",
        label=1, num_frames=10, stride_range=(1, 15), clips_per_video=70, transform=transform_train
    )
    full_real_val = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/real_frame_cropped",
        label=0, num_frames=10, stride_range=(1, 15), clips_per_video=10, transform=transform_val
    )
    full_fake_val = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/attack_frame_cropped2",
        label=1, num_frames=10, stride_range=(1, 15), clips_per_video=10, transform=transform_val
    )
    from torch.utils.data import ConcatDataset, random_split
    full_dataset_train = ConcatDataset([full_real_train, full_fake_train])
    full_dataset_val = ConcatDataset([full_real_val, full_fake_val])
    # 95% of the train pool is used; the 5% split-off remainder is discarded
    # (validation uses the separate val-transform datasets above).
    train_size = int(0.95 * len(full_dataset_train))
    val_from_train = len(full_dataset_train) - train_size
    train_dataset, _ = random_split(full_dataset_train, [train_size, val_from_train],
                                    generator=torch.Generator().manual_seed(42))
    val_dataset = full_dataset_val
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=6,
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=6, pin_memory=True)
    print(f"\nTraining samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    # ============================================================
    # CREATE MODEL — SMART CHECKPOINT LOGIC
    # ============================================================
    print("\n" + "=" * 60)
    print("CREATING MODEL")
    print("=" * 60)
    resume_checkpoint = "antispoofing_full.pth"
    backbone_file = "faceRecognition_arcface_ckpt(2).pth"
    # If a FULL checkpoint exists → do NOT load the backbone separately
    # (the checkpoint already contains the backbone weights).
    if os.path.exists(resume_checkpoint):
        print(f"\n✓ Found full checkpoint: {resume_checkpoint}")
        print(f"✓ Will load backbone FROM checkpoint (not from {backbone_file})")
        model = DeepFakeModel(
            backbone_ckpt=None,  # ← do not load the backbone separately
            freeze_backbone=True
        )
    else:
        print(f"\n✓ No checkpoint found, starting from scratch")
        print(f"✓ Loading backbone FROM: {backbone_file}")
        model = DeepFakeModel(
            backbone_ckpt=backbone_file,  # ← load the backbone separately
            freeze_backbone=True
        )
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    # Train (backbone frozen: lr_backbone=0 keeps it out of the optimizer).
    print("\n" + "=" * 60)
    print("START TRAINING")
    print("=" * 60)
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=15,
        lr_backbone=0.0,
        lr_temporal=1e-4,
        lr_head=3e-4,
        weight_decay=1e-4,
        device="cuda",
        save_path="antispoofing_full.pth",
        warmup_epochs=1,
        use_scheduler=True,
        resume_from=resume_checkpoint
    )