| import random
|
| import os
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torchvision.models as models
|
| import torch.nn.functional as F
|
| from torch.utils.data import Dataset, DataLoader
|
| from torchvision import transforms
|
| from PIL import Image
|
|
|
| from sklearn.metrics import roc_auc_score
|
| from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
class WRN101_2_Early(nn.Module):
    """Truncated Wide-ResNet-101-2 feature extractor (stem + layer1 + layer2).

    Produces a 512-channel feature map at 1/8 of the input resolution;
    layer3, layer4, avgpool and the fc head of the torchvision model are
    discarded.
    """

    def __init__(self, ckpt_path=None):
        """
        Args:
            ckpt_path: optional checkpoint path. If the loaded dict contains
                a "model" entry, keys prefixed with "backbone." are stripped
                and loaded (non-strict); otherwise the raw state dict is
                loaded directly (non-strict).
        """
        super().__init__()
        base = models.wide_resnet101_2(weights=None)

        if ckpt_path is not None:
            print(f"Loading backbone checkpoint from: {ckpt_path}")
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

            if "model" in checkpoint:
                full_state_dict = checkpoint["model"]
                backbone_state_dict = {}
                for key, value in full_state_dict.items():
                    if key.startswith("backbone."):
                        # Strip only the leading prefix. The previous
                        # key.replace("backbone.", "") would also mangle any
                        # later "backbone." occurrence inside the key.
                        new_key = key[len("backbone."):]
                        backbone_state_dict[new_key] = value

                # strict=False: layer3/layer4/fc are intentionally absent here.
                missing, unexpected = base.load_state_dict(backbone_state_dict, strict=False)
                print(f"✓ Loaded backbone weights successfully")
                print(f" Missing keys: {len(missing)} (expected: layer3, layer4, fc)")
                print(f" Unexpected keys: {len(unexpected)}")
            else:
                base.load_state_dict(checkpoint, strict=False)
                print(f"✓ Loaded weights directly")

        # Keep only the early stages; the rest of the network is dropped.
        self.conv1 = base.conv1
        self.bn1 = base.bn1
        self.relu = base.relu
        self.maxpool = base.maxpool
        self.layer1 = base.layer1
        self.layer2 = base.layer2

    def forward(self, x):
        """Map (B, 3, H, W) images to (B, 512, H/8, W/8) features."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Temporal transformer over per-frame feature maps.

    Each frame's (C, H, W) feature map is global-average-pooled into a
    C-dimensional token; a learned positional embedding is added, the token
    sequence is encoded, and tokens are mean-pooled into one clip embedding.
    """

    def __init__(self, embed_dim=512, num_heads=8, depth=4, mlp_ratio=4.0,
                 dropout=0.1, max_frames=20):
        """
        Args:
            embed_dim: token dimension; must match the backbone channel count.
            num_heads: attention heads per encoder layer (must divide embed_dim).
            depth: number of encoder layers.
            mlp_ratio: feed-forward width as a multiple of embed_dim.
            dropout: dropout probability inside the encoder.
            max_frames: maximum supported clip length (size of the learned
                positional table). Default 20 preserves the previous
                hard-coded behavior.
        """
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            batch_first=True,
            activation="gelu"
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_frames, embed_dim))

    def forward(self, x):
        """Encode (B, T, C, H, W) frame features into a (B, C) clip embedding.

        Raises:
            ValueError: if T exceeds the positional table size, instead of
                the opaque broadcasting error the old code produced.
        """
        B, T, C, H, W = x.shape
        if T > self.pos_embedding.shape[1]:
            raise ValueError(
                f"Clip length {T} exceeds max_frames={self.pos_embedding.shape[1]}"
            )
        x = x.mean(dim=[3, 4])                # GAP over H, W -> (B, T, C)
        x = x + self.pos_embedding[:, :T, :]  # learned temporal positions
        x = self.encoder(x)
        x = x.mean(dim=1)                     # mean-pool tokens -> (B, C)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeepFakeModel(nn.Module):
    """Video anti-spoofing / deepfake classifier.

    Pipeline: per-frame spatial features from a truncated WRN-101-2 backbone
    -> temporal transformer over the frame sequence -> small MLP head
    producing one real/fake logit per clip.
    """

    def __init__(self, backbone_ckpt=None, freeze_backbone=True):
        """
        Args:
            backbone_ckpt: optional checkpoint path forwarded to the backbone.
            freeze_backbone: if True, backbone parameters get
                requires_grad=False (pure feature-extractor mode).
        """
        super().__init__()

        self.backbone = WRN101_2_Early(ckpt_path=backbone_ckpt)

        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
            print("✓ Backbone frozen")
        else:
            print("✓ Backbone trainable")

        # embed_dim must match the backbone's layer2 output channels (512).
        self.temporal_encoder = Transformer(
            embed_dim=512,
            num_heads=8,
            depth=4,
            mlp_ratio=2.0,
            dropout=0.1
        )

        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, frames, labels=None):
        """
        Args:
            frames: (B, T, 3, H, W) clip batch.
            labels: optional (B,) binary labels (0=real, 1=fake); when given,
                the BCE-with-logits loss is also computed.

        Returns:
            (B, 1) logits, or (logits, loss) when labels is not None.
        """
        B, T = frames.shape[:2]

        # Fold time into the batch so the 2D backbone sees (B*T, 3, H, W).
        # reshape (not view) tolerates non-contiguous inputs.
        frames_flat = frames.reshape(B * T, *frames.shape[2:])

        # Build the autograd graph through the backbone only when it actually
        # has trainable parameters and we are training. The previous check
        # (`self.backbone.training`) kept the graph alive for a frozen
        # backbone, wasting activation memory: gradients for the temporal
        # encoder / head never need to flow through the backbone itself.
        needs_grad = self.training and any(
            p.requires_grad for p in self.backbone.parameters()
        )
        with torch.set_grad_enabled(needs_grad):
            feats_flat = self.backbone(frames_flat)

        # Unfold time back out: (B, T, C, H', W').
        feats = feats_flat.view(B, T, *feats_flat.shape[1:])

        x = self.temporal_encoder(feats)
        logits = self.classifier(x)

        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(
                logits.view(-1),
                labels.float()
            )
            return logits, loss

        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
