Commit ·
0966609
1
Parent(s): 3dccfa9
Deploy Backend (No Frontend)
Browse files- .dockerignore +10 -0
- Dockerfile +31 -0
- backend/.DS_Store +0 -0
- backend/app.py +368 -0
- backend/database.db +0 -0
- backend/database.py +106 -0
- backend/requirements_web.txt +9 -0
- model/.DS_Store +0 -0
- model/results/.DS_Store +0 -0
- model/results/checkpoints/.DS_Store +0 -0
- model/results/checkpoints/best_model.safetensors +3 -0
- model/src/__init__.py +0 -0
- model/src/__pycache__/__init__.cpython-314.pyc +0 -0
- model/src/__pycache__/config.cpython-314.pyc +0 -0
- model/src/__pycache__/dataset.cpython-314.pyc +0 -0
- model/src/__pycache__/models.cpython-314.pyc +0 -0
- model/src/__pycache__/train.cpython-314.pyc +0 -0
- model/src/__pycache__/utils.cpython-314.pyc +0 -0
- model/src/__pycache__/video_inference.cpython-314.pyc +0 -0
- model/src/config.py +50 -0
- model/src/dataset.py +109 -0
- model/src/finetune.py +182 -0
- model/src/finetune_dataset_a.py +204 -0
- model/src/inference.py +139 -0
- model/src/models.py +197 -0
- model/src/test_dataloading.py +47 -0
- model/src/test_dryrun.py +56 -0
- model/src/train.py +271 -0
- model/src/utils.py +37 -0
- model/src/video_inference.py +172 -0
.dockerignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.gitignore
|
| 3 |
+
__pycache__
|
| 4 |
+
*.pyc
|
| 5 |
+
.DS_Store
|
| 6 |
+
model/test_images
|
| 7 |
+
venv
|
| 8 |
+
env
|
| 9 |
+
node_modules
|
| 10 |
+
frontend/tests
|
Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9

WORKDIR /app

# Install system dependencies for OpenCV and GLib
RUN apt-get update && apt-get install -y \
    libgl1-mesa-glx \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first to leverage Docker cache
COPY backend/requirements_web.txt .
RUN pip install --no-cache-dir -r requirements_web.txt

# Copy the rest of the application (duplicate comment line removed)
COPY backend/ backend/
COPY model/ model/

# Set working directory to backend
WORKDIR /app/backend

# Create necessary directories (single layer instead of two RUNs)
RUN mkdir -p uploads history_uploads

# Expose Hugging Face default port
EXPOSE 7860

# Run the application
CMD ["python", "app.py"]
|
backend/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
backend/app.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import sys
import os

# Make the sibling 'model' package importable.
# NOTE: must run BEFORE the 'src' imports below — do not reorder.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'model')))
import datetime
import torch
import cv2
import numpy as np
import ssl
import base64
from werkzeug.utils import secure_filename
import io
from PIL import Image
from src import video_inference

# Disable SSL verification so pretrained-weight downloads work behind broken cert chains.
# NOTE(review): this disables certificate checking process-wide — confirm acceptable.
ssl._create_default_https_context = ssl._create_unverified_context
import albumentations as A
from albumentations.pytorch import ToTensorV2
from src.models import DeepfakeDetector
from src.config import Config
import database

# safetensors is optional; load_model() falls back to torch.load when unavailable.
# (Removed: duplicate `import os` and duplicate ToTensorV2 import.)
try:
    from safetensors.torch import load_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
|
| 34 |
+
|
| 35 |
+
app = Flask(__name__, static_folder='../frontend', static_url_path='')
CORS(app)

# --- Configuration ---
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), 'uploads')
HISTORY_FOLDER = os.path.join(os.path.dirname(__file__), 'history_uploads')
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp', 'mp4', 'avi', 'mov', 'webm'}

# Ensure the working folders exist before the first request arrives.
for _folder in (UPLOAD_FOLDER, HISTORY_FOLDER):
    os.makedirs(_folder, exist_ok=True)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size

