| | import os |
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader |
| | from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_fscore_support, accuracy_score |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | import numpy as np |
| | from tqdm import tqdm |
| | import sys |
| |
|
| | |
# Put the project root (the parent of this script's folder) at the front of
# sys.path so the ``src.*`` imports that follow resolve when this file is run
# directly as a script.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.split(CURRENT_DIR)[0]
sys.path.insert(0, _PROJECT_ROOT)
| |
|
| | from src.config import Config |
| | from src.models import DeepfakeDetector |
| | from src.dataset import DeepfakeDataset |
| | from safetensors.torch import load_model |
| |
|
| |
|
# Fallback FF++ frame directory used when neither ``val_loader`` nor
# ``dataset_root`` is supplied. It is machine-specific; pass ``dataset_root``
# to generate_report to point at a different copy of the data.
_DEFAULT_FF_DATASET_ROOT = "/Users/harshvardhan/Developer/Deepfake Project /DataSet/FaceForencis++ extracted frames"


def _sample_validation_pairs(paths, labels, val_size=2000, seed=None):
    """Randomly pick up to ``val_size`` aligned (path, label) pairs.

    Args:
        paths: Sequence of sample file paths.
        labels: Sequence of labels aligned index-for-index with ``paths``.
        val_size: Maximum number of pairs to keep.
        seed: Seed for a private ``random.Random``. The same seed and inputs
            always yield the same subset; ``None`` seeds from system entropy.

    Returns:
        ``(sampled_paths, sampled_labels)`` as two equal-length tuples, or
        ``((), ())`` when there is nothing to sample.
    """
    import random

    pairs = list(zip(paths, labels))
    # A private RNG makes the subset reproducible without touching the global
    # ``random`` state that other code may rely on.
    random.Random(seed).shuffle(pairs)
    subset = pairs[:val_size]
    if not subset:
        # zip(*[]) yields nothing, which cannot be unpacked into two names.
        return (), ()
    sampled_paths, sampled_labels = zip(*subset)
    return sampled_paths, sampled_labels


def _build_default_val_loader(dataset_root, val_size, seed):
    """Build a DataLoader over a random subset of ``dataset_root``/real and /fake.

    Returns ``None`` when the sampled subset is empty (e.g. no files found).
    """
    real_files, real_labels = DeepfakeDataset.scan_directory(os.path.join(dataset_root, "real"))
    fake_files, fake_labels = DeepfakeDataset.scan_directory(os.path.join(dataset_root, "fake"))

    val_paths, val_labels = _sample_validation_pairs(
        list(real_files) + list(fake_files),
        list(real_labels) + list(fake_labels),
        val_size=val_size,
        seed=seed,
    )
    if not val_paths:
        return None

    val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')
    return DataLoader(val_dataset, batch_size=64, num_workers=4, shuffle=False)