class FaceVideoDataset(Dataset):
    """Clip dataset over directories of pre-extracted face frames.

    Expects root_dir/<video_id>/<frame>.jpg|.png|.jpeg. Every video with at
    least num_frames frames contributes clips_per_video entries; the actual
    clip (stride + start offset) is sampled randomly in __getitem__, so the
    same entry yields different clips across epochs.
    """

    def __init__(
        self,
        root_dir,
        label,
        num_frames=8,
        stride_range=(1, 15),
        clips_per_video=10,
        transform=None
    ):
        """
        Args:
            root_dir: directory containing one sub-directory per video.
            label: class label assigned to every clip of this dataset.
            num_frames: frames per sampled clip.
            stride_range: inclusive (min, max) temporal stride to draw from.
            clips_per_video: dataset entries per eligible video.
            transform: optional per-frame PIL transform.
        """
        self.samples = []
        self.num_frames = num_frames
        self.stride_range = stride_range
        self.transform = transform

        video_dirs = sorted(os.listdir(root_dir))

        for vid in video_dirs:
            frame_dir = os.path.join(root_dir, vid)

            if not os.path.isdir(frame_dir):
                continue

            frames = sorted([
                os.path.join(frame_dir, f)
                for f in os.listdir(frame_dir)
                if f.endswith((".jpg", ".png", ".jpeg"))
            ])

            # Too few frames to build a clip even at stride 1 — skip video.
            if len(frames) < num_frames:
                continue

            # Store the same frame list multiple times; randomness happens
            # at sampling time in __getitem__.
            for _ in range(clips_per_video):
                self.samples.append((frames, label))

        print(f"✓ Loaded {len(self.samples)} clips from {len(video_dirs)} videos (label={label})")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Sample one clip: returns ((num_frames, C, H, W) tensor, label)."""
        frames, label = self.samples[idx]

        stride = random.randint(*self.stride_range)
        required_length = stride * self.num_frames

        # Shrink the stride when the video is too short for the drawn one.
        if len(frames) < required_length:
            stride = max(1, len(frames) // self.num_frames)
            required_length = stride * self.num_frames

        max_start = max(0, len(frames) - required_length)
        start = random.randint(0, max_start) if max_start > 0 else 0

        selected = []
        for i in range(self.num_frames):
            frame_idx = start + i * stride
            frame_idx = min(frame_idx, len(frames) - 1)  # clamp at last frame
            selected.append(frames[frame_idx])

        imgs = []
        for f in selected:
            # Context manager closes the underlying file handle; the previous
            # bare Image.open() leaked descriptors until garbage collection.
            with Image.open(f) as img:
                img = img.convert("RGB")
            if self.transform:
                img = self.transform(img)
            imgs.append(img)

        imgs = torch.stack(imgs, dim=0)
        return imgs, torch.tensor(label, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model(
    model,
    train_loader,
    val_loader=None,
    epochs=20,
    lr_backbone=0.0,
    lr_temporal=5e-5,
    lr_head=1e-4,
    weight_decay=1e-4,
    device="cuda",
    save_path="antispoofing_full.pth",
    warmup_epochs=2,
    use_scheduler=True,
    resume_from=None
):
    """
    Train with a FULL checkpoint (backbone weights included).

    Checkpoint format:
    {
        'model_state_dict': FULL model (backbone + temporal + classifier),
        'optimizer_state_dict': optimizer,
        'scheduler_state_dict': scheduler (optional),
        'epoch': epoch number,
        'auc': validation AUC,
        'val_loss': validation loss,
        'train_loss': training loss,
        'info': description
    }

    Args:
        model: DeepFakeModel instance; moved to `device` here.
        train_loader: training clip DataLoader.
        val_loader: optional validation DataLoader; enables AUC tracking
            and best-checkpoint saving.
        epochs: total epoch count (resuming continues the same schedule).
        lr_backbone: backbone LR; 0.0 (or a fully frozen backbone) leaves
            the backbone out of the optimizer entirely.
        lr_temporal: LR for the temporal transformer parameter group.
        lr_head: LR for the classifier head parameter group.
        weight_decay: AdamW weight decay applied to all groups.
        device: torch device string.
        save_path: destination of the best-AUC checkpoint.
        warmup_epochs: linear warmup length (epochs) before cosine decay.
        use_scheduler: enable warmup + cosine per-step LR scheduling.
        resume_from: optional checkpoint path; restores model / optimizer /
            scheduler state and the epoch counter when the file exists.
    """

    # --- Per-module parameter groups so each part gets its own LR ---
    param_groups = []

    # Backbone joins the optimizer only if it is both trainable and given
    # a non-zero learning rate.
    if lr_backbone > 0 and any(p.requires_grad for p in model.backbone.parameters()):
        param_groups.append({
            "params": [p for p in model.backbone.parameters() if p.requires_grad],
            "lr": lr_backbone
        })
        print(f"✓ Backbone learning rate: {lr_backbone}")

    param_groups.append({
        "params": model.temporal_encoder.parameters(),
        "lr": lr_temporal
    })

    param_groups.append({
        "params": model.classifier.parameters(),
        "lr": lr_head
    })

    optimizer = torch.optim.AdamW(
        param_groups,
        weight_decay=weight_decay
    )

    # --- LR schedule: linear warmup then cosine annealing. Stepped once per
    #     batch, hence the len(train_loader) scaling of all iteration counts.
    scheduler = None
    if use_scheduler:
        from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=0.1,
            total_iters=warmup_epochs * len(train_loader)
        )

        cosine_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=(epochs - warmup_epochs) * len(train_loader),
            eta_min=1e-6
        )

        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_epochs * len(train_loader)]
        )
        print(f"✓ Using LR scheduler: Warmup {warmup_epochs} epochs → Cosine annealing")

    model.to(device)
    best_auc = 0.0
    max_grad_norm = 1.0  # global gradient-clipping threshold
    start_epoch = 1

    # --- Optional resume: restore model, optimizer, scheduler and counters.
    #     Must run AFTER the optimizer/scheduler are built so their
    #     load_state_dict targets exist.
    if resume_from is not None and os.path.exists(resume_from):
        print(f"\n{'=' * 60}")
        print(f"📥 LOADING FULL CHECKPOINT: {resume_from}")
        print(f"{'=' * 60}")

        # NOTE(review): weights_only=False unpickles arbitrary objects;
        # only resume from trusted checkpoints.
        checkpoint = torch.load(resume_from, map_location=device, weights_only=False)

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'], strict=True)
            print(f"✓ Full model loaded (backbone + temporal + classifier)")
        else:
            # Raw state dict without the wrapper keys.
            model.load_state_dict(checkpoint, strict=True)
            print(f"✓ Model loaded")

        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f"✓ Optimizer loaded")

        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print(f"✓ Scheduler loaded")

        # Continue from the epoch after the saved one; keep its best AUC so
        # we never overwrite the checkpoint with a worse model.
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_auc = checkpoint.get('auc', 0.0)

        print(f"\n✓ Resuming from epoch {start_epoch}")
        print(f"✓ Best AUC: {best_auc:.4f}")
        print(f"✓ Train Loss: {checkpoint.get('train_loss', 0):.4f}")
        print(f"✓ Val Loss: {checkpoint.get('val_loss', 0):.4f}")

        if 'info' in checkpoint:
            print(f"ℹ {checkpoint['info']}")

        print(f"{'=' * 60}\n")

    elif resume_from is not None:
        print(f"⚠ Checkpoint not found: {resume_from}")
        print(f" Starting from scratch...\n")

    # --- Main training loop ---
    for epoch in range(start_epoch, epochs + 1):
        model.train()
        epoch_loss = 0.0
        running_loss = 0.0  # EMA-smoothed loss for the progress bar

        pbar = tqdm(
            train_loader,
            desc=f"Epoch [{epoch}/{epochs}]",
            dynamic_ncols=True
        )

        for batch_idx, (frames, labels) in enumerate(pbar):
            frames = frames.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            logits, loss = model(frames, labels)
            loss.backward()

            # Clip the global grad norm before stepping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()

            # Per-batch scheduler step (warmup/cosine counts iterations).
            if scheduler is not None:
                scheduler.step()

            epoch_loss += loss.item()
            # Exponential moving average; seeded with the first batch's loss.
            running_loss = 0.95 * running_loss + 0.05 * loss.item() if batch_idx > 0 else loss.item()

            postfix = {
                'loss': f"{loss.item():.4f}",
                'smooth': f"{running_loss:.4f}"
            }

            if scheduler is not None:
                current_lr = optimizer.param_groups[0]['lr']
                postfix['lr'] = f"{current_lr:.2e}"

            pbar.set_postfix(postfix)

        avg_loss = epoch_loss / len(train_loader)
        print(f"\n📊 Epoch {epoch} | Train Loss: {avg_loss:.4f} | Smooth: {running_loss:.4f}")

        # --- Validation + best-AUC checkpointing ---
        if val_loader is not None:
            val_metrics = validate(model, val_loader, device)
            auc = val_metrics['auc']
            val_loss = val_metrics['loss']

            print(f"📊 Validation | AUC: {auc:.4f} | Loss: {val_loss:.4f}")

            if auc > best_auc:
                best_auc = auc

                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'auc': auc,
                    'val_loss': val_loss,
                    'train_loss': avg_loss,
                    'epoch': epoch,
                    'info': f"Full model | Epoch {epoch} | AUC {auc:.4f}"
                }

                if scheduler is not None:
                    checkpoint['scheduler_state_dict'] = scheduler.state_dict()

                torch.save(checkpoint, save_path)
                print(f"💾 Saved → {save_path}")
                print(f" ✓ Full model (backbone included)")
                print(f" ✓ AUC: {auc:.4f}")

            print(f"🏆 Best AUC: {best_auc:.4f}\n")

    print(f"\n{'=' * 60}")
    print(f"✅ TRAINING COMPLETED")
    print(f"{'=' * 60}")
    print(f"🏆 Best AUC: {best_auc:.4f}")
    print(f"💾 Saved to: {save_path}")
    print(f"\n📌 Use in inference:")
    print(f" detector = AntiSpoofingDetector(")
    print(f" model_path='{save_path}',")
    print(f" device='cuda'")
    print(f" )")
    print(f"{'=' * 60}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def validate(model, val_loader, device="cuda"):
    """Evaluate over val_loader and return {'auc': ..., 'loss': ...}.

    Scores are sigmoid probabilities; AUC is computed over the whole set,
    loss is the mean per-batch BCE-with-logits.
    """
    model.eval()

    score_chunks, label_chunks = [], []
    loss_sum = 0.0

    for clip_batch, target_batch in tqdm(val_loader, desc="Validating", leave=False):
        clip_batch = clip_batch.to(device)
        target_batch = target_batch.to(device)

        raw_logits = model(clip_batch).view(-1)

        loss_sum += F.binary_cross_entropy_with_logits(
            raw_logits,
            target_batch.float()
        ).item()

        score_chunks.append(torch.sigmoid(raw_logits).cpu())
        label_chunks.append(target_batch.cpu())

    scores = torch.cat(score_chunks).numpy()
    targets = torch.cat(label_chunks).numpy()

    return {
        'auc': roc_auc_score(targets, scores),
        'loss': loss_sum / len(val_loader),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # Heavy per-frame augmentation for training. NOTE(review): each frame of
    # a clip is transformed independently, so augmentations are not
    # temporally consistent within a clip — confirm this is intended.
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomRotation(45, interpolation=transforms.InterpolationMode.BILINEAR, fill=0),
        transforms.RandomAffine(0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10,
                                interpolation=transforms.InterpolationMode.BILINEAR, fill=0),
        transforms.RandomPerspective(0.2, p=0.3, interpolation=transforms.InterpolationMode.BILINEAR, fill=0),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
        transforms.RandomApply([transforms.GaussianBlur(5, sigma=(0.1, 2.0))], p=0.3),
        transforms.ToTensor(),
        # RandomErasing runs on the un-normalized tensor (before Normalize).
        transforms.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3), value='random'),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Deterministic eval pipeline: resize + center crop + ImageNet stats.
    transform_val = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("=" * 60)
    print("LOADING DATASETS")
    print("=" * 60)

    # Label convention: 0 = real, 1 = attack/spoof.
    full_real_train = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/real_frame_cropped",
        label=0, num_frames=10, stride_range=(1, 15), clips_per_video=70, transform=transform_train
    )

    full_fake_train = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/attack_frame_cropped2",
        label=1, num_frames=10, stride_range=(1, 15), clips_per_video=70, transform=transform_train
    )

    # NOTE(review): the validation datasets below are built from the SAME
    # video directories as the training datasets, so validation videos
    # overlap training videos. The reported AUC therefore measures
    # clip-level, not video/subject-level, generalization — confirm whether
    # a disjoint video split is intended.
    full_real_val = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/real_frame_cropped",
        label=0, num_frames=10, stride_range=(1, 15), clips_per_video=10, transform=transform_val
    )

    full_fake_val = FaceVideoDataset(
        root_dir="/teamspace/studios/this_studio/dataset_antispoofing_cropped1/dataset_antispoofing_cropped/attack_frame_cropped2",
        label=1, num_frames=10, stride_range=(1, 15), clips_per_video=10, transform=transform_val
    )

    from torch.utils.data import ConcatDataset, random_split

    full_dataset_train = ConcatDataset([full_real_train, full_fake_train])
    full_dataset_val = ConcatDataset([full_real_val, full_fake_val])

    # 95% of the training clips are used; the held-out 5% from random_split
    # is discarded (validation comes from full_dataset_val instead).
    train_size = int(0.95 * len(full_dataset_train))
    val_from_train = len(full_dataset_train) - train_size

    # Fixed seed keeps the split reproducible across runs/resumes.
    train_dataset, _ = random_split(full_dataset_train, [train_size, val_from_train],
                                    generator=torch.Generator().manual_seed(42))
    val_dataset = full_dataset_val

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=6,
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=6, pin_memory=True)

    print(f"\nTraining samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    print("\n" + "=" * 60)
    print("CREATING MODEL")
    print("=" * 60)

    resume_checkpoint = "antispoofing_full.pth"
    backbone_file = "faceRecognition_arcface_ckpt(2).pth"

    # When a full training checkpoint exists, skip loading the stand-alone
    # backbone file — train_model() will restore everything from the
    # checkpoint (including backbone weights).
    if os.path.exists(resume_checkpoint):
        print(f"\n✓ Found full checkpoint: {resume_checkpoint}")
        print(f"✓ Will load backbone FROM checkpoint (not from {backbone_file})")

        model = DeepFakeModel(
            backbone_ckpt=None,
            freeze_backbone=True
        )
    else:
        print(f"\n✓ No checkpoint found, starting from scratch")
        print(f"✓ Loading backbone FROM: {backbone_file}")

        model = DeepFakeModel(
            backbone_ckpt=backbone_file,
            freeze_backbone=True
        )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\n" + "=" * 60)
    print("START TRAINING")
    print("=" * 60)

    # resume_from may point at a non-existent file on the first run;
    # train_model handles that and starts from scratch.
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=15,
        lr_backbone=0.0,
        lr_temporal=1e-4,
        lr_head=3e-4,
        weight_decay=1e-4,
        device="cuda",
        save_path="antispoofing_full.pth",
        warmup_epochs=1,
        use_scheduler=True,
        resume_from=resume_checkpoint
    )