# --- Globals populated by load_model() at startup ---
device = torch.device(Config.DEVICE)
model = None
transform = None
|
| 52 |
+
|
| 53 |
+
def get_transform():
    """Build the inference-time preprocessing pipeline: resize -> ImageNet normalize -> tensor."""
    steps = [
        A.Resize(Config.IMAGE_SIZE, Config.IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
|
| 59 |
+
|
| 60 |
+
def load_model():
    """Load the trained deepfake detection model.

    Populates the module-level ``model`` and ``transform`` globals and also
    returns them as a tuple. On any failure ``model`` is left as None so the
    endpoints can report a clear error instead of crashing at request time.
    """
    global model, transform

    checkpoint_dir = Config.CHECKPOINT_DIR
    # Explicitly target the model requested by the user
    target_model_name = "best_model.safetensors"
    checkpoint_path = os.path.join(checkpoint_dir, target_model_name)

    print(f"Using device: {device}")

    # Initialize with pretrained=True to ensure missing keys (frozen layers) have valid ImageNet weights
    # instead of random noise. This fixes the "random prediction" issue when the checkpoint
    # only contains finetuned layers.
    model = DeepfakeDetector(pretrained=True)
    model.to(device)
    model.eval()

    # Check if file exists first
    if not os.path.exists(checkpoint_path):
        print(f"❌ CRITICAL ERROR: Model file not found at: {checkpoint_path}")
        print(f"Please ensure '{target_model_name}' exists in '{checkpoint_dir}'")
        model = None
        transform = get_transform()
        return model, transform

    try:
        print(f"Loading checkpoint: {checkpoint_path}")
        # Prefer the safetensors loader when available; fall back to torch.load otherwise.
        if checkpoint_path.endswith(".safetensors") and SAFETENSORS_AVAILABLE:
            state_dict = load_file(checkpoint_path)
        else:
            state_dict = torch.load(checkpoint_path, map_location=device)

        # Use strict=False because the checkpoint might be a partial save (e.g. only finetuned layers)
        # or there might be minor architecture mismatches.
        # Since we use pretrained=True, the missing keys will remain as ImageNet weights (valid features).
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        print(f"✅ Model loaded successfully!")
        if missing_keys:
            print(f"ℹ️ {len(missing_keys)} keys missing from checkpoint (using pretrained defaults).")
        if unexpected_keys:
            print(f"ℹ️ {len(unexpected_keys)} unexpected keys in checkpoint.")

    except Exception as e:
        # Any load failure leaves model=None; predict_image() checks for this.
        print(f"❌ Error loading checkpoint: {e}")
        print("Predictions will fail until this is resolved.")
        model = None

    transform = get_transform()
    return model, transform
|
| 111 |
+
|
| 112 |
+
def allowed_file(filename):
    """Return True when *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
|
| 114 |
+
|
| 115 |
+
def predict_image(image_path):
    """Run the deepfake detector on a single image file.

    Returns:
        (result_dict, None) on success, where result_dict holds the label,
        confidence, class probabilities and a base64-encoded heatmap overlay;
        (None, error_message) on failure.
    """
    if model is None:
        # BUG FIX: the message previously referenced a stale checkpoint name
        # ('best_finetuned_datasetB.safetensors'); load_model() actually loads
        # 'best_model.safetensors'. Keep the two in sync.
        return None, "Error: Model not loaded. Check backend logs for 'best_model.safetensors' error."

    try:
        # Read and preprocess image
        image = cv2.imread(image_path)
        if image is None:
            return None, "Error: Could not read image"

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        augmented = transform(image=image)
        image_tensor = augmented['image'].unsqueeze(0).to(device)

        # Forward pass. NOTE(review): deliberately NOT wrapped in torch.no_grad()
        # because model.get_heatmap() below may rely on gradients (Grad-CAM style)
        # — confirm before optimizing.
        logits = model(image_tensor)
        prob = torch.sigmoid(logits).item()

        # Generate heatmap and resize it to the original image size.
        heatmap = model.get_heatmap(image_tensor)
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        # Superimpose: heatmap is BGR (from cv2), image is RGB -> convert image to BGR.
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        superimposed_img = heatmap * 0.4 + image_bgr * 0.6
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

        # Encode the overlay to base64 JPEG for the frontend.
        _, buffer = cv2.imencode('.jpg', superimposed_img)
        heatmap_b64 = base64.b64encode(buffer).decode('utf-8')

        is_fake = prob > 0.5
        label = "FAKE" if is_fake else "REAL"
        confidence = prob if is_fake else 1 - prob

        return {
            'prediction': label,
            'confidence': float(confidence),
            'fake_probability': float(prob),
            'real_probability': float(1 - prob),
            'heatmap': heatmap_b64
        }, None
    except Exception as e:
        return None, str(e)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
@app.route('/')
def index():
    """Backend root: simple liveness/info payload."""
    payload = {
        "status": "online",
        "message": "DeepGuard Backend is Running",
        "endpoints": ["/api/predict", "/api/history", "/api/health"],
    }
    return jsonify(payload)
|
| 176 |
+
|
| 177 |
+
@app.route('/history_uploads/<path:filename>')
def serve_history_image(filename):
    """Serve an archived upload from the history folder."""
    directory = HISTORY_FOLDER
    return send_from_directory(directory, filename)
|
| 181 |
+
|
| 182 |
+
@app.route('/api/health', methods=['GET'])
def health_check():
    """Health check endpoint reporting model and device status."""
    status = {
        'status': 'healthy',
        'model_loaded': model is not None,
        'device': str(device),
    }
    return jsonify(status)
|
| 190 |
+
|
| 191 |
+
@app.route('/api/predict', methods=['POST'])
def predict():
    """Handle an image upload: validate, run inference, archive to history.

    Returns the prediction JSON on success, or {'error': ...} with 400/500.
    """
    try:
        # Check if file is present
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: png, jpg, jpeg, webp'}), 400

        # Save file to the temporary upload folder.
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Make prediction
        result, error = predict_image(filepath)

        # BUG FIX: the error path was previously ignored, which crashed below
        # with a TypeError (result is None). Clean up and report instead.
        if error is not None or result is None:
            try:
                os.remove(filepath)
            except OSError:
                pass
            return jsonify({'error': error or 'Prediction failed'}), 500

        # Archive the original upload so the history entry stays viewable.
        import shutil
        # Keep the sanitized original name so the extension/mimetype survives
        # (the filename suffix here was lost in a bad merge — reconstructed).
        history_filename = f"scan_{int(datetime.datetime.now().timestamp())}_{filename}"
        history_path = os.path.join(HISTORY_FOLDER, history_filename)
        shutil.copy(filepath, history_path)

        # Relative path for frontend
        relative_path = f"history_uploads/{history_filename}"

        database.add_scan(
            filename=filename,
            prediction=result['prediction'],
            confidence=result['confidence'],
            fake_prob=result['fake_probability'],
            real_prob=result['real_probability'],
            image_path=relative_path
        )

        # Clean up the temporary upload; the archived copy lives in history_uploads.
        try:
            os.remove(filepath)
        except OSError:
            pass

        return jsonify(result)

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
| 247 |
+
|
| 248 |
+
@app.route('/api/predict_video', methods=['POST'])
def predict_video():
    """Handle a video upload: validate, run frame-level inference, archive to history."""
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type'}), 400

        # Fail fast before doing any disk work if the model never loaded.
        if model is None:
            return jsonify({'error': 'Model not loaded'}), 500

        # Save file
        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Process video. The already-loaded 'model'/'transform' objects are
        # passed in, so process_video does not need to import them itself.
        result = video_inference.process_video(filepath, model, transform, device)
        if "error" in result:
            return jsonify(result), 500

        # Archive the original video so the history entry can replay it.
        import shutil
        # Keep the sanitized original name so the extension/mimetype survives
        # (the filename suffix here was lost in a bad merge — reconstructed).
        history_filename = f"scan_{int(datetime.datetime.now().timestamp())}_{filename}"
        history_path = os.path.join(HISTORY_FOLDER, history_filename)
        shutil.copy(filepath, history_path)

        relative_path = f"history_uploads/{history_filename}"

        # The history schema is image-oriented; 'fake_prob' stores the video's
        # average fake probability.
        database.add_scan(
            filename=filename,
            prediction=result['prediction'],
            confidence=result['confidence'],
            fake_prob=result['avg_fake_prob'],
            real_prob=1 - result['avg_fake_prob'],
            image_path=relative_path
        )

        # Clean up the temporary upload.
        try:
            os.remove(filepath)
        except OSError:
            pass

        # Add video URL for frontend playback
        result['video_url'] = relative_path
        return jsonify(result)

    except Exception as e:
        print(f"Video Error: {e}")
        return jsonify({'error': str(e)}), 500
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@app.route('/api/history', methods=['GET'])
def get_history():
    """Return all past scans, newest first."""
    # BUG FIX: the history query was executed twice back-to-back; one call suffices.
    history = database.get_history()
    return jsonify(history)
|
| 323 |
+
|
| 324 |
+
@app.route('/api/history/<int:scan_id>', methods=['DELETE'])
def delete_scan(scan_id):
    """Delete one scan row by its database id."""
    deleted = database.delete_scan(scan_id)
    if not deleted:
        return jsonify({'error': 'Failed to delete scan'}), 500
    return jsonify({'message': 'Scan deleted'})
|
| 330 |
+
|
| 331 |
+
@app.route('/api/history', methods=['DELETE'])
def clear_history():
    """Delete every scan from history."""
    cleared = database.clear_history()
    if not cleared:
        return jsonify({'error': 'Failed to clear history'}), 500
    return jsonify({'message': 'History cleared'})
|
| 337 |
+
|
| 338 |
+
@app.route('/api/model-info', methods=['GET'])
def model_info():
    """Describe the loaded detector: architecture, enabled components, device."""
    components = {
        'RGB Analysis': Config.USE_RGB,
        'Frequency Domain': Config.USE_FREQ,
        'Patch-based Detection': Config.USE_PATCH,
        'Vision Transformer': Config.USE_VIT,
    }
    info = {
        'model_name': 'DeepGuard: Advanced Deepfake Detector',
        'architecture': 'Hybrid CNN-ViT',
        'components': components,
        'image_size': Config.IMAGE_SIZE,
        'device': str(device),
        'threshold': 0.5,
    }
    return jsonify(info)
|
| 354 |
+
|
| 355 |
+
if __name__ == '__main__':
    banner = "=" * 60
    print(banner)
    print("🚀 DeepGuard - Deepfake Detection System")
    print(banner)

    # Load the detector once at startup; endpoints check for failure themselves.
    load_model()

    print(banner)
    port = int(os.environ.get("PORT", 7860))
    print(f"🌐 Starting server on http://0.0.0.0:{port}")
    print(banner)

    app.run(debug=False, host='0.0.0.0', port=port)
|
backend/database.db
ADDED
|
Binary file (20.5 kB). View file
|
|
|
backend/database.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import datetime
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
DB_NAME = os.path.join(os.path.dirname(__file__), 'database.db')
|
| 6 |
+
|
| 7 |
+
def get_db_connection():
    """Open the history DB with dict-like row access; return None on failure."""
    try:
        connection = sqlite3.connect(DB_NAME)
        connection.row_factory = sqlite3.Row
    except sqlite3.Error as e:
        print(f"Database error: {e}")
        return None
    return connection
|
| 15 |
+
|
| 16 |
+
def init_db():
    """Create the history table if it does not exist and run a lightweight
    migration (adding the image_path column on older databases)."""
    conn = get_db_connection()
    if conn:
        try:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    filename TEXT NOT NULL,
                    prediction TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    fake_probability REAL NOT NULL,
                    real_probability REAL NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            conn.commit()
            print("✅ Database initialized successfully.")
        except sqlite3.Error as e:
            print(f"Error initializing database: {e}")

        # Migration: Add image_path if not exists.
        # ALTER TABLE raises once the column is present; swallowing the error
        # makes the migration idempotent.
        try:
            conn.execute('ALTER TABLE history ADD COLUMN image_path TEXT')
            print("✅ Added image_path column.")
        except sqlite3.Error:
            pass  # Column likely exists

        # The finally clause belongs to the migration try above and always
        # closes the connection.
        finally:
            conn.close()
|
| 45 |
+
|
| 46 |
+
def add_scan(filename, prediction, confidence, fake_prob, real_prob, image_path=""):
    """Insert one scan record into history; return True on success, False otherwise."""
    conn = get_db_connection()
    if conn is None:
        return False
    try:
        conn.execute(
            '''
            INSERT INTO history (filename, prediction, confidence, fake_probability, real_probability, image_path)
            VALUES (?, ?, ?, ?, ?, ?)
            ''',
            (filename, prediction, confidence, fake_prob, real_prob, image_path),
        )
        conn.commit()
        return True
    except sqlite3.Error as e:
        print(f"Error adding scan: {e}")
        return False
    finally:
        conn.close()
|
| 62 |
+
|
| 63 |
+
def get_history():
    """Return all scans as a list of dicts, newest first; [] on any failure."""
    conn = get_db_connection()
    if conn is None:
        return []
    try:
        cursor = conn.execute('SELECT * FROM history ORDER BY timestamp DESC')
        rows = cursor.fetchall()
        return [dict(row) for row in rows]
    except sqlite3.Error as e:
        print(f"Error retrieving history: {e}")
        return []
    finally:
        conn.close()
|
| 76 |
+
|
| 77 |
+
def clear_history():
    """Delete every row from history; return True on success, False otherwise."""
    conn = get_db_connection()
    if conn is None:
        return False
    try:
        conn.execute('DELETE FROM history')
        conn.commit()
        return True
    except sqlite3.Error as e:
        print(f"Error clearing history: {e}")
        return False
    finally:
        conn.close()
|
| 90 |
+
|
| 91 |
+
def delete_scan(scan_id):
    """Delete the scan with the given id; return True on success, False otherwise.

    NOTE: deleting a nonexistent id still returns True (no rowcount check),
    matching the caller's existing expectations.
    """
    conn = get_db_connection()
    if conn is None:
        return False
    try:
        conn.execute('DELETE FROM history WHERE id = ?', (scan_id,))
        conn.commit()
        return True
    except sqlite3.Error as e:
        print(f"Error deleting scan: {e}")
        return False
    finally:
        conn.close()
|
| 104 |
+
|
| 105 |
+
# Initialize DB on module load.
# NOTE: import-time side effect — importing this module creates/migrates database.db.
init_db()
|
backend/requirements_web.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask==3.0.0
|
| 2 |
+
flask-cors==4.0.0
|
| 3 |
+
torch
|
| 4 |
+
torchvision
|
| 5 |
+
opencv-python
|
| 6 |
+
albumentations
|
| 7 |
+
Pillow
|
| 8 |
+
numpy
|
| 9 |
+
safetensors
|
model/.DS_Store
ADDED
|
Binary file (10.2 kB). View file
|
|
|
model/results/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
model/results/checkpoints/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
model/results/checkpoints/best_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a58fd840e8ebab964a3021acb9e365fe445a86dbcb8af93d808f68eb3a254ad4
|
| 3 |
+
size 202457588
|
model/src/__init__.py
ADDED
|
File without changes
|
model/src/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (175 Bytes). View file
|
|
|
model/src/__pycache__/config.cpython-314.pyc
ADDED
|
Binary file (2.58 kB). View file
|
|
|
model/src/__pycache__/dataset.cpython-314.pyc
ADDED
|
Binary file (5.79 kB). View file
|
|
|
model/src/__pycache__/models.cpython-314.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
model/src/__pycache__/train.cpython-314.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
model/src/__pycache__/utils.cpython-314.pyc
ADDED
|
Binary file (1.46 kB). View file
|
|
|
model/src/__pycache__/video_inference.cpython-314.pyc
ADDED
|
Binary file (6.7 kB). View file
|
|
|
model/src/config.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import platform
|
| 4 |
+
|
| 5 |
+
class Config:
    """Central configuration: filesystem layout, model architecture flags,
    training hyperparameters and hardware selection."""

    # System paths (DATA_DIR was previously defined twice with the same value;
    # deduplicated here).
    PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    DATA_DIR = os.path.join(PROJECT_ROOT, "data")
    RESULTS_DIR = os.path.join(PROJECT_ROOT, "results")

    # Model Architecture
    IMAGE_SIZE = 256
    NUM_CLASSES = 1  # Logic: 0=Real, 1=Fake (Sigmoid output)

    # Component Flags
    USE_RGB = True
    USE_FREQ = True
    USE_PATCH = True
    USE_VIT = True

    # Training Hyperparameters
    BATCH_SIZE = 32  # Optimized for Mac M4 (Unified Memory)
    EPOCHS = 3
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    NUM_WORKERS = 8  # Leverage M4 Performance Cores

    # Hardware: prefer CUDA, then Apple MPS, else CPU.
    DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    # Since we use the root data folder, the loaders recursively find ALL images
    # in all sub-datasets and split them 80/20 for training/validation.
    TRAIN_DATA_PATH = DATA_DIR
    TEST_DATA_PATH = DATA_DIR
    CHECKPOINT_DIR = os.path.join(RESULTS_DIR, "checkpoints")

    @classmethod
    def setup(cls):
        """Create the results/checkpoint/data directories and report the environment."""
        os.makedirs(cls.RESULTS_DIR, exist_ok=True)
        os.makedirs(cls.CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(cls.DATA_DIR, exist_ok=True)
        print(f"Project initialized at {cls.PROJECT_ROOT}")
        print(f"Using device: {cls.DEVICE}")
|
| 48 |
+
|
| 49 |
+
# Allow `python config.py` to bootstrap the project directory layout.
if __name__ == "__main__":
    Config.setup()
|
model/src/dataset.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
import albumentations as A
|
| 7 |
+
from albumentations.pytorch import ToTensorV2
|
| 8 |
+
from src.config import Config
|
| 9 |
+
|
| 10 |
+
class DeepfakeDataset(Dataset):
|
| 11 |
+
def __init__(self, root_dir=None, file_paths=None, labels=None, phase='train', max_samples=None):
|
| 12 |
+
"""
|
| 13 |
+
Args:
|
| 14 |
+
root_dir (str): Directory with subfolders containing images. (Optional if file_paths provided)
|
| 15 |
+
file_paths (list): List of absolute paths to images.
|
| 16 |
+
labels (list): List of labels corresponding to file_paths.
|
| 17 |
+
phase (str): 'train' or 'val'.
|
| 18 |
+
max_samples (int): Optional limit for quick debugging.
|
| 19 |
+
"""
|
| 20 |
+
self.phase = phase
|
| 21 |
+
|
| 22 |
+
if file_paths is not None and labels is not None:
|
| 23 |
+
self.image_paths = file_paths
|
| 24 |
+
self.labels = labels
|
| 25 |
+
elif root_dir is not None:
|
| 26 |
+
self.image_paths, self.labels = self.scan_directory(root_dir)
|
| 27 |
+
else:
|
| 28 |
+
raise ValueError("Either root_dir or (file_paths, labels) must be provided.")
|
| 29 |
+
|
| 30 |
+
if max_samples:
|
| 31 |
+
self.image_paths = self.image_paths[:max_samples]
|
| 32 |
+
self.labels = self.labels[:max_samples]
|
| 33 |
+
|
| 34 |
+
self.transform = self._get_transforms()
|
| 35 |
+
|
| 36 |
+
print(f"Initialized {self.phase} dataset with {len(self.image_paths)} samples.")
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
def scan_directory(root_dir):
|
| 40 |
+
image_paths = []
|
| 41 |
+
labels = []
|
| 42 |
+
print(f"Scanning dataset at {root_dir}...")
|
| 43 |
+
|
| 44 |
+
# Valid extensions
|
| 45 |
+
exts = ('.png', '.jpg', '.jpeg', '.webp', '.bmp', '.tif')
|
| 46 |
+
|
| 47 |
+
for root, dirs, files in os.walk(root_dir):
|
| 48 |
+
for file in files:
|
| 49 |
+
if file.lower().endswith(exts):
|
| 50 |
+
path = os.path.join(root, file)
|
| 51 |
+
# Label inference based on full path
|
| 52 |
+
path_lower = path.lower()
|
| 53 |
+
|
| 54 |
+
label = None
|
| 55 |
+
# Prioritize explicit folder names
|
| 56 |
+
if "real" in path_lower:
|
| 57 |
+
label = 0.0
|
| 58 |
+
elif any(x in path_lower for x in ["fake", "df", "synthesis", "generated", "ai"]):
|
| 59 |
+
label = 1.0
|
| 60 |
+
|
| 61 |
+
if label is not None:
|
| 62 |
+
image_paths.append(path)
|
| 63 |
+
labels.append(label)
|
| 64 |
+
|
| 65 |
+
return image_paths, labels
|
| 66 |
+
|
| 67 |
+
def _get_transforms(self):
|
| 68 |
+
size = Config.IMAGE_SIZE
|
| 69 |
+
if self.phase == 'train':
|
| 70 |
+
return A.Compose([
|
| 71 |
+
A.Resize(size, size),
|
| 72 |
+
A.HorizontalFlip(p=0.5),
|
| 73 |
+
A.RandomBrightnessContrast(p=0.2),
|
| 74 |
+
A.GaussNoise(p=0.2),
|
| 75 |
+
# A.GaussianBlur(p=0.1),
|
| 76 |
+
# Fixed for newer albumentations versions
|
| 77 |
+
A.ImageCompression(quality_lower=60, quality_upper=100, p=0.3),
|
| 78 |
+
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
| 79 |
+
ToTensorV2(),
|
| 80 |
+
])
|
| 81 |
+
else:
|
| 82 |
+
return A.Compose([
|
| 83 |
+
A.Resize(size, size),
|
| 84 |
+
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
| 85 |
+
ToTensorV2(),
|
| 86 |
+
])
|
| 87 |
+
|
| 88 |
+
def __len__(self):
|
| 89 |
+
return len(self.image_paths)
|
| 90 |
+
|
| 91 |
+
def __getitem__(self, idx):
    """Load, augment, and return ``(image_tensor, label_tensor)`` for *idx*.

    Unreadable or corrupt files are skipped by advancing to the next
    index. The previous implementation recursed on failure, which could
    exhaust the stack (RecursionError) if many consecutive images — or
    all of them — were unreadable; this version scans at most one full
    cycle and raises a clear error instead.

    Raises:
        RuntimeError: if no image in the dataset can be read.
    """
    n = len(self)
    for offset in range(n):
        i = (idx + offset) % n
        path = self.image_paths[i]
        label = self.labels[i]

        try:
            image = cv2.imread(path)
            if image is None:
                raise ValueError("Image not found or corrupt")
            # OpenCV loads BGR; the model/normalization expect RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception:
            # Best-effort fallback: try the next sample.
            continue

        if self.transform:
            image = self.transform(image=image)['image']

        return image, torch.tensor(label, dtype=torch.float32)

    raise RuntimeError("No readable images in dataset")
|
model/src/finetune.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import random
|
| 8 |
+
import ssl
|
| 9 |
+
# Disable SSL verification for downloading pretrained weights
|
| 10 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 11 |
+
|
| 12 |
+
from src.config import Config
|
| 13 |
+
from src.models import DeepfakeDetector
|
| 14 |
+
from src.dataset import DeepfakeDataset
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from safetensors.torch import save_file, load_model
|
| 18 |
+
SAFETENSORS_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
SAFETENSORS_AVAILABLE = False
|
| 21 |
+
print("Warning: safetensors not installed. Checkpoints will be saved as .pt")
|
| 22 |
+
|
| 23 |
+
def finetune():
    """Fine-tune the pretrained deepfake detector on Dataset C.

    Loads the best Dataset-A checkpoint, trains for a few epochs at a
    reduced learning rate, validates on a held-out 20% split, and saves
    per-epoch checkpoints plus a best-validation checkpoint.
    """
    # Setup
    Config.setup()
    device = torch.device(Config.DEVICE)

    # Fine-tuning dataset path
    # NOTE(review): machine-specific absolute path — consider making this
    # configurable (CLI arg or Config).
    FINETUNE_DATA_PATH = "/Users/harshvardhan/Developer/dataset/Dataset c"

    print(f"\n{'='*80}")
    print("FINE-TUNING ON DATASET C")
    print(f"{'='*80}\n")

    # --- Data Loading ---
    print(f"Loading data from: {FINETUNE_DATA_PATH}")
    all_paths, all_labels = DeepfakeDataset.scan_directory(FINETUNE_DATA_PATH)

    if len(all_paths) == 0:
        print(f"No images found in {FINETUNE_DATA_PATH}")
        return

    # Shuffle then split 80/20 into train/val.
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)

    split_idx = int(len(combined) * 0.8)
    train_data = combined[:split_idx]
    val_data = combined[split_idx:]

    train_paths, train_labels = zip(*train_data)
    val_paths, val_labels = zip(*val_data)

    train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
    val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')

    # Dataloaders — pinning/persistence only pay off on CUDA / with workers.
    pin = device.type == 'cuda'
    persistent = Config.NUM_WORKERS > 0
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=pin, persistent_workers=persistent)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=pin, persistent_workers=persistent)

    # Load pre-trained model from Dataset A
    print("\n🔄 Loading pre-trained model from Dataset A...")
    model = DeepfakeDetector(pretrained=False).to(device)

    # Resolve against Config.CHECKPOINT_DIR (consistent with
    # finetune_dataset_a.py) instead of a hard-coded relative path that
    # breaks when the script is run from another working directory.
    checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
    if os.path.exists(checkpoint_path) and SAFETENSORS_AVAILABLE:
        # Guarded: load_model is only bound when safetensors imported; the
        # unguarded call would raise NameError when the package is missing.
        load_model(model, checkpoint_path, strict=False)
        print(f"✅ Loaded checkpoint: {checkpoint_path}")
    else:
        print("⚠️ No checkpoint found! Starting from random weights.")

    model.to(device)

    # Optimization with LOWER learning rate for fine-tuning
    FINETUNE_LR = 1e-5  # 10x lower than original training
    FINETUNE_EPOCHS = 2

    print(f"\n📝 Fine-tuning settings:")
    print(f"   Learning Rate: {FINETUNE_LR} (10x lower for fine-tuning)")
    print(f"   Epochs: {FINETUNE_EPOCHS}")
    print(f"   Batch Size: {Config.BATCH_SIZE}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Training loop
    best_acc = 0.0
    for epoch in range(FINETUNE_EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINETUNE_EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            # BCEWithLogitsLoss expects (B, 1) targets matching the logits.
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct / labels.size(0))

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

        # Save checkpoint after every epoch
        save_checkpoint(model, epoch+1, train_acc, name=f"finetuned_datasetC_ep{epoch+1}")

        # Validation
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            # Save best model if validation accuracy improved
            if val_acc > best_acc:
                best_acc = val_acc
                print(f"⭐ New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, name="best_finetuned_datasetC")

        scheduler.step()

    print(f"\n🎉 Fine-tuning Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    print(f"\n💾 Checkpoints saved in: results/checkpoints/")
|
| 143 |
+
|
| 144 |
+
def validate(model, loader, criterion, device):
    """Run one evaluation pass over *loader*.

    Returns:
        (mean_loss, accuracy): loss averaged over batches, accuracy over
        all samples, using a 0.5 sigmoid threshold.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            # Match the (B, 1) logit shape expected by the criterion.
            batch_labels = batch_labels.to(device).unsqueeze(1)

            logits = model(batch_images)
            total_loss += criterion(logits, batch_labels).item()

            predictions = (torch.sigmoid(logits) > 0.5).float()
            n_correct += (predictions == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return total_loss / len(loader), n_correct / n_seen
|
| 164 |
+
|
| 165 |
+
def save_checkpoint(model, epoch, acc, name="checkpoint"):
    """Save model weights as ``<name>.safetensors`` under Config.CHECKPOINT_DIR.

    Prefers safetensors; falls back to a ``.pth`` state_dict when
    safetensors is unavailable or the save fails.

    Args:
        model: the nn.Module to persist.
        epoch, acc: currently unused; kept so existing callers keep working.
        name: checkpoint basename (no extension).
    """
    state_dict = model.state_dict()
    filename = f"{name}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, filename)

    if SAFETENSORS_AVAILABLE:
        try:
            from safetensors.torch import save_model
            save_model(model, path)
            # Fixed: the f-string previously contained no placeholder, so
            # the saved path was never reported.
            print(f"✅ Saved: {path}")
        except Exception as e:
            print(f"SafeTensors save failed, falling back to .pth: {e}")
            torch.save(state_dict, path.replace(".safetensors", ".pth"))
    else:
        torch.save(state_dict, path.replace(".safetensors", ".pth"))

if __name__ == "__main__":
    finetune()
|
model/src/finetune_dataset_a.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import random
|
| 8 |
+
import ssl
|
| 9 |
+
import platform
|
| 10 |
+
|
| 11 |
+
# Disable SSL verification for downloading pretrained weights
|
| 12 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 13 |
+
|
| 14 |
+
from src.config import Config
|
| 15 |
+
from src.models import DeepfakeDetector
|
| 16 |
+
from src.dataset import DeepfakeDataset
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from safetensors.torch import save_file, load_model
|
| 20 |
+
SAFETENSORS_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
SAFETENSORS_AVAILABLE = False
|
| 23 |
+
print("Warning: safetensors not installed. Checkpoints will be saved as .pt")
|
| 24 |
+
|
| 25 |
+
def finetune():
    """Fine-tune the best checkpoint on Dataset A.

    Resolves a platform-specific dataset path, trains for a few epochs at
    a reduced learning rate with ReduceLROnPlateau on validation accuracy,
    and saves per-epoch plus best-validation checkpoints.
    """
    # Setup
    Config.setup()
    device = torch.device(Config.DEVICE)

    # Fine-tuning dataset path - Dataset A (per-developer machine paths)
    if platform.system() == "Windows":
        FINETUNE_DATA_PATH = r"C:\Users\kanna\Downloads\Dataset\Dataset A\Dataset A"
    else:
        FINETUNE_DATA_PATH = "/Users/harshvardhan/Developer/dataset/Dataset A"

    print(f"\n{'='*80}")
    print("FINE-TUNING ON DATASET A")
    print(f"{'='*80}\n")

    # --- Data Loading ---
    print(f"Loading data from: {FINETUNE_DATA_PATH}")
    if not os.path.exists(FINETUNE_DATA_PATH):
        print(f"❌ Error: Dataset path not found: {FINETUNE_DATA_PATH}")
        return

    all_paths, all_labels = DeepfakeDataset.scan_directory(FINETUNE_DATA_PATH)

    if len(all_paths) == 0:
        print(f"No images found in {FINETUNE_DATA_PATH}")
        return

    # Shuffle, then 80/20 split for fine-tuning.
    combined = list(zip(all_paths, all_labels))
    random.shuffle(combined)

    split_idx = int(len(combined) * 0.8)
    train_data = combined[:split_idx]
    val_data = combined[split_idx:]

    train_paths, train_labels = zip(*train_data)
    val_paths, val_labels = zip(*val_data)

    train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
    val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')

    # Dataloaders
    pin = device.type == 'cuda'
    persistent = Config.NUM_WORKERS > 0
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=pin, persistent_workers=persistent)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=pin, persistent_workers=persistent)

    # Load pre-trained model
    print("\n🔄 Loading pre-trained model (best_model)...")
    model = DeepfakeDetector(pretrained=False).to(device)

    # Try the safetensors checkpoint first, then fall back to .pth.
    checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
    if not os.path.exists(checkpoint_path):
        checkpoint_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.pth")

    if os.path.exists(checkpoint_path):
        try:
            # Guarded on SAFETENSORS_AVAILABLE: load_model is only bound
            # when safetensors imported; the unguarded call raised
            # NameError when the package is missing.
            if checkpoint_path.endswith(".safetensors") and SAFETENSORS_AVAILABLE:
                load_model(model, checkpoint_path, strict=False)
            else:
                model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print(f"✅ Loaded checkpoint: {checkpoint_path}")
        except Exception as e:
            print(f"⚠️ Error loading checkpoint: {e}")
            print("Starting from random weights (not ideal for fine-tuning!)")
    else:
        print("⚠️ No checkpoint found! Starting from random weights.")

    model.to(device)

    # Optimization with LOWER learning rate for fine-tuning
    FINETUNE_LR = 1e-5  # 10x lower than original training
    FINETUNE_EPOCHS = 5  # Give it a few epochs to adapt

    print(f"\n📝 Fine-tuning settings:")
    print(f"   Learning Rate: {FINETUNE_LR} (Low LR for fine-tuning)")
    print(f"   Epochs: {FINETUNE_EPOCHS}")
    print(f"   Batch Size: {Config.BATCH_SIZE}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=FINETUNE_LR, weight_decay=Config.WEIGHT_DECAY)
    # NOTE(review): the `verbose` kwarg is deprecated in recent PyTorch
    # releases — confirm the pinned torch version still accepts it.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

    # Training loop
    best_acc = 0.0
    for epoch in range(FINETUNE_EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{FINETUNE_EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct / labels.size(0) if labels.size(0) > 0 else 0)

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

        # Save checkpoint after every epoch
        save_checkpoint(model, epoch+1, train_acc, name=f"finetuned_datasetA_ep{epoch+1}")

        # Validation
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            # Plateau scheduler keys off validation accuracy (mode='max').
            scheduler.step(val_acc)

            # Save best model if validation accuracy improved
            if val_acc > best_acc:
                best_acc = val_acc
                print(f"⭐ New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, name="best_finetuned_datasetA")

    print(f"\n🎉 Fine-tuning Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
    print(f"\n💾 Checkpoints saved in: {Config.CHECKPOINT_DIR}")
|
| 165 |
+
|
| 166 |
+
def validate(model, loader, criterion, device):
    """Evaluate *model* on *loader*; return (mean batch loss, accuracy)."""
    model.eval()
    running_loss = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for imgs, targets in loader:
            imgs = imgs.to(device)
            # Targets reshaped to (B, 1) to match the logits.
            targets = targets.to(device).unsqueeze(1)

            logits = model(imgs)
            running_loss += criterion(logits, targets).item()

            guesses = (torch.sigmoid(logits) > 0.5).float()
            hits += (guesses == targets).sum().item()
            seen += targets.size(0)

    return running_loss / len(loader), hits / seen
|
| 186 |
+
|
| 187 |
+
def save_checkpoint(model, epoch, acc, name="checkpoint"):
    """Save model weights as ``<name>.safetensors`` under Config.CHECKPOINT_DIR.

    Prefers safetensors; falls back to a ``.pth`` state_dict when
    safetensors is unavailable or the save fails.

    Args:
        model: the nn.Module to persist.
        epoch, acc: currently unused; kept so existing callers keep working.
        name: checkpoint basename (no extension).
    """
    state_dict = model.state_dict()
    filename = f"{name}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, filename)

    if SAFETENSORS_AVAILABLE:
        try:
            from safetensors.torch import save_model
            save_model(model, path)
            # Fixed: the f-string previously contained no placeholder, so
            # the saved path was never reported.
            print(f"✅ Saved: {path}")
        except Exception as e:
            print(f"SafeTensors save failed, falling back to .pth: {e}")
            torch.save(state_dict, path.replace(".safetensors", ".pth"))
    else:
        torch.save(state_dict, path.replace(".safetensors", ".pth"))

if __name__ == "__main__":
    finetune()
|
model/src/inference.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
import cv2
|
| 4 |
+
import os
|
| 5 |
+
import glob
|
| 6 |
+
import numpy as np
|
| 7 |
+
import ssl
|
| 8 |
+
# Disable SSL verification
|
| 9 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 10 |
+
|
| 11 |
+
import albumentations as A
|
| 12 |
+
from albumentations.pytorch import ToTensorV2
|
| 13 |
+
from src.models import DeepfakeDetector
|
| 14 |
+
from src.config import Config
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from safetensors.torch import load_file
|
| 18 |
+
SAFETENSORS_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
SAFETENSORS_AVAILABLE = False
|
| 21 |
+
|
| 22 |
+
def get_transform():
    """Inference preprocessing: resize, ImageNet-normalize, convert to tensor."""
    side = Config.IMAGE_SIZE
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
|
| 28 |
+
|
| 29 |
+
def load_models(checkpoints_arg, device):
    """
    Load one or multiple models for ensemble inference.

    checkpoints_arg: Comma-separated list of paths, or single path, or directory.

    Returns a non-empty list of eval-mode models on *device*. If nothing
    could be loaded, a randomly initialized model is returned so the
    pipeline can still be exercised end-to-end.
    """
    paths = []
    if os.path.isdir(checkpoints_arg):
        # Prefer safetensors checkpoints; fall back to .pth files.
        paths = glob.glob(os.path.join(checkpoints_arg, "*.safetensors"))
        if not paths:
            paths = glob.glob(os.path.join(checkpoints_arg, "*.pth"))
    else:
        paths = checkpoints_arg.split(',')

    models = []
    print(f"Loading {len(paths)} model(s) for ensemble inference...")

    for path in paths:
        path = path.strip()
        if not path:
            continue

        print(f"Loading: {path}")
        model = DeepfakeDetector(pretrained=False)  # Structure only
        model.to(device)
        model.eval()

        try:
            if path.endswith(".safetensors") and SAFETENSORS_AVAILABLE:
                state_dict = load_file(path)
            else:
                state_dict = torch.load(path, map_location=device)
            model.load_state_dict(state_dict)
            models.append(model)
            print(f"✅ Successfully loaded: {os.path.basename(path)}")
        except Exception as e:
            print(f"❌ Failed to load {path}: {e}")
            import traceback
            traceback.print_exc()

    if not models:
        # Fallback for testing if no checkpoint exists yet.
        # Typo fixed: "checkoints" -> "checkpoints".
        print("Warning: No valid checkpoints loaded. Using random initialization for testing flow.")
        model = DeepfakeDetector(pretrained=False).to(device)
        model.eval()
        models.append(model)

    return models
|
| 75 |
+
|
| 76 |
+
def predict_ensemble(models, image_path, device, transform):
    """Average the fake-probability for *image_path* across *models*.

    Returns:
        (probability, None) on success, or (None, error_message) when the
        image cannot be read.
    """
    try:
        bgr = cv2.imread(image_path)
        if bgr is None:
            return None, "Error: Could not read image"
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    except Exception as exc:
        return None, str(exc)

    tensor = transform(image=rgb)['image'].unsqueeze(0).to(device)

    scores = []
    with torch.no_grad():
        for net in models:
            scores.append(torch.sigmoid(net(tensor)).item())

    # Ensemble strategy: mean probability across models.
    return sum(scores) / len(scores), None
|
| 98 |
+
|
| 99 |
+
def main():
    """CLI entry point: run (ensemble) deepfake inference on an image or directory."""
    parser = argparse.ArgumentParser(description="Deepfake Detection Inference (Ensemble Support)")
    parser.add_argument("--source", type=str, required=True, help="Path to image or directory")
    parser.add_argument("--checkpoints", type=str, default="results/checkpoints", help="Path to checkpoint file, list of files (comma-separated), or directory")
    parser.add_argument("--device", type=str, default=Config.DEVICE, help="Device to use (cuda/mps/cpu)")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load models and preprocessing pipeline.
    models = load_models(args.checkpoints, device)
    transform = get_transform()

    # Collect input files (filter a directory down to supported images).
    if os.path.isdir(args.source):
        candidates = glob.glob(os.path.join(args.source, "*.*"))
        files = [f for f in candidates if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
    else:
        files = [args.source]

    print(f"Processing {len(files)} images with {len(models)} model(s)...")
    print("-" * 65)
    print(f"{'Image Name':<40} | {'Prediction':<10} | {'Confidence':<10}")
    print("-" * 65)

    for file_path in files:
        prob, error = predict_ensemble(models, file_path, device, transform)
        if error:
            print(f"{os.path.basename(file_path):<40} | ERROR: {error}")
            continue

        verdict_fake = prob > 0.5
        label = "FAKE" if verdict_fake else "REAL"
        confidence = prob if verdict_fake else 1 - prob
        print(f"{os.path.basename(file_path):<40} | {label:<10} | {confidence:.2%}")

if __name__ == "__main__":
    main()
|
model/src/models.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchvision.models as models
|
| 5 |
+
import numpy as np
|
| 6 |
+
from src.utils import get_fft_feature
|
| 7 |
+
|
| 8 |
+
class RGBBranch(nn.Module):
    """Spatial feature extractor based on EfficientNet-V2-S (1280-d output)."""

    def __init__(self, pretrained=True):
        super().__init__()
        # EfficientNet V2 Small: robust and efficient spatial features.
        weights = models.EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        backbone = models.efficientnet_v2_s(weights=weights)
        self.net = backbone
        # Keep direct handles to the feature stack and pooling so the
        # classification head is bypassed.
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        self.out_dim = 1280

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return torch.flatten(pooled, 1)
|
| 24 |
+
|
| 25 |
+
class FreqBranch(nn.Module):
    """Small CNN over the frequency-domain image; outputs a 128-d descriptor."""

    def __init__(self):
        super().__init__()

        def conv_bn_relu(cin, cout):
            # Conv -> BatchNorm -> ReLU building block.
            return [
                nn.Conv2d(cin, cout, kernel_size=3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(),
            ]

        layers = conv_bn_relu(3, 32) + [nn.MaxPool2d(2)]
        layers += conv_bn_relu(32, 64) + [nn.MaxPool2d(2)]
        layers += conv_bn_relu(64, 128) + [nn.AdaptiveAvgPool2d((1, 1))]
        self.net = nn.Sequential(*layers)
        self.out_dim = 128

    def forward(self, x):
        return torch.flatten(self.net(x), 1)
|
| 49 |
+
|
| 50 |
+
class PatchBranch(nn.Module):
    """Encodes 64x64 patches of a 256x256 input and max-pools over patches.

    The max over patch embeddings surfaces the single most suspicious
    local region (the "most fake" patch signal).
    """

    def __init__(self):
        super().__init__()
        # Shared lightweight CNN applied to every patch.
        self.patch_encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64 -> 32
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32 -> 16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.out_dim = 64

    def forward(self, x):
        # x: (B, 3, 256, 256) -> 4x4 grid of 64x64 patches via unfold.
        tiles = x.unfold(2, 64, 64).unfold(3, 64, 64)  # (B, 3, 4, 4, 64, 64)
        b, c, grid_h, grid_w, patch_h, patch_w = tiles.shape

        # Fold the grid into the batch dimension so one encoder pass
        # handles all patches in parallel.
        flat = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()
        flat = flat.view(b * grid_h * grid_w, c, patch_h, patch_w)

        encoded = torch.flatten(self.patch_encoder(flat), 1)   # (B*16, 64)
        per_image = encoded.view(b, grid_h * grid_w, -1)       # (B, 16, 64)

        # Keep the strongest per-dimension response across patches.
        strongest, _ = per_image.max(dim=1)                    # (B, 64)
        return strongest
|
| 91 |
+
|
| 92 |
+
class ViTBranch(nn.Module):
    """Swin-V2-Tiny backbone with its classification head stripped."""

    def __init__(self, pretrained=True):
        super().__init__()
        # Swin Transformer Tiny: captures long-range dependencies.
        weights = models.Swin_V2_T_Weights.DEFAULT if pretrained else None
        self.net = models.swin_v2_t(weights=weights)
        # Record the feature width, then replace the head with Identity so
        # forward() yields features instead of class logits.
        self.out_dim = self.net.head.in_features
        self.net.head = nn.Identity()

    def forward(self, x):
        return self.net(x)
|
| 105 |
+
|
| 106 |
+
class DeepfakeDetector(nn.Module):
    """Multi-branch deepfake classifier.

    Fuses four feature extractors — spatial (RGB CNN), frequency (FFT),
    local patch statistics, and a global ViT — by concatenating their
    feature vectors and passing them through a small MLP that emits a
    single logit (use ``torch.sigmoid`` for the fake probability).
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Four parallel feature branches; each exposes an `out_dim`
        # attribute used below to size the fusion classifier.
        self.rgb_branch = RGBBranch(pretrained)
        self.freq_branch = FreqBranch()
        self.patch_branch = PatchBranch()
        self.vit_branch = ViTBranch(pretrained)

        input_dim = (self.rgb_branch.out_dim +
                     self.freq_branch.out_dim +
                     self.patch_branch.out_dim +
                     self.vit_branch.out_dim)

        # Confidence-based fusion head
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1)
        )

    def forward(self, x):
        """Return a (B, 1) logit tensor for a batch of images ``x``."""
        # 1. Spatial Analysis
        rgb_feat = self.rgb_branch(x)

        # 2. Frequency Analysis
        freq_img = get_fft_feature(x)
        freq_feat = self.freq_branch(freq_img)

        # 3. Patch Analysis (Local Inconsistencies)
        patch_feat = self.patch_branch(x)

        # 4. Global Consistency (ViT)
        vit_feat = self.vit_branch(x)

        # 5. Feature Fusion
        combined = torch.cat([rgb_feat, freq_feat, patch_feat, vit_feat], dim=1)

        return self.classifier(combined)

    def get_heatmap(self, x):
        """Generate Grad-CAM heatmap for the input image.

        Hooks the last conv block of the RGB branch, runs a forward and
        backward pass, and weights the activations by their pooled
        gradients. NOTE(review): only the first sample of the batch is
        visualized (``activations[0][0]``) — callers appear to pass a
        batch of one; confirm before reusing with larger batches.
        """
        # We'll use the RGB branch for visualization as it contains spatial features
        # Enable gradients for the input if needed, though typically we hook into layers

        # 1. Forward pass through RGB branch
        # We need to register a hook on the last conv layer of the efficientnet features
        # Target layer: self.rgb_branch.features[-1] (the last block)

        gradients = []     # filled by the backward hook
        activations = []   # filled by the forward hook

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0])

        def forward_hook(module, input, output):
            activations.append(output)

        # Register hooks on the last convolutional layer of RGB branch
        target_layer = self.rgb_branch.features[-1]
        hook_b = target_layer.register_full_backward_hook(backward_hook)
        hook_f = target_layer.register_forward_hook(forward_hook)

        # Forward pass
        logits = self(x)
        pred_idx = 0  # Binary classification, output is a scalar logit (index unused)

        # Backward pass
        self.zero_grad()
        logits.backward(retain_graph=True)

        # Get gradients and activations
        # Channel-wise importance: mean gradient over batch and spatial dims.
        pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])
        activation = activations[0][0]

        # Weight activations by gradients (Grad-CAM)
        # NOTE(review): this mutates the hooked activation tensor in place.
        for i in range(activation.shape[0]):
            activation[i, :, :] *= pooled_gradients[i]

        heatmap = torch.mean(activation, dim=0).cpu().detach().numpy()
        heatmap = np.maximum(heatmap, 0)  # ReLU

        # Normalize to [0, 1] (guard against an all-zero map)
        if np.max(heatmap) != 0:
            heatmap /= np.max(heatmap)

        # Remove hooks so repeated calls don't accumulate them
        hook_b.remove()
        hook_f.remove()

        return heatmap
|
model/src/test_dataloading.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from src.config import Config
|
| 4 |
+
from src.dataset import DeepfakeDataset
|
| 5 |
+
|
| 6 |
+
def test_dataloading():
    """Smoke-test the directory scan, 80/20 split simulation, and one sample load."""
    print("Testing Data Loading & Splitting Logic...")
    Config.setup()

    print(f"Data Path: {Config.TRAIN_DATA_PATH}")

    # Scan the training directory for image paths and their labels.
    paths, labels = DeepfakeDataset.scan_directory(Config.TRAIN_DATA_PATH)
    total_files = len(paths)
    print(f"Total images found: {total_files}")

    if not total_files:
        print("[FAIL] No images found! Check path.")
        return

    # Reproduce the shuffle-split the training script performs.
    samples = list(zip(paths, labels))
    random.shuffle(samples)
    cut = int(len(samples) * 0.8)
    train_data, val_data = samples[:cut], samples[cut:]

    print(f"Train Split: {len(train_data)} images")
    print(f"Val Split: {len(val_data)} images")

    # Instantiate the dataset and pull a single sample to verify decoding.
    try:
        train_paths, train_labels = zip(*train_data)
        ds = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
        print(f"[Pass] Dataset initialized with {len(ds)} samples.")

        img, lbl = ds[0]
        print(f"[Pass] Loaded sample image. Shape: {img.shape}, Label: {lbl}")
    except Exception as e:
        print(f"[FAIL] Dataset initialization or loading error: {e}")
        return

    print("\nSUCCESS: Data loading verification passed!")


if __name__ == "__main__":
    test_dataloading()
|
model/src/test_dryrun.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from src.models import DeepfakeDetector
|
| 4 |
+
from src.config import Config
|
| 5 |
+
|
| 6 |
+
def test_model_architecture():
    """Dry-run the DeepfakeDetector: initialization, forward, and backward."""
    print("Testing DeepfakeDetector Architecture...")

    # CPU keeps the check portable (swap in Config.DEVICE if desired).
    device = torch.device("cpu")
    print(f"Device: {device}")

    # Build the network without downloading pretrained weights.
    try:
        model = DeepfakeDetector(pretrained=False).to(device)
        print("[Pass] Model Initialization")
    except Exception as e:
        print(f"[Fail] Model Initialization: {e}")
        return

    # A small random batch shaped like real inputs.
    batch_size = 2
    x = torch.randn(batch_size, 3, Config.IMAGE_SIZE, Config.IMAGE_SIZE).to(device)
    print(f"Input Shape: {x.shape}")

    # Forward pass must yield exactly one logit per sample.
    try:
        out = model(x)
        print(f"Output Shape: {out.shape}")

        if out.shape == (batch_size, 1):
            print("[Pass] Output Shape Correct")
        else:
            print(f"[Fail] Output Shape Incorrect. Expected ({batch_size}, 1), got {out.shape}")
    except Exception as e:
        print(f"[Fail] Forward Pass: {e}")
        import traceback
        traceback.print_exc()
        return

    # Gradients must flow through all fused branches.
    try:
        criterion = nn.BCEWithLogitsLoss()
        target = torch.ones(batch_size, 1).to(device)
        loss = criterion(out, target)
        loss.backward()
        print(f"[Pass] Backward Pass (Loss: {loss.item():.4f})")
    except Exception as e:
        print(f"[Fail] Backward Pass: {e}")
        return

    print("\nSUCCESS: Model architecture verification passed!")


if __name__ == "__main__":
    test_model_architecture()
|
model/src/train.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import random
|
| 8 |
+
import ssl
|
| 9 |
+
# Disable SSL verification for downloading pretrained weights
|
| 10 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 11 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 12 |
+
|
| 13 |
+
from src.config import Config
|
| 14 |
+
from src.models import DeepfakeDetector
|
| 15 |
+
from src.dataset import DeepfakeDataset
|
| 16 |
+
|
| 17 |
+
# Optional dependency probe: prefer .safetensors checkpoints when the
# library is present; otherwise save_checkpoint() falls back to torch's
# native serialization format.
try:
    from safetensors.torch import save_file, load_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not installed. Checkpoints will be saved as .pt")
|
| 23 |
+
|
| 24 |
+
def train():
    """Train the multi-branch DeepfakeDetector.

    Handles: automatic 80/20 split when train and test paths coincide,
    AMP on CUDA, auto-resume from the newest checkpoint, per-epoch
    checkpointing and validation. All paths and hyper-parameters come
    from ``Config``.

    BUGFIX: the criterion/optimizer pair was previously constructed twice
    back-to-back; the duplicate (dead) construction has been removed.
    """
    # Setup
    Config.setup()
    device = torch.device(Config.DEVICE)

    # --- Data Loading with Automatic Split ---
    if Config.TRAIN_DATA_PATH == Config.TEST_DATA_PATH:
        print("Train and Test paths are identical. Performing automatic 80/20 shuffle split...")
        all_paths, all_labels = DeepfakeDataset.scan_directory(Config.TRAIN_DATA_PATH)

        if len(all_paths) == 0:
            print(f"No images found in {Config.TRAIN_DATA_PATH}")
            return

        # Combine and shuffle
        combined = list(zip(all_paths, all_labels))
        random.shuffle(combined)

        split_idx = int(len(combined) * 0.8)
        train_data = combined[:split_idx]
        val_data = combined[split_idx:]

        train_paths, train_labels = zip(*train_data)
        val_paths, val_labels = zip(*val_data)

        train_dataset = DeepfakeDataset(file_paths=list(train_paths), labels=list(train_labels), phase='train')
        val_dataset = DeepfakeDataset(file_paths=list(val_paths), labels=list(val_labels), phase='val')
    else:
        # Standard folder-based loading
        train_dataset = DeepfakeDataset(root_dir=Config.TRAIN_DATA_PATH, phase='train')
        val_dataset = DeepfakeDataset(root_dir=Config.TEST_DATA_PATH, phase='val')

    # Dataloaders — pinned memory only helps on CUDA; persistent workers
    # only make sense when workers exist.
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                              num_workers=Config.NUM_WORKERS,
                              pin_memory=True if device.type=='cuda' else False,
                              persistent_workers=True if Config.NUM_WORKERS > 0 else False)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False,
                            num_workers=Config.NUM_WORKERS,
                            pin_memory=True if device.type=='cuda' else False,
                            persistent_workers=True if Config.NUM_WORKERS > 0 else False)

    # Model
    print("Initializing Multi-Branch DeepfakeDetector...")
    model = DeepfakeDetector(pretrained=True).to(device)

    # Optimization (single construction — duplicate removed)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Enable AMP only for CUDA (Windows NVIDIA)
    use_amp = (Config.DEVICE == 'cuda')
    scaler = GradScaler() if use_amp else None
    if use_amp:
        print("🚀 Mixed Precision (AMP) Enabled for RTX GPU")
    else:
        print("🐌 Standard Precision (No AMP) for CPU/MPS")

    # Resume from checkpoint if exists
    start_epoch = 0
    best_acc = 0.0

    # Priority:
    # 1. best_model.safetensors (if we crashed mid-training)
    # 2. latest checkpoint_ep*.safetensors
    # 3. patched_model.safetensors (the model we want to improve)
    resume_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.safetensors")
    if not os.path.exists(resume_path):
        # Look for latest epoch checkpoint
        import glob
        import re
        checkpoints = glob.glob(os.path.join(Config.CHECKPOINT_DIR, "checkpoint_ep*.safetensors"))
        if checkpoints:
            # Sort by epoch number embedded in the filename
            def get_epoch(p):
                match = re.search(r"checkpoint_ep(\d+)", p)
                return int(match.group(1)) if match else 0

            latest_ckpt = max(checkpoints, key=get_epoch)
            resume_path = latest_ckpt
            start_epoch = get_epoch(latest_ckpt)
            print(f"🔄 Auto-Resuming from latest epoch: {start_epoch}")
        else:
            resume_path = os.path.join(Config.CHECKPOINT_DIR, "patched_model.safetensors")

    if os.path.exists(resume_path):
        print(f"\n🔄 Found existing checkpoint: {resume_path}")
        print("Auto-resuming to FINETUNE this model...")

        try:
            if resume_path.endswith(".safetensors") and SAFETENSORS_AVAILABLE:
                state_dict = load_file(resume_path)
            else:
                state_dict = torch.load(resume_path, map_location=device)

            # Use strict=False to allow for minor architecture changes or missing keys
            model.load_state_dict(state_dict, strict=False)
            print("✅ Weights loaded. Starting Fine-Tuning.")
        except Exception as e:
            print(f"⚠ Failed to load checkpoint: {e}")
            print("Starting from ImageNet weights.")

    # Training loop
    for epoch in range(start_epoch, Config.EPOCHS):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{Config.EPOCHS}")
        for images, labels in loop:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            optimizer.zero_grad()

            if use_amp:
                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard training for Mac/CPU
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            train_loss += loss.item()
            # Threshold sigmoid probabilities at 0.5 for binary accuracy.
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct = (preds == labels).sum().item()
            train_correct += correct
            train_total += labels.size(0)

            loop.set_postfix(loss=loss.item(), acc=correct/labels.size(0))

        train_acc = train_correct / train_total if train_total > 0 else 0
        print(f"Epoch {epoch+1} Train Loss: {train_loss/len(train_loader):.4f} Acc: {train_acc:.4f}")

        # Save checkpoint after every epoch
        save_checkpoint(model, epoch+1, train_acc, best=False)

        # Validation
        if len(val_dataset) > 0:
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            print(f"Epoch {epoch+1} Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            # Save best model if validation accuracy improved
            if val_acc > best_acc:
                best_acc = val_acc
                print(f"⭐ New best model! Validation Accuracy: {val_acc:.4f}")
                save_checkpoint(model, epoch+1, val_acc, best=True)

        scheduler.step()

    print(f"\n🎉 Training Complete!")
    print(f"Best Validation Accuracy: {best_acc:.4f}")
|
| 189 |
+
|
| 190 |
+
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network producing one logit per sample, shape (B, 1).
        loader: iterable (with ``len``) of ``(images, labels)`` batches;
            labels are shaped (B,) and get unsqueezed to (B, 1).
        criterion: loss taking (logits, targets), e.g. BCEWithLogitsLoss.
        device: torch.device to run inference on.

    Returns:
        tuple: ``(mean_batch_loss, accuracy)``. Returns ``(0.0, 0.0)`` when
        the loader yields no samples (previously this raised
        ``ZeroDivisionError``).
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            # Threshold sigmoid probabilities at 0.5 for binary accuracy.
            preds = (torch.sigmoid(outputs) > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Robustness: an empty loader has no meaningful loss or accuracy.
    if total == 0 or len(loader) == 0:
        return 0.0, 0.0
    return val_loss / len(loader), correct / total
|
| 210 |
+
|
| 211 |
+
def save_checkpoint(model, epoch, acc, best=False):
    """Persist model weights and append run metadata to the markdown logs.

    Args:
        model: the network to serialize.
        epoch (int): 1-based epoch number, used in the checkpoint filename.
        acc (float): accuracy in [0, 1] recorded in the logs.
        best (bool): save as ``best_model.safetensors`` instead of a
            per-epoch ``checkpoint_ep{epoch}.safetensors``.

    BUGFIX: when safetensors is unavailable, the success message previously
    printed the ``.safetensors`` path although the file was actually written
    with a ``.pth`` extension.
    """
    state_dict = model.state_dict()
    name = "best_model.safetensors" if best else f"checkpoint_ep{epoch}.safetensors"
    path = os.path.join(Config.CHECKPOINT_DIR, name)
    # Path used whenever we fall back to torch's native serialization.
    legacy_path = path.replace(".safetensors", ".pth")

    if SAFETENSORS_AVAILABLE:
        try:
            # save_model handles shared tensors, unlike save_file(state_dict).
            from safetensors.torch import save_model
            save_model(model, path)
            print(f"Saved Checkpoint: {path}")

            # 📝 Auto-Log to History (best-effort; never fails the save)
            try:
                from datetime import datetime
                log_path = os.path.join(Config.PROJECT_ROOT, "TRAINING_HISTORY.md")
                timestamp = datetime.now().strftime("%Y-%m-%d | %I:%M %p")

                # Create file with header if it doesn't exist yet
                if not os.path.exists(log_path):
                    with open(log_path, "w", encoding="utf-8") as f:
                        f.write("# 📜 Training History Log\n\n")
                        f.write("| Date | Time | Model Name | Dataset | Epochs | Accuracy | Loss | Status |\n")
                        f.write("| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |\n")

                # Append Entry to Summary Log
                with open(log_path, "a", encoding="utf-8") as f:
                    # Format: Date | Time | Name | Dataset | Epoch | Acc | Loss | Status
                    dataset_name = os.path.basename(Config.DATA_DIR)
                    entry = f"| **{timestamp.split(' | ')[0]}** | {timestamp.split(' | ')[1]} | {name} | {dataset_name} | {epoch} | {acc*100:.2f}% | N/A | ✅ Saved |\n"
                    f.write(entry)
                print(f"📝 Logged to TRAINING_HISTORY.md")

                # 📝 Detailed Lab Notebook Logging
                detail_path = os.path.join(Config.PROJECT_ROOT, "DETAILED_HISTORY.md")
                with open(detail_path, "a", encoding="utf-8") as f:
                    f.write(f"\n## Model: {name} (Epoch {epoch})\n")
                    f.write(f"| Feature | Detail |\n| :--- | :--- |\n")
                    f.write(f"| **Date** | {timestamp} |\n")
                    f.write(f"| **Training Accuracy** | {acc*100:.2f}% |\n")
                    f.write(f"| **Dataset** | {Config.DATA_DIR} |\n")
                    f.write(f"| **Batch Size** | {Config.BATCH_SIZE} |\n")
                    f.write(f"| **Optimizer** | AdamW (lr={Config.LEARNING_RATE}) |\n")
                    f.write(f"| **Device** | {Config.DEVICE.upper()} |\n")
                    f.write("\n---\n")
                print(f"📘 Detailed log written to DETAILED_HISTORY.md")

            except Exception as e:
                print(f"⚠️ Failed to write log: {e}")

        except Exception as e:
            # Fallback to regular torch save if safetensors fails
            print(f"SafeTensors save failed ({e}), falling back to .pth format")
            torch.save(state_dict, legacy_path)
            print(f"Saved Checkpoint (Legacy): {legacy_path}")
    else:
        torch.save(state_dict, legacy_path)
        print(f"Saved Checkpoint (Legacy): {legacy_path}")
|
| 269 |
+
|
| 270 |
+
# Script entry point: launch a full training run.
if __name__ == "__main__":
    train()
|
model/src/utils.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
|
| 5 |
+
def get_fft_feature(x):
    """
    Return the shifted log-magnitude spectrum of a batch of images.

    Args:
        x (torch.Tensor): Images shaped (B, C, H, W); a single (C, H, W)
            image is promoted to a batch of one.

    Returns:
        torch.Tensor: Log-magnitude spectrum with the zero-frequency (DC)
        component shifted to the spatial center, same shape as the
        (possibly batched) input.
    """
    if x.dim() == 3:
        x = x.unsqueeze(0)

    # FFT -> magnitude -> log (epsilon keeps the log finite) -> center DC.
    spectrum = torch.fft.fft2(x, norm='ortho')
    log_mag = torch.log(torch.abs(spectrum) + 1e-6)
    return torch.fft.fftshift(log_mag, dim=(-2, -1))
|
| 29 |
+
|
| 30 |
+
def min_max_normalize(tensor):
    """
    Rescale ``tensor`` into [0, 1] via min-max normalization, e.g. for
    visualization or stable training. The epsilon in the denominator
    guards against a constant input (zero range).
    """
    lo, hi = tensor.min(), tensor.max()
    return (tensor - lo) / (hi - lo + 1e-8)
|
| 37 |
+
|
model/src/video_inference.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
def process_video(video_path, model, transform, device, frames_per_second=1):
    """
    Process a video file frame-by-frame using the deepfake detection model.

    Frames are sampled at roughly ``frames_per_second``, the largest
    detected face (if any) is cropped with a 20% margin, and per-frame
    fake probabilities are aggregated into a FAKE/REAL verdict.

    Args:
        video_path (str): Path to the video file.
        model (torch.nn.Module): Loaded PyTorch model.
        transform (callable): Albumentations transform pipeline.
        device (torch.device): Device to run inference on.
        frames_per_second (int): Number of frames to sample per second of video.
                                 Default is 1 to keep processing fast.

    Returns:
        dict: Aggregated results including verdict, average confidence, and
        frame-level details, or ``{"error": ...}`` on failure.
    """
    if model is None:
        return {"error": "Model not loaded"}

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return {"error": "Could not open video file"}

    # Read video properties needed for sampling and timestamps.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0: fps = 30 # Fallback when the container reports no FPS

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = total_frames / fps

    # Calculate sampling interval (step size)
    # If we want 1 frame per second, we step by 'fps' frames
    step = int(fps / frames_per_second)
    if step < 1: step = 1

    frame_indices = []  # raw frame numbers of the processed frames
    probs = []          # per-frame fake probabilities (parallel to frame_indices)

    print(f"Processing video: {video_path}")
    print(f"Duration: {duration:.2f}s, FPS: {fps}, Total Frames: {total_frames}")
    print(f"Sampling every {step} frames...")

    count = 0
    processed_count = 0

    suspicious_frames = [] # Store frames with high fake probability

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        if count % step == 0:
            # Process this frame; any per-frame failure is logged and skipped
            # so one bad frame doesn't abort the whole video.
            try:
                # Convert BGR (OpenCV) to RGB
                image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # --- Face Extraction ---
                # Load Haar Cascade lazily and cache it as a function
                # attribute so it is built at most once per process.
                if not hasattr(process_video, "face_cascade"):
                    try:
                        cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
                        process_video.face_cascade = cv2.CascadeClassifier(cascade_path)
                    except:
                        # NOTE(review): bare except — deliberately best-effort;
                        # detection below is skipped when the cascade is missing.
                        process_video.face_cascade = None

                face_crop = None
                if process_video.face_cascade:
                    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    faces = process_video.face_cascade.detectMultiScale(
                        gray, scaleFactor=1.1, minNeighbors=5, minSize=(60, 60)
                    )

                    if len(faces) > 0:
                        # Find largest face (by bounding-box area)
                        largest_face = max(faces, key=lambda rect: rect[2] * rect[3])
                        x, y, w, h = largest_face

                        # Add margin (20%), clamped to the frame bounds
                        margin = int(max(w, h) * 0.2)
                        x_start = max(x - margin, 0)
                        y_start = max(y - margin, 0)
                        x_end = min(x + w + margin, frame.shape[1])
                        y_end = min(y + h + margin, frame.shape[0])

                        face_crop = image[y_start:y_end, x_start:x_end]

                # Use face crop if found, otherwise use full image
                input_image = face_crop if face_crop is not None else image

                # Apply transforms (Albumentations returns a dict)
                augmented = transform(image=input_image)
                image_tensor = augmented['image'].unsqueeze(0).to(device)

                # Inference: single logit -> sigmoid fake probability
                with torch.no_grad():
                    logits = model(image_tensor)
                    prob = torch.sigmoid(logits).item()

                probs.append(prob)
                frame_indices.append(count)
                processed_count += 1

                # If likely fake (prob > 0.5), store metadata (timestamp)
                if prob > 0.5:
                    timestamp = count / fps
                    suspicious_frames.append({
                        "timestamp": round(timestamp, 2),
                        "frame_index": count,
                        "fake_prob": round(prob, 4)
                    })

            except Exception as e:
                print(f"Error processing frame {count}: {e}")

        count += 1

    cap.release()

    if processed_count == 0:
        return {"error": "No frames processed"}

    # Aggregation over all sampled frames
    avg_prob = sum(probs) / len(probs)
    max_prob = max(probs)
    fake_frame_count = len([p for p in probs if p > 0.6]) # Stricter frame threshold
    fake_ratio = fake_frame_count / processed_count

    # Verdict Logic (Tuned for High Efficiency Model)
    # The new model is detecting everything as fake, so we need stricter rules.

    # 1. Standard Average Check (shifted above 0.5 to reduce false positives)
    cond1 = avg_prob > 0.65

    # 2. Density Check: Require at least 15% of frames to be strictly fake
    # Was 5%, which is too low for a sensitive model
    cond2 = fake_ratio > 0.15 and max_prob > 0.7

    # 3. Peak Check: Only flag single-frame anomalies if EXTREMELY suspicious
    cond3 = max_prob > 0.95

    is_fake = cond1 or cond2 or cond3

    verdict = "FAKE" if is_fake else "REAL"

    # Confidence Calculation: peak probability when fake (floored at 0.6),
    # inverse of the average probability when real.
    if is_fake:
        confidence = max(max_prob, 0.6)
    else:
        confidence = 1 - avg_prob

    return {
        "type": "video",
        "prediction": verdict,
        "confidence": float(confidence),
        "avg_fake_prob": float(avg_prob),
        "max_fake_prob": float(max_prob),
        "fake_frame_ratio": float(fake_ratio),
        "processed_frames": processed_count,
        "duration": float(duration),
        "timeline": [
            {"time": round(i / fps, 2), "prob": round(p, 3)}
            for i, p in zip(frame_indices, probs)
        ],
        "suspicious_frames": suspicious_frames[:10] # Top 10 suspicious moments
    }
|