def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and return flat ``(labels, probabilities)`` arrays.

    Probabilities are ``sigmoid(model(images))``. Both arrays are flattened.
    This assumes the model emits one logit per image (e.g. shape ``(B,)`` or
    ``(B, 1)``) — TODO confirm against ``DeepfakeDetector``. An empty loader
    yields two empty arrays.
    """
    label_batches, prob_batches = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Reporting"):
            logits = model(images.to(device))
            prob_batches.append(torch.sigmoid(logits).cpu().numpy().reshape(-1))
            label_batches.append(np.asarray(labels).reshape(-1))
    if not prob_batches:
        return np.array([]), np.array([])
    return np.concatenate(label_batches), np.concatenate(prob_batches)


def _compute_roc(y_true, y_prob):
    """Return ``(fpr, tpr, roc_auc)``, or ``(None, None, 0.0)`` when ROC is undefined."""
    # A ROC curve needs both classes among the true labels. Otherwise sklearn
    # only warns and returns NaN rates.
    if np.unique(y_true).size < 2:
        print("⚠️ Warning: ROC Curve needs both classes in the labels. Skipping ROC plot.")
        return None, None, 0.0
    try:
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        return fpr, tpr, auc(fpr, tpr)
    except (IndexError, ValueError) as e:
        # ValueError covers e.g. NaN probabilities; IndexError was seen from sklearn.
        print(f"⚠️ Warning: ROC Curve generation failed ({type(e).__name__}: {e}). Skipping ROC plot.")
        return None, None, 0.0


def _save_confusion_matrix(cm, model_name, out_path):
    """Render ``cm`` as an annotated heatmap (index 0 = 'Real', 1 = 'Fake') and save it."""
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'], ax=ax)
        ax.set_title(f'Confusion Matrix - {model_name}')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        fig.savefig(out_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def _save_roc_curve(fpr, tpr, roc_auc, model_name, out_path):
    """Plot the ROC curve with a chance diagonal and save it to ``out_path``."""
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(f'ROC - {model_name}')
        ax.legend(loc="lower right")
        fig.savefig(out_path)
    finally:
        plt.close(fig)


def generate_report(
    model_filename="Mark-III.safetensors",
    val_loader=None,
    device_str=None,
    output_dir=None,
    *,
    dataset_root=None,
    val_size=2000,
    seed=42,
):
    """Evaluate a safetensors checkpoint, print metrics, and save report plots.

    Predictions use a 0.5 threshold on ``sigmoid(logits)``. Label 1 is the
    positive class (plotted as 'Fake'); label 0 is plotted as 'Real'.

    Args:
        model_filename: Checkpoint file. Absolute paths are used as-is.
            Relative names are resolved against ``Config.CHECKPOINT_DIR``.
        val_loader: Iterable of ``(images, labels)`` batches. When ``None``,
            a random subset of the FF++ frames under ``dataset_root`` is used.
        device_str: Torch device string (e.g. ``"cpu"``). When falsy,
            ``Config.setup()`` is called and ``Config.DEVICE`` is used.
        output_dir: Directory for the PNG plots. Defaults to
            ``<Config.RESULTS_DIR>/plots`` and is created if missing.
        dataset_root: Folder containing ``real`` and ``fake`` sub-folders for
            the default loader. Defaults to ``_DEFAULT_FF_DATASET_ROOT``.
        val_size: Maximum number of samples in the default validation subset.
        seed: Seed for choosing the default subset, so different checkpoints
            are scored on the same images. ``None`` draws a fresh subset per call.

    Returns:
        ``(accuracy, roc_auc)`` on success (``roc_auc`` is ``0.0`` when the ROC
        is undefined). ``(0.0, 0.0)`` when the loader yields no samples or
        computing metrics/plots raises. ``None`` when the checkpoint file or
        the default dataset cannot be found.
    """
    if device_str:
        device = torch.device(device_str)
    else:
        Config.setup()
        device = torch.device(Config.DEVICE)

    report_plots_dir = output_dir if output_dir is not None else os.path.join(Config.RESULTS_DIR, "plots")
    os.makedirs(report_plots_dir, exist_ok=True)

    print(f"🔄 Loading {model_filename}...")
    if os.path.isabs(model_filename):
        checkpoint_path = model_filename
    else:
        checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, model_filename)

    # isfile (not exists): a directory with a matching name is rejected here
    # instead of failing later inside load_model.
    if not os.path.isfile(checkpoint_path):
        print(f"⚠️ Model not found at {checkpoint_path}")
        return None

    model = DeepfakeDetector(pretrained=False).to(device)
    # strict=False tolerates key mismatches. Report them so a partially
    # loaded model is not evaluated silently.
    missing, unexpected = load_model(model, checkpoint_path, strict=False)
    if missing or unexpected:
        print(
            f"⚠️ Checkpoint/model mismatch: {len(missing)} missing and "
            f"{len(unexpected)} unexpected keys (loaded with strict=False)."
        )
    model.eval()

    if val_loader is None:
        print("📂 Loading Default Validation Dataset (FF++)...")
        root = _DEFAULT_FF_DATASET_ROOT if dataset_root is None else dataset_root
        val_loader = _build_default_val_loader(root, val_size, seed)
        if val_loader is None:
            print(f"⚠️ No validation images found under {root}")
            return None

    print("⚡ Running Inference for Report...")
    y_true, y_prob = _collect_predictions(model, val_loader, device)
    if y_true.size == 0:
        print("⚠️ Validation loader produced no samples; nothing to report.")
        return 0.0, 0.0
    y_pred = (y_prob > 0.5).astype(int)

    try:
        acc = accuracy_score(y_true, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='binary', zero_division=0
        )
        fpr, tpr, roc_auc = _compute_roc(y_true, y_prob)
        # Fixed labels keep the matrix 2x2, matching the two tick labels, even
        # when only one class appears in both labels and predictions.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

        print("\n📊 Report Metrics:")
        print(f" Accuracy: {acc:.4f}")
        print(f" Precision: {precision:.4f}")
        print(f" Recall: {recall:.4f}")
        print(f" F1-Score: {f1:.4f}")
        print(f" ROC-AUC: {roc_auc:.4f}")

        model_name = os.path.basename(model_filename)
        _save_confusion_matrix(cm, model_name, os.path.join(report_plots_dir, "confusion_matrix.png"))
        if fpr is not None:
            _save_roc_curve(fpr, tpr, roc_auc, model_name, os.path.join(report_plots_dir, "roc_curve.png"))

        print(f"\n✅ Visuals saved to {report_plots_dir}")
        return acc, roc_auc

    except Exception as e:
        # Deliberate best-effort boundary: a metrics/plotting failure is
        # reported and mapped to the (0.0, 0.0) sentinel rather than raised.
        print(f"🚨 Error during report generation metrics: {e}")
        return 0.0, 0.0
| |
|
if __name__ == "__main__":
    # Running this file directly scores a checkpoint using generate_report's
    # default arguments for every setting.
    generate_report()
| |
|
| |
|