import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import ssl
# Disable SSL verification for downloading pretrained weights
ssl._create_default_https_context = ssl._create_unverified_context
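# Note: this disables certificate verification for every HTTPS request in the
# process. A narrower fix (an assumption, not from the original code) is installing
# certifi or, on python.org macOS builds, running "Install Certificates.command".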
from src.config import Config
from src.models import DeepfakeDetector
from src.dataset import DeepfakeDataset
try:
    from safetensors.torch import save_model, load_model
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Checkpoints will be saved as .pth")
def finetune():
    # Setup
    Config.setup()
    device = torch.device(Config.DEVICE)

    # Fine-tuning dataset path
    FINETUNE_DATA_PATH = "/Users/harshvardhan/Developer/dataset/Dataset c"

    print(f"\n{'='*80}")
    print("FINE-TUNING ON DATASET C")
    print(f"{'='*80}\n")
    # --- Data Loading ---
    print(f"Loading data from: {FINETUNE_DATA_PATH}")
    all_paths, all_labels = DeepfakeDataset.scan_directory(FINETUNE_DATA_PATH)
    if len(all_paths) == 0:
        print(f"No images found in {FINETUNE_DATA_PATH}")
        return

    # Shuffle and split 80/20 into train/val
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)
    split_idx = int(len(combined) * 0.8)
    train_data = combined[:split_idx]
    val_data = combined[split_idx:]
    train_paths, train_labels = zip(*train_data)
    val_paths, val_labels = zip(*val_data)

    train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
    val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')
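    # Assumption: `phase` selects the transform pipeline inside DeepfakeDataset
    # (augmentations for 'train', deterministic resize/normalize for 'val').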
    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=(device.type == 'cuda'),
                              persistent_workers=(Config.NUM_WORKERS > 0))
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=(device.type == 'cuda'),
                            persistent_workers=(Config.NUM_WORKERS > 0))
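    # pin_memory speeds up host-to-GPU transfers on CUDA; persistent_workers keeps
    # worker processes alive across epochs instead of respawning them each epoch.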
    # Load the model pre-trained on Dataset A
    print("\nLoading pre-trained model from Dataset A...")
    model = DeepfakeDetector(pretrained=False).to(device)

    checkpoint_path = "results/checkpoints/best_model.safetensors"
    if SAFETENSORS_AVAILABLE and os.path.exists(checkpoint_path):
        # strict=False tolerates missing/unexpected keys between checkpoints
        load_model(model, checkpoint_path, strict=False)
        print(f"Loaded checkpoint: {checkpoint_path}")
    else:
        print("Warning: no checkpoint found! Starting from random weights.")
    # Optimization with a LOWER learning rate for fine-tuning
    FINETUNE_LR = 1e-5  # 10x lower than original training
    FINETUNE_EPOCHS = 2

    print("\nFine-tuning settings:")
    print(f"  Learning Rate: {FINETUNE_LR} (10x lower for fine-tuning)")
    print(f"  Epochs: {FINETUNE_EPOCHS}")
    print(f"  Batch Size: {Config.BATCH_SIZE}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=Config.WEIGHT_DECAY)
    # Note: with step_size=5 and only 2 epochs, the LR never actually decays here
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
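    # AdamW applies weight decay directly to the weights rather than through the
    # gradient (decoupled weight decay), the usual choice for fine-tuning runs.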
    # Training loop
    best_acc = 0.0
    for epoch in range(FINETUNE_EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINETUNE_EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)  # (B,) -> (B, 1) to match model output
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # BCEWithLogitsLoss takes raw logits, so apply sigmoid before thresholding
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)
            loop.set_postfix(loss=loss.item(), acc=correct/labels.size(0))

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")
        # Save a checkpoint after every epoch
        save_checkpoint(model, epoch+1, train_acc, name=f"finetuned_datasetC_ep{epoch+1}")

        # Validation
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            # Save the best model if validation accuracy improved
            if val_acc > best_acc:
                best_acc = val_acc
                print(f"New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, name="best_finetuned_datasetC")

        scheduler.step()

    print("\nFine-tuning complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    print("\nCheckpoints saved in: results/checkpoints/")
def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return val_loss / len(loader), correct / total
def save_checkpoint(model, epoch, acc, name="checkpoint"):
    filename = f"{name}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, filename)
    if SAFETENSORS_AVAILABLE:
        try:
            # save_model handles shared/tied tensors, which a raw save_file would reject
            save_model(model, path)
            print(f"Saved: {filename} (epoch {epoch}, acc {acc:.4f})")
        except Exception as e:
            print(f"SafeTensors save failed, falling back to .pth: {e}")
            torch.save(model.state_dict(), path.replace(".safetensors", ".pth"))
    else:
        torch.save(model.state_dict(), path.replace(".safetensors", ".pth"))
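# Restoring a fallback .pth checkpoint later would look like this (sketch, assuming
# the same DeepfakeDetector architecture and a hypothetical path):
#   model = DeepfakeDetector(pretrained=False)
#   model.load_state_dict(torch.load(path, map_location="cpu"))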
if __name__ == "__main__":
    finetune()