# Deepfake_Detection_System_V1 / model/src/finetune_dataset_a.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import ssl
import platform
# Disable SSL verification for downloading pretrained weights
ssl._create_default_https_context = ssl._create_unverified_context
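# NOTE: this disables certificate verification for the whole process, not just the
# weight download; only acceptable on a trusted network.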
from src.config import Config
from src.models import DeepfakeDetector
from src.dataset import DeepfakeDataset
try:
    from safetensors.torch import save_model, load_model  # both used below
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Checkpoints will be saved as .pth")
def finetune():
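    """Fine-tune the pretrained DeepfakeDetector on Dataset A at a reduced learning rate."""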
# Setup
Config.setup()
device = torch.device(Config.DEVICE)
# Fine-tuning dataset path - Dataset A
if platform.system() == "Windows":
FINETUNE_DATA_PATH = r"C:\Users\kanna\Downloads\Dataset\Dataset A\Dataset A"
else:
FINETUNE_DATA_PATH = "/Users/harshvardhan/Developer/dataset/Dataset A"
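    # These paths are machine-specific; an env-var override (a suggestion, not in the
    # original) would keep the script portable:
    # FINETUNE_DATA_PATH = os.environ.get("FINETUNE_DATA_PATH", FINETUNE_DATA_PATH)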
print(f"\n{'='*80}")
print("FINE-TUNING ON DATASET A")
print(f"{'='*80}\n")
# --- Data Loading ---
print(f"Loading data from: {FINETUNE_DATA_PATH}")
if not os.path.exists(FINETUNE_DATA_PATH):
        print(f"❌ Error: Dataset path not found: {FINETUNE_DATA_PATH}")
return
all_paths, all_labels = DeepfakeDataset.scan_directory(FINETUNE_DATA_PATH)
if len(all_paths) == 0:
print(f"No images found in {FINETUNE_DATA_PATH}")
return
# Shuffle and split
combined = list(zip(all_paths, all_labels))
random.shuffle(combined)
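    # Note: no seed is set, so the train/val split differs between runs; call
    # random.seed(<int>) beforehand if a reproducible split is needed.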
# Use 80/20 split for fine-tuning dataset
split_idx = int(len(combined) * 0.8)
train_data = combined[:split_idx]
val_data = combined[split_idx:]
train_paths, train_labels = zip(*train_data)
val_paths, val_labels = zip(*val_data)
train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')
# Dataloaders - Use Config.BATCH_SIZE but ensure it fits GPU
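    # pin_memory speeds host-to-GPU copies on CUDA; persistent_workers keeps worker
    # processes alive between epochs instead of re-forking them.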
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=(device.type == 'cuda'),
                              persistent_workers=(Config.NUM_WORKERS > 0))
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=(device.type == 'cuda'),
                            persistent_workers=(Config.NUM_WORKERS > 0))
# Load pre-trained model
    print("\n🔄 Loading pre-trained model (best_model)...")
model = DeepfakeDetector(pretrained=False).to(device)
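    # pretrained=False: the weights are expected to come from the checkpoint loaded below.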
# Try to load the best model found so far
checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
if not os.path.exists(checkpoint_path):
        # Fall back to .pth if no safetensors checkpoint exists
checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.pth")
if os.path.exists(checkpoint_path):
try:
if checkpoint_path.endswith(".safetensors"):
load_model(model, checkpoint_path, strict=False)
else:
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print(f"✅ Loaded checkpoint: {checkpoint_path}")
except Exception as e:
            print(f"⚠️ Error loading checkpoint: {e}")
            print("Starting from random weights (not ideal for fine-tuning!)")
    else:
        print("⚠️ No checkpoint found! Starting from random weights.")
# Optimization with LOWER learning rate for fine-tuning
FINETUNE_LR = 1e-5 # 10x lower than original training
FINETUNE_EPOCHS = 5 # Give it a few epochs to adapt
    print("\n📝 Fine-tuning settings:")
print(f" Learning Rate: {FINETUNE_LR} (Low LR for fine-tuning)")
print(f" Epochs: {FINETUNE_EPOCHS}")
print(f" Batch Size: {Config.BATCH_SIZE}")
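    # BCEWithLogitsLoss fuses sigmoid and binary cross-entropy in one numerically stable op.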
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)  # 'verbose' kwarg is deprecated in recent PyTorch
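    # mode='max': step() receives validation accuracy, and the LR is halved after 2 stagnant epochs.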
# Loop
best_acc = 0.0
for epoch in range(FINETUNE_EPOCHS):
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINETUNE_EPOCHS}")
for images, labels in loop:
images = images.to(device)
            labels = labels.to(device).float().unsqueeze(1)  # (B, 1) float targets, as BCEWithLogitsLoss requires
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
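            # A sigmoid probability above 0.5 corresponds to a positive logit; preds are 0/1 floats.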
preds = (torch.sigmoid(outputs) > 0.5).float()
correct = (preds == labels).sum().item()
train_correct += correct
train_total += labels.size(0)
            loop.set_postfix(loss=loss.item(), acc=correct / labels.size(0))  # per-batch accuracy
train_acc = train_correct / train_total if train_total > 0 else 0
print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")
# Save checkpoint after every epoch
save_checkpoint(model, epoch+1, train_acc, name=f"finetuned_datasetA_ep{epoch+1}")
# Validation
if len(val_dataset) > 0:
val_loss, val_acc = validate(model, val_loader, criterion, device)
print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")
scheduler.step(val_acc)
# Save best model if validation accuracy improved
if val_acc > best_acc:
best_acc = val_acc
                print(f"⭐ New best model! Validation Accuracy: {val_acc:.4f}")
save_checkpoint(model, epoch+1, val_acc, name="best_finetuned_datasetA")
    print("\n🎉 Fine-tuning Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    print(f"\n💾 Checkpoints saved in: {Config.CHECKPOINT_DIR}")
def validate(model, loader, criterion, device):
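    """Run a no-grad pass over loader and return (mean batch loss, accuracy)."""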
model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for images, labels in loader:
images = images.to(device)
            labels = labels.to(device).float().unsqueeze(1)  # float targets for BCEWithLogitsLoss
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item()
preds = (torch.sigmoid(outputs) > 0.5).float()
correct += (preds == labels).sum().item()
total += labels.size(0)
return val_loss / len(loader), correct / total
def save_checkpoint(model, epoch, acc, name="checkpoint"):
    """Save model weights, preferring safetensors with a .pth fallback.

    `epoch` and `acc` are accepted for call-site symmetry but are not stored yet.
    """
    path = os.path.join(Config.CHECKPOINT_DIR, f"{name}.safetensors")
    if SAFETENSORS_AVAILABLE:
        try:
            save_model(model, path)  # uses the top-level safetensors import
            print(f"✅ Saved: {name}.safetensors")
            return
        except Exception as e:
            print(f"SafeTensors save failed, falling back to .pth: {e}")
    torch.save(model.state_dict(), path.replace(".safetensors", ".pth"))
    print(f"✅ Saved: {name}.pth")
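# Assuming the layout implied by the `src.*` imports, run from the directory that
# contains the `src` package, e.g.:
#   python -m src.finetune_dataset_a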
if __name__ == "__main__":
finetune()