import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import ssl

# Disable SSL certificate verification so downloading pretrained weights does not
# fail on machines with missing or self-signed certificates.
ssl._create_default_https_context = ssl._create_unverified_context

from src.config import Config
from src.models import DeepfakeDetector
from src.dataset import DeepfakeDataset

try:
    from safetensors.torch import save_model, load_model
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Checkpoints will be saved as .pth")


def finetune():
    """Fine-tune the model pre-trained on Dataset A using Dataset C."""
    Config.setup()
    device = torch.device(Config.DEVICE)

    FINETUNE_DATA_PATH = "/Users/harshvardhan/Developer/dataset/Dataset c"

    print(f"\n{'='*80}")
    print("FINE-TUNING ON DATASET C")
    print(f"{'='*80}\n")

    print(f"Loading data from: {FINETUNE_DATA_PATH}")
    all_paths, all_labels = DeepfakeDataset.scan_directory(FINETUNE_DATA_PATH)

    if len(all_paths) == 0:
        print(f"No images found in {FINETUNE_DATA_PATH}")
        return

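    # Shuffle the scanned files and split them 80/20 into train/validation sets.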
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)

    split_idx = int(len(combined) * 0.8)
    train_data = combined[:split_idx]
    val_data = combined[split_idx:]

    train_paths, train_labels = zip(*train_data)
    val_paths, val_labels = zip(*val_data)

    train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
    val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')

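    # pin_memory only helps when batches are copied to a CUDA device, and
    # persistent workers avoid re-spawning worker processes every epoch.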
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=(device.type == 'cuda'),
                              persistent_workers=(Config.NUM_WORKERS > 0))
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=(device.type == 'cuda'),
                            persistent_workers=(Config.NUM_WORKERS > 0))

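    # Build the detector and restore the checkpoint from the Dataset A training run.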
    print("\nLoading pre-trained model from Dataset A...")
    model = DeepfakeDetector(pretrained=False).to(device)

    checkpoint_path = "results/checkpoints/best_model.safetensors"
    if SAFETENSORS_AVAILABLE and os.path.exists(checkpoint_path):
        load_model(model, checkpoint_path, strict=False)
        print(f"Loaded checkpoint: {checkpoint_path}")
    else:
        print("WARNING: No checkpoint loaded! Starting from random weights.")

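    # Fine-tuning hyperparameters: a much smaller learning rate than the original
    # training run, applied for only a couple of epochs.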
    FINETUNE_LR = 1e-5
    FINETUNE_EPOCHS = 2

    print("\nFine-tuning settings:")
    print(f"  Learning Rate: {FINETUNE_LR} (10x lower for fine-tuning)")
    print(f"  Epochs: {FINETUNE_EPOCHS}")
    print(f"  Batch Size: {Config.BATCH_SIZE}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=Config.WEIGHT_DECAY)
    # Note: with step_size=5 and only 2 fine-tuning epochs, this scheduler never
    # actually reduces the learning rate.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

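    # Track the best validation accuracy and run the fine-tuning loop: one
    # forward/backward pass per batch, with running loss and accuracy on the
    # tqdm progress bar.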
    best_acc = 0.0

    for epoch in range(FINETUNE_EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINETUNE_EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct / labels.size(0))

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

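        # Save a per-epoch checkpoint regardless of validation performance.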
        save_checkpoint(model, epoch + 1, train_acc, name=f"finetuned_datasetC_ep{epoch+1}")

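        # Validate on the held-out split and keep the best-performing weights.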
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            if val_acc > best_acc:
                best_acc = val_acc
                print(f"New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch + 1, val_acc, name="best_finetuned_datasetC")

        scheduler.step()

    print("\nFine-tuning Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    print("\nCheckpoints saved in: results/checkpoints/")


def validate(model, loader, criterion, device):
    """Run one full pass over the validation loader and return (mean loss, accuracy)."""
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return val_loss / len(loader), correct / total


def save_checkpoint(model, epoch, acc, name="checkpoint"):
    """Save model weights as .safetensors, falling back to a .pth state dict."""
    state_dict = model.state_dict()
    filename = f"{name}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, filename)

    if SAFETENSORS_AVAILABLE:
        try:
            save_model(model, path)
            print(f"Saved: {filename}")
        except Exception as e:
            print(f"SafeTensors save failed, falling back to .pth: {e}")
            torch.save(state_dict, path.replace(".safetensors", ".pth"))
    else:
        torch.save(state_dict, path.replace(".safetensors", ".pth"))


if __name__ == "__main__":
    finetune()